ccareaga committed on
Commit
e57d033
·
1 Parent(s): baf4bca

adding model switch

Browse files
Files changed (1) hide show
  1. app.py +52 -11
app.py CHANGED
@@ -9,24 +9,35 @@ from intrinsic.pipeline import load_models, run_pipeline
9
 
10
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
 
12
- intrinsic_models = load_models('v2', device=DEVICE)
 
13
 
14
- def generate_pipeline(models):
 
 
 
15
 
 
16
  def pipeline_func(image, **kwargs):
17
  return run_pipeline(models, image, **kwargs)
18
 
19
  return pipeline_func
20
 
21
-
22
- pipeline_func = generate_pipeline(intrinsic_models)
23
-
24
  @spaces.GPU
25
- def process_image(image):
 
 
 
 
 
26
  print(image.shape)
27
  image = image.astype(np.single) / 255.
28
 
29
- result = pipeline_func(image, device=DEVICE)
 
 
 
 
30
 
31
  return [view(result['hr_alb']), 1 - invert(result['dif_shd']), view_scale(result['pos_res'])]
32
 
@@ -61,6 +72,16 @@ with gr.Blocks(
61
  .md_feedback li {
62
  margin-bottom: 0px !important;
63
  }
 
 
 
 
 
 
 
 
 
 
64
  """,
65
  ) as demo:
66
  gr.Markdown(
@@ -75,14 +96,34 @@ with gr.Blocks(
75
  </a>
76
  """
77
  )
78
- with gr.Row():
 
 
 
 
 
 
 
 
 
 
 
79
  input_img = gr.Image(label="Input Image")
80
-
81
- with gr.Row():
82
  alb_img = gr.Image(label="Albedo")
83
  shd_img = gr.Image(label="Diffuse Shading")
84
  dif_img = gr.Image(label="Diffuse Image")
85
 
86
- input_img.change(process_image, inputs=input_img, outputs=[alb_img, shd_img, dif_img])
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  demo.launch(show_error=True)
 
9
 
10
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
 
12
+ # Instead of loading models at startup, we'll create a cache for models
13
+ model_cache = {}
14
 
15
+ def get_model(model_version):
16
+ if model_version not in model_cache:
17
+ model_cache[model_version] = load_models(model_version, device=DEVICE)
18
+ return model_cache[model_version]
19
 
20
+ def generate_pipeline(models):
21
  def pipeline_func(image, **kwargs):
22
  return run_pipeline(models, image, **kwargs)
23
 
24
  return pipeline_func
25
 
 
 
 
26
  @spaces.GPU
27
+ def process_image(image, model_version):
28
+ # Check if image is provided
29
+ if image is None:
30
+ return [None, None, None]
31
+
32
+ print(f"Processing with model version: {model_version}")
33
  print(image.shape)
34
  image = image.astype(np.single) / 255.
35
 
36
+ # Get or load the selected model
37
+ models = get_model(model_version)
38
+ pipeline_func = generate_pipeline(models)
39
+
40
+ result = pipeline_func(image, device=DEVICE, resize_conf=1024)
41
 
42
  return [view(result['hr_alb']), 1 - invert(result['dif_shd']), view_scale(result['pos_res'])]
43
 
 
72
  .md_feedback li {
73
  margin-bottom: 0px !important;
74
  }
75
+ .image-gallery {
76
+ display: flex;
77
+ flex-wrap: wrap;
78
+ gap: 10px;
79
+ justify-content: center;
80
+ }
81
+ .image-gallery > * {
82
+ flex: 1;
83
+ min-width: 200px;
84
+ }
85
  """,
86
  ) as demo:
87
  gr.Markdown(
 
96
  </a>
97
  """
98
  )
99
+
100
+ # More compact model version selector using dropdown
101
+ model_version = gr.Dropdown(
102
+ choices=["v2", "v2.1"],
103
+ value="v2",
104
+ label="Model Version",
105
+ info="Select which model weights to use",
106
+ scale=2
107
+ )
108
+
109
+ # Gallery-style layout for all images
110
+ with gr.Row(elem_classes="image-gallery"):
111
  input_img = gr.Image(label="Input Image")
 
 
112
  alb_img = gr.Image(label="Albedo")
113
  shd_img = gr.Image(label="Diffuse Shading")
114
  dif_img = gr.Image(label="Diffuse Image")
115
 
116
+ # Update to pass model_version to process_image
117
+ input_img.change(
118
+ process_image,
119
+ inputs=[input_img, model_version],
120
+ outputs=[alb_img, shd_img, dif_img]
121
+ )
122
+ # Add event handler for when model_version changes
123
+ model_version.change(
124
+ process_image,
125
+ inputs=[input_img, model_version],
126
+ outputs=[alb_img, shd_img, dif_img]
127
+ )
128
 
129
  demo.launch(show_error=True)