Analysis started and almost complete: compiled some Excel sheets from the CSV output, with notes. Started infer.py; nothing major implemented yet.
This commit is contained in:
585
notebooks/reclass_analysis.ipynb
Normal file
585
notebooks/reclass_analysis.ipynb
Normal file
@@ -0,0 +1,585 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "03dd6d43",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'c:\\\\Users\\\\Charlie\\\\6013\\\\notebooks'"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.getcwd()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "e0c7757e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" mode dataset task model_path \\\n",
|
||||
"bug_report mtl original all outputs/best_model_mtl_original.pt \n",
|
||||
"feature_request mtl original all outputs/best_model_mtl_original.pt \n",
|
||||
"aspect mtl original all outputs/best_model_mtl_original.pt \n",
|
||||
"aspect_sentiment mtl original all outputs/best_model_mtl_original.pt \n",
|
||||
"\n",
|
||||
" results \n",
|
||||
"bug_report {'macro_f1': 0.7833333333333331, 'macro_precis... \n",
|
||||
"feature_request {'macro_f1': 0.7632819746470161, 'macro_precis... \n",
|
||||
"aspect {'macro_f1': 0.7170467094024511, 'macro_precis... \n",
|
||||
"aspect_sentiment {'macro_f1': 0.7574652640875611, 'macro_precis... \n",
|
||||
" mode dataset task model_path \\\n",
|
||||
"bug_report mtl boosted all outputs/best_model_mtl_boosted.pt \n",
|
||||
"feature_request mtl boosted all outputs/best_model_mtl_boosted.pt \n",
|
||||
"aspect mtl boosted all outputs/best_model_mtl_boosted.pt \n",
|
||||
"aspect_sentiment mtl boosted all outputs/best_model_mtl_boosted.pt \n",
|
||||
"\n",
|
||||
" results \n",
|
||||
"bug_report {'macro_f1': 0.9051856266200821, 'macro_precis... \n",
|
||||
"feature_request {'macro_f1': 0.8164215686274511, 'macro_precis... \n",
|
||||
"aspect {'macro_f1': 0.8025782333853201, 'macro_precis... \n",
|
||||
"aspect_sentiment {'macro_f1': 0.6003394063095551, 'macro_precis... \n",
|
||||
" mode dataset task model_path \\\n",
|
||||
"aspect stl original aspect outputs/best_model_stl_aspect_original.pt \n",
|
||||
"\n",
|
||||
" results \n",
|
||||
"aspect {'macro_f1': 0.6941267385341161, 'macro_precis... \n",
|
||||
" mode dataset task \\\n",
|
||||
"aspect_sentiment stl original aspect_sentiment \n",
|
||||
"\n",
|
||||
" model_path \\\n",
|
||||
"aspect_sentiment outputs/best_model_stl_aspect_sentiment_origin... \n",
|
||||
"\n",
|
||||
" results \n",
|
||||
"aspect_sentiment {'macro_f1': 0.7856154710827541, 'macro_precis... \n",
|
||||
" mode dataset task \\\n",
|
||||
"bug_report stl original bug_report \n",
|
||||
"\n",
|
||||
" model_path \\\n",
|
||||
"bug_report outputs/best_model_stl_bug_report_original.pt \n",
|
||||
"\n",
|
||||
" results \n",
|
||||
"bug_report {'macro_f1': 0.7845034000574651, 'macro_precis... \n",
|
||||
" mode dataset task \\\n",
|
||||
"feature_request stl original feature_request \n",
|
||||
"\n",
|
||||
" model_path \\\n",
|
||||
"feature_request outputs/best_model_stl_feature_request_origina... \n",
|
||||
"\n",
|
||||
" results \n",
|
||||
"feature_request {'macro_f1': 0.741924491616304, 'macro_precisi... \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"summaries = {\n",
|
||||
" \"mtl_original\": pd.read_json(\"../outputs/eval_summary_mtl_mtl_original.json\"),\n",
|
||||
" \"mtl_boosted\": pd.read_json(\"../outputs/eval_summary_mtl_mtl_boosted.json\"),\n",
|
||||
" \"stl_aspect\": pd.read_json(\"../outputs/eval_summary_stl_aspect_original.json\"),\n",
|
||||
" \"stl_aspect_sentiment\": pd.read_json(\"../outputs/eval_summary_stl_aspect_sentiment_original.json\"),\n",
|
||||
" \"stl_bug_report\": pd.read_json(\"../outputs/eval_summary_stl_bug_report_original.json\"),\n",
|
||||
" \"stl_feature_request\": pd.read_json(\"../outputs/eval_summary_stl_feature_request_original.json\"),\n",
|
||||
"}\n",
|
||||
"for task, df in summaries.items():\n",
|
||||
" print(df.head())\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "e9ddb0d1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<class 'pandas.core.frame.DataFrame'>\n",
|
||||
"Index(['mode', 'dataset', 'task', 'model_path', 'results'], dtype='object')\n",
|
||||
"mode mtl\n",
|
||||
"dataset original\n",
|
||||
"task all\n",
|
||||
"model_path outputs/best_model_mtl_original.pt\n",
|
||||
"results {'macro_f1': 0.7833333333333331, 'macro_precis...\n",
|
||||
"Name: bug_report, dtype: object\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(type(summaries[\"mtl_original\"]))\n",
|
||||
"print(summaries[\"mtl_original\"].columns)\n",
|
||||
"print(summaries[\"mtl_original\"].iloc[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "7fbdcc74",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" model task macro_f1 macro_precision macro_recall \\\n",
|
||||
"0 mtl_original bug_report 0.783333 0.768456 0.802785 \n",
|
||||
"1 mtl_original feature_request 0.763282 0.752334 0.776279 \n",
|
||||
"2 mtl_original aspect 0.717047 0.718659 0.717574 \n",
|
||||
"3 mtl_original aspect_sentiment 0.757465 0.771312 0.747191 \n",
|
||||
"4 mtl_boosted bug_report 0.905186 0.901647 0.909222 \n",
|
||||
"\n",
|
||||
" accuracy conf_overall conf_correct conf_incorrect \n",
|
||||
"0 0.861333 0.956886 0.971339 0.867113 \n",
|
||||
"1 0.869333 0.960423 0.971002 0.890037 \n",
|
||||
"2 0.736000 0.898030 0.927060 0.817099 \n",
|
||||
"3 0.912000 0.965079 0.976294 0.848845 \n",
|
||||
"4 0.913218 0.975117 0.983552 0.886357 \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Section 1: general summary table with macro_f1, macro_precision, macro_recall, accuracy (stored inside per_class), and the confidence values\n",
|
||||
"\"\"\"\n",
|
||||
"{\n",
|
||||
" \"mode\": \"mtl\",\n",
|
||||
" \"dataset\": \"original\",\n",
|
||||
" \"task\": \"all\",\n",
|
||||
" \"model_path\": \"outputs/best_model_mtl_original.pt\",\n",
|
||||
" \"results\": {\n",
|
||||
" \"bug_report\": {\n",
|
||||
" \"macro_f1\": 0.7833333333333333,\n",
|
||||
" \"macro_precision\": 0.7684555303602922,\n",
|
||||
" \"macro_recall\": 0.8027848820687695,\n",
|
||||
" \"confidence\": {\n",
|
||||
" \"overall\": 0.9568860530853271,\n",
|
||||
" \"correct\": 0.97133868932724,\n",
|
||||
" \"incorrect\": 0.8671128153800964\n",
|
||||
" },\n",
|
||||
" \"per_class\": {\n",
|
||||
" \"No\": {\n",
|
||||
" \"precision\": 0.9319727891156463,\n",
|
||||
" \"recall\": 0.8954248366013072,\n",
|
||||
" \"f1-score\": 0.9133333333333333,\n",
|
||||
" \"support\": 612.0\n",
|
||||
" },\n",
|
||||
" \"Yes\": {\n",
|
||||
" \"precision\": 0.6049382716049383,\n",
|
||||
" \"recall\": 0.7101449275362319,\n",
|
||||
" \"f1-score\": 0.6533333333333333,\n",
|
||||
" \"support\": 138.0\n",
|
||||
" },\n",
|
||||
" \"accuracy\": 0.8613333333333333,\n",
|
||||
" \"macro avg\": {\n",
|
||||
" \"precision\": 0.7684555303602922,\n",
|
||||
" \"recall\": 0.8027848820687695,\n",
|
||||
" \"f1-score\": 0.7833333333333333,\n",
|
||||
" \"support\": 750.0\n",
|
||||
" },\n",
|
||||
" \"weighted avg\": {\n",
|
||||
" \"precision\": 0.8717984378936761,\n",
|
||||
" \"recall\": 0.8613333333333333,\n",
|
||||
" \"f1-score\": 0.8654933333333333,\n",
|
||||
" \"support\": 750.0\n",
|
||||
" }\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"rows = []\n",
|
||||
"# \"mtl_original\", pd.read_json(\"../outputs/eval_summary_mtl_mtl_original.json\")\n",
|
||||
"for model_name, df in summaries.items():\n",
|
||||
" \n",
|
||||
" for task_name, row in df.iterrows():\n",
|
||||
" \n",
|
||||
" task_result = row[\"results\"]\n",
|
||||
"\n",
|
||||
" rows.append({\n",
|
||||
" \"model\": model_name,\n",
|
||||
" \"task\": task_name,\n",
|
||||
"\n",
|
||||
" \"macro_f1\": task_result[\"macro_f1\"],\n",
|
||||
" \"macro_precision\": task_result[\"macro_precision\"],\n",
|
||||
" \"macro_recall\": task_result[\"macro_recall\"],\n",
|
||||
" \"accuracy\": task_result[\"per_class\"][\"accuracy\"],\n",
|
||||
"\n",
|
||||
" \"conf_overall\": task_result[\"confidence\"][\"overall\"],\n",
|
||||
" \"conf_correct\": task_result[\"confidence\"][\"correct\"],\n",
|
||||
" \"conf_incorrect\": task_result[\"confidence\"][\"incorrect\"],\n",
|
||||
"\n",
|
||||
" })\n",
|
||||
"\n",
|
||||
"analysis_summary_df = pd.DataFrame(rows)\n",
|
||||
"print(analysis_summary_df.head())\n",
|
||||
"analysis_summary_df.to_csv(\"analysis.csv\", index=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b66b19f4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"id": "4147fda8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"model mtl_boosted mtl_original stl_aspect stl_aspect_sentiment \\\n",
|
||||
"task \n",
|
||||
"aspect 0.802578 0.717047 0.694127 NaN \n",
|
||||
"aspect_sentiment 0.600339 0.757465 NaN 0.785615 \n",
|
||||
"bug_report 0.905186 0.783333 NaN NaN \n",
|
||||
"feature_request 0.816422 0.763282 NaN NaN \n",
|
||||
"\n",
|
||||
"model stl_bug_report stl_feature_request \n",
|
||||
"task \n",
|
||||
"aspect NaN NaN \n",
|
||||
"aspect_sentiment NaN NaN \n",
|
||||
"bug_report 0.784503 NaN \n",
|
||||
"feature_request NaN 0.741924 \n",
|
||||
"\n",
|
||||
"final:\n",
|
||||
"\n",
|
||||
" model stl mtl_original mtl_boosted mtl_orig_excluding_stl \\\n",
|
||||
"task \n",
|
||||
"aspect 0.694127 0.717047 0.802578 0.02292 \n",
|
||||
"aspect_sentiment 0.785615 0.757465 0.600339 -0.02815 \n",
|
||||
"bug_report 0.784503 0.783333 0.905186 -0.00117 \n",
|
||||
"feature_request 0.741924 0.763282 0.816422 0.021357 \n",
|
||||
"\n",
|
||||
"model mtl_boost_excluding_orig \n",
|
||||
"task \n",
|
||||
"aspect 0.085532 \n",
|
||||
"aspect_sentiment -0.157126 \n",
|
||||
"bug_report 0.121852 \n",
|
||||
"feature_request 0.053140 \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pivot = analysis_summary_df.pivot(\n",
|
||||
" index=\"task\",\n",
|
||||
" columns=\"model\",\n",
|
||||
" values=\"macro_f1\"\n",
|
||||
")\n",
|
||||
"\n",
|
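||||
"print(pivot)  # not the final layout: each STL score sits in its own stl_<task> column; the actual STL score is pulled into one 'stl' column below\n",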
|
||||
"\n",
|
||||
"pivot[\"stl\"] = np.nan  # float NaN (not None) keeps the column float64, so the STL values and the deltas computed below stay numeric instead of object dtype\n",
|
||||
"\n",
|
||||
"pivot.loc[\"bug_report\", \"stl\"] = pivot.loc[\"bug_report\", \"stl_bug_report\"]\n",
|
||||
"pivot.loc[\"feature_request\", \"stl\"] = pivot.loc[\"feature_request\", \"stl_feature_request\"]\n",
|
||||
"pivot.loc[\"aspect\", \"stl\"] = pivot.loc[\"aspect\", \"stl_aspect\"]\n",
|
||||
"pivot.loc[\"aspect_sentiment\", \"stl\"] = pivot.loc[\"aspect_sentiment\", \"stl_aspect_sentiment\"]\n",
|
||||
"\n",
|
||||
"pivot[\"mtl_orig_excluding_stl\"] = pivot[\"mtl_original\"] - pivot[\"stl\"]\n",
|
||||
"pivot[\"mtl_boost_excluding_orig\"] = pivot[\"mtl_boosted\"] - pivot[\"mtl_original\"]\n",
|
||||
"final = pivot[\n",
|
||||
" [\n",
|
||||
" \"stl\",\n",
|
||||
" \"mtl_original\",\n",
|
||||
" \"mtl_boosted\",\n",
|
||||
" \"mtl_orig_excluding_stl\",\n",
|
||||
" \"mtl_boost_excluding_orig\"\n",
|
||||
" ]\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"print(\"\\nfinal:\\n\\n\", final)\n",
|
||||
"\n",
|
||||
"final.to_csv(\"analysis_macro_f1_pivot.csv\")\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"id": "1182dc66",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Per class df model task class precision recall \\\n",
|
||||
"0 mtl_original bug_report No 0.931973 0.895425 \n",
|
||||
"1 mtl_original bug_report Yes 0.604938 0.710145 \n",
|
||||
"2 mtl_original feature_request No 0.932149 0.911532 \n",
|
||||
"3 mtl_original feature_request Yes 0.572519 0.641026 \n",
|
||||
"4 mtl_original aspect App 0.787500 0.840000 \n",
|
||||
"5 mtl_original aspect Driver 0.786260 0.780303 \n",
|
||||
"6 mtl_original aspect General 0.803738 0.688000 \n",
|
||||
"7 mtl_original aspect Payment 0.617647 0.636364 \n",
|
||||
"8 mtl_original aspect Pricing 0.698630 0.750000 \n",
|
||||
"9 mtl_original aspect Service 0.618182 0.610778 \n",
|
||||
"10 mtl_original aspect_sentiment Positive 0.938776 0.953368 \n",
|
||||
"11 mtl_original aspect_sentiment Neutral 0.451613 0.358974 \n",
|
||||
"12 mtl_original aspect_sentiment Negative 0.923547 0.929231 \n",
|
||||
"13 mtl_boosted bug_report No 0.943515 0.922290 \n",
|
||||
"14 mtl_boosted bug_report Yes 0.859779 0.896154 \n",
|
||||
"15 mtl_boosted feature_request No 0.963343 0.969027 \n",
|
||||
"16 mtl_boosted feature_request Yes 0.686567 0.647887 \n",
|
||||
"17 mtl_boosted aspect App 0.932990 0.841860 \n",
|
||||
"18 mtl_boosted aspect Driver 0.839572 0.892045 \n",
|
||||
"19 mtl_boosted aspect General 0.829268 0.772727 \n",
|
||||
"20 mtl_boosted aspect Payment 0.711111 0.842105 \n",
|
||||
"21 mtl_boosted aspect Pricing 0.826087 0.802817 \n",
|
||||
"22 mtl_boosted aspect Service 0.669291 0.691057 \n",
|
||||
"23 mtl_boosted aspect_sentiment Positive 0.800000 0.882353 \n",
|
||||
"24 mtl_boosted aspect_sentiment Neutral 0.000000 0.000000 \n",
|
||||
"25 mtl_boosted aspect_sentiment Negative 0.968280 0.955519 \n",
|
||||
"26 stl_aspect aspect App 0.797357 0.804444 \n",
|
||||
"27 stl_aspect aspect Driver 0.717241 0.787879 \n",
|
||||
"28 stl_aspect aspect General 0.761905 0.640000 \n",
|
||||
"29 stl_aspect aspect Payment 0.625000 0.606061 \n",
|
||||
"30 stl_aspect aspect Pricing 0.561404 0.941176 \n",
|
||||
"31 stl_aspect aspect Service 0.692913 0.526946 \n",
|
||||
"32 stl_aspect_sentiment aspect_sentiment Positive 0.960000 0.932642 \n",
|
||||
"33 stl_aspect_sentiment aspect_sentiment Neutral 0.500000 0.461538 \n",
|
||||
"34 stl_aspect_sentiment aspect_sentiment Negative 0.911504 0.950769 \n",
|
||||
"35 stl_bug_report bug_report No 0.942105 0.877451 \n",
|
||||
"36 stl_bug_report bug_report Yes 0.583333 0.760870 \n",
|
||||
"37 stl_feature_request feature_request No 0.920635 0.916272 \n",
|
||||
"38 stl_feature_request feature_request Yes 0.558333 0.572650 \n",
|
||||
"\n",
|
||||
" f1 support \n",
|
||||
"0 0.913333 612.0 \n",
|
||||
"1 0.653333 138.0 \n",
|
||||
"2 0.921725 633.0 \n",
|
||||
"3 0.604839 117.0 \n",
|
||||
"4 0.812903 225.0 \n",
|
||||
"5 0.783270 132.0 \n",
|
||||
"6 0.741379 125.0 \n",
|
||||
"7 0.626866 33.0 \n",
|
||||
"8 0.723404 68.0 \n",
|
||||
"9 0.614458 167.0 \n",
|
||||
"10 0.946015 386.0 \n",
|
||||
"11 0.400000 39.0 \n",
|
||||
"12 0.926380 325.0 \n",
|
||||
"13 0.932782 489.0 \n",
|
||||
"14 0.877589 260.0 \n",
|
||||
"15 0.966176 678.0 \n",
|
||||
"16 0.666667 71.0 \n",
|
||||
"17 0.885086 215.0 \n",
|
||||
"18 0.865014 176.0 \n",
|
||||
"19 0.800000 88.0 \n",
|
||||
"20 0.771084 76.0 \n",
|
||||
"21 0.814286 71.0 \n",
|
||||
"22 0.680000 123.0 \n",
|
||||
"23 0.839161 136.0 \n",
|
||||
"24 0.000000 6.0 \n",
|
||||
"25 0.961857 607.0 \n",
|
||||
"26 0.800885 225.0 \n",
|
||||
"27 0.750903 132.0 \n",
|
||||
"28 0.695652 125.0 \n",
|
||||
"29 0.615385 33.0 \n",
|
||||
"30 0.703297 68.0 \n",
|
||||
"31 0.598639 167.0 \n",
|
||||
"32 0.946124 386.0 \n",
|
||||
"33 0.480000 39.0 \n",
|
||||
"34 0.930723 325.0 \n",
|
||||
"35 0.908629 612.0 \n",
|
||||
"36 0.660377 138.0 \n",
|
||||
"37 0.918448 633.0 \n",
|
||||
"38 0.565401 117.0 \n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"per_class_rows = []\n",
|
||||
"skip = {\"accuracy\", \"macro avg\", \"weighted avg\"}\n",
|
||||
"\n",
|
||||
"for model_name, df in summaries.items():\n",
|
||||
" for task_name, row in df.iterrows():\n",
|
||||
" task_result = row[\"results\"]\n",
|
||||
" \n",
|
||||
" for class_name, class_metrics in task_result[\"per_class\"].items():\n",
|
||||
" if class_name in skip:\n",
|
||||
" continue\n",
|
||||
" per_class_rows.append({\n",
|
||||
" \"model\": model_name,\n",
|
||||
" \"task\": task_name,\n",
|
||||
" \"class\": class_name,\n",
|
||||
" \"precision\": class_metrics[\"precision\"],\n",
|
||||
" \"recall\": class_metrics[\"recall\"],\n",
|
||||
" \"f1\": class_metrics[\"f1-score\"],\n",
|
||||
" \"support\": class_metrics[\"support\"],\n",
|
||||
" })\n",
|
||||
"\n",
|
||||
"per_class_analysis_df = pd.DataFrame(per_class_rows)\n",
|
||||
"print(\"Per class df\", per_class_analysis_df)\n",
|
||||
"per_class_analysis_df.to_csv(\"per_class_analysis.csv\", index=False)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "f796bc9e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"label_names = {\n",
|
||||
" 'bug_report': ['No', 'Yes'],\n",
|
||||
" 'feature_request': ['No', 'Yes'],\n",
|
||||
" 'aspect': ['App', 'Driver', 'General', 'Payment', 'Pricing', 'Service'],\n",
|
||||
" 'aspect_sentiment': ['Positive', 'Neutral', 'Negative']\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "4cc572c1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tasks = [\"bug_report\", \"feature_request\", \"aspect\", \"aspect_sentiment\"]\n",
|
||||
"mcnemar_rows = []\n",
|
||||
"\n",
|
||||
"mtl_path = \"../outputs/test_predictions_mtl_mtl_original.csv\"\n",
|
||||
"mtl_df = pd.read_csv(mtl_path)\n",
|
||||
"\n",
|
||||
"from statsmodels.stats.contingency_tables import mcnemar"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"id": "0ef9d5ca",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"McNemar Test Results (MTL vs STL Original):\n",
|
||||
" task mtl_only_correct stl_only_correct p_value significant\n",
|
||||
"0 bug_report 32 28 0.698883 False\n",
|
||||
"1 feature_request 42 37 0.652964 False\n",
|
||||
"2 aspect 70 55 0.210327 False\n",
|
||||
"3 aspect_sentiment 15 18 0.728332 False\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"mcnemar_rows = []  # reset in the cell that fills it, so re-running this cell does not append duplicate rows\n",
|
||||
"for task in tasks:\n",
|
||||
" stl_path = f\"../outputs/test_predictions_stl_{task}_original.csv\"\n",
|
||||
" \n",
|
||||
" if not os.path.exists(stl_path):\n",
|
||||
" print(f\"Warning: Missing STL file for {task}\")\n",
|
||||
" continue\n",
|
||||
" \n",
|
||||
" stl_df = pd.read_csv(stl_path)\n",
|
||||
"\n",
|
||||
" # 3. Align ground truth\n",
|
||||
"    # The '{task}' column in the CSV (e.g. 'bug_report') is numeric (0, 1, 2...).\n",
|
||||
" # The '{task}_pred' column is a string ('Yes', 'No', 'App'...).\n",
|
||||
" # We map the numeric ground truth to labels so we can compare directly.\n",
|
||||
" y_true = [label_names[task][int(val)] for val in mtl_df[task]]\n",
|
||||
" \n",
|
||||
" # 4. Get predictions\n",
|
||||
" y_mtl = mtl_df[f\"{task}_pred\"].values\n",
|
||||
" y_stl = stl_df[f\"{task}_pred\"].values\n",
|
||||
"\n",
|
||||
" # 5. Determine correctness for each model\n",
|
||||
" mtl_correct = (y_mtl == y_true)\n",
|
||||
" stl_correct = (y_stl == y_true)\n",
|
||||
"\n",
|
||||
" # 6. Build the 2x2 Contingency Table\n",
|
||||
" # Table structure for McNemar:\n",
|
||||
" # MTL Correct | MTL Wrong\n",
|
||||
" # STL Correct | [0,0] | [0,1]\n",
|
||||
" # STL Wrong | [1,0] | [1,1]\n",
|
||||
" \n",
|
||||
" both_correct = np.sum(mtl_correct & stl_correct)\n",
|
||||
" stl_only = np.sum(~mtl_correct & stl_correct) # STL right, MTL wrong\n",
|
||||
" mtl_only = np.sum(mtl_correct & ~stl_correct) # MTL right, STL wrong\n",
|
||||
" both_wrong = np.sum(~mtl_correct & ~stl_correct)\n",
|
||||
"\n",
|
||||
" contingency_table = [\n",
|
||||
" [both_correct, stl_only],\n",
|
||||
" [mtl_only, both_wrong]\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" # 7. Run McNemar's Test\n",
|
||||
" # We use exact=True because some of our discordant cells (mtl_only/stl_only) might be small\n",
|
||||
" result = mcnemar(contingency_table, exact=True)\n",
|
||||
" \n",
|
||||
" mcnemar_rows.append({\n",
|
||||
" \"task\": task,\n",
|
||||
" \"both_correct\": both_correct,\n",
|
||||
" \"mtl_only_correct\": mtl_only,\n",
|
||||
" \"stl_only_correct\": stl_only,\n",
|
||||
" \"both_wrong\": both_wrong,\n",
|
||||
" \"p_value\": result.pvalue,\n",
|
||||
" \"significant\": result.pvalue < 0.05\n",
|
||||
" })\n",
|
||||
"\n",
|
||||
"# 8. Save and Display Results\n",
|
||||
"mcnemar_df = pd.DataFrame(mcnemar_rows)\n",
|
||||
"print(\"\\nMcNemar Test Results (MTL vs STL Original):\")\n",
|
||||
"print(mcnemar_df[[\"task\", \"mtl_only_correct\", \"stl_only_correct\", \"p_value\", \"significant\"]])\n",
|
||||
"\n",
|
||||
"mcnemar_df.to_csv(\"analysis_mcnemar_results.csv\", index=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6afd6cf8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "multitag",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.13.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user