rhfeiyang commited on
Commit
c2786e2
·
verified ·
1 Parent(s): 19ff7f7

Update hf_demo.py

Browse files
Files changed (1) hide show
  1. hf_demo.py +149 -149
hf_demo.py CHANGED
@@ -1,150 +1,150 @@
1
- # Authors: Hui Ren (rhfeiyang.github.io)
2
- import os
3
- import spaces
4
- import gradio as gr
5
- from diffusers import DiffusionPipeline
6
- import matplotlib.pyplot as plt
7
- import torch
8
- from PIL import Image
9
-
10
-
11
-
12
- device = "cuda" if torch.cuda.is_available() else "cpu"
13
- dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
14
- pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",
15
- dtype=dtype).to(device)
16
-
17
- from inference import get_lora_network, inference, get_validation_dataloader
18
- lora_map = {
19
- "None": "None",
20
- "Andre Derain": "andre-derain_subset1",
21
- "Vincent van Gogh": "van_gogh_subset1",
22
- "Andy Warhol": "andy_subset1",
23
- "Walter Battiss": "walter-battiss_subset2",
24
- "Camille Corot": "camille-corot_subset1",
25
- "Claude Monet": "monet_subset2",
26
- "Pablo Picasso": "picasso_subset1",
27
- "Jackson Pollock": "jackson-pollock_subset1",
28
- "Gerhard Richter": "gerhard-richter_subset1",
29
- "M.C. Escher": "m.c.-escher_subset1",
30
- "Albert Gleizes": "albert-gleizes_subset1",
31
- "Hokusai": "katsushika-hokusai_subset1",
32
- "Wassily Kandinsky": "kandinsky_subset1",
33
- "Gustav Klimt": "klimt_subset3",
34
- "Roy Lichtenstein": "roy-lichtenstein_subset1",
35
- "Henri Matisse": "henri-matisse_subset1",
36
- "Joan Miro": "joan-miro_subset2",
37
- }
38
- @spaces.GPU
39
- def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):
40
- adapter_path = lora_map[adapter_choice]
41
- if adapter_path not in [None, "None"]:
42
- adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
43
-
44
- prompts = [prompt]*samples
45
- infer_loader = get_validation_dataloader(prompts)
46
- network = get_lora_network(pipe.unet, adapter_path)["network"]
47
- pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
48
- height=512, width=512, scales=[1.0],
49
- save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
50
- start_noise=-1, show=False, style_prompt="sks art", no_load=True,
51
- from_scratch=True, device=device)[0][1.0]
52
- return pred_images
53
- @spaces.GPU
54
- def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
55
- infer_loader = get_validation_dataloader(prompts, image)
56
- network = get_lora_network(pipe.unet, adapter_path,"all_up")["network"]
57
- pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
58
- height=512, width=512, scales=[0.,1.],
59
- save_dir=None, seed=seed,steps=20, guidance_scale=7.5,
60
- start_noise=start_noise, show=True, style_prompt="sks art", no_load=True,
61
- from_scratch=False, device=device)
62
- return pred_images
63
-
64
- # def infer(prompt, samples, steps, scale, seed):
65
- # generator = torch.Generator(device=device).manual_seed(seed)
66
- # images_list = pipe( # type: ignore
67
- # [prompt] * samples,
68
- # num_inference_steps=steps,
69
- # guidance_scale=scale,
70
- # generator=generator,
71
- # )
72
- # images = []
73
- # safe_image = Image.open(r"data/unsafe.png")
74
- # print(images_list)
75
- # for i, image in enumerate(images_list["images"]): # type: ignore
76
- # if images_list["nsfw_content_detected"][i]: # type: ignore
77
- # images.append(safe_image)
78
- # else:
79
- # images.append(image)
80
- # return images
81
-
82
-
83
-
84
-
85
- block = gr.Blocks()
86
- # Direct infer
87
- with block:
88
- with gr.Group():
89
- gr.Markdown(" # Art-Free Diffusion Demo")
90
- with gr.Row():
91
- text = gr.Textbox(
92
- label="Enter your prompt",
93
- max_lines=2,
94
- placeholder="Enter your prompt",
95
- container=False,
96
- value="Park with cherry blossom trees, picnicker’s and a clear blue pond.",
97
- )
98
-
99
-
100
-
101
- btn = gr.Button("Run", scale=0)
102
- gallery = gr.Gallery(
103
- label="Generated images",
104
- show_label=False,
105
- elem_id="gallery",
106
- columns=[2],
107
- )
108
-
109
- advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
110
-
111
- with gr.Row(elem_id="advanced-options"):
112
- adapter_choice = gr.Dropdown(
113
- label="Choose adapter",
114
- choices=["None", "Andre Derain","Vincent van Gogh","Andy Warhol", "Walter Battiss",
115
- "Camille Corot", "Claude Monet", "Pablo Picasso",
116
- "Jackson Pollock", "Gerhard Richter", "M.C. Escher",
117
- "Albert Gleizes", "Hokusai", "Wassily Kandinsky", "Gustav Klimt", "Roy Lichtenstein",
118
- "Henri Matisse", "Joan Miro"
119
- ],
120
- value="None"
121
- )
122
- # print(adapter_choice[0])
123
- # lora_path = lora_map[adapter_choice.value]
124
- # if lora_path is not None:
125
- # lora_path = f"data/Art_adapters/{lora_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
126
-
127
- samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
128
- steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
129
- scale = gr.Slider(
130
- label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
131
- )
132
- print(scale)
133
- seed = gr.Slider(
134
- label="Seed",
135
- minimum=0,
136
- maximum=2147483647,
137
- step=1,
138
- randomize=True,
139
- )
140
-
141
- gr.on([text.submit, btn.click], demo_inference_gen, inputs=[adapter_choice, text, samples, seed, steps, scale], outputs=gallery)
142
- advanced_button.click(
143
- None,
144
- [],
145
- text,
146
- )
147
-
148
-
149
-
150
  block.launch()
 
1
+ # Authors: Hui Ren (rhfeiyang.github.io)
2
+ import spaces
3
+ import os
4
+ import gradio as gr
5
+ from diffusers import DiffusionPipeline
6
+ import matplotlib.pyplot as plt
7
+ import torch
8
+ from PIL import Image
9
+
10
+
11
+
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16
14
+ pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",
15
+ dtype=dtype).to(device)
16
+
17
+ from inference import get_lora_network, inference, get_validation_dataloader
18
+ lora_map = {
19
+ "None": "None",
20
+ "Andre Derain": "andre-derain_subset1",
21
+ "Vincent van Gogh": "van_gogh_subset1",
22
+ "Andy Warhol": "andy_subset1",
23
+ "Walter Battiss": "walter-battiss_subset2",
24
+ "Camille Corot": "camille-corot_subset1",
25
+ "Claude Monet": "monet_subset2",
26
+ "Pablo Picasso": "picasso_subset1",
27
+ "Jackson Pollock": "jackson-pollock_subset1",
28
+ "Gerhard Richter": "gerhard-richter_subset1",
29
+ "M.C. Escher": "m.c.-escher_subset1",
30
+ "Albert Gleizes": "albert-gleizes_subset1",
31
+ "Hokusai": "katsushika-hokusai_subset1",
32
+ "Wassily Kandinsky": "kandinsky_subset1",
33
+ "Gustav Klimt": "klimt_subset3",
34
+ "Roy Lichtenstein": "roy-lichtenstein_subset1",
35
+ "Henri Matisse": "henri-matisse_subset1",
36
+ "Joan Miro": "joan-miro_subset2",
37
+ }
38
+ @spaces.GPU
39
+ def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):
40
+ adapter_path = lora_map[adapter_choice]
41
+ if adapter_path not in [None, "None"]:
42
+ adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
43
+
44
+ prompts = [prompt]*samples
45
+ infer_loader = get_validation_dataloader(prompts)
46
+ network = get_lora_network(pipe.unet, adapter_path)["network"]
47
+ pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
48
+ height=512, width=512, scales=[1.0],
49
+ save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
50
+ start_noise=-1, show=False, style_prompt="sks art", no_load=True,
51
+ from_scratch=True, device=device)[0][1.0]
52
+ return pred_images
53
+ @spaces.GPU
54
+ def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
55
+ infer_loader = get_validation_dataloader(prompts, image)
56
+ network = get_lora_network(pipe.unet, adapter_path,"all_up")["network"]
57
+ pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
58
+ height=512, width=512, scales=[0.,1.],
59
+ save_dir=None, seed=seed,steps=20, guidance_scale=7.5,
60
+ start_noise=start_noise, show=True, style_prompt="sks art", no_load=True,
61
+ from_scratch=False, device=device)
62
+ return pred_images
63
+
64
+ # def infer(prompt, samples, steps, scale, seed):
65
+ # generator = torch.Generator(device=device).manual_seed(seed)
66
+ # images_list = pipe( # type: ignore
67
+ # [prompt] * samples,
68
+ # num_inference_steps=steps,
69
+ # guidance_scale=scale,
70
+ # generator=generator,
71
+ # )
72
+ # images = []
73
+ # safe_image = Image.open(r"data/unsafe.png")
74
+ # print(images_list)
75
+ # for i, image in enumerate(images_list["images"]): # type: ignore
76
+ # if images_list["nsfw_content_detected"][i]: # type: ignore
77
+ # images.append(safe_image)
78
+ # else:
79
+ # images.append(image)
80
+ # return images
81
+
82
+
83
+
84
+
85
+ block = gr.Blocks()
86
+ # Direct infer
87
+ with block:
88
+ with gr.Group():
89
+ gr.Markdown(" # Art-Free Diffusion Demo")
90
+ with gr.Row():
91
+ text = gr.Textbox(
92
+ label="Enter your prompt",
93
+ max_lines=2,
94
+ placeholder="Enter your prompt",
95
+ container=False,
96
+ value="Park with cherry blossom trees, picnicker’s and a clear blue pond.",
97
+ )
98
+
99
+
100
+
101
+ btn = gr.Button("Run", scale=0)
102
+ gallery = gr.Gallery(
103
+ label="Generated images",
104
+ show_label=False,
105
+ elem_id="gallery",
106
+ columns=[2],
107
+ )
108
+
109
+ advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
110
+
111
+ with gr.Row(elem_id="advanced-options"):
112
+ adapter_choice = gr.Dropdown(
113
+ label="Choose adapter",
114
+ choices=["None", "Andre Derain","Vincent van Gogh","Andy Warhol", "Walter Battiss",
115
+ "Camille Corot", "Claude Monet", "Pablo Picasso",
116
+ "Jackson Pollock", "Gerhard Richter", "M.C. Escher",
117
+ "Albert Gleizes", "Hokusai", "Wassily Kandinsky", "Gustav Klimt", "Roy Lichtenstein",
118
+ "Henri Matisse", "Joan Miro"
119
+ ],
120
+ value="None"
121
+ )
122
+ # print(adapter_choice[0])
123
+ # lora_path = lora_map[adapter_choice.value]
124
+ # if lora_path is not None:
125
+ # lora_path = f"data/Art_adapters/{lora_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
126
+
127
+ samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
128
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
129
+ scale = gr.Slider(
130
+ label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
131
+ )
132
+ print(scale)
133
+ seed = gr.Slider(
134
+ label="Seed",
135
+ minimum=0,
136
+ maximum=2147483647,
137
+ step=1,
138
+ randomize=True,
139
+ )
140
+
141
+ gr.on([text.submit, btn.click], demo_inference_gen, inputs=[adapter_choice, text, samples, seed, steps, scale], outputs=gallery)
142
+ advanced_button.click(
143
+ None,
144
+ [],
145
+ text,
146
+ )
147
+
148
+
149
+
150
  block.launch()