Added a single-task RoBERTa implementation; driving everything through command-line args kept it simple
This commit is contained in:
Binary file not shown.
16
src/model.py
16
src/model.py
@@ -10,7 +10,21 @@ import torch.nn as nn
|
|||||||
|
|
||||||
# Each nn.Linear maps RoBERTa's hidden representation onto the output space of one task head
|
# Each nn.Linear maps RoBERTa's hidden representation onto the output space of one task head
|
||||||
# Each hidden representation is size 768
|
# Each hidden representation is size 768
|
||||||
class Model(nn.Module):
|
|
||||||
|
class SingleTaskModel(nn.Module):  # SINGLE TASK MODEL ARCHITECTURE
    """Pretrained XLM-RoBERTa encoder followed by a single linear task head.

    The forward pass returns a one-entry dict ``{task_name: logits}`` so the
    output can be indexed by task name.

    Args:
        task_name: Key under which the logits are returned by ``forward``.
        num_classes: Output width of the linear head.
        dropout_rate: Dropout probability applied to the encoder
            representation before the head.
    """

    def __init__(self, task_name, num_classes, dropout_rate=0.2):
        super().__init__()
        self.encoder = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base")
        # Previously misspelled ``droput``. nn.Dropout holds no parameters, so
        # the rename does not change the keys of ``state_dict()``.
        self.dropout = nn.Dropout(dropout_rate)
        # Head input width is read from the encoder config rather than hard-coded.
        self.head = nn.Linear(self.encoder.config.hidden_size, num_classes)
        self.task_name = task_name

    def forward(self, input_ids, attention_mask):
        """Encode the batch and return ``{self.task_name: logits}``."""
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Take index 0 along the second axis of the last hidden state
        # (presumably the sequence-start token -- confirm against the tokenizer).
        pooled = self.dropout(outputs.last_hidden_state[:, 0, :])
        logits = self.head(pooled)
        return {self.task_name: logits}
|
||||||
|
|
||||||
|
class Model(nn.Module): # MULTITASK MODEL ARCHITECTURE
|
||||||
def __init__(self, dropout_rate=0.2): # Try other p values
|
def __init__(self, dropout_rate=0.2): # Try other p values
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.encoder = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base")
|
self.encoder = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base")
|
||||||
|
|||||||
46
src/train.py
46
src/train.py
@@ -18,7 +18,14 @@ from sklearn.utils.class_weight import compute_class_weight
|
|||||||
|
|
||||||
|
|
||||||
from dataset import ReviewDataset
|
from dataset import ReviewDataset
|
||||||
from model import Model
|
from model import Model, SingleTaskModel
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# =======================================================================
|
||||||
|
# Multitask implementation
|
||||||
|
# =======================================================================
|
||||||
|
|
||||||
# NFR5, reproducibility
|
# NFR5, reproducibility
|
||||||
SEED = 4321
|
SEED = 4321
|
||||||
@@ -41,6 +48,8 @@ def compute_weights(df, column, device):
|
|||||||
# python src/train.py --epochs 15 NOTE: 8 - 12 epochs have given the best results so far
|
# python src/train.py --epochs 15 NOTE: 8 - 12 epochs have given the best results so far
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description="RECLASS, Multitask learning for review classification.")
|
parser = argparse.ArgumentParser(description="RECLASS, Multitask learning for review classification.")
|
||||||
|
parser.add_argument("--mode", type=str, default="mtl", choices=["mtl", "stl"], help="Choose between 'mtl' (multitask learning) and 'stl' (single task learning).")
|
||||||
|
parser.add_argument("--task", type=str, default="all", choices=["all", "bug_report", "feature_request", "aspect", "aspect_sentiment"], help="Specific task to train for stl usage only" )
|
||||||
parser.add_argument("--dataset", type=str, default="original", choices=["original", "boosted"], help="Choose between 'original' and 'boosted' dataset.")
|
parser.add_argument("--dataset", type=str, default="original", choices=["original", "boosted"], help="Choose between 'original' and 'boosted' dataset.")
|
||||||
parser.add_argument("--batch_size", type=int, default=16, help="Keep to 16 or 8 for 8GB VRAM")
|
parser.add_argument("--batch_size", type=int, default=16, help="Keep to 16 or 8 for 8GB VRAM")
|
||||||
parser.add_argument("--epochs", type=int, default=5, help="Maxiumum training epochs.")
|
parser.add_argument("--epochs", type=int, default=5, help="Maxiumum training epochs.")
|
||||||
@@ -81,7 +90,24 @@ def main():
|
|||||||
validation_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
|
validation_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
|
||||||
|
|
||||||
# FR3, shared multilingual model with task-specific heads
|
# FR3, shared multilingual model with task-specific heads
|
||||||
model = Model().to(device)
|
if args.mode == "mtl":
|
||||||
|
model = Model().to(device)
|
||||||
|
active_tasks = ['bug_report', 'feature_request', 'aspect', 'aspect_sentiment']
|
||||||
|
run_name = f"mtl_{args.dataset}"
|
||||||
|
else:
|
||||||
|
if args.task == "all":
|
||||||
|
raise ValueError("For single task learning, please specify a task using --task argument.")
|
||||||
|
|
||||||
|
task_classes = {
|
||||||
|
'bug_report': 2,
|
||||||
|
'feature_request': 2,
|
||||||
|
'aspect': 6,
|
||||||
|
'aspect_sentiment': 3
|
||||||
|
}
|
||||||
|
model = SingleTaskModel(args.task, task_classes[args.task]).to(device)
|
||||||
|
active_tasks = [args.task]
|
||||||
|
run_name = f"stl_{args.task}_{args.dataset}"
|
||||||
|
|
||||||
train_df = pd.read_csv(train)
|
train_df = pd.read_csv(train)
|
||||||
|
|
||||||
# Class weights
|
# Class weights
|
||||||
@@ -128,7 +154,7 @@ def main():
|
|||||||
|
|
||||||
# ------------------- Training loop -------------------
|
# ------------------- Training loop -------------------
|
||||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||||
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
|
writer = SummaryWriter(f'runs/reclass_{run_name}_{timestamp}')
|
||||||
|
|
||||||
best_f1 = 0.0
|
best_f1 = 0.0
|
||||||
patience_counter = 0
|
patience_counter = 0
|
||||||
@@ -153,7 +179,7 @@ def main():
|
|||||||
outputs = model(input_ids, attention_mask)
|
outputs = model(input_ids, attention_mask)
|
||||||
|
|
||||||
loss = 0
|
loss = 0
|
||||||
for task in criterions.keys():
|
for task in active_tasks:
|
||||||
labels = batch[task].to(device)
|
labels = batch[task].to(device)
|
||||||
loss += criterions[task](outputs[task], labels)
|
loss += criterions[task](outputs[task], labels)
|
||||||
|
|
||||||
@@ -177,8 +203,8 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
total_val_loss = 0.0
|
total_val_loss = 0.0
|
||||||
|
|
||||||
all_preds = {task: [] for task in criterions.keys()}
|
all_preds = {task: [] for task in active_tasks}
|
||||||
all_labels ={task: [] for task in criterions.keys()}
|
all_labels ={task: [] for task in active_tasks}
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch in validation_loader:
|
for batch in validation_loader:
|
||||||
@@ -188,7 +214,7 @@ def main():
|
|||||||
outputs = model(input_ids, attention_mask)
|
outputs = model(input_ids, attention_mask)
|
||||||
|
|
||||||
v_loss = 0.0 # batch validation loss
|
v_loss = 0.0 # batch validation loss
|
||||||
for task in criterions.keys():
|
for task in active_tasks:
|
||||||
labels = batch[task].to(device)
|
labels = batch[task].to(device)
|
||||||
v_loss += criterions[task](outputs[task], labels).item() # detatch .item(*)
|
v_loss += criterions[task](outputs[task], labels).item() # detatch .item(*)
|
||||||
|
|
||||||
@@ -203,8 +229,8 @@ def main():
|
|||||||
# FR11, Performance evaluation
|
# FR11, Performance evaluation
|
||||||
print("\nValidation Metrics (MACRO F1):")
|
print("\nValidation Metrics (MACRO F1):")
|
||||||
epoch_f1 = []
|
epoch_f1 = []
|
||||||
for task in criterions.keys():
|
for task in active_tasks:
|
||||||
task_f1 = f1_score(all_labels[task], all_preds[task], average='macro')
|
task_f1 = f1_score(all_labels[task], all_preds[task], average='macro', zero_division=0)
|
||||||
epoch_f1.append(task_f1)
|
epoch_f1.append(task_f1)
|
||||||
writer.add_scalar(f"F1/val_{task}", task_f1, epoch)
|
writer.add_scalar(f"F1/val_{task}", task_f1, epoch)
|
||||||
print(f" {task}: {task_f1:.4f}")
|
print(f" {task}: {task_f1:.4f}")
|
||||||
@@ -218,7 +244,7 @@ def main():
|
|||||||
best_f1 = avg_macro_f1
|
best_f1 = avg_macro_f1
|
||||||
patience_counter = 0
|
patience_counter = 0
|
||||||
# Save the model with a name for the type of dataset and epoch for later analysis
|
# Save the model with a name for the type of dataset and epoch for later analysis
|
||||||
model_save_path = f"outputs/best_model_{args.dataset}.pt"
|
model_save_path = f"outputs/best_model_{run_name}.pt"
|
||||||
torch.save(model.state_dict(), model_save_path)
|
torch.save(model.state_dict(), model_save_path)
|
||||||
print(" New best model saved to:", model_save_path)
|
print(" New best model saved to:", model_save_path)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user