Spaces:
Sleeping
Sleeping
rm stuff
Browse files
app.py
CHANGED
@@ -47,27 +47,27 @@ def generate_imgs_from_user(image,
|
|
47 |
x = transforms.Grayscale(num_output_channels=1)(x)
|
48 |
x = torch.cat((x, torch.zeros_like(x)), dim=1)
|
49 |
|
50 |
-
return generate_imgs(x, physics,
|
51 |
|
52 |
def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
|
53 |
-
physics: PhysicsWithGenerator, use_gen: bool,
|
54 |
baseline: BaselineModel, model: EvalModel,
|
55 |
metrics: List[Metric]):
|
56 |
### Load 1 image
|
57 |
x = dataset[idx] # shape : (C, H, W)
|
58 |
x = x.unsqueeze(0) # shape : (1, C, H, W)
|
59 |
|
60 |
-
return generate_imgs(x, physics,
|
61 |
|
62 |
def generate_random_imgs_from_dataset(dataset: EvalDataset,
|
63 |
physics: PhysicsWithGenerator,
|
64 |
-
use_gen: bool,
|
65 |
baseline: BaselineModel,
|
66 |
model: EvalModel,
|
67 |
metrics: List[Metric]):
|
68 |
idx = random.randint(0, len(dataset)-1)
|
69 |
x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs_from_dataset(
|
70 |
-
dataset, idx, physics,
|
71 |
)
|
72 |
return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
|
73 |
|
|
|
47 |
x = transforms.Grayscale(num_output_channels=1)(x)
|
48 |
x = torch.cat((x, torch.zeros_like(x)), dim=1)
|
49 |
|
50 |
+
return generate_imgs(x, physics, True, baseline, model, metrics)
|
51 |
|
52 |
def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
|
53 |
+
physics: PhysicsWithGenerator, # use_gen: bool,
|
54 |
baseline: BaselineModel, model: EvalModel,
|
55 |
metrics: List[Metric]):
|
56 |
### Load 1 image
|
57 |
x = dataset[idx] # shape : (C, H, W)
|
58 |
x = x.unsqueeze(0) # shape : (1, C, H, W)
|
59 |
|
60 |
+
return generate_imgs(x, physics, True, baseline, model, metrics)
|
61 |
|
62 |
def generate_random_imgs_from_dataset(dataset: EvalDataset,
|
63 |
physics: PhysicsWithGenerator,
|
64 |
+
# use_gen: bool,
|
65 |
baseline: BaselineModel,
|
66 |
model: EvalModel,
|
67 |
metrics: List[Metric]):
|
68 |
idx = random.randint(0, len(dataset)-1)
|
69 |
x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs_from_dataset(
|
70 |
+
dataset, idx, physics, True, baseline, model, metrics
|
71 |
)
|
72 |
return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
|
73 |
|