From 7fa67af6c0e8094e4a18d8e76aedae4017badd65 Mon Sep 17 00:00:00 2001 From: charlie-rasberry Date: Wed, 1 Apr 2026 02:33:47 +0100 Subject: [PATCH] added timestamps to infer.py and tested it --- src/infer.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/infer.py b/src/infer.py index 6076e7d..3955b8b 100644 --- a/src/infer.py +++ b/src/infer.py @@ -1,4 +1,5 @@ # infer.py +from datetime import datetime import os import torch import time @@ -19,6 +20,8 @@ from torch.utils.data import Dataset from dataset import InferenceDataset from model import Model, SingleTaskModel + + label_names = { 'bug_report': ['No', 'Yes'], 'feature_request': ['No', 'Yes'], @@ -31,6 +34,8 @@ torch.manual_seed(SEED) np.random.seed(SEED) + + def parse_args(): parser = argparse.ArgumentParser(description="RECLASS, Multitask learning for review classification.") parser.add_argument("--model_path", type=str, required=True, help=".pt file in outputs/") @@ -113,6 +118,8 @@ def main(): all_preds = {task: [] for task in active_tasks} all_confidences = {task: [] for task in active_tasks} print(f"Running inference on {args.dataset} dataset") + start_time = time.time() + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') with torch.no_grad(): for batch in infer_loader: @@ -129,6 +136,7 @@ def main(): all_preds[task].extend(preds.cpu().numpy()) all_confidences[task].extend(confidence.cpu().numpy()) + end_time = time.time() df = pd.DataFrame({"text": infer_df[args.text_column]}) for task in active_tasks: # ensures ALL tasks included @@ -137,10 +145,12 @@ def main(): output_path = filename df.to_csv(output_path, index=False) + if not args.text: print(f"Inference finished. Predictions saved to {output_path}") + print(f"Time taken: {end_time - start_time:.2f} seconds") else: - print(f"Inference finished.\n") + print(f"Inference completed in {end_time - start_time:.2f} seconds.\n") print(df.to_string(index=False)) again = input("Do you want to enter another text for inference? (y/n): ") if again.lower() == 'y': @@ -148,6 +158,7 @@ def main(): else: print("Exiting interactive inference.") + if __name__ == "__main__": main() \ No newline at end of file