libokj commited on
Commit
d0473a1
·
1 Parent(s): f4b7f1a

Upload dti.py

Browse files
Files changed (1) hide show
  1. deepscreen/data/dti.py +8 -5
deepscreen/data/dti.py CHANGED
@@ -192,17 +192,20 @@ class DTIDataset(Dataset):
192
 
193
  def __getitem__(self, i):
194
  sample = self.df.loc[i]
195
- return {
196
  'N': i,
197
  'X1': sample['X1'],
198
  'X1^': self.drug_featurizer(sample['X1^']),
199
- 'ID1': sample.get('ID1'),
200
  'X2': sample['X2'],
201
  'X2^': self.protein_featurizer(sample['X2']),
202
- 'ID2': sample.get('ID2'),
203
- 'Y': sample.get('Y'),
204
- 'ID^': sample.get('ID^'),
205
  }
 
 
 
206
 
207
 
208
  class DTIDataModule(LightningDataModule):
 
192
 
193
  def __getitem__(self, i):
194
  sample = self.df.loc[i]
195
+ sample_dict = {
196
  'N': i,
197
  'X1': sample['X1'],
198
  'X1^': self.drug_featurizer(sample['X1^']),
199
+ # 'ID1': sample.get('ID1'),
200
  'X2': sample['X2'],
201
  'X2^': self.protein_featurizer(sample['X2']),
202
+ # 'ID2': sample.get('ID2'),
203
+ # 'Y': sample.get('Y'),
204
+ # 'ID^': sample.get('ID^'),
205
  }
206
+ optional_keys = ['ID1', 'ID2', 'ID^', 'Y']
207
+ sample_dict.update({key: sample[key] for key in optional_keys if sample.get(key) is not None})
208
+ return sample_dict
209
 
210
 
211
  class DTIDataModule(LightningDataModule):