ovi054 commited on
Commit
021dd02
·
verified ·
1 Parent(s): d232e5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -58
app.py CHANGED
@@ -1,17 +1,25 @@
1
  import gradio as gr
2
- import requests
3
- import io
4
  import random
5
- import os
6
- import time
7
- from PIL import Image
8
- from deep_translator import GoogleTranslator
9
- import json
10
 
11
 
12
- API_TOKEN = os.getenv("HF_READ_TOKEN")
13
- headers = {"Authorization": f"Bearer {API_TOKEN}"}
14
- timeout = 100
 
 
 
 
 
 
 
 
 
15
 
16
  article_text = """
17
  <div style="text-align: center;">
@@ -27,57 +35,104 @@ article_text = """
27
  </div>
28
  """
29
 
30
- def query(lora_id, prompt, steps=28, cfg_scale=3.5, randomize_seed=True, seed=-1, width=1024, height=1024):
31
- if prompt == "" or prompt == None:
32
- return None
33
-
34
- if lora_id.strip() == "" or lora_id == None:
35
- lora_id = "black-forest-labs/FLUX.1-dev"
36
-
37
- key = random.randint(0, 999)
38
 
39
- API_URL = "https://api-inference.huggingface.co/models/"+ lora_id.strip()
40
 
41
- API_TOKEN = random.choice([os.getenv("HF_READ_TOKEN")])
42
- headers = {"Authorization": f"Bearer {API_TOKEN}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- # prompt = GoogleTranslator(source='ru', target='en').translate(prompt)
45
- # print(f'\033[1mGeneration {key} translation:\033[0m {prompt}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- prompt = f"{prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
48
- # print(f'\033[1mGeneration {key}:\033[0m {prompt}')
49
 
50
- # If seed is -1, generate a random seed and use it
51
- if randomize_seed:
52
- seed = random.randint(1, 4294967296)
53
 
54
- payload = {
55
- "inputs": prompt,
56
- "steps": steps,
57
- "cfg_scale": cfg_scale,
58
- "seed": seed,
59
- "parameters": {
60
- "width": width, # Pass the width to the API
61
- "height": height # Pass the height to the API
62
- }
63
- }
64
-
65
- response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
66
- if response.status_code != 200:
67
- print(f"Error: Failed to get image. Response status: {response.status_code}")
68
- print(f"Response content: {response.text}")
69
- if response.status_code == 503:
70
- raise gr.Error(f"{response.status_code} : The model is being loaded")
71
- raise gr.Error(f"{response.status_code}")
72
 
73
- try:
74
- image_bytes = response.content
75
- image = Image.open(io.BytesIO(image_bytes))
76
- print(f'\033[1mGeneration {key} completed!\033[0m ({prompt})')
77
- return image, seed, seed
78
- except Exception as e:
79
- print(f"Error when trying to open the image: {e}")
80
- return None
81
 
82
 
83
  examples = [
@@ -114,13 +169,20 @@ with gr.Blocks(css=css) as app:
114
  with gr.Row():
115
  with gr.Accordion("Advanced Settings", open=False):
116
  with gr.Row():
 
 
 
 
 
 
 
117
  width = gr.Slider(label="Width", value=1024, minimum=64, maximum=1216, step=8)
118
  height = gr.Slider(label="Height", value=1024, minimum=64, maximum=1216, step=8)
119
  seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=4294967296, step=1)
120
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
121
  with gr.Row():
122
- steps = gr.Slider(label="Sampling steps", value=28, minimum=1, maximum=100, step=1)
123
- cfg = gr.Slider(label="CFG Scale", value=3.5, minimum=1, maximum=20, step=0.5)
124
  # method = gr.Radio(label="Sampling method", value="DPM++ 2M Karras", choices=["DPM++ 2M Karras", "DPM++ SDE Karras", "Euler", "Euler a", "Heun", "DDIM"])
125
 
126
  with gr.Row():
@@ -140,6 +202,7 @@ with gr.Blocks(css=css) as app:
140
  )
141
 
142
 
143
- text_button.click(query, inputs=[custom_lora, text_prompt, steps, cfg, randomize_seed, seed, width, height], outputs=[image_output,seed_output, seed])
 
144
 
145
- app.launch(show_api=False, share=False)
 
1
  import gradio as gr
2
+ import numpy as np
 
3
  import random
4
+ import spaces
5
+ import torch
6
+ from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
7
+ from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
8
+ from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
9
 
10
 
11
+ dtype = torch.bfloat16
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+
14
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
15
+ good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=dtype).to(device)
16
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, vae=taef1).to(device)
17
+ torch.cuda.empty_cache()
18
+
19
+ MAX_SEED = np.iinfo(np.int32).max
20
+ MAX_IMAGE_SIZE = 2048
21
+
22
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
23
 
24
  article_text = """
25
  <div style="text-align: center;">
 
35
  </div>
36
  """
37
 
38
+ @spaces.GPU()
39
+ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, lora_id=None, lora_scale=0.95, progress=gr.Progress(track_tqdm=True)):
40
+ if randomize_seed:
41
+ seed = random.randint(0, MAX_SEED)
42
+ generator = torch.Generator().manual_seed(seed)
 
 
 
43
 
 
44
 
45
+ # for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
46
+ # prompt=prompt,
47
+ # guidance_scale=guidance_scale,
48
+ # num_inference_steps=num_inference_steps,
49
+ # width=width,
50
+ # height=height,
51
+ # generator=generator,
52
+ # output_type="pil",
53
+ # good_vae=good_vae,
54
+ # ):
55
+ # yield img, seed
56
+
57
+ # Handle LoRA loading
58
+ # Load LoRA weights and prepare joint_attention_kwargs
59
+ if lora_id:
60
+ pipe.unload_lora_weights()
61
+ pipe.load_lora_weights(lora_id)
62
+ joint_attention_kwargs = {"scale": lora_scale}
63
+ else:
64
+ joint_attention_kwargs = None
65
 
66
+ try:
67
+ # Call the custom pipeline function with the correct keyword argument
68
+ for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
69
+ prompt=prompt,
70
+ guidance_scale=guidance_scale,
71
+ num_inference_steps=num_inference_steps,
72
+ width=width,
73
+ height=height,
74
+ generator=generator,
75
+ output_type="pil",
76
+ good_vae=good_vae, # Assuming good_vae is defined elsewhere
77
+ joint_attention_kwargs=joint_attention_kwargs, # Fixed parameter name
78
+ ):
79
+ yield img, seed
80
+ finally:
81
+ # Unload LoRA weights if they were loaded
82
+ if lora_id:
83
+ pipe.unload_lora_weights()
84
+
85
+ # def query(lora_id, prompt, steps=28, cfg_scale=3.5, randomize_seed=True, seed=-1, width=1024, height=1024):
86
+ # if prompt == "" or prompt == None:
87
+ # return None
88
+
89
+ # if lora_id.strip() == "" or lora_id == None:
90
+ # lora_id = "black-forest-labs/FLUX.1-dev"
91
+
92
+ # key = random.randint(0, 999)
93
+
94
+ # API_URL = "https://api-inference.huggingface.co/models/"+ lora_id.strip()
95
+
96
+ # API_TOKEN = random.choice([os.getenv("HF_READ_TOKEN")])
97
+ # headers = {"Authorization": f"Bearer {API_TOKEN}"}
98
+
99
+ # # prompt = GoogleTranslator(source='ru', target='en').translate(prompt)
100
+ # # print(f'\033[1mGeneration {key} translation:\033[0m {prompt}')
101
 
102
+ # prompt = f"{prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
103
+ # # print(f'\033[1mGeneration {key}:\033[0m {prompt}')
104
 
105
+ # # If seed is -1, generate a random seed and use it
106
+ # if randomize_seed:
107
+ # seed = random.randint(1, 4294967296)
108
 
109
+ # payload = {
110
+ # "inputs": prompt,
111
+ # "steps": steps,
112
+ # "cfg_scale": cfg_scale,
113
+ # "seed": seed,
114
+ # "parameters": {
115
+ # "width": width, # Pass the width to the API
116
+ # "height": height # Pass the height to the API
117
+ # }
118
+ # }
119
+
120
+ # response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
121
+ # if response.status_code != 200:
122
+ # print(f"Error: Failed to get image. Response status: {response.status_code}")
123
+ # print(f"Response content: {response.text}")
124
+ # if response.status_code == 503:
125
+ # raise gr.Error(f"{response.status_code} : The model is being loaded")
126
+ # raise gr.Error(f"{response.status_code}")
127
 
128
+ # try:
129
+ # image_bytes = response.content
130
+ # image = Image.open(io.BytesIO(image_bytes))
131
+ # print(f'\033[1mGeneration {key} completed!\033[0m ({prompt})')
132
+ # return image, seed, seed
133
+ # except Exception as e:
134
+ # print(f"Error when trying to open the image: {e}")
135
+ # return None
136
 
137
 
138
  examples = [
 
169
  with gr.Row():
170
  with gr.Accordion("Advanced Settings", open=False):
171
  with gr.Row():
172
+ lora_scale = gr.Slider(
173
+ label="LoRA Scale",
174
+ minimum=0,
175
+ maximum=2,
176
+ step=0.01,
177
+ value=0.95,
178
+ )
179
  width = gr.Slider(label="Width", value=1024, minimum=64, maximum=1216, step=8)
180
  height = gr.Slider(label="Height", value=1024, minimum=64, maximum=1216, step=8)
181
  seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=4294967296, step=1)
182
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
183
  with gr.Row():
184
+ steps = gr.Slider(label="Inference steps steps", value=28, minimum=1, maximum=100, step=1)
185
+ cfg = gr.Slider(label="Guidance Scale", value=3.5, minimum=1, maximum=20, step=0.5)
186
  # method = gr.Radio(label="Sampling method", value="DPM++ 2M Karras", choices=["DPM++ 2M Karras", "DPM++ SDE Karras", "Euler", "Euler a", "Heun", "DDIM"])
187
 
188
  with gr.Row():
 
202
  )
203
 
204
 
205
+ # text_button.click(query, inputs=[custom_lora, text_prompt, steps, cfg, randomize_seed, seed, width, height], outputs=[image_output,seed_output, seed])
206
+ text_button.click(infer, inputs=[text_prompt, seed, randomize_seed, width, height, cfg, steps, custom_lora, lora_scale], outputs=[image_output,seed_output, seed])
207
 
208
+ app.launch()