Added comments and made a start on infer.py
This commit is contained in:
261
notebooks/getting_csv_for_inference.ipynb
Normal file
261
notebooks/getting_csv_for_inference.ipynb
Normal file
@@ -0,0 +1,261 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"id": "79ac71dd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"id": "aa9117f0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>source</th>\n",
|
||||
" <th>review_id</th>\n",
|
||||
" <th>user_name</th>\n",
|
||||
" <th>review_title</th>\n",
|
||||
" <th>review_description</th>\n",
|
||||
" <th>rating</th>\n",
|
||||
" <th>thumbs_up</th>\n",
|
||||
" <th>review_date</th>\n",
|
||||
" <th>developer_response</th>\n",
|
||||
" <th>developer_response_date</th>\n",
|
||||
" <th>appVersion</th>\n",
|
||||
" <th>laguage_code</th>\n",
|
||||
" <th>country_code</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>Google Play</td>\n",
|
||||
" <td>18d6584c-d0e9-4833-a744-f607058aee97</td>\n",
|
||||
" <td>Milky Way</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>Suddenly, the driver can't have my location an...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>2023-08-10 17:48:51</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>en</td>\n",
|
||||
" <td>in</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>Google Play</td>\n",
|
||||
" <td>50a08f18-cece-4ddf-b617-028844c8aa28</td>\n",
|
||||
" <td>Bradlee Severa</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>Very cordial.. And helped with a quick turnaro...</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>2023-08-10 17:38:35</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>4.485.10000</td>\n",
|
||||
" <td>en</td>\n",
|
||||
" <td>in</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>Google Play</td>\n",
|
||||
" <td>b0d8e75a-80a7-4dcd-abaf-72b046dbeeb7</td>\n",
|
||||
" <td>Amit Aggarwal</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>Very good experience</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>2023-08-10 17:38:17</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>4.486.10002</td>\n",
|
||||
" <td>en</td>\n",
|
||||
" <td>in</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>Google Play</td>\n",
|
||||
" <td>502702a9-25ed-4373-a96c-7fa1f06caacd</td>\n",
|
||||
" <td>Bryant Inman</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>All I use</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>2023-08-10 17:37:45</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>4.467.10008</td>\n",
|
||||
" <td>en</td>\n",
|
||||
" <td>in</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>Google Play</td>\n",
|
||||
" <td>f47a3fb6-23db-49bd-9e63-f33c8d724d07</td>\n",
|
||||
" <td>Addie Whittaker</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>I have enjoyed traveling by Uber my drivers ha...</td>\n",
|
||||
" <td>5</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>2023-08-10 17:36:56</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>NaN</td>\n",
|
||||
" <td>4.486.10002</td>\n",
|
||||
" <td>en</td>\n",
|
||||
" <td>in</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" source review_id user_name \\\n",
|
||||
"0 Google Play 18d6584c-d0e9-4833-a744-f607058aee97 Milky Way \n",
|
||||
"1 Google Play 50a08f18-cece-4ddf-b617-028844c8aa28 Bradlee Severa \n",
|
||||
"2 Google Play b0d8e75a-80a7-4dcd-abaf-72b046dbeeb7 Amit Aggarwal \n",
|
||||
"3 Google Play 502702a9-25ed-4373-a96c-7fa1f06caacd Bryant Inman \n",
|
||||
"4 Google Play f47a3fb6-23db-49bd-9e63-f33c8d724d07 Addie Whittaker \n",
|
||||
"\n",
|
||||
" review_title review_description rating \\\n",
|
||||
"0 NaN Suddenly, the driver can't have my location an... 1 \n",
|
||||
"1 NaN Very cordial.. And helped with a quick turnaro... 5 \n",
|
||||
"2 NaN Very good experience 5 \n",
|
||||
"3 NaN All I use 5 \n",
|
||||
"4 NaN I have enjoyed traveling by Uber my drivers ha... 5 \n",
|
||||
"\n",
|
||||
" thumbs_up review_date developer_response developer_response_date \\\n",
|
||||
"0 0.0 2023-08-10 17:48:51 NaN NaN \n",
|
||||
"1 0.0 2023-08-10 17:38:35 NaN NaN \n",
|
||||
"2 0.0 2023-08-10 17:38:17 NaN NaN \n",
|
||||
"3 0.0 2023-08-10 17:37:45 NaN NaN \n",
|
||||
"4 0.0 2023-08-10 17:36:56 NaN NaN \n",
|
||||
"\n",
|
||||
" appVersion laguage_code country_code \n",
|
||||
"0 NaN en in \n",
|
||||
"1 4.485.10000 en in \n",
|
||||
"2 4.486.10002 en in \n",
|
||||
"3 4.467.10008 en in \n",
|
||||
"4 4.486.10002 en in "
|
||||
]
|
||||
},
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df = pd.read_csv(\"../data/raw/uber_reviews.csv\", low_memory=False)\n",
|
||||
"df.head(5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"id": "36683790",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<class 'pandas.DataFrame'>\n",
|
||||
"RangeIndex: 1069616 entries, 0 to 1069615\n",
|
||||
"Data columns (total 13 columns):\n",
|
||||
" # Column Non-Null Count Dtype \n",
|
||||
"--- ------ -------------- ----- \n",
|
||||
" 0 source 1069616 non-null str \n",
|
||||
" 1 review_id 1069616 non-null str \n",
|
||||
" 2 user_name 1069615 non-null str \n",
|
||||
" 3 review_title 2180 non-null str \n",
|
||||
" 4 review_description 1069447 non-null str \n",
|
||||
" 5 rating 1069616 non-null int64 \n",
|
||||
" 6 thumbs_up 1067436 non-null float64\n",
|
||||
" 7 review_date 1069616 non-null str \n",
|
||||
" 8 developer_response 198264 non-null str \n",
|
||||
" 9 developer_response_date 197278 non-null str \n",
|
||||
" 10 appVersion 828068 non-null str \n",
|
||||
" 11 laguage_code 1069616 non-null str \n",
|
||||
" 12 country_code 1069616 non-null str \n",
|
||||
"dtypes: float64(1), int64(1), str(11)\n",
|
||||
"memory usage: 106.1 MB\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"df.info()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"id": "0b1c7f73",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df = df[['review_description']].reset_index(drop=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"id": "90e6d653",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df.head(5)\n",
|
||||
"df.to_csv(\"../data/raw/review_description.csv\", index=False)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "multitag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -5,19 +5,7 @@
|
||||
"execution_count": 1,
|
||||
"id": "2b7cfa1a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "ModuleNotFoundError",
|
||||
"evalue": "No module named 'sklearn'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
||||
"\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)",
|
||||
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpandas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpd\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01msklearn\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mmodel_selection\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m train_test_split\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n",
|
||||
"\u001b[31mModuleNotFoundError\u001b[39m: No module named 'sklearn'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
@@ -1209,7 +1197,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.13.11"
|
||||
"version": "3.11.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
37
src/infer.py
37
src/infer.py
@@ -2,14 +2,21 @@ import pandas as pd
|
||||
import numpy as np
|
||||
import torch
|
||||
import argparse
|
||||
from transformers import AutoTokenizer
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
# mappings
|
||||
binary_map = {1:'Yes', 0:'No'}
|
||||
aspect_map = {0:'App', 1:'Driver', 2:'General', 3:'Payment', 4:'Pricing', 5:'Service'}
|
||||
sentiment_map = {0:'Positive', 1:'Neutral', 2:'Negative'}
|
||||
|
||||
label_names = {
|
||||
'bug_report': ['No', 'Yes'],
|
||||
'feature_request': ['No', 'Yes'],
|
||||
'aspect': ['App', 'Driver', 'General', 'Payment', 'Pricing', 'Service'],
|
||||
'aspect_sentiment': ['Positive', 'Neutral', 'Negative']
|
||||
}
|
||||
|
||||
SEED = 4321
|
||||
torch.manual_seed(SEED)
|
||||
@@ -17,9 +24,31 @@ np.random.seed(SEED)
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="RECLASS, Multitask learning for review classification.")
|
||||
parser.add_argument("--model_path", type=str, help="Enter the models path / the desired .pt file")
|
||||
parser.add_argument("--task", type=str, default="all", choices=["all", "bug_report", "feature_request", "aspect", "aspect_sentiment"], help="Specific task to train for stl usage only" )
|
||||
parser.add_argument("--interactive", help="Loops reading input until exit")
|
||||
parser.add_argument("--model_path", type=str, required=True, help=".pt file path")
|
||||
parser.add_argument("--task", type=str, default="all", choices=["all", "bug_report", "feature_request", "aspect", "aspect_sentiment"])
|
||||
parser.add_argument("--interactive", help="Loops reading input until exit()")
|
||||
parser.add_argument("--text", help="Use command line text for input")
|
||||
parser.add_argument("--dataset", type=str, required=True, help="Enter a file for inference")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
print(f'='*50)
|
||||
print(f' '*15 + "Starting inference")
|
||||
if torch.cuda.is_available():
|
||||
print(f' '*15 + "GPU:", torch.cuda.get_device_name(0))
|
||||
torch.cuda.manual_seed_all(SEED)
|
||||
torch.cuda.manual_seed(SEED)
|
||||
else:
|
||||
print(f' '*15 + "No GPUs available")
|
||||
print(f'='*50 + "\n")
|
||||
print(f"Running inference on: {args.model_path.upper()} using {args.dataset}")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
|
||||
infer_data = f"data/processed/{args.dataset}_infer.csv"
|
||||
|
||||
if __name__ == main():
|
||||
main()
|
||||
@@ -52,7 +52,7 @@ class Model(nn.Module):
|
||||
|
||||
# Applied across shared cls token, before all task heads
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
# get logits for each head
|
||||
self.bug_head = nn.Linear(hidden_size, 2)
|
||||
self.feature_head = nn.Linear(hidden_size, 2)
|
||||
self.aspect_head = nn.Linear(hidden_size, 6)
|
||||
|
||||
@@ -160,8 +160,8 @@ def preprocess_uber_reviews(input_path, output_path):
|
||||
return df_clean
|
||||
|
||||
if __name__ == "__main__":
|
||||
input_file = "multitag/data/uber_reviews.csv"
|
||||
output_file = "multitag/data/uber_reviews_cleaned.csv"
|
||||
input_file = "data/raw/uber_reviews.csv"
|
||||
output_file = "data/raw/uber_reviews_cleaned.csv"
|
||||
|
||||
df_clean = preprocess_uber_reviews(input_file, output_file)
|
||||
print("\nPreprocessing complete!")
|
||||
|
||||
@@ -191,7 +191,6 @@ class Sampler:
|
||||
return mini_sample
|
||||
|
||||
|
||||
|
||||
def save_sample(self, sample_df,output_path):
|
||||
"""Save sample and display statistics"""
|
||||
sample_df.to_csv(output_path, index=False)
|
||||
|
||||
@@ -126,6 +126,7 @@ def main():
|
||||
print("Aspect sentiment class weights:", aspect_sentiment_weights.cpu().numpy())
|
||||
|
||||
# equal weighted task losses. unequal was considered but equal weights performed well without adding complexity
|
||||
# CrossEntropyLoss = LogSoftmax + NLLLoss (negative log likelihood)
|
||||
criterions = {
|
||||
'bug_report': nn.CrossEntropyLoss(weight=bug_weights),
|
||||
'feature_request': nn.CrossEntropyLoss(weight=feature_weights),
|
||||
@@ -134,6 +135,7 @@ def main():
|
||||
}
|
||||
|
||||
# -------------------- Optimizer and scheduler -------------------
|
||||
# adaptive momentum and weight decay keeps track of previous weight adaptions and ensures they dont get too large (weight also shrinks towards 0 each pass)
|
||||
optimizer = torch.optim.AdamW(
|
||||
model.parameters(),
|
||||
lr=args.lr,
|
||||
|
||||
Reference in New Issue
Block a user