diff --git a/.gitignore b/.gitignore index 132f261..61417b6 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ __pycache__/ *.png *.jpg *.json +*.bat \ No newline at end of file diff --git a/notebooks/reclass_analysis.ipynb b/notebooks/reclass_analysis.ipynb new file mode 100644 index 0000000..ec4e87c --- /dev/null +++ b/notebooks/reclass_analysis.ipynb @@ -0,0 +1,585 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 9, + "id": "03dd6d43", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'c:\\\\Users\\\\Charlie\\\\6013\\\\notebooks'" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import os\n", + "\n", + "os.getcwd()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "e0c7757e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " mode dataset task model_path \\\n", + "bug_report mtl original all outputs/best_model_mtl_original.pt \n", + "feature_request mtl original all outputs/best_model_mtl_original.pt \n", + "aspect mtl original all outputs/best_model_mtl_original.pt \n", + "aspect_sentiment mtl original all outputs/best_model_mtl_original.pt \n", + "\n", + " results \n", + "bug_report {'macro_f1': 0.7833333333333331, 'macro_precis... \n", + "feature_request {'macro_f1': 0.7632819746470161, 'macro_precis... \n", + "aspect {'macro_f1': 0.7170467094024511, 'macro_precis... \n", + "aspect_sentiment {'macro_f1': 0.7574652640875611, 'macro_precis... 
\n", + " mode dataset task model_path \\\n", + "bug_report mtl boosted all outputs/best_model_mtl_boosted.pt \n", + "feature_request mtl boosted all outputs/best_model_mtl_boosted.pt \n", + "aspect mtl boosted all outputs/best_model_mtl_boosted.pt \n", + "aspect_sentiment mtl boosted all outputs/best_model_mtl_boosted.pt \n", + "\n", + " results \n", + "bug_report {'macro_f1': 0.9051856266200821, 'macro_precis... \n", + "feature_request {'macro_f1': 0.8164215686274511, 'macro_precis... \n", + "aspect {'macro_f1': 0.8025782333853201, 'macro_precis... \n", + "aspect_sentiment {'macro_f1': 0.6003394063095551, 'macro_precis... \n", + " mode dataset task model_path \\\n", + "aspect stl original aspect outputs/best_model_stl_aspect_original.pt \n", + "\n", + " results \n", + "aspect {'macro_f1': 0.6941267385341161, 'macro_precis... \n", + " mode dataset task \\\n", + "aspect_sentiment stl original aspect_sentiment \n", + "\n", + " model_path \\\n", + "aspect_sentiment outputs/best_model_stl_aspect_sentiment_origin... \n", + "\n", + " results \n", + "aspect_sentiment {'macro_f1': 0.7856154710827541, 'macro_precis... \n", + " mode dataset task \\\n", + "bug_report stl original bug_report \n", + "\n", + " model_path \\\n", + "bug_report outputs/best_model_stl_bug_report_original.pt \n", + "\n", + " results \n", + "bug_report {'macro_f1': 0.7845034000574651, 'macro_precis... \n", + " mode dataset task \\\n", + "feature_request stl original feature_request \n", + "\n", + " model_path \\\n", + "feature_request outputs/best_model_stl_feature_request_origina... \n", + "\n", + " results \n", + "feature_request {'macro_f1': 0.741924491616304, 'macro_precisi... 
\n" + ] + } + ], + "source": [ + "summaries = {\n", + " \"mtl_original\": pd.read_json(\"../outputs/eval_summary_mtl_mtl_original.json\"),\n", + " \"mtl_boosted\": pd.read_json(\"../outputs/eval_summary_mtl_mtl_boosted.json\"),\n", + " \"stl_aspect\": pd.read_json(\"../outputs/eval_summary_stl_aspect_original.json\"),\n", + " \"stl_aspect_sentiment\": pd.read_json(\"../outputs/eval_summary_stl_aspect_sentiment_original.json\"),\n", + " \"stl_bug_report\": pd.read_json(\"../outputs/eval_summary_stl_bug_report_original.json\"),\n", + " \"stl_feature_request\": pd.read_json(\"../outputs/eval_summary_stl_feature_request_original.json\"),\n", + "}\n", + "for task, df in summaries.items():\n", + " print(df.head())\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "e9ddb0d1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Index(['mode', 'dataset', 'task', 'model_path', 'results'], dtype='object')\n", + "mode mtl\n", + "dataset original\n", + "task all\n", + "model_path outputs/best_model_mtl_original.pt\n", + "results {'macro_f1': 0.7833333333333331, 'macro_precis...\n", + "Name: bug_report, dtype: object\n" + ] + } + ], + "source": [ + "print(type(summaries[\"mtl_original\"]))\n", + "print(summaries[\"mtl_original\"].columns)\n", + "print(summaries[\"mtl_original\"].iloc[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "7fbdcc74", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " model task macro_f1 macro_precision macro_recall \\\n", + "0 mtl_original bug_report 0.783333 0.768456 0.802785 \n", + "1 mtl_original feature_request 0.763282 0.752334 0.776279 \n", + "2 mtl_original aspect 0.717047 0.718659 0.717574 \n", + "3 mtl_original aspect_sentiment 0.757465 0.771312 0.747191 \n", + "4 mtl_boosted bug_report 0.905186 0.901647 0.909222 \n", + "\n", + " accuracy conf_overall conf_correct conf_incorrect 
\n", + "0 0.861333 0.956886 0.971339 0.867113 \n", + "1 0.869333 0.960423 0.971002 0.890037 \n", + "2 0.736000 0.898030 0.927060 0.817099 \n", + "3 0.912000 0.965079 0.976294 0.848845 \n", + "4 0.913218 0.975117 0.983552 0.886357 \n" + ] + } + ], + "source": [ + "# Creating section 1, general summary table of macro_f1, macro_precision, macro_recall, accuracy inside per_class with confidence values\n", + "\"\"\"\n", + "{\n", + " \"mode\": \"mtl\",\n", + " \"dataset\": \"original\",\n", + " \"task\": \"all\",\n", + " \"model_path\": \"outputs/best_model_mtl_original.pt\",\n", + " \"results\": {\n", + " \"bug_report\": {\n", + " \"macro_f1\": 0.7833333333333333,\n", + " \"macro_precision\": 0.7684555303602922,\n", + " \"macro_recall\": 0.8027848820687695,\n", + " \"confidence\": {\n", + " \"overall\": 0.9568860530853271,\n", + " \"correct\": 0.97133868932724,\n", + " \"incorrect\": 0.8671128153800964\n", + " },\n", + " \"per_class\": {\n", + " \"No\": {\n", + " \"precision\": 0.9319727891156463,\n", + " \"recall\": 0.8954248366013072,\n", + " \"f1-score\": 0.9133333333333333,\n", + " \"support\": 612.0\n", + " },\n", + " \"Yes\": {\n", + " \"precision\": 0.6049382716049383,\n", + " \"recall\": 0.7101449275362319,\n", + " \"f1-score\": 0.6533333333333333,\n", + " \"support\": 138.0\n", + " },\n", + " \"accuracy\": 0.8613333333333333,\n", + " \"macro avg\": {\n", + " \"precision\": 0.7684555303602922,\n", + " \"recall\": 0.8027848820687695,\n", + " \"f1-score\": 0.7833333333333333,\n", + " \"support\": 750.0\n", + " },\n", + " \"weighted avg\": {\n", + " \"precision\": 0.8717984378936761,\n", + " \"recall\": 0.8613333333333333,\n", + " \"f1-score\": 0.8654933333333333,\n", + " \"support\": 750.0\n", + " }\n", + " }\n", + " },\n", + "\"\"\"\n", + "\n", + "\n", + "\n", + "rows = []\n", + "# \"mtl_original\", pd.read_json(\"../outputs/eval_summary_mtl_mtl_original.json\")\n", + "for model_name, df in summaries.items():\n", + " \n", + " for task_name, row in 
df.iterrows():\n", + " \n", + " task_result = row[\"results\"]\n", + "\n", + " rows.append({\n", + " \"model\": model_name,\n", + " \"task\": task_name,\n", + "\n", + " \"macro_f1\": task_result[\"macro_f1\"],\n", + " \"macro_precision\": task_result[\"macro_precision\"],\n", + " \"macro_recall\": task_result[\"macro_recall\"],\n", + " \"accuracy\": task_result[\"per_class\"][\"accuracy\"],\n", + "\n", + " \"conf_overall\": task_result[\"confidence\"][\"overall\"],\n", + " \"conf_correct\": task_result[\"confidence\"][\"correct\"],\n", + " \"conf_incorrect\": task_result[\"confidence\"][\"incorrect\"],\n", + "\n", + " })\n", + "\n", + "analysis_summary_df = pd.DataFrame(rows)\n", + "print(analysis_summary_df.head())\n", + "analysis_summary_df.to_csv(\"analysis.csv\", index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b66b19f4", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "4147fda8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model mtl_boosted mtl_original stl_aspect stl_aspect_sentiment \\\n", + "task \n", + "aspect 0.802578 0.717047 0.694127 NaN \n", + "aspect_sentiment 0.600339 0.757465 NaN 0.785615 \n", + "bug_report 0.905186 0.783333 NaN NaN \n", + "feature_request 0.816422 0.763282 NaN NaN \n", + "\n", + "model stl_bug_report stl_feature_request \n", + "task \n", + "aspect NaN NaN \n", + "aspect_sentiment NaN NaN \n", + "bug_report 0.784503 NaN \n", + "feature_request NaN 0.741924 \n", + "\n", + "final:\n", + "\n", + " model stl mtl_original mtl_boosted mtl_orig_excluding_stl \\\n", + "task \n", + "aspect 0.694127 0.717047 0.802578 0.02292 \n", + "aspect_sentiment 0.785615 0.757465 0.600339 -0.02815 \n", + "bug_report 0.784503 0.783333 0.905186 -0.00117 \n", + "feature_request 0.741924 0.763282 0.816422 0.021357 \n", + "\n", + "model mtl_boost_excluding_orig \n", + "task \n", + "aspect 0.085532 \n", + 
"aspect_sentiment -0.157126 \n", + "bug_report 0.121852 \n", + "feature_request 0.053140 \n" + ] + } + ], + "source": [ + "pivot = analysis_summary_df.pivot(\n", + " index=\"task\",\n", + " columns=\"model\",\n", + " values=\"macro_f1\"\n", + ")\n", + "\n", + "print(pivot) # wrong, pull actual stl score\n", + "\n", + "pivot[\"stl\"] = None\n", + "\n", + "pivot.loc[\"bug_report\", \"stl\"] = pivot.loc[\"bug_report\", \"stl_bug_report\"]\n", + "pivot.loc[\"feature_request\", \"stl\"] = pivot.loc[\"feature_request\", \"stl_feature_request\"]\n", + "pivot.loc[\"aspect\", \"stl\"] = pivot.loc[\"aspect\", \"stl_aspect\"]\n", + "pivot.loc[\"aspect_sentiment\", \"stl\"] = pivot.loc[\"aspect_sentiment\", \"stl_aspect_sentiment\"]\n", + "\n", + "pivot[\"mtl_orig_excluding_stl\"] = pivot[\"mtl_original\"] - pivot[\"stl\"]\n", + "pivot[\"mtl_boost_excluding_orig\"] = pivot[\"mtl_boosted\"] - pivot[\"mtl_original\"]\n", + "final = pivot[\n", + " [\n", + " \"stl\",\n", + " \"mtl_original\",\n", + " \"mtl_boosted\",\n", + " \"mtl_orig_excluding_stl\",\n", + " \"mtl_boost_excluding_orig\"\n", + " ]\n", + "]\n", + "\n", + "print(\"\\nfinal:\\n\\n\", final)\n", + "\n", + "final.to_csv(\"analysis_macro_f1_pivot.csv\")\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "1182dc66", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Per class df model task class precision recall \\\n", + "0 mtl_original bug_report No 0.931973 0.895425 \n", + "1 mtl_original bug_report Yes 0.604938 0.710145 \n", + "2 mtl_original feature_request No 0.932149 0.911532 \n", + "3 mtl_original feature_request Yes 0.572519 0.641026 \n", + "4 mtl_original aspect App 0.787500 0.840000 \n", + "5 mtl_original aspect Driver 0.786260 0.780303 \n", + "6 mtl_original aspect General 0.803738 0.688000 \n", + "7 mtl_original aspect Payment 0.617647 0.636364 \n", + "8 mtl_original aspect Pricing 0.698630 0.750000 \n", + "9 mtl_original aspect 
Service 0.618182 0.610778 \n", + "10 mtl_original aspect_sentiment Positive 0.938776 0.953368 \n", + "11 mtl_original aspect_sentiment Neutral 0.451613 0.358974 \n", + "12 mtl_original aspect_sentiment Negative 0.923547 0.929231 \n", + "13 mtl_boosted bug_report No 0.943515 0.922290 \n", + "14 mtl_boosted bug_report Yes 0.859779 0.896154 \n", + "15 mtl_boosted feature_request No 0.963343 0.969027 \n", + "16 mtl_boosted feature_request Yes 0.686567 0.647887 \n", + "17 mtl_boosted aspect App 0.932990 0.841860 \n", + "18 mtl_boosted aspect Driver 0.839572 0.892045 \n", + "19 mtl_boosted aspect General 0.829268 0.772727 \n", + "20 mtl_boosted aspect Payment 0.711111 0.842105 \n", + "21 mtl_boosted aspect Pricing 0.826087 0.802817 \n", + "22 mtl_boosted aspect Service 0.669291 0.691057 \n", + "23 mtl_boosted aspect_sentiment Positive 0.800000 0.882353 \n", + "24 mtl_boosted aspect_sentiment Neutral 0.000000 0.000000 \n", + "25 mtl_boosted aspect_sentiment Negative 0.968280 0.955519 \n", + "26 stl_aspect aspect App 0.797357 0.804444 \n", + "27 stl_aspect aspect Driver 0.717241 0.787879 \n", + "28 stl_aspect aspect General 0.761905 0.640000 \n", + "29 stl_aspect aspect Payment 0.625000 0.606061 \n", + "30 stl_aspect aspect Pricing 0.561404 0.941176 \n", + "31 stl_aspect aspect Service 0.692913 0.526946 \n", + "32 stl_aspect_sentiment aspect_sentiment Positive 0.960000 0.932642 \n", + "33 stl_aspect_sentiment aspect_sentiment Neutral 0.500000 0.461538 \n", + "34 stl_aspect_sentiment aspect_sentiment Negative 0.911504 0.950769 \n", + "35 stl_bug_report bug_report No 0.942105 0.877451 \n", + "36 stl_bug_report bug_report Yes 0.583333 0.760870 \n", + "37 stl_feature_request feature_request No 0.920635 0.916272 \n", + "38 stl_feature_request feature_request Yes 0.558333 0.572650 \n", + "\n", + " f1 support \n", + "0 0.913333 612.0 \n", + "1 0.653333 138.0 \n", + "2 0.921725 633.0 \n", + "3 0.604839 117.0 \n", + "4 0.812903 225.0 \n", + "5 0.783270 132.0 \n", + "6 0.741379 
125.0 \n", + "7 0.626866 33.0 \n", + "8 0.723404 68.0 \n", + "9 0.614458 167.0 \n", + "10 0.946015 386.0 \n", + "11 0.400000 39.0 \n", + "12 0.926380 325.0 \n", + "13 0.932782 489.0 \n", + "14 0.877589 260.0 \n", + "15 0.966176 678.0 \n", + "16 0.666667 71.0 \n", + "17 0.885086 215.0 \n", + "18 0.865014 176.0 \n", + "19 0.800000 88.0 \n", + "20 0.771084 76.0 \n", + "21 0.814286 71.0 \n", + "22 0.680000 123.0 \n", + "23 0.839161 136.0 \n", + "24 0.000000 6.0 \n", + "25 0.961857 607.0 \n", + "26 0.800885 225.0 \n", + "27 0.750903 132.0 \n", + "28 0.695652 125.0 \n", + "29 0.615385 33.0 \n", + "30 0.703297 68.0 \n", + "31 0.598639 167.0 \n", + "32 0.946124 386.0 \n", + "33 0.480000 39.0 \n", + "34 0.930723 325.0 \n", + "35 0.908629 612.0 \n", + "36 0.660377 138.0 \n", + "37 0.918448 633.0 \n", + "38 0.565401 117.0 \n" + ] + } + ], + "source": [ + "per_class_rows = []\n", + "skip = {\"accuracy\", \"macro avg\", \"weighted avg\"}\n", + "\n", + "for model_name, df in summaries.items():\n", + " for task_name, row in df.iterrows():\n", + " task_result = row[\"results\"]\n", + " \n", + " for class_name, class_metrics in task_result[\"per_class\"].items():\n", + " if class_name in skip:\n", + " continue\n", + " per_class_rows.append({\n", + " \"model\": model_name,\n", + " \"task\": task_name,\n", + " \"class\": class_name,\n", + " \"precision\": class_metrics[\"precision\"],\n", + " \"recall\": class_metrics[\"recall\"],\n", + " \"f1\": class_metrics[\"f1-score\"],\n", + " \"support\": class_metrics[\"support\"],\n", + " })\n", + "\n", + "per_class_analysis_df = pd.DataFrame(per_class_rows)\n", + "print(\"Per class df\", per_class_analysis_df)\n", + "per_class_analysis_df.to_csv(\"per_class_analysis.csv\", index=False)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "f796bc9e", + "metadata": {}, + "outputs": [], + "source": [ + "label_names = {\n", + " 'bug_report': ['No', 'Yes'],\n", + " 'feature_request': ['No', 'Yes'],\n", + " 'aspect': ['App', 
'Driver', 'General', 'Payment', 'Pricing', 'Service'],\n", + " 'aspect_sentiment': ['Positive', 'Neutral', 'Negative']\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "4cc572c1", + "metadata": {}, + "outputs": [], + "source": [ + "tasks = [\"bug_report\", \"feature_request\", \"aspect\", \"aspect_sentiment\"]\n", + "mcnemar_rows = []\n", + "\n", + "mtl_path = \"../outputs/test_predictions_mtl_mtl_original.csv\"\n", + "mtl_df = pd.read_csv(mtl_path)\n", + "\n", + "from statsmodels.stats.contingency_tables import mcnemar" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "0ef9d5ca", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "McNemar Test Results (MTL vs STL Original):\n", + " task mtl_only_correct stl_only_correct p_value significant\n", + "0 bug_report 32 28 0.698883 False\n", + "1 feature_request 42 37 0.652964 False\n", + "2 aspect 70 55 0.210327 False\n", + "3 aspect_sentiment 15 18 0.728332 False\n" + ] + } + ], + "source": [ + "for task in tasks:\n", + " stl_path = f\"../outputs/test_predictions_stl_{task}_original.csv\"\n", + " \n", + " if not os.path.exists(stl_path):\n", + " print(f\"Warning: Missing STL file for {task}\")\n", + " continue\n", + " \n", + " stl_df = pd.read_csv(stl_path)\n", + "\n", + " # 3. Align ground truth\n", + " # The 'task' column in the CSV is numeric (0, 1, 2...). \n", + " # The '{task}_pred' column is a string ('Yes', 'No', 'App'...).\n", + " # We map the numeric ground truth to labels so we can compare directly.\n", + " y_true = [label_names[task][int(val)] for val in mtl_df[task]]\n", + " \n", + " # 4. Get predictions\n", + " y_mtl = mtl_df[f\"{task}_pred\"].values\n", + " y_stl = stl_df[f\"{task}_pred\"].values\n", + "\n", + " # 5. Determine correctness for each model\n", + " mtl_correct = (y_mtl == y_true)\n", + " stl_correct = (y_stl == y_true)\n", + "\n", + " # 6. 
Build the 2x2 Contingency Table\n", + " # Table structure for McNemar:\n", + " # MTL Correct | MTL Wrong\n", + " # STL Correct | [0,0] | [0,1]\n", + " # STL Wrong | [1,0] | [1,1]\n", + " \n", + " both_correct = np.sum(mtl_correct & stl_correct)\n", + " stl_only = np.sum(~mtl_correct & stl_correct) # STL right, MTL wrong\n", + " mtl_only = np.sum(mtl_correct & ~stl_correct) # MTL right, STL wrong\n", + " both_wrong = np.sum(~mtl_correct & ~stl_correct)\n", + "\n", + " contingency_table = [\n", + " [both_correct, stl_only],\n", + " [mtl_only, both_wrong]\n", + " ]\n", + "\n", + " # 7. Run McNemar's Test\n", + " # We use exact=True because some of our discordant cells (mtl_only/stl_only) might be small\n", + " result = mcnemar(contingency_table, exact=True)\n", + " \n", + " mcnemar_rows.append({\n", + " \"task\": task,\n", + " \"both_correct\": both_correct,\n", + " \"mtl_only_correct\": mtl_only,\n", + " \"stl_only_correct\": stl_only,\n", + " \"both_wrong\": both_wrong,\n", + " \"p_value\": result.pvalue,\n", + " \"significant\": result.pvalue < 0.05\n", + " })\n", + "\n", + "# 8. 
Save and Display Results\n", + "mcnemar_df = pd.DataFrame(mcnemar_rows)\n", + "print(\"\\nMcNemar Test Results (MTL vs STL Original):\")\n", + "print(mcnemar_df[[\"task\", \"mtl_only_correct\", \"stl_only_correct\", \"p_value\", \"significant\"]])\n", + "\n", + "mcnemar_df.to_csv(\"analysis_mcnemar_results.csv\", index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6afd6cf8", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "multitag", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/dataset.py b/src/dataset.py index 3eaa170..e5b8ccc 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -51,7 +51,7 @@ class ReviewDataset(Dataset): if __name__ == "__main__": dataset = ReviewDataset("data/processed/original_train.csv", AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")) - # print(dataset.__getitem__(1)) + print(dataset.__getitem__(1)) diff --git a/src/infer.py b/src/infer.py index 27a9f8a..f324e53 100644 --- a/src/infer.py +++ b/src/infer.py @@ -1,4 +1,25 @@ +import pandas as pd +import numpy as np +import torch +import argparse +from transformers import AutoTokenizer +from torch.utils.tensorboard import SummaryWriter + # mappings -binary_map = {'Yes': 1, 'No': 0} -aspect_map = {'App': 0, 'Driver': 1, 'General': 2, 'Payment': 3, 'Pricing': 4, 'Service': 5} -sentiment_map = {'Positive': 0, 'Neutral': 1, 'Negative': 2} +binary_map = {1:'Yes', 0:'No'} +aspect_map = {0:'App', 1:'Driver', 2:'General', 3:'Payment', 4:'Pricing', 5:'Service'} +sentiment_map = {0:'Positive', 1:'Neutral', 2:'Negative'} + + +SEED = 4321 +torch.manual_seed(SEED) 
# Seed NumPy's global RNG with the module-level SEED constant.
np.random.seed(SEED)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of this script.

    Options:
        --model_path: string path to the desired ``.pt`` file; ``None`` when
            the option is omitted.
        --task: one of ``all``, ``bug_report``, ``feature_request``,
            ``aspect`` or ``aspect_sentiment``; defaults to ``all``.
        --interactive: mode switch. A bare ``--interactive`` yields ``True``,
            ``--interactive VALUE`` yields the string ``VALUE`` and omitting
            the option yields ``None``.
        --text: input text given directly on the command line; ``None`` when
            the option is omitted.

    Returns:
        The ``argparse.Namespace`` produced by parsing ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description="RECLASS, Multitask learning for review classification."
    )
    # NOTE(review): --model_path has no default and is not marked required, so
    # callers receive None when it is omitted -- confirm that is handled.
    parser.add_argument(
        "--model_path",
        type=str,
        help="Enter the models path / the desired .pt file",
    )
    parser.add_argument(
        "--task",
        type=str,
        default="all",
        choices=["all", "bug_report", "feature_request", "aspect", "aspect_sentiment"],
        help="Specific task to train for stl usage only",
    )
    # Without nargs/const argparse insists on a value, so a bare
    # ``--interactive`` aborted with "expected one argument". nargs="?" makes
    # the value optional while still accepting the old ``--interactive VALUE``
    # spelling and keeping the default at None.
    parser.add_argument(
        "--interactive",
        nargs="?",
        const=True,
        default=None,
        help="Loops reading input until exit",
    )
    parser.add_argument("--text", help="Use command line text for input")
    return parser.parse_args()