Model almost complete, need to work on loss functions soon
This commit is contained in:
BIN
src/__pycache__/dataset.cpython-313.pyc
Normal file
BIN
src/__pycache__/dataset.cpython-313.pyc
Normal file
Binary file not shown.
31
src/model.py
31
src/model.py
@@ -24,13 +24,40 @@ class Model(nn.Module):
|
||||
self.feature_head = nn.Linear(hidden_size, 2)
|
||||
self.aspect_head = nn.Linear(hidden_size, 6)
|
||||
self.aspect_sentiment_head = nn.Linear(hidden_size, 3)
|
||||
|
||||
# Pass through encoder then extract the token representation
|
||||
# Apply droupout to it, take scores for each head, return them in a dictionary
|
||||
def forward(self, input_ids, attention_mask):
|
||||
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
|
||||
output = outputs.last_hidden_state[:, 0, :]
|
||||
|
||||
output = self.dropout(output)
|
||||
|
||||
# Logits for each head:
|
||||
bug_logits = self.bug_head(output)
|
||||
feature_logits = self.feature_head(output)
|
||||
aspect_logits = self.aspect_head(output)
|
||||
aspect_sentiment = self.aspect_sentiment_head(output)
|
||||
return {
|
||||
'bug_report': bug_logits,
|
||||
'feature_request': feature_logits,
|
||||
'aspect': aspect_logits,
|
||||
'aspect_sentiment': aspect_sentiment
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
from dataset import ReviewDataset
|
||||
from transformers import AutoTokenizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
|
||||
dataset = ReviewDataset("data/processed/original_train.csv", tokenizer)
|
||||
loader = DataLoader(dataset, batch_size=2)
|
||||
|
||||
batch = next(iter(loader))
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
|
||||
model = AutoModelForMaskedLM.from_pretrained("FacebookAI/xlm-roberta-base")
|
||||
model = Model()
|
||||
outputs = model(batch["input_ids"], batch["attention_mask"])
|
||||
|
||||
for k, v in outputs.items():
|
||||
print(k, v.shape)
|
||||
Reference in New Issue
Block a user