Fixed evaluation indentation and other bugs
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -11,3 +11,7 @@ backup/*.csv
|
|||||||
runs/
|
runs/
|
||||||
outputs/
|
outputs/
|
||||||
*.pt
|
*.pt
|
||||||
|
__pycache__/
|
||||||
|
*.png
|
||||||
|
*.jpg
|
||||||
|
*.json
|
||||||
|
|||||||
BIN
src/__pycache__/model.cpython-311.pyc
Normal file
BIN
src/__pycache__/model.cpython-311.pyc
Normal file
Binary file not shown.
@@ -75,14 +75,14 @@ def main():
|
|||||||
all_preds = {task: [] for task in active_tasks}
|
all_preds = {task: [] for task in active_tasks}
|
||||||
all_confidences = {task: [] for task in active_tasks}
|
all_confidences = {task: [] for task in active_tasks}
|
||||||
|
|
||||||
print("Running inference on test set").upper()
|
print("Running inference on test set")
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch in test_loader:
|
for batch in test_loader:
|
||||||
input_ids = batch['input_ids'].to(device)
|
input_ids = batch['input_ids'].to(device)
|
||||||
attention_mask = batch['attention_mask'].to(device)
|
attention_mask = batch['attention_mask'].to(device)
|
||||||
outputs = model(input_ids, attention_mask)
|
outputs = model(input_ids, attention_mask)
|
||||||
for task in active_tasks:
|
for task in active_tasks:
|
||||||
labels = batch[task].to_device()
|
labels = batch[task].to(device)
|
||||||
logits = outputs[task]
|
logits = outputs[task]
|
||||||
preds = torch.argmax(logits, dim=1)
|
preds = torch.argmax(logits, dim=1)
|
||||||
|
|
||||||
@@ -122,6 +122,7 @@ def main():
|
|||||||
|
|
||||||
report_dict = classification_report(
|
report_dict = classification_report(
|
||||||
labels_arr,
|
labels_arr,
|
||||||
|
preds_arr,
|
||||||
target_names=label_names[task],
|
target_names=label_names[task],
|
||||||
output_dict=True,
|
output_dict=True,
|
||||||
zero_division=0
|
zero_division=0
|
||||||
@@ -138,13 +139,13 @@ def main():
|
|||||||
|
|
||||||
# save summary to JSON
|
# save summary to JSON
|
||||||
summary["results"][task] = {
|
summary["results"][task] = {
|
||||||
"macro_f1": report_dict["macro avg"]["f1-score"],
|
"macro_f1": float(report_dict["macro avg"]["f1-score"]),
|
||||||
"macro_precision": report_dict["macro avg"]["precision"],
|
"macro_precision": float(report_dict["macro avg"]["precision"]),
|
||||||
"macro_recall": report_dict["macro avg"]["recall"],
|
"macro_recall": float(report_dict["macro avg"]["recall"]),
|
||||||
"confidence": {
|
"confidence": {
|
||||||
"overall": mean_conf,
|
"overall": float(mean_conf),
|
||||||
"correct": mean_conf_correct,
|
"correct": float(mean_conf_correct),
|
||||||
"incorrect": mean_conf_incorrect
|
"incorrect": float(mean_conf_incorrect)
|
||||||
},
|
},
|
||||||
"per_class": report_dict
|
"per_class": report_dict
|
||||||
}
|
}
|
||||||
@@ -160,7 +161,7 @@ def main():
|
|||||||
)
|
)
|
||||||
ax.set_xlabel("Predicted Label", fontweight="bold")
|
ax.set_xlabel("Predicted Label", fontweight="bold")
|
||||||
ax.set_ylabel("True Label", fontweight="bold")
|
ax.set_ylabel("True Label", fontweight="bold")
|
||||||
ax.set_title(f"{task.replace("_", " ").title()} Confusion Matrix ({args.mode.upper()})", fontweight="bold")
|
ax.set_title(f"{task.replace('_', ' ').title()} Confusion Matrix ({args.mode.upper()})", fontweight="bold")
|
||||||
|
|
||||||
run_name = args.task if args.mode == "stl" else "mtl"
|
run_name = args.task if args.mode == "stl" else "mtl"
|
||||||
cm_path = f"outputs/figures/cm_{args.mode}_{args.dataset}_{task}.png"
|
cm_path = f"outputs/figures/cm_{args.mode}_{args.dataset}_{task}.png"
|
||||||
|
|||||||
Reference in New Issue
Block a user