added timestamps to infer.py and tested it
This commit is contained in:
13
src/infer.py
13
src/infer.py
@@ -1,4 +1,5 @@
|
|||||||
# infer.py
|
# infer.py
|
||||||
|
from datetime import datetime
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import time
|
import time
|
||||||
@@ -19,6 +20,8 @@ from torch.utils.data import Dataset
|
|||||||
from dataset import InferenceDataset
|
from dataset import InferenceDataset
|
||||||
from model import Model, SingleTaskModel
|
from model import Model, SingleTaskModel
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
label_names = {
|
label_names = {
|
||||||
'bug_report': ['No', 'Yes'],
|
'bug_report': ['No', 'Yes'],
|
||||||
'feature_request': ['No', 'Yes'],
|
'feature_request': ['No', 'Yes'],
|
||||||
@@ -31,6 +34,8 @@ torch.manual_seed(SEED)
|
|||||||
np.random.seed(SEED)
|
np.random.seed(SEED)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description="RECLASS, Multitask learning for review classification.")
|
parser = argparse.ArgumentParser(description="RECLASS, Multitask learning for review classification.")
|
||||||
parser.add_argument("--model_path", type=str, required=True, help=".pt file in outputs/")
|
parser.add_argument("--model_path", type=str, required=True, help=".pt file in outputs/")
|
||||||
@@ -113,6 +118,8 @@ def main():
|
|||||||
all_preds = {task: [] for task in active_tasks}
|
all_preds = {task: [] for task in active_tasks}
|
||||||
all_confidences = {task: [] for task in active_tasks}
|
all_confidences = {task: [] for task in active_tasks}
|
||||||
print(f"Running inference on {args.dataset} dataset")
|
print(f"Running inference on {args.dataset} dataset")
|
||||||
|
start_time = time.time()
|
||||||
|
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch in infer_loader:
|
for batch in infer_loader:
|
||||||
@@ -129,6 +136,7 @@ def main():
|
|||||||
all_preds[task].extend(preds.cpu().numpy())
|
all_preds[task].extend(preds.cpu().numpy())
|
||||||
all_confidences[task].extend(confidence.cpu().numpy())
|
all_confidences[task].extend(confidence.cpu().numpy())
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
df = pd.DataFrame({"text": infer_df[args.text_column]})
|
df = pd.DataFrame({"text": infer_df[args.text_column]})
|
||||||
|
|
||||||
for task in active_tasks: # ensures ALL tasks included
|
for task in active_tasks: # ensures ALL tasks included
|
||||||
@@ -137,10 +145,12 @@ def main():
|
|||||||
|
|
||||||
output_path = filename
|
output_path = filename
|
||||||
df.to_csv(output_path, index=False)
|
df.to_csv(output_path, index=False)
|
||||||
|
|
||||||
if not args.text:
|
if not args.text:
|
||||||
print(f"Inference finished. Predictions saved to {output_path}")
|
print(f"Inference finished. Predictions saved to {output_path}")
|
||||||
|
print(f"Time taken: {end_time - start_time:.2f} seconds")
|
||||||
else:
|
else:
|
||||||
print(f"Inference finished.\n")
|
print(f"Inference completed in {end_time - start_time:.2f} seconds.\n")
|
||||||
print(df.to_string(index=False))
|
print(df.to_string(index=False))
|
||||||
again = input("Do you want to enter another text for inference? (y/n): ")
|
again = input("Do you want to enter another text for inference? (y/n): ")
|
||||||
if again.lower() == 'y':
|
if again.lower() == 'y':
|
||||||
@@ -149,5 +159,6 @@ def main():
|
|||||||
print("Exiting interactive inference.")
|
print("Exiting interactive inference.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
Reference in New Issue
Block a user