fffiloni commited on
Commit
536d153
·
verified ·
1 Parent(s): a40a1a9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -0
app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import StableDiffusionXLPipeline, AutoencoderKL
4
+ from blora_utils import BLOCKS, filter_lora, scale_lora
5
+
6
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
7
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
8
+ "stabilityai/stable-diffusion-xl-base-1.0",
9
+ vae=vae,
10
+ torch_dtype=torch.float16,
11
+ ).to("cuda")
12
+
13
+ def load_b_lora_to_unet(pipe, content_lora_model_id: str = '', style_lora_model_id: str = '', content_alpha: float = 1.,
14
+ style_alpha: float = 1.) -> None:
15
+ try:
16
+ # Get Content B-LoRA SD
17
+ if content_lora_model_id:
18
+ content_B_LoRA_sd, _ = pipe.lora_state_dict(content_lora_model_id)
19
+ content_B_LoRA = filter_lora(content_B_LoRA_sd, BLOCKS['content'])
20
+ content_B_LoRA = scale_lora(content_B_LoRA, content_alpha)
21
+ else:
22
+ content_B_LoRA = {}
23
+
24
+ # Get Style B-LoRA SD
25
+ if style_lora_model_id:
26
+ style_B_LoRA_sd, _ = pipe.lora_state_dict(style_lora_model_id)
27
+ style_B_LoRA = filter_lora(style_B_LoRA_sd, BLOCKS['style'])
28
+ style_B_LoRA = scale_lora(style_B_LoRA, style_alpha)
29
+ else:
30
+ style_B_LoRA = {}
31
+
32
+ # Merge B-LoRAs SD
33
+ res_lora = {**content_B_LoRA, **style_B_LoRA}
34
+
35
+ # Load
36
+ pipe.load_lora_into_unet(res_lora, None, pipe.unet)
37
+ except Exception as e:
38
+ raise type(e)(f'failed to load_b_lora_to_unet, due to: {e}')
39
+
40
+ def main(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
41
+ content_B_LoRA_path = ''
42
+ style_B_LoRA_path = 'fffiloni/b_lora_trained_test_7'
43
+ content_alpha,style_alpha = 1,1.1
44
+
45
+ load_b_lora_to_unet(pipeline, content_B_LoRA_path, style_B_LoRA_path, content_alpha, style_alpha)
46
+ prompt = 'An eagle in [v42] style'
47
+ image = pipeline(
48
+ prompt,
49
+ generator=torch.Generator(device="cuda").manual_seed(48),
50
+ num_images_per_prompt=1
51
+ ).images[0].resize((512,512))
52
+
53
+ pipeline.unload_lora_weights()
54
+
55
+ return image
56
+
57
+ css="""
58
+ #col-container {
59
+ margin: 0 auto;
60
+ max-width: 520px;
61
+ }
62
+ """
63
+
64
+ if torch.cuda.is_available():
65
+ power_device = "GPU"
66
+ else:
67
+ power_device = "CPU"
68
+
69
+ with gr.Blocks(css=css) as demo:
70
+
71
+ with gr.Column(elem_id="col-container"):
72
+ gr.Markdown(f"""
73
+ # Text-to-Image Gradio Template
74
+ Currently running on {power_device}.
75
+ """)
76
+
77
+ with gr.Row():
78
+
79
+ prompt = gr.Text(
80
+ label="Prompt",
81
+ show_label=False,
82
+ max_lines=1,
83
+ placeholder="Enter your prompt",
84
+ container=False,
85
+ )
86
+
87
+ run_button = gr.Button("Run", scale=0)
88
+
89
+ result = gr.Image(label="Result", show_label=False)
90
+
91
+ with gr.Accordion("Advanced Settings", open=False):
92
+
93
+ negative_prompt = gr.Text(
94
+ label="Negative prompt",
95
+ max_lines=1,
96
+ placeholder="Enter a negative prompt",
97
+ visible=False,
98
+ )
99
+
100
+ seed = gr.Slider(
101
+ label="Seed",
102
+ minimum=0,
103
+ maximum=MAX_SEED,
104
+ step=1,
105
+ value=0,
106
+ )
107
+
108
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
109
+
110
+ with gr.Row():
111
+
112
+ width = gr.Slider(
113
+ label="Width",
114
+ minimum=256,
115
+ maximum=MAX_IMAGE_SIZE,
116
+ step=32,
117
+ value=512,
118
+ )
119
+
120
+ height = gr.Slider(
121
+ label="Height",
122
+ minimum=256,
123
+ maximum=MAX_IMAGE_SIZE,
124
+ step=32,
125
+ value=512,
126
+ )
127
+
128
+ with gr.Row():
129
+
130
+ guidance_scale = gr.Slider(
131
+ label="Guidance scale",
132
+ minimum=0.0,
133
+ maximum=10.0,
134
+ step=0.1,
135
+ value=0.0,
136
+ )
137
+
138
+ num_inference_steps = gr.Slider(
139
+ label="Number of inference steps",
140
+ minimum=1,
141
+ maximum=50,
142
+ step=1,
143
+ value=50,
144
+ )
145
+
146
+ gr.Examples(
147
+ examples = examples,
148
+ inputs = [prompt]
149
+ )
150
+
151
+ run_button.click(
152
+ fn = main,
153
+ inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
154
+ outputs = [result]
155
+ )
156
+
157
+ demo.queue().launch()