Added implementation for single task roberta, using args for everything made it simple

This commit is contained in:
2026-02-26 18:21:13 +00:00
parent 01e2142276
commit 96a0c45e84
3 changed files with 51 additions and 11 deletions

View File

@@ -10,7 +10,21 @@ import torch.nn as nn
# Each nn.linear is used to map RoBERTa's hidden representation onto the output space of each task head # Each nn.linear is used to map RoBERTa's hidden representation onto the output space of each task head
# Each hidden representation is size 768 # Each hidden representation is size 768
class Model(nn.Module):
class SingleTaskModel(nn.Module): # SINGLE TASK MODEL ARCHITECTURE
def __init__(self, task_name, num_classes, dropout_rate=0.2):
super().__init__()
self.encoder = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base")
self.droput = nn.Dropout(dropout_rate)
self.head = nn.Linear(self.encoder.config.hidden_size, num_classes)
self.task_name = task_name
def forward(self, input_ids, attention_mask):
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
output= self.droput(outputs.last_hidden_state[:, 0, :])
logits = self.head(output)
return {self.task_name: logits}
class Model(nn.Module): # MULTITASK MODEL ARCHITECTURE
def __init__(self, dropout_rate=0.2): # Try other p values def __init__(self, dropout_rate=0.2): # Try other p values
super().__init__() super().__init__()
self.encoder = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base") self.encoder = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base")

View File

@@ -18,7 +18,14 @@ from sklearn.utils.class_weight import compute_class_weight
from dataset import ReviewDataset from dataset import ReviewDataset
from model import Model from model import Model, SingleTaskModel
# =======================================================================
# Multitask implementation
# =======================================================================
# NFR5, reproducibility # NFR5, reproducibility
SEED = 4321 SEED = 4321
@@ -41,6 +48,8 @@ def compute_weights(df, column, device):
# python src/train.py --epochs 15 NOTE: 8 - 12 epochs has seen best results so far # python src/train.py --epochs 15 NOTE: 8 - 12 epochs has seen best results so far
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="RECLASS, Multitask learning for review classification.") parser = argparse.ArgumentParser(description="RECLASS, Multitask learning for review classification.")
parser.add_argument("--mode", type=str, default="mtl", choices=["mtl", "stl"], help="Choose between 'mtl' (multitask learning) and 'stl' (single task learning).")
parser.add_argument("--task", type=str, default="all", choices=["all", "bug_report", "feature_request", "aspect", "aspect_sentiment"], help="Specific task to train for stl usage only" )
parser.add_argument("--dataset", type=str, default="original", choices=["original", "boosted"], help="Choose between 'original' and 'boosted' dataset.") parser.add_argument("--dataset", type=str, default="original", choices=["original", "boosted"], help="Choose between 'original' and 'boosted' dataset.")
parser.add_argument("--batch_size", type=int, default=16, help="Keep to 16 or 8 for 8GB VRAM") parser.add_argument("--batch_size", type=int, default=16, help="Keep to 16 or 8 for 8GB VRAM")
parser.add_argument("--epochs", type=int, default=5, help="Maxiumum training epochs.") parser.add_argument("--epochs", type=int, default=5, help="Maxiumum training epochs.")
@@ -81,7 +90,24 @@ def main():
validation_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False) validation_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
# FR3, shared multilingual model with task-specific heads # FR3, shared multilingual model with task-specific heads
if args.mode == "mtl":
model = Model().to(device) model = Model().to(device)
active_tasks = ['bug_report', 'feature_request', 'aspect', 'aspect_sentiment']
run_name = f"mtl_{args.dataset}"
else:
if args.task == "all":
raise ValueError("For single task learning, please specify a task using --task argument.")
task_classes = {
'bug_report': 2,
'feature_request': 2,
'aspect': 6,
'aspect_sentiment': 3
}
model = SingleTaskModel(args.task, task_classes[args.task]).to(device)
active_tasks = [args.task]
run_name = f"stl_{args.task}_{args.dataset}"
train_df = pd.read_csv(train) train_df = pd.read_csv(train)
# Class weights # Class weights
@@ -128,7 +154,7 @@ def main():
# ------------------- Training loop ------------------- # ------------------- Training loop -------------------
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp)) writer = SummaryWriter(f'runs/reclass_{run_name}_{timestamp}')
best_f1 = 0.0 best_f1 = 0.0
patience_counter = 0 patience_counter = 0
@@ -153,7 +179,7 @@ def main():
outputs = model(input_ids, attention_mask) outputs = model(input_ids, attention_mask)
loss = 0 loss = 0
for task in criterions.keys(): for task in active_tasks:
labels = batch[task].to(device) labels = batch[task].to(device)
loss += criterions[task](outputs[task], labels) loss += criterions[task](outputs[task], labels)
@@ -177,8 +203,8 @@ def main():
model.eval() model.eval()
total_val_loss = 0.0 total_val_loss = 0.0
all_preds = {task: [] for task in criterions.keys()} all_preds = {task: [] for task in active_tasks}
all_labels ={task: [] for task in criterions.keys()} all_labels ={task: [] for task in active_tasks}
with torch.no_grad(): with torch.no_grad():
for batch in validation_loader: for batch in validation_loader:
@@ -188,7 +214,7 @@ def main():
outputs = model(input_ids, attention_mask) outputs = model(input_ids, attention_mask)
v_loss = 0.0 # batch validation loss v_loss = 0.0 # batch validation loss
for task in criterions.keys(): for task in active_tasks:
labels = batch[task].to(device) labels = batch[task].to(device)
v_loss += criterions[task](outputs[task], labels).item() # detatch .item(*) v_loss += criterions[task](outputs[task], labels).item() # detatch .item(*)
@@ -203,8 +229,8 @@ def main():
# FR11, Performance evaluation # FR11, Performance evaluation
print("\nValidation Metrics (MACRO F1):") print("\nValidation Metrics (MACRO F1):")
epoch_f1 = [] epoch_f1 = []
for task in criterions.keys(): for task in active_tasks:
task_f1 = f1_score(all_labels[task], all_preds[task], average='macro') task_f1 = f1_score(all_labels[task], all_preds[task], average='macro', zero_division=0)
epoch_f1.append(task_f1) epoch_f1.append(task_f1)
writer.add_scalar(f"F1/val_{task}", task_f1, epoch) writer.add_scalar(f"F1/val_{task}", task_f1, epoch)
print(f" {task}: {task_f1:.4f}") print(f" {task}: {task_f1:.4f}")
@@ -218,7 +244,7 @@ def main():
best_f1 = avg_macro_f1 best_f1 = avg_macro_f1
patience_counter = 0 patience_counter = 0
# Save the model with a name for the type of dataset and epoch for later analysis # Save the model with a name for the type of dataset and epoch for later analysis
model_save_path = f"outputs/best_model_{args.dataset}.pt" model_save_path = f"outputs/best_model_{run_name}.pt"
torch.save(model.state_dict(), model_save_path) torch.save(model.state_dict(), model_save_path)
print(" New best model saved to:", model_save_path) print(" New best model saved to:", model_save_path)
else: else: