libokj commited on
Commit
0cb6552
·
verified ·
1 Parent(s): 22761bf

Update deepscreen/models/dti.py

Browse files
Files changed (1) hide show
  1. deepscreen/models/dti.py +48 -30
deepscreen/models/dti.py CHANGED
@@ -17,6 +17,8 @@ class DTILightningModule(LightningModule):
17
  model: a fully initialized instance of class torch.nn.Module
18
  metrics: a list of fully initialized instances of class torchmetrics.Metric
19
  """
 
 
20
  def __init__(
21
  self,
22
  optimizer: optim.Optimizer,
@@ -49,20 +51,19 @@ class DTILightningModule(LightningModule):
49
  match stage:
50
  case 'fit':
51
  dataloader = self.trainer.datamodule.train_dataloader()
52
- case 'validate':
53
- dataloader = self.trainer.datamodule.val_dataloader()
54
- case 'test':
55
- dataloader = self.trainer.datamodule.test_dataloader()
56
- case 'predict':
57
- dataloader = self.trainer.datamodule.predict_dataloader()
58
- dummy_batch = next(iter(dataloader))
 
59
 
60
  # for key, value in dummy_batch.items():
61
  # if isinstance(value, Tensor):
62
  # dummy_batch[key] = value.to(self.device)
63
 
64
- self.forward(dummy_batch)
65
-
66
  def forward(self, batch):
67
  output = self.predictor(batch['X1^'], batch['X2^'])
68
  target = batch.get('Y')
@@ -92,13 +93,18 @@ class DTILightningModule(LightningModule):
92
  self.train_metrics(preds=preds, target=target, indexes=indexes.long())
93
  self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
94
 
95
- return {
96
- 'N': batch['N'],
97
- 'ID1': batch['ID1'], 'X1': batch['X1'],
98
- 'ID2': batch['ID2'], 'X2': batch['X2'],
99
- 'Y^': preds, 'Y': target, 'loss': loss
100
  }
101
 
 
 
 
 
 
 
102
  def on_train_epoch_end(self):
103
  pass
104
 
@@ -109,13 +115,18 @@ class DTILightningModule(LightningModule):
109
  self.val_metrics(preds=preds, target=target, indexes=indexes.long())
110
  self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
111
 
112
- return {
113
- 'N': batch['N'],
114
- 'ID1': batch['ID1'], 'X1': batch['X1'],
115
- 'ID2': batch['ID2'], 'X2': batch['X2'],
116
- 'Y^': preds, 'Y': target, 'loss': loss
117
  }
118
 
 
 
 
 
 
 
119
  def on_validation_epoch_end(self):
120
  pass
121
 
@@ -126,27 +137,34 @@ class DTILightningModule(LightningModule):
126
  self.test_metrics(preds=preds, target=target, indexes=indexes.long())
127
  self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
128
 
129
- # return a dictionary for callbacks like BasePredictionWriter
130
- return {
131
- 'N': batch['N'],
132
- 'ID1': batch['ID1'], 'X1': batch['X1'],
133
- 'ID2': batch['ID2'], 'X2': batch['X2'],
134
- 'Y^': preds, 'Y': target, 'loss': loss
135
  }
136
 
 
 
 
 
 
 
137
  def on_test_epoch_end(self):
138
  pass
139
 
140
  def predict_step(self, batch, batch_idx, dataloader_idx=0):
141
  preds, _, _, _ = self.forward(batch)
142
  # return a dictionary for callbacks like BasePredictionWriter
143
- return {
144
- 'N': batch['N'],
145
- 'ID1': batch['ID1'], 'X1': batch['X1'],
146
- 'ID2': batch['ID2'], 'X2': batch['X2'],
147
- 'Y^': preds
148
  }
149
 
 
 
 
 
 
 
150
  def configure_optimizers(self):
151
  optimizers_config = {'optimizer': self.hparams.optimizer(params=self.parameters())}
152
  if self.hparams.get('scheduler'):
 
17
  model: a fully initialized instance of class torch.nn.Module
18
  metrics: a list of fully initialized instances of class torchmetrics.Metric
19
  """
20
+ extra_return_keys = ['ID1', 'X1', 'ID2', 'X2', 'N']
21
+
22
  def __init__(
23
  self,
24
  optimizer: optim.Optimizer,
 
51
  match stage:
52
  case 'fit':
53
  dataloader = self.trainer.datamodule.train_dataloader()
54
+ dummy_batch = next(iter(dataloader))
55
+ self.forward(dummy_batch)
56
+ # case 'validate':
57
+ # dataloader = self.trainer.datamodule.val_dataloader()
58
+ # case 'test':
59
+ # dataloader = self.trainer.datamodule.test_dataloader()
60
+ # case 'predict':
61
+ # dataloader = self.trainer.datamodule.predict_dataloader()
62
 
63
  # for key, value in dummy_batch.items():
64
  # if isinstance(value, Tensor):
65
  # dummy_batch[key] = value.to(self.device)
66
 
 
 
67
  def forward(self, batch):
68
  output = self.predictor(batch['X1^'], batch['X2^'])
69
  target = batch.get('Y')
 
93
  self.train_metrics(preds=preds, target=target, indexes=indexes.long())
94
  self.log_dict(self.train_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
95
 
96
+ return_dict = {
97
+ 'Y^': preds,
98
+ 'Y': target,
99
+ 'loss': loss
 
100
  }
101
 
102
+ for key in self.extra_return_keys:
103
+ if key in batch:
104
+ return_dict[key] = batch[key]
105
+
106
+ return return_dict
107
+
108
  def on_train_epoch_end(self):
109
  pass
110
 
 
115
  self.val_metrics(preds=preds, target=target, indexes=indexes.long())
116
  self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
117
 
118
+ return_dict = {
119
+ 'Y^': preds,
120
+ 'Y': target,
121
+ 'loss': loss
 
122
  }
123
 
124
+ for key in self.extra_return_keys:
125
+ if key in batch:
126
+ return_dict[key] = batch[key]
127
+
128
+ return return_dict
129
+
130
  def on_validation_epoch_end(self):
131
  pass
132
 
 
137
  self.test_metrics(preds=preds, target=target, indexes=indexes.long())
138
  self.log_dict(self.test_metrics, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
139
 
140
+ return_dict = {
141
+ 'Y^': preds,
142
+ 'Y': target,
143
+ 'loss': loss
 
 
144
  }
145
 
146
+ for key in self.extra_return_keys:
147
+ if key in batch:
148
+ return_dict[key] = batch[key]
149
+
150
+ return return_dict
151
+
152
  def on_test_epoch_end(self):
153
  pass
154
 
155
  def predict_step(self, batch, batch_idx, dataloader_idx=0):
156
  preds, _, _, _ = self.forward(batch)
157
  # return a dictionary for callbacks like BasePredictionWriter
158
+ return_dict = {
159
+ 'Y^': preds,
 
 
 
160
  }
161
 
162
+ for key in self.extra_return_keys:
163
+ if key in batch:
164
+ return_dict[key] = batch[key]
165
+
166
+ return return_dict
167
+
168
  def configure_optimizers(self):
169
  optimizers_config = {'optimizer': self.hparams.optimizer(params=self.parameters())}
170
  if self.hparams.get('scheduler'):