msong97 commited on
Commit
4e7aed4
·
1 Parent(s): ccc37f0

unfinished work

Browse files
Files changed (1) hide show
  1. app.py +19 -8
app.py CHANGED
@@ -74,18 +74,18 @@ def generate_imgs(dataset: EvalDataset, idx: int,
74
 
75
  return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
76
 
77
- def update_random_idx_and_generate_imgs(dataset: EvalDataset,
78
- model: EvalModel,
79
- baseline: BaselineModel,
80
  physics: PhysicsWithGenerator,
81
  use_gen: bool,
82
  metrics: List[Metric]):
83
  idx = random.randint(0, len(dataset)-1)
84
- x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs(dataset,
85
- idx,
86
- model,
87
- baseline,
88
- physics,
89
  use_gen,
90
  metrics)
91
  return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
@@ -145,6 +145,17 @@ def get_model(model_name, ckpt_pth):
145
  else:
146
  return get_eval_model_on_DEVICE_STR(model_name, ckpt_pth)
147
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  ### Gradio Blocks interface
150
 
 
74
 
75
  return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
76
 
77
+ def update_random_idx_and_generate_imgs(dataset: EvalDataset,
78
+ model: EvalModel,
79
+ baseline: BaselineModel,
80
  physics: PhysicsWithGenerator,
81
  use_gen: bool,
82
  metrics: List[Metric]):
83
  idx = random.randint(0, len(dataset)-1)
84
+ x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs(dataset,
85
+ idx,
86
+ model,
87
+ baseline,
88
+ physics,
89
  use_gen,
90
  metrics)
91
  return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
 
145
  else:
146
  return get_eval_model_on_DEVICE_STR(model_name, ckpt_pth)
147
 
148
+ AVAILABLE_PHYSICS = PhysicsWithGenerator.all_physics
149
+
150
+ def get_dataset(dataset_name):
151
+ global AVAILABLE_PHYSICS
152
+ if dataset_name = 'MRI':
153
+ AVAILABLE_PHYSICS = ['MRI']
154
+ elif dataset_name = 'CT':
155
+ AVAILABLE_PHYSICS = ['CT']
156
+ else:
157
+ AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard', 'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
158
+ return get_dataset_on_DEVICE_STR(dataset_name)
159
 
160
  ### Gradio Blocks interface
161