hysts HF staff commited on
Commit
985f232
·
1 Parent(s): 2dfdfa9

Make color and structure weights configurable

Browse files
Files changed (1) hide show
  1. app.py +11 -3
app.py CHANGED
@@ -151,6 +151,8 @@ def run(
151
  image,
152
  style_type: str,
153
  style_id: float,
 
 
154
  dlib_landmark_model,
155
  encoder: nn.Module,
156
  generator_dict: dict[str, nn.Module],
@@ -191,7 +193,8 @@ def run(
191
  truncation=0.7,
192
  truncation_latent=0,
193
  use_res=True,
194
- interp_weights=[0.6] * 7 + [1] * 11)
 
195
  img_gen = torch.clamp(img_gen.detach(), -1, 1)
196
  # deactivate color-related layers by setting w_c = 0
197
  img_gen2, _ = generator([instyle],
@@ -200,7 +203,7 @@ def run(
200
  truncation=0.7,
201
  truncation_latent=0,
202
  use_res=True,
203
- interp_weights=[0.6] * 7 + [0] * 11)
204
  img_gen2 = torch.clamp(img_gen2.detach(), -1, 1)
205
 
206
  img_rec = postprocess(img_rec[0])
@@ -249,7 +252,8 @@ def main():
249
  func = functools.update_wrapper(func, run)
250
 
251
  image_paths = sorted(pathlib.Path('images').glob('*.jpg'))
252
- examples = [[path.as_posix(), 'cartoon', 26] for path in image_paths]
 
253
 
254
  gr.Interface(
255
  func,
@@ -262,6 +266,10 @@ def main():
262
  label='Style Type',
263
  ),
264
  gr.inputs.Number(default=26, label='Style Image Index'),
 
 
 
 
265
  ],
266
  [
267
  gr.outputs.Image(type='pil', label='Aligned Face'),
 
151
  image,
152
  style_type: str,
153
  style_id: float,
154
+ structure_weight: float,
155
+ color_weight: float,
156
  dlib_landmark_model,
157
  encoder: nn.Module,
158
  generator_dict: dict[str, nn.Module],
 
193
  truncation=0.7,
194
  truncation_latent=0,
195
  use_res=True,
196
+ interp_weights=[structure_weight] * 7 +
197
+ [color_weight] * 11)
198
  img_gen = torch.clamp(img_gen.detach(), -1, 1)
199
  # deactivate color-related layers by setting w_c = 0
200
  img_gen2, _ = generator([instyle],
 
203
  truncation=0.7,
204
  truncation_latent=0,
205
  use_res=True,
206
+ interp_weights=[structure_weight] * 7 + [0] * 11)
207
  img_gen2 = torch.clamp(img_gen2.detach(), -1, 1)
208
 
209
  img_rec = postprocess(img_rec[0])
 
252
  func = functools.update_wrapper(func, run)
253
 
254
  image_paths = sorted(pathlib.Path('images').glob('*.jpg'))
255
+ examples = [[path.as_posix(), 'cartoon', 26, 0.6, 1.0]
256
+ for path in image_paths]
257
 
258
  gr.Interface(
259
  func,
 
266
  label='Style Type',
267
  ),
268
  gr.inputs.Number(default=26, label='Style Image Index'),
269
+ gr.inputs.Slider(
270
+ 0, 1, step=0.1, default=0.6, label='Structure Weight'),
271
+ gr.inputs.Slider(0, 1, step=0.1, default=1.0,
272
+ label='Color Weight'),
273
  ],
274
  [
275
  gr.outputs.Image(type='pil', label='Aligned Face'),