added timestamps to infer.py and tested it

This commit is contained in:
2026-04-01 02:33:47 +01:00
parent 82e6277cc1
commit 7fa67af6c0

View File

@@ -1,4 +1,5 @@
# infer.py
from datetime import datetime
import os
import torch
import time
@@ -19,6 +20,8 @@ from torch.utils.data import Dataset
from dataset import InferenceDataset
from model import Model, SingleTaskModel
label_names = {
'bug_report': ['No', 'Yes'],
'feature_request': ['No', 'Yes'],
@@ -31,6 +34,8 @@ torch.manual_seed(SEED)
np.random.seed(SEED)
def parse_args():
parser = argparse.ArgumentParser(description="RECLASS, Multitask learning for review classification.")
parser.add_argument("--model_path", type=str, required=True, help=".pt file in outputs/")
@@ -113,6 +118,8 @@ def main():
all_preds = {task: [] for task in active_tasks}
all_confidences = {task: [] for task in active_tasks}
print(f"Running inference on {args.dataset} dataset")
start_time = time.time()
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
with torch.no_grad():
for batch in infer_loader:
@@ -129,6 +136,7 @@ def main():
all_preds[task].extend(preds.cpu().numpy())
all_confidences[task].extend(confidence.cpu().numpy())
end_time = time.time()
df = pd.DataFrame({"text": infer_df[args.text_column]})
for task in active_tasks: # ensures ALL tasks included
@@ -137,10 +145,12 @@ def main():
output_path = filename
df.to_csv(output_path, index=False)
if not args.text:
print(f"Inference finished. Predictions saved to {output_path}")
print(f"Time taken: {end_time - start_time:.2f} seconds")
else:
print(f"Inference finished.\n")
print(f"Inference completed in {end_time - start_time:.2f} seconds.\n")
print(df.to_string(index=False))
again = input("Do you want to enter another text for inference? (y/n): ")
if again.lower() == 'y':
@@ -148,6 +158,7 @@ def main():
else:
print("Exiting interactive inference.")
if __name__ == "__main__":
main()