mterris commited on
Commit
abfbafe
·
1 Parent(s): c6dcd55
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -47,27 +47,27 @@ def generate_imgs_from_user(image,
47
  x = transforms.Grayscale(num_output_channels=1)(x)
48
  x = torch.cat((x, torch.zeros_like(x)), dim=1)
49
 
50
- return generate_imgs(x, physics, use_gen, baseline, model, metrics)
51
 
52
  def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
53
- physics: PhysicsWithGenerator, use_gen: bool,
54
  baseline: BaselineModel, model: EvalModel,
55
  metrics: List[Metric]):
56
  ### Load 1 image
57
  x = dataset[idx] # shape : (C, H, W)
58
  x = x.unsqueeze(0) # shape : (1, C, H, W)
59
 
60
- return generate_imgs(x, physics, use_gen, baseline, model, metrics)
61
 
62
  def generate_random_imgs_from_dataset(dataset: EvalDataset,
63
  physics: PhysicsWithGenerator,
64
- use_gen: bool,
65
  baseline: BaselineModel,
66
  model: EvalModel,
67
  metrics: List[Metric]):
68
  idx = random.randint(0, len(dataset)-1)
69
  x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs_from_dataset(
70
- dataset, idx, physics, use_gen, baseline, model, metrics
71
  )
72
  return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
73
 
 
47
  x = transforms.Grayscale(num_output_channels=1)(x)
48
  x = torch.cat((x, torch.zeros_like(x)), dim=1)
49
 
50
+ return generate_imgs(x, physics, True, baseline, model, metrics)
51
 
52
  def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
53
+ physics: PhysicsWithGenerator, # use_gen: bool,
54
  baseline: BaselineModel, model: EvalModel,
55
  metrics: List[Metric]):
56
  ### Load 1 image
57
  x = dataset[idx] # shape : (C, H, W)
58
  x = x.unsqueeze(0) # shape : (1, C, H, W)
59
 
60
+ return generate_imgs(x, physics, True, baseline, model, metrics)
61
 
62
  def generate_random_imgs_from_dataset(dataset: EvalDataset,
63
  physics: PhysicsWithGenerator,
64
+ # use_gen: bool,
65
  baseline: BaselineModel,
66
  model: EvalModel,
67
  metrics: List[Metric]):
68
  idx = random.randint(0, len(dataset)-1)
69
  x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs_from_dataset(
70
+ dataset, idx, physics, True, baseline, model, metrics
71
  )
72
  return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
73