Added implementation for single task roberta, using args for everything made it simple
This commit is contained in:
Binary file not shown.
16
src/model.py
16
src/model.py
@@ -10,7 +10,21 @@ import torch.nn as nn
|
||||
|
||||
# Each nn.linear is used to map RoBERTa's hidden representation onto the output space of each task head
|
||||
# Each hidden representation is size 768
|
||||
class Model(nn.Module):
|
||||
|
||||
class SingleTaskModel(nn.Module): # SINGLE TASK MODEL ARCHITECTURE
|
||||
def __init__(self, task_name, num_classes, dropout_rate=0.2):
|
||||
super().__init__()
|
||||
self.encoder = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base")
|
||||
self.droput = nn.Dropout(dropout_rate)
|
||||
self.head = nn.Linear(self.encoder.config.hidden_size, num_classes)
|
||||
self.task_name = task_name
|
||||
def forward(self, input_ids, attention_mask):
|
||||
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
|
||||
output= self.droput(outputs.last_hidden_state[:, 0, :])
|
||||
logits = self.head(output)
|
||||
return {self.task_name: logits}
|
||||
|
||||
class Model(nn.Module): # MULTITASK MODEL ARCHITECTURE
|
||||
def __init__(self, dropout_rate=0.2): # Try other p values
|
||||
super().__init__()
|
||||
self.encoder = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base")
|
||||
|
||||
44
src/train.py
44
src/train.py
@@ -18,7 +18,14 @@ from sklearn.utils.class_weight import compute_class_weight
|
||||
|
||||
|
||||
from dataset import ReviewDataset
|
||||
from model import Model
|
||||
from model import Model, SingleTaskModel
|
||||
|
||||
|
||||
|
||||
|
||||
# =======================================================================
|
||||
# Multitask implementation
|
||||
# =======================================================================
|
||||
|
||||
# NFR5, reproducibility
|
||||
SEED = 4321
|
||||
@@ -41,6 +48,8 @@ def compute_weights(df, column, device):
|
||||
# python src/train.py --epochs 15 NOTE: 8 - 12 epochs has seen best results so far
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="RECLASS, Multitask learning for review classification.")
|
||||
parser.add_argument("--mode", type=str, default="mtl", choices=["mtl", "stl"], help="Choose between 'mtl' (multitask learning) and 'stl' (single task learning).")
|
||||
parser.add_argument("--task", type=str, default="all", choices=["all", "bug_report", "feature_request", "aspect", "aspect_sentiment"], help="Specific task to train for stl usage only" )
|
||||
parser.add_argument("--dataset", type=str, default="original", choices=["original", "boosted"], help="Choose between 'original' and 'boosted' dataset.")
|
||||
parser.add_argument("--batch_size", type=int, default=16, help="Keep to 16 or 8 for 8GB VRAM")
|
||||
parser.add_argument("--epochs", type=int, default=5, help="Maxiumum training epochs.")
|
||||
@@ -81,7 +90,24 @@ def main():
|
||||
validation_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
|
||||
|
||||
# FR3, shared multilingual model with task-specific heads
|
||||
if args.mode == "mtl":
|
||||
model = Model().to(device)
|
||||
active_tasks = ['bug_report', 'feature_request', 'aspect', 'aspect_sentiment']
|
||||
run_name = f"mtl_{args.dataset}"
|
||||
else:
|
||||
if args.task == "all":
|
||||
raise ValueError("For single task learning, please specify a task using --task argument.")
|
||||
|
||||
task_classes = {
|
||||
'bug_report': 2,
|
||||
'feature_request': 2,
|
||||
'aspect': 6,
|
||||
'aspect_sentiment': 3
|
||||
}
|
||||
model = SingleTaskModel(args.task, task_classes[args.task]).to(device)
|
||||
active_tasks = [args.task]
|
||||
run_name = f"stl_{args.task}_{args.dataset}"
|
||||
|
||||
train_df = pd.read_csv(train)
|
||||
|
||||
# Class weights
|
||||
@@ -128,7 +154,7 @@ def main():
|
||||
|
||||
# ------------------- Training loop -------------------
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
|
||||
writer = SummaryWriter(f'runs/reclass_{run_name}_{timestamp}')
|
||||
|
||||
best_f1 = 0.0
|
||||
patience_counter = 0
|
||||
@@ -153,7 +179,7 @@ def main():
|
||||
outputs = model(input_ids, attention_mask)
|
||||
|
||||
loss = 0
|
||||
for task in criterions.keys():
|
||||
for task in active_tasks:
|
||||
labels = batch[task].to(device)
|
||||
loss += criterions[task](outputs[task], labels)
|
||||
|
||||
@@ -177,8 +203,8 @@ def main():
|
||||
model.eval()
|
||||
total_val_loss = 0.0
|
||||
|
||||
all_preds = {task: [] for task in criterions.keys()}
|
||||
all_labels ={task: [] for task in criterions.keys()}
|
||||
all_preds = {task: [] for task in active_tasks}
|
||||
all_labels ={task: [] for task in active_tasks}
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in validation_loader:
|
||||
@@ -188,7 +214,7 @@ def main():
|
||||
outputs = model(input_ids, attention_mask)
|
||||
|
||||
v_loss = 0.0 # batch validation loss
|
||||
for task in criterions.keys():
|
||||
for task in active_tasks:
|
||||
labels = batch[task].to(device)
|
||||
v_loss += criterions[task](outputs[task], labels).item() # detatch .item(*)
|
||||
|
||||
@@ -203,8 +229,8 @@ def main():
|
||||
# FR11, Performance evaluation
|
||||
print("\nValidation Metrics (MACRO F1):")
|
||||
epoch_f1 = []
|
||||
for task in criterions.keys():
|
||||
task_f1 = f1_score(all_labels[task], all_preds[task], average='macro')
|
||||
for task in active_tasks:
|
||||
task_f1 = f1_score(all_labels[task], all_preds[task], average='macro', zero_division=0)
|
||||
epoch_f1.append(task_f1)
|
||||
writer.add_scalar(f"F1/val_{task}", task_f1, epoch)
|
||||
print(f" {task}: {task_f1:.4f}")
|
||||
@@ -218,7 +244,7 @@ def main():
|
||||
best_f1 = avg_macro_f1
|
||||
patience_counter = 0
|
||||
# Save the model with a name for the type of dataset and epoch for later analysis
|
||||
model_save_path = f"outputs/best_model_{args.dataset}.pt"
|
||||
model_save_path = f"outputs/best_model_{run_name}.pt"
|
||||
torch.save(model.state_dict(), model_save_path)
|
||||
print(" New best model saved to:", model_save_path)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user