Fixed evaluation indentation and other bugs
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -11,3 +11,7 @@ backup/*.csv
|
||||
runs/
|
||||
outputs/
|
||||
*.pt
|
||||
__pycache__/
|
||||
*.png
|
||||
*.jpg
|
||||
*.json
|
||||
|
||||
BIN
src/__pycache__/model.cpython-311.pyc
Normal file
BIN
src/__pycache__/model.cpython-311.pyc
Normal file
Binary file not shown.
@@ -75,14 +75,14 @@ def main():
|
||||
all_preds = {task: [] for task in active_tasks}
|
||||
all_confidences = {task: [] for task in active_tasks}
|
||||
|
||||
print("Running inference on test set").upper()
|
||||
print("Running inference on test set")
|
||||
with torch.no_grad():
|
||||
for batch in test_loader:
|
||||
input_ids = batch['input_ids'].to(device)
|
||||
attention_mask = batch['attention_mask'].to(device)
|
||||
outputs = model(input_ids, attention_mask)
|
||||
for task in active_tasks:
|
||||
labels = batch[task].to_device()
|
||||
labels = batch[task].to(device)
|
||||
logits = outputs[task]
|
||||
preds = torch.argmax(logits, dim=1)
|
||||
|
||||
@@ -122,6 +122,7 @@ def main():
|
||||
|
||||
report_dict = classification_report(
|
||||
labels_arr,
|
||||
preds_arr,
|
||||
target_names=label_names[task],
|
||||
output_dict=True,
|
||||
zero_division=0
|
||||
@@ -138,13 +139,13 @@ def main():
|
||||
|
||||
# save summary to JSON
|
||||
summary["results"][task] = {
|
||||
"macro_f1": report_dict["macro avg"]["f1-score"],
|
||||
"macro_precision": report_dict["macro avg"]["precision"],
|
||||
"macro_recall": report_dict["macro avg"]["recall"],
|
||||
"macro_f1": float(report_dict["macro avg"]["f1-score"]),
|
||||
"macro_precision": float(report_dict["macro avg"]["precision"]),
|
||||
"macro_recall": float(report_dict["macro avg"]["recall"]),
|
||||
"confidence": {
|
||||
"overall": mean_conf,
|
||||
"correct": mean_conf_correct,
|
||||
"incorrect": mean_conf_incorrect
|
||||
"overall": float(mean_conf),
|
||||
"correct": float(mean_conf_correct),
|
||||
"incorrect": float(mean_conf_incorrect)
|
||||
},
|
||||
"per_class": report_dict
|
||||
}
|
||||
@@ -160,7 +161,7 @@ def main():
|
||||
)
|
||||
ax.set_xlabel("Predicted Label", fontweight="bold")
|
||||
ax.set_ylabel("True Label", fontweight="bold")
|
||||
ax.set_title(f"{task.replace("_", " ").title()} Confusion Matrix ({args.mode.upper()})", fontweight="bold")
|
||||
ax.set_title(f"{task.replace('_', ' ').title()} Confusion Matrix ({args.mode.upper()})", fontweight="bold")
|
||||
|
||||
run_name = args.task if args.mode == "stl" else "mtl"
|
||||
cm_path = f"outputs/figures/cm_{args.mode}_{args.dataset}_{task}.png"
|
||||
|
||||
Reference in New Issue
Block a user