added timestamps to infer.py and tested it
This commit is contained in:
13
src/infer.py
13
src/infer.py
@@ -1,4 +1,5 @@
|
||||
# infer.py
|
||||
from datetime import datetime
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
@@ -19,6 +20,8 @@ from torch.utils.data import Dataset
|
||||
from dataset import InferenceDataset
|
||||
from model import Model, SingleTaskModel
|
||||
|
||||
|
||||
|
||||
label_names = {
|
||||
'bug_report': ['No', 'Yes'],
|
||||
'feature_request': ['No', 'Yes'],
|
||||
@@ -31,6 +34,8 @@ torch.manual_seed(SEED)
|
||||
np.random.seed(SEED)
|
||||
|
||||
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="RECLASS, Multitask learning for review classification.")
|
||||
parser.add_argument("--model_path", type=str, required=True, help=".pt file in outputs/")
|
||||
@@ -113,6 +118,8 @@ def main():
|
||||
all_preds = {task: [] for task in active_tasks}
|
||||
all_confidences = {task: [] for task in active_tasks}
|
||||
print(f"Running inference on {args.dataset} dataset")
|
||||
start_time = time.time()
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in infer_loader:
|
||||
@@ -129,6 +136,7 @@ def main():
|
||||
all_preds[task].extend(preds.cpu().numpy())
|
||||
all_confidences[task].extend(confidence.cpu().numpy())
|
||||
|
||||
end_time = time.time()
|
||||
df = pd.DataFrame({"text": infer_df[args.text_column]})
|
||||
|
||||
for task in active_tasks: # ensures ALL tasks included
|
||||
@@ -137,10 +145,12 @@ def main():
|
||||
|
||||
output_path = filename
|
||||
df.to_csv(output_path, index=False)
|
||||
|
||||
if not args.text:
|
||||
print(f"Inference finished. Predictions saved to {output_path}")
|
||||
print(f"Time taken: {end_time - start_time:.2f} seconds")
|
||||
else:
|
||||
print(f"Inference finished.\n")
|
||||
print(f"Inference completed in {end_time - start_time:.2f} seconds.\n")
|
||||
print(df.to_string(index=False))
|
||||
again = input("Do you want to enter another text for inference? (y/n): ")
|
||||
if again.lower() == 'y':
|
||||
@@ -149,5 +159,6 @@ def main():
|
||||
print("Exiting interactive inference.")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user