ccchenzc commited on
Commit
82a7709
·
verified ·
1 Parent(s): 0fa3943

Update webui/runner.py

Browse files
Files changed (1) hide show
  1. webui/runner.py +161 -161
webui/runner.py CHANGED
@@ -1,161 +1,161 @@
1
- import torch
2
- from PIL import Image
3
- from diffusers import DDIMScheduler
4
- from accelerate.utils import set_seed
5
- from torchvision.transforms.functional import to_pil_image, to_tensor
6
-
7
- from pipeline_sd import ADPipeline
8
- from pipeline_sdxl import ADPipeline as ADXLPipeline
9
- from utils import Controller
10
-
11
- import os
12
- import spaces
13
-
14
-
15
- class Runner:
16
- def __init__(self):
17
- self.sd15 = None
18
- self.sdxl = None
19
- self.loss_fn = torch.nn.L1Loss(reduction="mean")
20
-
21
- def load_pipeline(self, model_path_or_name):
22
-
23
- if 'xl' in model_path_or_name and self.sdxl is None:
24
- scheduler = DDIMScheduler.from_pretrained(os.path.join('./checkpoints', model_path_or_name), subfolder="scheduler")
25
- self.sdxl = ADXLPipeline.from_pretrained(os.path.join('./checkpoints', model_path_or_name), scheduler=scheduler, safety_checker=None)
26
- self.sdxl.classifier = self.sdxl.unet
27
- elif self.sd15 is None:
28
- scheduler = DDIMScheduler.from_pretrained(os.path.join('./checkpoints', model_path_or_name), subfolder="scheduler")
29
- self.sd15 = ADPipeline.from_pretrained(os.path.join('./checkpoints', model_path_or_name), scheduler=scheduler, safety_checker=None)
30
- self.sd15.classifier = self.sd15.unet
31
-
32
- def preprocecss(self, image: Image.Image, height=None, width=None):
33
- if width is None or height is None:
34
- width, height = image.size
35
- new_width = (width // 64) * 64
36
- new_height = (height // 64) * 64
37
- size = (new_width, new_height)
38
- image = image.resize(size, Image.BICUBIC)
39
- return to_tensor(image).unsqueeze(0)
40
-
41
- @spaces.GPU
42
- def run_style_transfer(self, content_image, style_image, seed, num_steps, lr, content_weight, mixed_precision, model, **kwargs):
43
- self.load_pipeline(model)
44
-
45
- content_image = self.preprocecss(content_image)
46
- style_image = self.preprocecss(style_image, height=512, width=512)
47
-
48
- height, width = content_image.shape[-2:]
49
- set_seed(seed)
50
- controller = Controller(self_layers=(10, 16))
51
- result = self.sd15.optimize(
52
- lr=lr,
53
- batch_size=1,
54
- iters=1,
55
- width=width,
56
- height=height,
57
- weight=content_weight,
58
- controller=controller,
59
- style_image=style_image,
60
- content_image=content_image,
61
- mixed_precision=mixed_precision,
62
- num_inference_steps=num_steps,
63
- enable_gradient_checkpoint=False,
64
- )
65
- output_image = to_pil_image(result[0])
66
- del result
67
- torch.cuda.empty_cache()
68
- return [output_image]
69
-
70
- @spaces.GPU
71
- def run_style_t2i_generation(self, style_image, prompt, negative_prompt, guidance_scale, height, width, seed, num_steps, iterations, lr, num_images_per_prompt, mixed_precision, is_adain, model):
72
- self.load_pipeline(model)
73
-
74
- use_xl = 'xl' in model
75
- height, width = (1024, 1024) if 'xl' in model else (512, 512)
76
- style_image = self.preprocecss(style_image, height=height, width=width)
77
-
78
- set_seed(seed)
79
- self_layers = (64, 70) if use_xl else (10, 16)
80
-
81
- controller = Controller(self_layers=self_layers)
82
-
83
- pipeline = self.sdxl if use_xl else self.sd15
84
- images = pipeline.sample(
85
- controller=controller,
86
- iters=iterations,
87
- lr=lr,
88
- adain=is_adain,
89
- height=height,
90
- width=width,
91
- mixed_precision=mixed_precision,
92
- style_image=style_image,
93
- prompt=prompt,
94
- negative_prompt=negative_prompt,
95
- guidance_scale=guidance_scale,
96
- num_inference_steps=num_steps,
97
- num_images_per_prompt=num_images_per_prompt,
98
- enable_gradient_checkpoint=False
99
- )
100
- output_images = [to_pil_image(image) for image in images]
101
-
102
- del images
103
- torch.cuda.empty_cache()
104
- return output_images
105
-
106
- @spaces.GPU
107
- def run_texture_synthesis(self, texture_image, height, width, seed, num_steps, iterations, lr, mixed_precision, num_images_per_prompt, synthesis_way,model):
108
- self.load_pipeline(model)
109
-
110
- texture_image = self.preprocecss(texture_image, height=512, width=512)
111
-
112
- set_seed(seed)
113
- controller = Controller(self_layers=(10, 16))
114
-
115
- if synthesis_way == 'Sampling':
116
- results = self.sd15.sample(
117
- lr=lr,
118
- adain=False,
119
- iters=iterations,
120
- width=width,
121
- height=height,
122
- weight=0.,
123
- controller=controller,
124
- style_image=texture_image,
125
- content_image=None,
126
- prompt="",
127
- negative_prompt="",
128
- mixed_precision=mixed_precision,
129
- num_inference_steps=num_steps,
130
- guidance_scale=1.,
131
- num_images_per_prompt=num_images_per_prompt,
132
- enable_gradient_checkpoint=False,
133
- )
134
- elif synthesis_way == 'MultiDiffusion':
135
- results = self.sd15.panorama(
136
- lr=lr,
137
- iters=iterations,
138
- width=width,
139
- height=height,
140
- weight=0.,
141
- controller=controller,
142
- style_image=texture_image,
143
- content_image=None,
144
- prompt="",
145
- negative_prompt="",
146
- stride=8,
147
- view_batch_size=8,
148
- mixed_precision=mixed_precision,
149
- num_inference_steps=num_steps,
150
- guidance_scale=1.,
151
- num_images_per_prompt=num_images_per_prompt,
152
- enable_gradient_checkpoint=False,
153
- )
154
- else:
155
- raise ValueError
156
-
157
- output_images = [to_pil_image(image) for image in results]
158
- del results
159
- torch.cuda.empty_cache()
160
- return output_images
161
-
 
1
+ import torch
2
+ from PIL import Image
3
+ from diffusers import DDIMScheduler
4
+ from accelerate.utils import set_seed
5
+ from torchvision.transforms.functional import to_pil_image, to_tensor
6
+
7
+ from pipeline_sd import ADPipeline
8
+ from pipeline_sdxl import ADPipeline as ADXLPipeline
9
+ from utils import Controller
10
+
11
+ import os
12
+ import spaces
13
+
14
+
15
+ class Runner:
16
+ def __init__(self):
17
+ self.sd15 = None
18
+ self.sdxl = None
19
+ self.loss_fn = torch.nn.L1Loss(reduction="mean")
20
+
21
+ def load_pipeline(self, model_path_or_name):
22
+
23
+ if 'xl' in model_path_or_name and self.sdxl is None:
24
+ scheduler = DDIMScheduler.from_pretrained(model_path_or_name, subfolder="scheduler")
25
+ self.sdxl = ADXLPipeline.from_pretrained(model_path_or_name, scheduler=scheduler, safety_checker=None)
26
+ self.sdxl.classifier = self.sdxl.unet
27
+ elif self.sd15 is None:
28
+ scheduler = DDIMScheduler.from_pretrained(model_path_or_name, subfolder="scheduler")
29
+ self.sd15 = ADPipeline.from_pretrained(model_path_or_name, scheduler=scheduler, safety_checker=None)
30
+ self.sd15.classifier = self.sd15.unet
31
+
32
+ def preprocecss(self, image: Image.Image, height=None, width=None):
33
+ if width is None or height is None:
34
+ width, height = image.size
35
+ new_width = (width // 64) * 64
36
+ new_height = (height // 64) * 64
37
+ size = (new_width, new_height)
38
+ image = image.resize(size, Image.BICUBIC)
39
+ return to_tensor(image).unsqueeze(0)
40
+
41
+ @spaces.GPU
42
+ def run_style_transfer(self, content_image, style_image, seed, num_steps, lr, content_weight, mixed_precision, model, **kwargs):
43
+ self.load_pipeline(model)
44
+
45
+ content_image = self.preprocecss(content_image)
46
+ style_image = self.preprocecss(style_image, height=512, width=512)
47
+
48
+ height, width = content_image.shape[-2:]
49
+ set_seed(seed)
50
+ controller = Controller(self_layers=(10, 16))
51
+ result = self.sd15.optimize(
52
+ lr=lr,
53
+ batch_size=1,
54
+ iters=1,
55
+ width=width,
56
+ height=height,
57
+ weight=content_weight,
58
+ controller=controller,
59
+ style_image=style_image,
60
+ content_image=content_image,
61
+ mixed_precision=mixed_precision,
62
+ num_inference_steps=num_steps,
63
+ enable_gradient_checkpoint=False,
64
+ )
65
+ output_image = to_pil_image(result[0])
66
+ del result
67
+ torch.cuda.empty_cache()
68
+ return [output_image]
69
+
70
+ @spaces.GPU
71
+ def run_style_t2i_generation(self, style_image, prompt, negative_prompt, guidance_scale, height, width, seed, num_steps, iterations, lr, num_images_per_prompt, mixed_precision, is_adain, model):
72
+ self.load_pipeline(model)
73
+
74
+ use_xl = 'xl' in model
75
+ height, width = (1024, 1024) if 'xl' in model else (512, 512)
76
+ style_image = self.preprocecss(style_image, height=height, width=width)
77
+
78
+ set_seed(seed)
79
+ self_layers = (64, 70) if use_xl else (10, 16)
80
+
81
+ controller = Controller(self_layers=self_layers)
82
+
83
+ pipeline = self.sdxl if use_xl else self.sd15
84
+ images = pipeline.sample(
85
+ controller=controller,
86
+ iters=iterations,
87
+ lr=lr,
88
+ adain=is_adain,
89
+ height=height,
90
+ width=width,
91
+ mixed_precision=mixed_precision,
92
+ style_image=style_image,
93
+ prompt=prompt,
94
+ negative_prompt=negative_prompt,
95
+ guidance_scale=guidance_scale,
96
+ num_inference_steps=num_steps,
97
+ num_images_per_prompt=num_images_per_prompt,
98
+ enable_gradient_checkpoint=False
99
+ )
100
+ output_images = [to_pil_image(image) for image in images]
101
+
102
+ del images
103
+ torch.cuda.empty_cache()
104
+ return output_images
105
+
106
+ @spaces.GPU
107
+ def run_texture_synthesis(self, texture_image, height, width, seed, num_steps, iterations, lr, mixed_precision, num_images_per_prompt, synthesis_way,model):
108
+ self.load_pipeline(model)
109
+
110
+ texture_image = self.preprocecss(texture_image, height=512, width=512)
111
+
112
+ set_seed(seed)
113
+ controller = Controller(self_layers=(10, 16))
114
+
115
+ if synthesis_way == 'Sampling':
116
+ results = self.sd15.sample(
117
+ lr=lr,
118
+ adain=False,
119
+ iters=iterations,
120
+ width=width,
121
+ height=height,
122
+ weight=0.,
123
+ controller=controller,
124
+ style_image=texture_image,
125
+ content_image=None,
126
+ prompt="",
127
+ negative_prompt="",
128
+ mixed_precision=mixed_precision,
129
+ num_inference_steps=num_steps,
130
+ guidance_scale=1.,
131
+ num_images_per_prompt=num_images_per_prompt,
132
+ enable_gradient_checkpoint=False,
133
+ )
134
+ elif synthesis_way == 'MultiDiffusion':
135
+ results = self.sd15.panorama(
136
+ lr=lr,
137
+ iters=iterations,
138
+ width=width,
139
+ height=height,
140
+ weight=0.,
141
+ controller=controller,
142
+ style_image=texture_image,
143
+ content_image=None,
144
+ prompt="",
145
+ negative_prompt="",
146
+ stride=8,
147
+ view_batch_size=8,
148
+ mixed_precision=mixed_precision,
149
+ num_inference_steps=num_steps,
150
+ guidance_scale=1.,
151
+ num_images_per_prompt=num_images_per_prompt,
152
+ enable_gradient_checkpoint=False,
153
+ )
154
+ else:
155
+ raise ValueError
156
+
157
+ output_images = [to_pil_image(image) for image in results]
158
+ del results
159
+ torch.cuda.empty_cache()
160
+ return output_images
161
+