Implemented dataset.py which tokenises and returns tensors, ready to load the model now
This commit is contained in:
@@ -17,10 +17,39 @@ class ReviewDataset(Dataset):
|
|||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
review = self.df.iloc[idx]['review']
|
review = self.df.iloc[idx]['review']
|
||||||
return review
|
|
||||||
|
|
||||||
uber = ReviewDataset("data/processed/original_train.csv", AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base"))
|
# encoding['input_ids']
|
||||||
print(uber.__getitem__(1))
|
# encoding['attention_mask']
|
||||||
|
# Both have shape [1, max_length] because of return_tensors='pt'
|
||||||
|
# Squeeze them to [max_length] with .squeeze(0)
|
||||||
|
encoding = self.tokenizer(
|
||||||
|
review,
|
||||||
|
max_length=self.max_length,
|
||||||
|
padding='max_length',
|
||||||
|
truncation=True,
|
||||||
|
return_tensors='pt'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Returns a dictionary with:
|
||||||
|
# 'input_ids': tensor of shape [max_length]
|
||||||
|
|
||||||
|
# 'attention_mask': tensor of shape [max_length]
|
||||||
|
|
||||||
|
# 'bug_report': tensor scalar (torch.tensor(label_value))
|
||||||
|
# 'feature_request': tensor scalar (torch.tensor(label_value))
|
||||||
|
# 'aspect': tensor scalar (torch.tensor(label_value))
|
||||||
|
# 'aspect_sentiment': tensor scalar (torch.tensor(label_value))
|
||||||
|
return {
|
||||||
|
'input_ids': encoding['input_ids'].squeeze(0),
|
||||||
|
'attention_mask': encoding['attention_mask'].squeeze(0),
|
||||||
|
'bug_report': torch.tensor(self.df.iloc[idx]['bug_report']),
|
||||||
|
'feature_request': torch.tensor(self.df.iloc[idx]['feature_request']),
|
||||||
|
'aspect': torch.tensor(self.df.iloc[idx]['aspect']),
|
||||||
|
'aspect_sentiment': torch.tensor(self.df.iloc[idx]['aspect_sentiment'])
|
||||||
|
}
|
||||||
|
|
||||||
|
# uber = ReviewDataset("data/processed/original_train.csv", AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base"))
|
||||||
|
# print(uber.__getitem__(1))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
0
src/model.py
Normal file
0
src/model.py
Normal file
Reference in New Issue
Block a user