From 7bd68108d0ec82eb99d0f0c1b87484667e682a0c Mon Sep 17 00:00:00 2001 From: charlie-rasberry Date: Mon, 23 Feb 2026 12:54:23 +0000 Subject: [PATCH] Implemented initial training structure, adding further logic soon including loss, stopping, optimisation and loop --- src/__pycache__/model.cpython-313.pyc | Bin 0 -> 2924 bytes src/dataset.py | 14 +++--- src/train.py | 70 ++++++++++++++++++++++++-- 3 files changed, 74 insertions(+), 10 deletions(-) create mode 100644 src/__pycache__/model.cpython-313.pyc diff --git a/src/__pycache__/model.cpython-313.pyc b/src/__pycache__/model.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f87f939135b2fd3b053114ab963abea1927be0b GIT binary patch literal 2924 zcmb7GO>7&-6`mz``6vGBhrhB-D|TYa7G*i6QU#42x3OLQ`9PY01FySSkxOYKa+jW6 zI;nc$DWK67wHhRHeX-C>m>vqlryhIlO$?Jr9>{=#^rD+u8F0{3-|UY_7(S!}aOS<4 z_ujmD?|W}~7zp?fjPp+)$_HMAexQs!IEy%}FbI8)3aG%$qd5jMa}Mm7b7E(j?m6b! zISzAa6h~vI;7p@%#n%^M8bGe0or z$Nn_pnS8Qm$EM&mZw@@80|$8@@IF46b67iG-am$TZ(*R|pK;(29|kT0zmTnCuXvDm zkD;8YhVT$S1X6>A0DV5JLAoAU9p;DgXD!jb#nv7|ze*s~03UD^$M`r{it@4ig+5V> zv@NF5VhW)Flt#;rWk)8oPtP+>61dXP)tl<7q{wR$Cc~6nRLfHJnu-@iZB;5SERxWz zg~hzOBw<}N&Tf?YFS6o|4(av3h79ruGX;bl1qM3{ToGMl3XUZD$nz0{S(-V_6(^*7!AY^t#(&#poHb&lM zQ@s_k#@LxCqUw=s5#>ESDVxe0Xtgm$L*jkl@HrgZyiI${#yb0=IGNJw)c7~ zXgLZFJ9FUefPMCV>K#?bPtbGr-A5B)jGv&#_URqPVi{$)8~;sR6HC&Ps;*v{o&2O) zo4_W7CYD4^+NXKmr-|C98P0?Vr#0$;K#w5EimVF)rUgMl6|B~TI+k=S%8FDbZb>Nt zJ~45XRHY&>lYtevT$U6;0{|0;q7e57SgorKow&;u;aUJy5o|9tmW35bEE9i467>d_ zOm0Bb>QYHJnGuW8Bt@5NFd5q!mav*stq{LuRKTJx;UVgyR}e~7QPTuLqqs_^(@io~ zdY3P164v-%u86oQOZ@ctskeCzm-rejuWbD{I7F3$2KKY|t7! z{*5QO8+&s*HvM&Mdh4B=-B_{3{@%YAO>R`ztJ~2`Cz{!g<~q?_Yv6?o#giLre_Go- z*Nsh5v46#rkG+q)n`hg1yYVU53G9v@e?0qWc60T~RChGjV!!buc2mb5&pevhy#3^4 zH+7M2#dlA?_2q|u`LO-lzhCK|zSLsB^^6^YTc~80J7s8!i~?ni)(ypEHA;GDk6VKU zX2zj7i_Uz1aa+gGn9&7*ySZL;Qj-b`mMHPcD6)158xrnUj)CoyfX0(QD!86xp9Hj^ z{84GWR>BPuN5$}32~UwLx{B_ir{K)8(9GBy83*yoO1+^Aa#0(AGG07N={M+f3MS%)(g5!?;xY--m#Vrf!(Pz##o_~&3nq)tfB}sL6)^nf5*N0DnM;TWy zMrPcM2-5Hx5zH{lkN!-@KQCuIv?(cKO%ep+69lN=MwQZjLAcuxtJXV70b)AB31>X;bu$PJxpr^O!}5Yy|iNV*&duSN#X}dMT3;r zpxqhwVA=-F{_W9KTw2L03Z@@A;sSg?@XVJ@(H%FvA6&#gqg>b&Z|HJW%hD4(Lru|7 zEAi?_e^Xpl;{}9Oq)&9h%1{7=Gu$H_q(`daajL<=K}{pfDq-#su8w6zhtJH4SeL#+ zX50+)dVX%E^b%eGEgFCADNF}WhGAa%5Oag+pvC8Chh=bP7`1!51~?JJ!? zwt4M)e|UZPLFKd2uK#%R>i7Q0@2)qmJ_`;#xccyHXZX!Fx0SurIdi!ie7|{pFC5zl ntOvdb-Dg`&YwDRN&=Ma+He&0s?>veBI8h{ha160Y<4ykqa#M_0 literal 0 HcmV?d00001 diff --git a/src/dataset.py b/src/dataset.py index 2be5173..379478e 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -43,14 +43,16 @@ class ReviewDataset(Dataset): return { 'input_ids': encoding['input_ids'].squeeze(0), 'attention_mask': encoding['attention_mask'].squeeze(0), - 'bug_report': torch.tensor(self.df.iloc[idx]['bug_report']), - 'feature_request': torch.tensor(self.df.iloc[idx]['feature_request']), - 'aspect': torch.tensor(self.df.iloc[idx]['aspect']), - 'aspect_sentiment': torch.tensor(self.df.iloc[idx]['aspect_sentiment']) + 'bug_report': torch.tensor(self.df.iloc[idx]['bug_report'], dtype=torch.long), + 'feature_request': torch.tensor(self.df.iloc[idx]['feature_request'], dtype=torch.long), + 'aspect': torch.tensor(self.df.iloc[idx]['aspect'], dtype=torch.long), + 'aspect_sentiment': torch.tensor(self.df.iloc[idx]['aspect_sentiment'], dtype=torch.long) } + +if __name__ == "__main__": + dataset = ReviewDataset("data/processed/original_train.csv", AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")) + print(dataset.__getitem__(1)) -# uber = ReviewDataset("data/processed/original_train.csv", AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")) -# print(uber.__getitem__(1)) diff --git a/src/train.py b/src/train.py index 46a59f7..e7734d6 100644 --- a/src/train.py +++ b/src/train.py @@ -1,7 +1,69 @@ -#train.py - +# train.py +import torch +from sklearn.utils.class_weight import compute_class_weight +import numpy as np +import torch.nn as nn +from torch.utils.data import DataLoader from transformers import AutoTokenizer +import pandas as pd + +from dataset import ReviewDataset +from model import Model + +# class weights, training loop and early stopping +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base") + +train = "data/processed/original_train.csv" +val = "data/processed/original_val.csv" +train_dataset = ReviewDataset(train, tokenizer) +val_dataset = ReviewDataset(val, tokenizer) +train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True) +val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False) + +model = Model().to(device) -class multiTaskModel(): - tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base") \ No newline at end of file +# move input_ids, attention_mask and labels to device in each batch + +# ------------------- Class weights ------------------- +# Using weights inversely proportional to class frequencies to avoid majority class bias, +# prioritize useful bug reports / feature requests +def compute_weights(train_df, column): + classes = np.unique(train_df[column]) + weights = compute_class_weight(class_weight='balanced', classes=classes, y=train_df[column]) + return torch.tensor(weights, dtype=torch.float).to(device) + +# -------------------- Loss functions ------------------- +# just a later idea +# 1.0 * bug_loss + +# 1.0 * feature_loss + +# 0.5 * aspect_loss + +# 0.5 * sentiment_loss + + +# -------------------- Optimizer and scheduler ------------------- + + + + +# ------------------- Training loop ------------------- +# For each epoch: + + + + +# ------------------- Stopping logic ------------------- +# After each epoch, find mean of 4 macro f1 scores +# If there is no improvement for 3 epochs consecutively, stop training +# Prevents overfitting which saves time and resources + + + + +train_df = pd.read_csv(train) +bug_weights = compute_weights(train_df, 'bug_report') +feature_weights = compute_weights(train_df, 'feature_request') +aspect_weights = compute_weights(train_df, 'aspect') +aspect_sentiment_weights = compute_weights(train_df, 'aspect_sentiment') \ No newline at end of file