Spaces:
Running
Running
Make color and structure weights configurable
Browse files
app.py
CHANGED
@@ -151,6 +151,8 @@ def run(
|
|
151 |
image,
|
152 |
style_type: str,
|
153 |
style_id: float,
|
|
|
|
|
154 |
dlib_landmark_model,
|
155 |
encoder: nn.Module,
|
156 |
generator_dict: dict[str, nn.Module],
|
@@ -191,7 +193,8 @@ def run(
|
|
191 |
truncation=0.7,
|
192 |
truncation_latent=0,
|
193 |
use_res=True,
|
194 |
-
interp_weights=[
|
|
|
195 |
img_gen = torch.clamp(img_gen.detach(), -1, 1)
|
196 |
# deactivate color-related layers by setting w_c = 0
|
197 |
img_gen2, _ = generator([instyle],
|
@@ -200,7 +203,7 @@ def run(
|
|
200 |
truncation=0.7,
|
201 |
truncation_latent=0,
|
202 |
use_res=True,
|
203 |
-
interp_weights=[
|
204 |
img_gen2 = torch.clamp(img_gen2.detach(), -1, 1)
|
205 |
|
206 |
img_rec = postprocess(img_rec[0])
|
@@ -249,7 +252,8 @@ def main():
|
|
249 |
func = functools.update_wrapper(func, run)
|
250 |
|
251 |
image_paths = sorted(pathlib.Path('images').glob('*.jpg'))
|
252 |
-
examples = [[path.as_posix(), 'cartoon', 26
|
|
|
253 |
|
254 |
gr.Interface(
|
255 |
func,
|
@@ -262,6 +266,10 @@ def main():
|
|
262 |
label='Style Type',
|
263 |
),
|
264 |
gr.inputs.Number(default=26, label='Style Image Index'),
|
|
|
|
|
|
|
|
|
265 |
],
|
266 |
[
|
267 |
gr.outputs.Image(type='pil', label='Aligned Face'),
|
|
|
151 |
image,
|
152 |
style_type: str,
|
153 |
style_id: float,
|
154 |
+
structure_weight: float,
|
155 |
+
color_weight: float,
|
156 |
dlib_landmark_model,
|
157 |
encoder: nn.Module,
|
158 |
generator_dict: dict[str, nn.Module],
|
|
|
193 |
truncation=0.7,
|
194 |
truncation_latent=0,
|
195 |
use_res=True,
|
196 |
+
interp_weights=[structure_weight] * 7 +
|
197 |
+
[color_weight] * 11)
|
198 |
img_gen = torch.clamp(img_gen.detach(), -1, 1)
|
199 |
# deactivate color-related layers by setting w_c = 0
|
200 |
img_gen2, _ = generator([instyle],
|
|
|
203 |
truncation=0.7,
|
204 |
truncation_latent=0,
|
205 |
use_res=True,
|
206 |
+
interp_weights=[structure_weight] * 7 + [0] * 11)
|
207 |
img_gen2 = torch.clamp(img_gen2.detach(), -1, 1)
|
208 |
|
209 |
img_rec = postprocess(img_rec[0])
|
|
|
252 |
func = functools.update_wrapper(func, run)
|
253 |
|
254 |
image_paths = sorted(pathlib.Path('images').glob('*.jpg'))
|
255 |
+
examples = [[path.as_posix(), 'cartoon', 26, 0.6, 1.0]
|
256 |
+
for path in image_paths]
|
257 |
|
258 |
gr.Interface(
|
259 |
func,
|
|
|
266 |
label='Style Type',
|
267 |
),
|
268 |
gr.inputs.Number(default=26, label='Style Image Index'),
|
269 |
+
gr.inputs.Slider(
|
270 |
+
0, 1, step=0.1, default=0.6, label='Structure Weight'),
|
271 |
+
gr.inputs.Slider(0, 1, step=0.1, default=1.0,
|
272 |
+
label='Color Weight'),
|
273 |
],
|
274 |
[
|
275 |
gr.outputs.Image(type='pil', label='Aligned Face'),
|