Added comments and made a start on infer.py

This commit is contained in:
2026-03-28 22:31:10 +00:00
parent 753723694b
commit 0af8bff4a8
7 changed files with 301 additions and 22 deletions

View File

@@ -0,0 +1,261 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 33,
"id": "79ac71dd",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "aa9117f0",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>source</th>\n",
" <th>review_id</th>\n",
" <th>user_name</th>\n",
" <th>review_title</th>\n",
" <th>review_description</th>\n",
" <th>rating</th>\n",
" <th>thumbs_up</th>\n",
" <th>review_date</th>\n",
" <th>developer_response</th>\n",
" <th>developer_response_date</th>\n",
" <th>appVersion</th>\n",
" <th>laguage_code</th>\n",
" <th>country_code</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Google Play</td>\n",
" <td>18d6584c-d0e9-4833-a744-f607058aee97</td>\n",
" <td>Milky Way</td>\n",
" <td>NaN</td>\n",
" <td>Suddenly, the driver can't have my location an...</td>\n",
" <td>1</td>\n",
" <td>0.0</td>\n",
" <td>2023-08-10 17:48:51</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>en</td>\n",
" <td>in</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Google Play</td>\n",
" <td>50a08f18-cece-4ddf-b617-028844c8aa28</td>\n",
" <td>Bradlee Severa</td>\n",
" <td>NaN</td>\n",
" <td>Very cordial.. And helped with a quick turnaro...</td>\n",
" <td>5</td>\n",
" <td>0.0</td>\n",
" <td>2023-08-10 17:38:35</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>4.485.10000</td>\n",
" <td>en</td>\n",
" <td>in</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Google Play</td>\n",
" <td>b0d8e75a-80a7-4dcd-abaf-72b046dbeeb7</td>\n",
" <td>Amit Aggarwal</td>\n",
" <td>NaN</td>\n",
" <td>Very good experience</td>\n",
" <td>5</td>\n",
" <td>0.0</td>\n",
" <td>2023-08-10 17:38:17</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>4.486.10002</td>\n",
" <td>en</td>\n",
" <td>in</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Google Play</td>\n",
" <td>502702a9-25ed-4373-a96c-7fa1f06caacd</td>\n",
" <td>Bryant Inman</td>\n",
" <td>NaN</td>\n",
" <td>All I use</td>\n",
" <td>5</td>\n",
" <td>0.0</td>\n",
" <td>2023-08-10 17:37:45</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>4.467.10008</td>\n",
" <td>en</td>\n",
" <td>in</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Google Play</td>\n",
" <td>f47a3fb6-23db-49bd-9e63-f33c8d724d07</td>\n",
" <td>Addie Whittaker</td>\n",
" <td>NaN</td>\n",
" <td>I have enjoyed traveling by Uber my drivers ha...</td>\n",
" <td>5</td>\n",
" <td>0.0</td>\n",
" <td>2023-08-10 17:36:56</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>4.486.10002</td>\n",
" <td>en</td>\n",
" <td>in</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" source review_id user_name \\\n",
"0 Google Play 18d6584c-d0e9-4833-a744-f607058aee97 Milky Way \n",
"1 Google Play 50a08f18-cece-4ddf-b617-028844c8aa28 Bradlee Severa \n",
"2 Google Play b0d8e75a-80a7-4dcd-abaf-72b046dbeeb7 Amit Aggarwal \n",
"3 Google Play 502702a9-25ed-4373-a96c-7fa1f06caacd Bryant Inman \n",
"4 Google Play f47a3fb6-23db-49bd-9e63-f33c8d724d07 Addie Whittaker \n",
"\n",
" review_title review_description rating \\\n",
"0 NaN Suddenly, the driver can't have my location an... 1 \n",
"1 NaN Very cordial.. And helped with a quick turnaro... 5 \n",
"2 NaN Very good experience 5 \n",
"3 NaN All I use 5 \n",
"4 NaN I have enjoyed traveling by Uber my drivers ha... 5 \n",
"\n",
" thumbs_up review_date developer_response developer_response_date \\\n",
"0 0.0 2023-08-10 17:48:51 NaN NaN \n",
"1 0.0 2023-08-10 17:38:35 NaN NaN \n",
"2 0.0 2023-08-10 17:38:17 NaN NaN \n",
"3 0.0 2023-08-10 17:37:45 NaN NaN \n",
"4 0.0 2023-08-10 17:36:56 NaN NaN \n",
"\n",
" appVersion laguage_code country_code \n",
"0 NaN en in \n",
"1 4.485.10000 en in \n",
"2 4.486.10002 en in \n",
"3 4.467.10008 en in \n",
"4 4.486.10002 en in "
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv(\"../data/raw/uber_reviews.csv\", low_memory=False)\n",
"df.head(5)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "36683790",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.DataFrame'>\n",
"RangeIndex: 1069616 entries, 0 to 1069615\n",
"Data columns (total 13 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 source 1069616 non-null str \n",
" 1 review_id 1069616 non-null str \n",
" 2 user_name 1069615 non-null str \n",
" 3 review_title 2180 non-null str \n",
" 4 review_description 1069447 non-null str \n",
" 5 rating 1069616 non-null int64 \n",
" 6 thumbs_up 1067436 non-null float64\n",
" 7 review_date 1069616 non-null str \n",
" 8 developer_response 198264 non-null str \n",
" 9 developer_response_date 197278 non-null str \n",
" 10 appVersion 828068 non-null str \n",
" 11 laguage_code 1069616 non-null str \n",
" 12 country_code 1069616 non-null str \n",
"dtypes: float64(1), int64(1), str(11)\n",
"memory usage: 106.1 MB\n"
]
}
],
"source": [
"df.info()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "0b1c7f73",
"metadata": {},
"outputs": [],
"source": [
"df = df[['review_description']].reset_index(drop=True)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "90e6d653",
"metadata": {},
"outputs": [],
"source": [
"df.head(5)\n",
"df.to_csv(\"../data/raw/review_description.csv\", index=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "multitag",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -5,19 +5,7 @@
"execution_count": 1,
"id": "2b7cfa1a",
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'sklearn'",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpandas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpd\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01msklearn\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mmodel_selection\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m train_test_split\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n",
"\u001b[31mModuleNotFoundError\u001b[39m: No module named 'sklearn'"
]
}
],
"outputs": [],
"source": [
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
@@ -1209,7 +1197,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.11"
"version": "3.11.15"
}
},
"nbformat": 4,

View File

@@ -2,14 +2,21 @@ import pandas as pd
import numpy as np
import torch
import argparse
from transformers import AutoTokenizer
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
# mappings
binary_map = {1:'Yes', 0:'No'}
aspect_map = {0:'App', 1:'Driver', 2:'General', 3:'Payment', 4:'Pricing', 5:'Service'}
sentiment_map = {0:'Positive', 1:'Neutral', 2:'Negative'}
label_names = {
'bug_report': ['No', 'Yes'],
'feature_request': ['No', 'Yes'],
'aspect': ['App', 'Driver', 'General', 'Payment', 'Pricing', 'Service'],
'aspect_sentiment': ['Positive', 'Neutral', 'Negative']
}
SEED = 4321
torch.manual_seed(SEED)
@@ -17,9 +24,31 @@ np.random.seed(SEED)
def parse_args():
parser = argparse.ArgumentParser(description="RECLASS, Multitask learning for review classification.")
parser.add_argument("--model_path", type=str, help="Enter the models path / the desired .pt file")
parser.add_argument("--task", type=str, default="all", choices=["all", "bug_report", "feature_request", "aspect", "aspect_sentiment"], help="Specific task to train for stl usage only" )
parser.add_argument("--interactive", help="Loops reading input until exit")
parser.add_argument("--model_path", type=str, required=True, help=".pt file path")
parser.add_argument("--task", type=str, default="all", choices=["all", "bug_report", "feature_request", "aspect", "aspect_sentiment"])
parser.add_argument("--interactive", help="Loops reading input until exit()")
parser.add_argument("--text", help="Use command line text for input")
parser.add_argument("--dataset", type=str, required=True, help="Enter a file for inference")
return parser.parse_args()
def main():
args = parse_args()
print(f'='*50)
print(f' '*15 + "Starting inference")
if torch.cuda.is_available():
print(f' '*15 + "GPU:", torch.cuda.get_device_name(0))
torch.cuda.manual_seed_all(SEED)
torch.cuda.manual_seed(SEED)
else:
print(f' '*15 + "No GPUs available")
print(f'='*50 + "\n")
print(f"Running inference on: {args.model_path.upper()} using {args.dataset}")
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
infer_data = f"data/processed/{args.dataset}_infer.csv"
if __name__ == main():
main()

View File

@@ -52,7 +52,7 @@ class Model(nn.Module):
# Applied across shared cls token, before all task heads
self.dropout = nn.Dropout(dropout_rate)
# get logits for each head
self.bug_head = nn.Linear(hidden_size, 2)
self.feature_head = nn.Linear(hidden_size, 2)
self.aspect_head = nn.Linear(hidden_size, 6)

View File

@@ -160,8 +160,8 @@ def preprocess_uber_reviews(input_path, output_path):
return df_clean
if __name__ == "__main__":
input_file = "multitag/data/uber_reviews.csv"
output_file = "multitag/data/uber_reviews_cleaned.csv"
input_file = "data/raw/uber_reviews.csv"
output_file = "data/raw/uber_reviews_cleaned.csv"
df_clean = preprocess_uber_reviews(input_file, output_file)
print("\nPreprocessing complete!")

View File

@@ -190,7 +190,6 @@ class Sampler:
mini_sample = self.data.sample(200) # reading some samples manually
return mini_sample
def save_sample(self, sample_df,output_path):
"""Save sample and display statistics"""

View File

@@ -126,6 +126,7 @@ def main():
print("Aspect sentiment class weights:", aspect_sentiment_weights.cpu().numpy())
# equal weighted task losses. unequal was considered but equal weights performed well without adding complexity
# CrossEntropyLoss = LogSoftmax + NLLLoss (negative log likelihood)
criterions = {
'bug_report': nn.CrossEntropyLoss(weight=bug_weights),
'feature_request': nn.CrossEntropyLoss(weight=feature_weights),
@@ -134,6 +135,7 @@ def main():
}
# -------------------- Optimizer and scheduler -------------------
# adaptive momentum and weight decay keeps track of previous weight adaptions and ensures they dont get too large (weight also shrinks towards 0 each pass)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=args.lr,