Added comments and made a start on infer.py
This commit is contained in:
261
notebooks/getting_csv_for_inference.ipynb
Normal file
261
notebooks/getting_csv_for_inference.ipynb
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 33,
|
||||||
|
"id": "79ac71dd",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import numpy as np"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 34,
|
||||||
|
"id": "aa9117f0",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
"<div>\n",
|
||||||
|
"<style scoped>\n",
|
||||||
|
" .dataframe tbody tr th:only-of-type {\n",
|
||||||
|
" vertical-align: middle;\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" .dataframe tbody tr th {\n",
|
||||||
|
" vertical-align: top;\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" .dataframe thead th {\n",
|
||||||
|
" text-align: right;\n",
|
||||||
|
" }\n",
|
||||||
|
"</style>\n",
|
||||||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
||||||
|
" <thead>\n",
|
||||||
|
" <tr style=\"text-align: right;\">\n",
|
||||||
|
" <th></th>\n",
|
||||||
|
" <th>source</th>\n",
|
||||||
|
" <th>review_id</th>\n",
|
||||||
|
" <th>user_name</th>\n",
|
||||||
|
" <th>review_title</th>\n",
|
||||||
|
" <th>review_description</th>\n",
|
||||||
|
" <th>rating</th>\n",
|
||||||
|
" <th>thumbs_up</th>\n",
|
||||||
|
" <th>review_date</th>\n",
|
||||||
|
" <th>developer_response</th>\n",
|
||||||
|
" <th>developer_response_date</th>\n",
|
||||||
|
" <th>appVersion</th>\n",
|
||||||
|
" <th>laguage_code</th>\n",
|
||||||
|
" <th>country_code</th>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </thead>\n",
|
||||||
|
" <tbody>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>0</th>\n",
|
||||||
|
" <td>Google Play</td>\n",
|
||||||
|
" <td>18d6584c-d0e9-4833-a744-f607058aee97</td>\n",
|
||||||
|
" <td>Milky Way</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>Suddenly, the driver can't have my location an...</td>\n",
|
||||||
|
" <td>1</td>\n",
|
||||||
|
" <td>0.0</td>\n",
|
||||||
|
" <td>2023-08-10 17:48:51</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>en</td>\n",
|
||||||
|
" <td>in</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>1</th>\n",
|
||||||
|
" <td>Google Play</td>\n",
|
||||||
|
" <td>50a08f18-cece-4ddf-b617-028844c8aa28</td>\n",
|
||||||
|
" <td>Bradlee Severa</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>Very cordial.. And helped with a quick turnaro...</td>\n",
|
||||||
|
" <td>5</td>\n",
|
||||||
|
" <td>0.0</td>\n",
|
||||||
|
" <td>2023-08-10 17:38:35</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>4.485.10000</td>\n",
|
||||||
|
" <td>en</td>\n",
|
||||||
|
" <td>in</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>2</th>\n",
|
||||||
|
" <td>Google Play</td>\n",
|
||||||
|
" <td>b0d8e75a-80a7-4dcd-abaf-72b046dbeeb7</td>\n",
|
||||||
|
" <td>Amit Aggarwal</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>Very good experience</td>\n",
|
||||||
|
" <td>5</td>\n",
|
||||||
|
" <td>0.0</td>\n",
|
||||||
|
" <td>2023-08-10 17:38:17</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>4.486.10002</td>\n",
|
||||||
|
" <td>en</td>\n",
|
||||||
|
" <td>in</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>3</th>\n",
|
||||||
|
" <td>Google Play</td>\n",
|
||||||
|
" <td>502702a9-25ed-4373-a96c-7fa1f06caacd</td>\n",
|
||||||
|
" <td>Bryant Inman</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>All I use</td>\n",
|
||||||
|
" <td>5</td>\n",
|
||||||
|
" <td>0.0</td>\n",
|
||||||
|
" <td>2023-08-10 17:37:45</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>4.467.10008</td>\n",
|
||||||
|
" <td>en</td>\n",
|
||||||
|
" <td>in</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>4</th>\n",
|
||||||
|
" <td>Google Play</td>\n",
|
||||||
|
" <td>f47a3fb6-23db-49bd-9e63-f33c8d724d07</td>\n",
|
||||||
|
" <td>Addie Whittaker</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>I have enjoyed traveling by Uber my drivers ha...</td>\n",
|
||||||
|
" <td>5</td>\n",
|
||||||
|
" <td>0.0</td>\n",
|
||||||
|
" <td>2023-08-10 17:36:56</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>4.486.10002</td>\n",
|
||||||
|
" <td>en</td>\n",
|
||||||
|
" <td>in</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </tbody>\n",
|
||||||
|
"</table>\n",
|
||||||
|
"</div>"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
" source review_id user_name \\\n",
|
||||||
|
"0 Google Play 18d6584c-d0e9-4833-a744-f607058aee97 Milky Way \n",
|
||||||
|
"1 Google Play 50a08f18-cece-4ddf-b617-028844c8aa28 Bradlee Severa \n",
|
||||||
|
"2 Google Play b0d8e75a-80a7-4dcd-abaf-72b046dbeeb7 Amit Aggarwal \n",
|
||||||
|
"3 Google Play 502702a9-25ed-4373-a96c-7fa1f06caacd Bryant Inman \n",
|
||||||
|
"4 Google Play f47a3fb6-23db-49bd-9e63-f33c8d724d07 Addie Whittaker \n",
|
||||||
|
"\n",
|
||||||
|
" review_title review_description rating \\\n",
|
||||||
|
"0 NaN Suddenly, the driver can't have my location an... 1 \n",
|
||||||
|
"1 NaN Very cordial.. And helped with a quick turnaro... 5 \n",
|
||||||
|
"2 NaN Very good experience 5 \n",
|
||||||
|
"3 NaN All I use 5 \n",
|
||||||
|
"4 NaN I have enjoyed traveling by Uber my drivers ha... 5 \n",
|
||||||
|
"\n",
|
||||||
|
" thumbs_up review_date developer_response developer_response_date \\\n",
|
||||||
|
"0 0.0 2023-08-10 17:48:51 NaN NaN \n",
|
||||||
|
"1 0.0 2023-08-10 17:38:35 NaN NaN \n",
|
||||||
|
"2 0.0 2023-08-10 17:38:17 NaN NaN \n",
|
||||||
|
"3 0.0 2023-08-10 17:37:45 NaN NaN \n",
|
||||||
|
"4 0.0 2023-08-10 17:36:56 NaN NaN \n",
|
||||||
|
"\n",
|
||||||
|
" appVersion laguage_code country_code \n",
|
||||||
|
"0 NaN en in \n",
|
||||||
|
"1 4.485.10000 en in \n",
|
||||||
|
"2 4.486.10002 en in \n",
|
||||||
|
"3 4.467.10008 en in \n",
|
||||||
|
"4 4.486.10002 en in "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 34,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"df = pd.read_csv(\"../data/raw/uber_reviews.csv\", low_memory=False)\n",
|
||||||
|
"df.head(5)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 35,
|
||||||
|
"id": "36683790",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"<class 'pandas.DataFrame'>\n",
|
||||||
|
"RangeIndex: 1069616 entries, 0 to 1069615\n",
|
||||||
|
"Data columns (total 13 columns):\n",
|
||||||
|
" # Column Non-Null Count Dtype \n",
|
||||||
|
"--- ------ -------------- ----- \n",
|
||||||
|
" 0 source 1069616 non-null str \n",
|
||||||
|
" 1 review_id 1069616 non-null str \n",
|
||||||
|
" 2 user_name 1069615 non-null str \n",
|
||||||
|
" 3 review_title 2180 non-null str \n",
|
||||||
|
" 4 review_description 1069447 non-null str \n",
|
||||||
|
" 5 rating 1069616 non-null int64 \n",
|
||||||
|
" 6 thumbs_up 1067436 non-null float64\n",
|
||||||
|
" 7 review_date 1069616 non-null str \n",
|
||||||
|
" 8 developer_response 198264 non-null str \n",
|
||||||
|
" 9 developer_response_date 197278 non-null str \n",
|
||||||
|
" 10 appVersion 828068 non-null str \n",
|
||||||
|
" 11 laguage_code 1069616 non-null str \n",
|
||||||
|
" 12 country_code 1069616 non-null str \n",
|
||||||
|
"dtypes: float64(1), int64(1), str(11)\n",
|
||||||
|
"memory usage: 106.1 MB\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"df.info()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 36,
|
||||||
|
"id": "0b1c7f73",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"df = df[['review_description']].reset_index(drop=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 37,
|
||||||
|
"id": "90e6d653",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"df.head(5)\n",
|
||||||
|
"df.to_csv(\"../data/raw/review_description.csv\", index=False)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "multitag",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.15"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
@@ -5,19 +5,7 @@
|
|||||||
"execution_count": 1,
|
"execution_count": 1,
|
||||||
"id": "2b7cfa1a",
|
"id": "2b7cfa1a",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"ename": "ModuleNotFoundError",
|
|
||||||
"evalue": "No module named 'sklearn'",
|
|
||||||
"output_type": "error",
|
|
||||||
"traceback": [
|
|
||||||
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
|
||||||
"\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)",
|
|
||||||
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpandas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpd\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01msklearn\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mmodel_selection\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m train_test_split\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnumpy\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mnp\u001b[39;00m\n",
|
|
||||||
"\u001b[31mModuleNotFoundError\u001b[39m: No module named 'sklearn'"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"import pandas as pd\n",
|
"import pandas as pd\n",
|
||||||
"from sklearn.model_selection import train_test_split\n",
|
"from sklearn.model_selection import train_test_split\n",
|
||||||
@@ -1209,7 +1197,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.13.11"
|
"version": "3.11.15"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
37
src/infer.py
37
src/infer.py
@@ -2,14 +2,21 @@ import pandas as pd
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import argparse
|
import argparse
|
||||||
from transformers import AutoTokenizer
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from transformers import AutoModelForSequenceClassification
|
||||||
|
|
||||||
# mappings
|
# mappings
|
||||||
binary_map = {1:'Yes', 0:'No'}
|
binary_map = {1:'Yes', 0:'No'}
|
||||||
aspect_map = {0:'App', 1:'Driver', 2:'General', 3:'Payment', 4:'Pricing', 5:'Service'}
|
aspect_map = {0:'App', 1:'Driver', 2:'General', 3:'Payment', 4:'Pricing', 5:'Service'}
|
||||||
sentiment_map = {0:'Positive', 1:'Neutral', 2:'Negative'}
|
sentiment_map = {0:'Positive', 1:'Neutral', 2:'Negative'}
|
||||||
|
|
||||||
|
label_names = {
|
||||||
|
'bug_report': ['No', 'Yes'],
|
||||||
|
'feature_request': ['No', 'Yes'],
|
||||||
|
'aspect': ['App', 'Driver', 'General', 'Payment', 'Pricing', 'Service'],
|
||||||
|
'aspect_sentiment': ['Positive', 'Neutral', 'Negative']
|
||||||
|
}
|
||||||
|
|
||||||
SEED = 4321
|
SEED = 4321
|
||||||
torch.manual_seed(SEED)
|
torch.manual_seed(SEED)
|
||||||
@@ -17,9 +24,31 @@ np.random.seed(SEED)
|
|||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description="RECLASS, Multitask learning for review classification.")
|
parser = argparse.ArgumentParser(description="RECLASS, Multitask learning for review classification.")
|
||||||
parser.add_argument("--model_path", type=str, help="Enter the models path / the desired .pt file")
|
parser.add_argument("--model_path", type=str, required=True, help=".pt file path")
|
||||||
parser.add_argument("--task", type=str, default="all", choices=["all", "bug_report", "feature_request", "aspect", "aspect_sentiment"], help="Specific task to train for stl usage only" )
|
parser.add_argument("--task", type=str, default="all", choices=["all", "bug_report", "feature_request", "aspect", "aspect_sentiment"])
|
||||||
parser.add_argument("--interactive", help="Loops reading input until exit")
|
parser.add_argument("--interactive", help="Loops reading input until exit()")
|
||||||
parser.add_argument("--text", help="Use command line text for input")
|
parser.add_argument("--text", help="Use command line text for input")
|
||||||
|
parser.add_argument("--dataset", type=str, required=True, help="Enter a file for inference")
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse CLI arguments, print a start-up banner and set up inference inputs.

    Prints a framed banner reporting whether a CUDA device is available
    (seeding the CUDA RNGs with ``SEED`` when one is), echoes the model path
    and dataset name taken from the command line, then loads the
    XLM-RoBERTa tokenizer and builds the path of the CSV to run inference on.
    """
    args = parse_args()

    # Banner pieces are reused several times, so build them once.
    rule = '=' * 50
    pad = ' ' * 15

    print(rule)
    print(pad + "Starting inference")
    if not torch.cuda.is_available():
        print(pad + "No GPUs available")
    else:
        print(pad + "GPU:", torch.cuda.get_device_name(0))
        # Seed every CUDA device as well as the current one.
        torch.cuda.manual_seed_all(SEED)
        torch.cuda.manual_seed(SEED)
    print(rule + "\n")
    print(f"Running inference on: {args.model_path.upper()} using {args.dataset}")

    # NOTE(review): neither value is consumed yet in the visible part of this
    # function — the inference loop still has to be written.
    tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
    infer_data = f"data/processed/{args.dataset}_infer.csv"
|
||||||
|
|
||||||
|
# Run inference only when executed as a script. The guard must compare against
# the string "__main__"; calling main() inside the condition would execute it
# on import as well.
if __name__ == "__main__":
    main()
|
||||||
@@ -52,7 +52,7 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
# Applied across shared cls token, before all task heads
|
# Applied across shared cls token, before all task heads
|
||||||
self.dropout = nn.Dropout(dropout_rate)
|
self.dropout = nn.Dropout(dropout_rate)
|
||||||
|
# get logits for each head
|
||||||
self.bug_head = nn.Linear(hidden_size, 2)
|
self.bug_head = nn.Linear(hidden_size, 2)
|
||||||
self.feature_head = nn.Linear(hidden_size, 2)
|
self.feature_head = nn.Linear(hidden_size, 2)
|
||||||
self.aspect_head = nn.Linear(hidden_size, 6)
|
self.aspect_head = nn.Linear(hidden_size, 6)
|
||||||
|
|||||||
@@ -160,8 +160,8 @@ def preprocess_uber_reviews(input_path, output_path):
|
|||||||
return df_clean
|
return df_clean
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
input_file = "multitag/data/uber_reviews.csv"
|
input_file = "data/raw/uber_reviews.csv"
|
||||||
output_file = "multitag/data/uber_reviews_cleaned.csv"
|
output_file = "data/raw/uber_reviews_cleaned.csv"
|
||||||
|
|
||||||
df_clean = preprocess_uber_reviews(input_file, output_file)
|
df_clean = preprocess_uber_reviews(input_file, output_file)
|
||||||
print("\nPreprocessing complete!")
|
print("\nPreprocessing complete!")
|
||||||
|
|||||||
@@ -191,7 +191,6 @@ class Sampler:
|
|||||||
return mini_sample
|
return mini_sample
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def save_sample(self, sample_df,output_path):
|
def save_sample(self, sample_df,output_path):
|
||||||
"""Save sample and display statistics"""
|
"""Save sample and display statistics"""
|
||||||
sample_df.to_csv(output_path, index=False)
|
sample_df.to_csv(output_path, index=False)
|
||||||
|
|||||||
@@ -126,6 +126,7 @@ def main():
|
|||||||
print("Aspect sentiment class weights:", aspect_sentiment_weights.cpu().numpy())
|
print("Aspect sentiment class weights:", aspect_sentiment_weights.cpu().numpy())
|
||||||
|
|
||||||
# equal weighted task losses. unequal was considered but equal weights performed well without adding complexity
|
# equal weighted task losses. unequal was considered but equal weights performed well without adding complexity
|
||||||
|
# CrossEntropyLoss = LogSoftmax + NLLLoss (negative log likelihood)
|
||||||
criterions = {
|
criterions = {
|
||||||
'bug_report': nn.CrossEntropyLoss(weight=bug_weights),
|
'bug_report': nn.CrossEntropyLoss(weight=bug_weights),
|
||||||
'feature_request': nn.CrossEntropyLoss(weight=feature_weights),
|
'feature_request': nn.CrossEntropyLoss(weight=feature_weights),
|
||||||
@@ -134,6 +135,7 @@ def main():
|
|||||||
}
|
}
|
||||||
|
|
||||||
# -------------------- Optimizer and scheduler -------------------
|
# -------------------- Optimizer and scheduler -------------------
|
||||||
|
# adaptive momentum and weight decay keeps track of previous weight adaptions and ensures they dont get too large (weight also shrinks towards 0 each pass)
|
||||||
optimizer = torch.optim.AdamW(
|
optimizer = torch.optim.AdamW(
|
||||||
model.parameters(),
|
model.parameters(),
|
||||||
lr=args.lr,
|
lr=args.lr,
|
||||||
|
|||||||
Reference in New Issue
Block a user