Added comments and made a start on infer.py

This commit is contained in:
2026-03-28 22:31:10 +00:00
parent 753723694b
commit 0af8bff4a8
7 changed files with 301 additions and 22 deletions

View File

@@ -0,0 +1,261 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 33,
"id": "79ac71dd",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "aa9117f0",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>source</th>\n",
" <th>review_id</th>\n",
" <th>user_name</th>\n",
" <th>review_title</th>\n",
" <th>review_description</th>\n",
" <th>rating</th>\n",
" <th>thumbs_up</th>\n",
" <th>review_date</th>\n",
" <th>developer_response</th>\n",
" <th>developer_response_date</th>\n",
" <th>appVersion</th>\n",
" <th>laguage_code</th>\n",
" <th>country_code</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Google Play</td>\n",
" <td>18d6584c-d0e9-4833-a744-f607058aee97</td>\n",
" <td>Milky Way</td>\n",
" <td>NaN</td>\n",
" <td>Suddenly, the driver can't have my location an...</td>\n",
" <td>1</td>\n",
" <td>0.0</td>\n",
" <td>2023-08-10 17:48:51</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>en</td>\n",
" <td>in</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Google Play</td>\n",
" <td>50a08f18-cece-4ddf-b617-028844c8aa28</td>\n",
" <td>Bradlee Severa</td>\n",
" <td>NaN</td>\n",
" <td>Very cordial.. And helped with a quick turnaro...</td>\n",
" <td>5</td>\n",
" <td>0.0</td>\n",
" <td>2023-08-10 17:38:35</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>4.485.10000</td>\n",
" <td>en</td>\n",
" <td>in</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Google Play</td>\n",
" <td>b0d8e75a-80a7-4dcd-abaf-72b046dbeeb7</td>\n",
" <td>Amit Aggarwal</td>\n",
" <td>NaN</td>\n",
" <td>Very good experience</td>\n",
" <td>5</td>\n",
" <td>0.0</td>\n",
" <td>2023-08-10 17:38:17</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>4.486.10002</td>\n",
" <td>en</td>\n",
" <td>in</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Google Play</td>\n",
" <td>502702a9-25ed-4373-a96c-7fa1f06caacd</td>\n",
" <td>Bryant Inman</td>\n",
" <td>NaN</td>\n",
" <td>All I use</td>\n",
" <td>5</td>\n",
" <td>0.0</td>\n",
" <td>2023-08-10 17:37:45</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>4.467.10008</td>\n",
" <td>en</td>\n",
" <td>in</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Google Play</td>\n",
" <td>f47a3fb6-23db-49bd-9e63-f33c8d724d07</td>\n",
" <td>Addie Whittaker</td>\n",
" <td>NaN</td>\n",
" <td>I have enjoyed traveling by Uber my drivers ha...</td>\n",
" <td>5</td>\n",
" <td>0.0</td>\n",
" <td>2023-08-10 17:36:56</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>4.486.10002</td>\n",
" <td>en</td>\n",
" <td>in</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" source review_id user_name \\\n",
"0 Google Play 18d6584c-d0e9-4833-a744-f607058aee97 Milky Way \n",
"1 Google Play 50a08f18-cece-4ddf-b617-028844c8aa28 Bradlee Severa \n",
"2 Google Play b0d8e75a-80a7-4dcd-abaf-72b046dbeeb7 Amit Aggarwal \n",
"3 Google Play 502702a9-25ed-4373-a96c-7fa1f06caacd Bryant Inman \n",
"4 Google Play f47a3fb6-23db-49bd-9e63-f33c8d724d07 Addie Whittaker \n",
"\n",
" review_title review_description rating \\\n",
"0 NaN Suddenly, the driver can't have my location an... 1 \n",
"1 NaN Very cordial.. And helped with a quick turnaro... 5 \n",
"2 NaN Very good experience 5 \n",
"3 NaN All I use 5 \n",
"4 NaN I have enjoyed traveling by Uber my drivers ha... 5 \n",
"\n",
" thumbs_up review_date developer_response developer_response_date \\\n",
"0 0.0 2023-08-10 17:48:51 NaN NaN \n",
"1 0.0 2023-08-10 17:38:35 NaN NaN \n",
"2 0.0 2023-08-10 17:38:17 NaN NaN \n",
"3 0.0 2023-08-10 17:37:45 NaN NaN \n",
"4 0.0 2023-08-10 17:36:56 NaN NaN \n",
"\n",
" appVersion laguage_code country_code \n",
"0 NaN en in \n",
"1 4.485.10000 en in \n",
"2 4.486.10002 en in \n",
"3 4.467.10008 en in \n",
"4 4.486.10002 en in "
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv(\"../data/raw/uber_reviews.csv\", low_memory=False)\n",
"df.head(5)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "36683790",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.DataFrame'>\n",
"RangeIndex: 1069616 entries, 0 to 1069615\n",
"Data columns (total 13 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 source 1069616 non-null str \n",
" 1 review_id 1069616 non-null str \n",
" 2 user_name 1069615 non-null str \n",
" 3 review_title 2180 non-null str \n",
" 4 review_description 1069447 non-null str \n",
" 5 rating 1069616 non-null int64 \n",
" 6 thumbs_up 1067436 non-null float64\n",
" 7 review_date 1069616 non-null str \n",
" 8 developer_response 198264 non-null str \n",
" 9 developer_response_date 197278 non-null str \n",
" 10 appVersion 828068 non-null str \n",
" 11 laguage_code 1069616 non-null str \n",
" 12 country_code 1069616 non-null str \n",
"dtypes: float64(1), int64(1), str(11)\n",
"memory usage: 106.1 MB\n"
]
}
],
"source": [
"df.info()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "0b1c7f73",
"metadata": {},
"outputs": [],
"source": [
"df = df[['review_description']].reset_index(drop=True)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "90e6d653",
"metadata": {},
"outputs": [],
"source": [
"df.head(5)\n",
"df.to_csv(\"../data/raw/review_description.csv\", index=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "multitag",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -5,19 +5,7 @@
"execution_count": 1, "execution_count": 1,
"id": "2b7cfa1a", "id": "2b7cfa1a",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'sklearn'",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpandas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpd\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01msklearn\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mmodel_selection\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m train_test_split\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n",
"\u001b[31mModuleNotFoundError\u001b[39m: No module named 'sklearn'"
]
}
],
"source": [ "source": [
"import pandas as pd\n", "import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n", "from sklearn.model_selection import train_test_split\n",
@@ -1209,7 +1197,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.13.11" "version": "3.11.15"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@@ -2,14 +2,21 @@ import pandas as pd
import numpy as np import numpy as np
import torch import torch
import argparse import argparse
from transformers import AutoTokenizer
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
# mappings # mappings
binary_map = {1:'Yes', 0:'No'} binary_map = {1:'Yes', 0:'No'}
aspect_map = {0:'App', 1:'Driver', 2:'General', 3:'Payment', 4:'Pricing', 5:'Service'} aspect_map = {0:'App', 1:'Driver', 2:'General', 3:'Payment', 4:'Pricing', 5:'Service'}
sentiment_map = {0:'Positive', 1:'Neutral', 2:'Negative'} sentiment_map = {0:'Positive', 1:'Neutral', 2:'Negative'}
label_names = {
'bug_report': ['No', 'Yes'],
'feature_request': ['No', 'Yes'],
'aspect': ['App', 'Driver', 'General', 'Payment', 'Pricing', 'Service'],
'aspect_sentiment': ['Positive', 'Neutral', 'Negative']
}
SEED = 4321 SEED = 4321
torch.manual_seed(SEED) torch.manual_seed(SEED)
@@ -17,9 +24,31 @@ np.random.seed(SEED)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="RECLASS, Multitask learning for review classification.") parser = argparse.ArgumentParser(description="RECLASS, Multitask learning for review classification.")
parser.add_argument("--model_path", type=str, help="Enter the models path / the desired .pt file") parser.add_argument("--model_path", type=str, required=True, help=".pt file path")
parser.add_argument("--task", type=str, default="all", choices=["all", "bug_report", "feature_request", "aspect", "aspect_sentiment"], help="Specific task to train for stl usage only" ) parser.add_argument("--task", type=str, default="all", choices=["all", "bug_report", "feature_request", "aspect", "aspect_sentiment"])
parser.add_argument("--interactive", help="Loops reading input until exit") parser.add_argument("--interactive", help="Loops reading input until exit()")
parser.add_argument("--text", help="Use command line text for input") parser.add_argument("--text", help="Use command line text for input")
parser.add_argument("--dataset", type=str, required=True, help="Enter a file for inference")
return parser.parse_args() return parser.parse_args()
def main():
args = parse_args()
print(f'='*50)
print(f' '*15 + "Starting inference")
if torch.cuda.is_available():
print(f' '*15 + "GPU:", torch.cuda.get_device_name(0))
torch.cuda.manual_seed_all(SEED)
torch.cuda.manual_seed(SEED)
else:
print(f' '*15 + "No GPUs available")
print(f'='*50 + "\n")
print(f"Running inference on: {args.model_path.upper()} using {args.dataset}")
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
infer_data = f"data/processed/{args.dataset}_infer.csv"
if __name__ == main():
main()

View File

@@ -52,7 +52,7 @@ class Model(nn.Module):
# Applied across shared cls token, before all task heads # Applied across shared cls token, before all task heads
self.dropout = nn.Dropout(dropout_rate) self.dropout = nn.Dropout(dropout_rate)
# get logits for each head
self.bug_head = nn.Linear(hidden_size, 2) self.bug_head = nn.Linear(hidden_size, 2)
self.feature_head = nn.Linear(hidden_size, 2) self.feature_head = nn.Linear(hidden_size, 2)
self.aspect_head = nn.Linear(hidden_size, 6) self.aspect_head = nn.Linear(hidden_size, 6)

View File

@@ -160,8 +160,8 @@ def preprocess_uber_reviews(input_path, output_path):
return df_clean return df_clean
if __name__ == "__main__": if __name__ == "__main__":
input_file = "multitag/data/uber_reviews.csv" input_file = "data/raw/uber_reviews.csv"
output_file = "multitag/data/uber_reviews_cleaned.csv" output_file = "data/raw/uber_reviews_cleaned.csv"
df_clean = preprocess_uber_reviews(input_file, output_file) df_clean = preprocess_uber_reviews(input_file, output_file)
print("\nPreprocessing complete!") print("\nPreprocessing complete!")

View File

@@ -190,7 +190,6 @@ class Sampler:
mini_sample = self.data.sample(200) # reading some samples manually mini_sample = self.data.sample(200) # reading some samples manually
return mini_sample return mini_sample
def save_sample(self, sample_df,output_path): def save_sample(self, sample_df,output_path):
"""Save sample and display statistics""" """Save sample and display statistics"""

View File

@@ -126,6 +126,7 @@ def main():
print("Aspect sentiment class weights:", aspect_sentiment_weights.cpu().numpy()) print("Aspect sentiment class weights:", aspect_sentiment_weights.cpu().numpy())
# equal weighted task losses. unequal was considered but equal weights performed well without adding complexity # equal weighted task losses. unequal was considered but equal weights performed well without adding complexity
# CrossEntropyLoss = LogSoftmax + NLLLoss (negative log likelihood)
criterions = { criterions = {
'bug_report': nn.CrossEntropyLoss(weight=bug_weights), 'bug_report': nn.CrossEntropyLoss(weight=bug_weights),
'feature_request': nn.CrossEntropyLoss(weight=feature_weights), 'feature_request': nn.CrossEntropyLoss(weight=feature_weights),
@@ -134,6 +135,7 @@ def main():
} }
# -------------------- Optimizer and scheduler ------------------- # -------------------- Optimizer and scheduler -------------------
# adaptive momentum and weight decay keeps track of previous weight adaptions and ensures they dont get too large (weight also shrinks towards 0 each pass)
optimizer = torch.optim.AdamW( optimizer = torch.optim.AdamW(
model.parameters(), model.parameters(),
lr=args.lr, lr=args.lr,