From 4f0c54fe28210229740db6dcf1a786bc6c8cd67e Mon Sep 17 00:00:00 2001 From: charlie-rasberry Date: Mon, 23 Feb 2026 16:26:48 +0000 Subject: [PATCH] Added training loop for the MTL architecture on the original distribution --- .gitignore | 2 + ...events.out.tfevents.1771860635.mik.38652.0 | Bin 0 -> 88 bytes ...events.out.tfevents.1771860871.mik.34960.0 | Bin 0 -> 88 bytes ...events.out.tfevents.1771861405.mik.44772.0 | Bin 0 -> 88 bytes ...events.out.tfevents.1771861946.mik.10220.0 | Bin 0 -> 1893 bytes src/__pycache__/dataset.cpython-313.pyc | Bin 2184 -> 2543 bytes src/dataset.py | 3 +- src/train.py | 206 +++++++++++++++--- 8 files changed, 174 insertions(+), 37 deletions(-) create mode 100644 runs/fashion_trainer_20260223_153035/events.out.tfevents.1771860635.mik.38652.0 create mode 100644 runs/fashion_trainer_20260223_153431/events.out.tfevents.1771860871.mik.34960.0 create mode 100644 runs/fashion_trainer_20260223_154325/events.out.tfevents.1771861405.mik.44772.0 create mode 100644 runs/fashion_trainer_20260223_155226/events.out.tfevents.1771861946.mik.10220.0 diff --git a/.gitignore b/.gitignore index f57e31f..acb32c8 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,5 @@ models/ .ipynb_checkpoints/ *.csv backup/*.csv +runs/ +outputs/ \ No newline at end of file diff --git a/runs/fashion_trainer_20260223_153035/events.out.tfevents.1771860635.mik.38652.0 b/runs/fashion_trainer_20260223_153035/events.out.tfevents.1771860635.mik.38652.0 new file mode 100644 index 0000000000000000000000000000000000000000..cefdbaf95f2f518e6dcdee5d4b2ca73c14b81a01 GIT binary patch literal 88 zcmeZZfPjCKJmzxlJossuO!_THDc+=_#LPTB*Rs^S5-X!1JuaP+)V$*SqNM!9q7=R2 h(%js{qDsB;qRf)iBE3|Qs`#|boYZ)TNWU)sbpTN4Ad>(9 literal 0 HcmV?d00001 diff --git a/runs/fashion_trainer_20260223_153431/events.out.tfevents.1771860871.mik.34960.0 b/runs/fashion_trainer_20260223_153431/events.out.tfevents.1771860871.mik.34960.0 new file mode 100644 index 0000000000000000000000000000000000000000..d930585742770f4345d808a9c5449e933acdea43 GIT binary patch literal 88 zcmeZZfPjCKJmzvPvcCOLCjFM96mL>dVrHJ6YguYuiIq{19+yr@YF=@EQBrM*}QKepaQD#YMkzOiDReV}zPHH?vWG!p0CIBs_AJ6~* literal 0 HcmV?d00001 diff --git a/runs/fashion_trainer_20260223_154325/events.out.tfevents.1771861405.mik.44772.0 b/runs/fashion_trainer_20260223_154325/events.out.tfevents.1771861405.mik.44772.0 new file mode 100644 index 0000000000000000000000000000000000000000..73e50226aec54e7d47598ad4e7ac5c3603432dc4 GIT binary patch literal 88 zcmeZZfPjCKJmzxV5e!e4O~2(R#hX-=n3<>NT9%quVr3Mh$E8z}npd1(l$4)Xl%iK$ hnwy(gRH;{9lv$Emq?Za(6`z)wlNt{ZdF1qXH303MAUyy8 literal 0 HcmV?d00001 diff --git a/runs/fashion_trainer_20260223_155226/events.out.tfevents.1771861946.mik.10220.0 b/runs/fashion_trainer_20260223_155226/events.out.tfevents.1771861946.mik.10220.0 new file mode 100644 index 0000000000000000000000000000000000000000..b53d518d2457727514689ccc48bf98a89959fee9 GIT binary patch literal 1893 zcmaLTYe-XJ9LI5`BF|GRmmOy-O>-`CGSTvuTMudGvYB8lQn7W^qs?@yv&}5K5jlq! zNG&YQR5B4MO_ZRUq(&D?$?8RkB^YRym&}YPx~%-adEmJ7;(2lSeV+$DDu>zccbTf1 z5{Hj*1$)fL>_bRK$&N%tN_T$~)X-YN;TQPD&!mR=Rwxu@3oT25| zg9gq-SBJ>4+BMMX3sG}P=?o>I*UUb!8o47^h|*ccWad~~c*x8eO?2^YF;?E+6YB+~ zW8T;|sx=f!SF{`npk#~(*;Et#Dqf7Q{P^4@gTS(~dpqTv1c4{=Kq4a08f|R4Y}-K( z9SFuB3hMStAdQfnilx-akgh-?JfplZ*C}&?k;bc%k#IYZQ2X zuz`Rs76(arprsQ73gM%9!36ZlNZ8H+EsfoRGk^RkC!sSSvIngh_rv?X*W?n=-9=FH zK$ji&!&4VKnh5B^W(e#;Yd1yXevcSW0=oMKr2NoT$}qg(^wll`dZ4TV^FqIKtqfNb z&>{jlF+sGM2YPUOo1*@L&p876Q9f+vfKHoH;V+>9^#rt_5h8og!kG=&vfJ$`0X^6Q zB@eV!>4!rtg-;0R+mjI3g}xOOgWtzEIS|lmzd*_ltrscrrHN||1auMImc$EPIJ#Jo znAuxPK);J*lX;+9qh2cnN%~p>nliw44(OK;wqa{TVjKaTS_+Xp=%UInj1Lv9BcMZC zpyYv`8Qy?%N4c{E^mH!-cA+22qVYG$j>QD@{60wep)CU#PX%=DB%pg1I_L62UwTm* zzEM0cfq?eblpNuKcKO<;cqYqwOF%c|!FCSl%DN~#)N)fpK!2}>$R2dll5jkty37#J unH^B_K-YN(;!j88j07|uhQKa#g&^vQo&!vqRjNnyu_UNk|Ll2z2xGu zDrvXG0{feAi-J<-Y)S{9^-K50gR88(8jmiJm9TW_JZob8zR+OI`Ur>|^ z)|{GhiydToF_3FI*`7m!%Lb?j6t%@3lM^{)jlVExGID(gV`O0A>8QHQEP0bp`htx8 h0om)0ahDw9F2vVdcC5Y3SJ&wNfti6tsz?N=698BBje7t9 delta 266 zcmaDa+##s`nU|M~0SFqqr)M7EU|@I*;=lkal=0bVqPjCnI-@4r#&xkwe4IdGW+47- z$Hc%eb#gw7(PVGt$CEEGN3wE(w18RMFqS9}h-JzMB!D28SC37FL6dLt0oLrvCTz9> zN best_f1: + best_f1 = avg_macro_f1 + patience_counter = 0 + torch.save(model.state_dict(), f"outputs/best_mode.pt") + print(" New best model saved.") + else: + patience_counter += 1 + print(f" No improvement. Patience counter: {patience_counter}/{PATIENCE}") + if patience_counter >= PATIENCE: + print(" Early stopping triggered.") + break + writer.close() + print("Training complete.") + +if __name__ == "__main__": + main() \ No newline at end of file