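Fixed evaluation indentation and other bugs: removed the invalid `.upper()` call on the result of `print(...)`, replaced `.to_device()` with `.to(device)` when moving labels, passed the missing `preds_arr` argument to `classification_report`, cast the values stored in the JSON summary to `float`, and fixed the nested double quotes in the confusion-matrix title f-string; also added `__pycache__/`, `*.png`, `*.jpg` and `*.json` to `.gitignore`.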

This commit is contained in:
2026-02-26 20:39:19 +00:00
parent 99896c0873
commit cabf8aa9b5
3 changed files with 111 additions and 106 deletions

4
.gitignore vendored
View File

@@ -11,3 +11,7 @@ backup/*.csv
runs/ runs/
outputs/ outputs/
*.pt *.pt
__pycache__/
*.png
*.jpg
*.json

Binary file not shown.

View File

@@ -75,14 +75,14 @@ def main():
all_preds = {task: [] for task in active_tasks} all_preds = {task: [] for task in active_tasks}
all_confidences = {task: [] for task in active_tasks} all_confidences = {task: [] for task in active_tasks}
print("Running inference on test set").upper() print("Running inference on test set")
with torch.no_grad(): with torch.no_grad():
for batch in test_loader: for batch in test_loader:
input_ids = batch['input_ids'].to(device) input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device) attention_mask = batch['attention_mask'].to(device)
outputs = model(input_ids, attention_mask) outputs = model(input_ids, attention_mask)
for task in active_tasks: for task in active_tasks:
labels = batch[task].to_device() labels = batch[task].to(device)
logits = outputs[task] logits = outputs[task]
preds = torch.argmax(logits, dim=1) preds = torch.argmax(logits, dim=1)
@@ -122,6 +122,7 @@ def main():
report_dict = classification_report( report_dict = classification_report(
labels_arr, labels_arr,
preds_arr,
target_names=label_names[task], target_names=label_names[task],
output_dict=True, output_dict=True,
zero_division=0 zero_division=0
@@ -138,13 +139,13 @@ def main():
# save summary to JSON # save summary to JSON
summary["results"][task] = { summary["results"][task] = {
"macro_f1": report_dict["macro avg"]["f1-score"], "macro_f1": float(report_dict["macro avg"]["f1-score"]),
"macro_precision": report_dict["macro avg"]["precision"], "macro_precision": float(report_dict["macro avg"]["precision"]),
"macro_recall": report_dict["macro avg"]["recall"], "macro_recall": float(report_dict["macro avg"]["recall"]),
"confidence": { "confidence": {
"overall": mean_conf, "overall": float(mean_conf),
"correct": mean_conf_correct, "correct": float(mean_conf_correct),
"incorrect": mean_conf_incorrect "incorrect": float(mean_conf_incorrect)
}, },
"per_class": report_dict "per_class": report_dict
} }
@@ -160,7 +161,7 @@ def main():
) )
ax.set_xlabel("Predicted Label", fontweight="bold") ax.set_xlabel("Predicted Label", fontweight="bold")
ax.set_ylabel("True Label", fontweight="bold") ax.set_ylabel("True Label", fontweight="bold")
ax.set_title(f"{task.replace("_", " ").title()} Confusion Matrix ({args.mode.upper()})", fontweight="bold") ax.set_title(f"{task.replace('_', ' ').title()} Confusion Matrix ({args.mode.upper()})", fontweight="bold")
run_name = args.task if args.mode == "stl" else "mtl" run_name = args.task if args.mode == "stl" else "mtl"
cm_path = f"outputs/figures/cm_{args.mode}_{args.dataset}_{task}.png" cm_path = f"outputs/figures/cm_{args.mode}_{args.dataset}_{task}.png"