LPX55 commited on
Commit
4a83d65
·
verified ·
1 Parent(s): 9aee934

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -27
app.py CHANGED
@@ -4,15 +4,19 @@ import random
4
  import torch
5
  import spaces
6
  from PIL import Image
7
-
8
  from diffusers import QwenImageEditPipeline
9
  from diffusers.utils import is_xformers_available
10
-
11
  import os
12
  import base64
13
  import json
14
  from huggingface_hub import InferenceClient
15
-
 
 
 
 
 
 
16
  def get_caption_language(prompt):
17
  """Detects if the prompt contains Chinese characters."""
18
  ranges = [
@@ -22,7 +26,6 @@ def get_caption_language(prompt):
22
  if any(start <= char <= end for start, end in ranges):
23
  return 'zh'
24
  return 'en'
25
-
26
  def polish_prompt(original_prompt, system_prompt, hf_token):
27
  """
28
  Rewrites the prompt using a Hugging Face InferenceClient.
@@ -31,7 +34,6 @@ def polish_prompt(original_prompt, system_prompt, hf_token):
31
  if not hf_token or not hf_token.strip():
32
  gr.Warning("HF Token is required for prompt rewriting but was not provided!")
33
  return original_prompt
34
-
35
  client = InferenceClient(
36
  provider="cerebras",
37
  api_key=hf_token,
@@ -53,7 +55,6 @@ def polish_prompt(original_prompt, system_prompt, hf_token):
53
  print(f"Error during Hugging Face API call: {e}")
54
  gr.Warning("Failed to rewrite prompt. Using original.")
55
  return original_prompt
56
-
57
  SYSTEM_PROMPT_EDIT = '''
58
  # Edit Instruction Rewriter
59
  You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable instruction based on the user's intent and the input image.
@@ -85,7 +86,6 @@ Please provide the rewritten instruction in a clean `json` format as:
85
  "Rewritten": "..."
86
  }
87
  '''
88
-
89
  dtype = torch.bfloat16
90
  device = "cuda" if torch.cuda.is_available() else "cpu"
91
  pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=dtype).to(device)
@@ -94,12 +94,10 @@ pipe.load_lora_weights(
94
  "lightx2v/Qwen-Image-Lightning", weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors"
95
  )
96
  pipe.fuse_lora()
97
-
98
  if is_xformers_available():
99
  pipe.enable_xformers_memory_efficient_attention()
100
  else:
101
  print("xformers not available or failed to load.")
102
-
103
  @spaces.GPU(duration=60)
104
  def infer(
105
  image,
@@ -116,21 +114,49 @@ def infer(
116
  """
117
  Requires user-provided HF token for prompt rewriting.
118
  """
 
119
  negative_prompt = " "
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  if randomize_seed:
121
  seed = random.randint(0, MAX_SEED)
122
  generator = torch.Generator(device=device).manual_seed(seed)
123
 
124
- if rewrite_prompt:
125
- lang = get_caption_language(prompt)
126
- system_prompt = SYSTEM_PROMPT_EDIT
127
- polished_prompt = polish_prompt(prompt, system_prompt, hf_token)
128
- print(f"Rewritten Prompt: {polished_prompt}")
129
- prompt = polished_prompt
130
-
131
  edited_images = pipe(
132
  image,
133
- prompt=prompt,
134
  negative_prompt=negative_prompt,
135
  num_inference_steps=num_inference_steps,
136
  generator=generator,
@@ -138,7 +164,7 @@ def infer(
138
  num_images_per_prompt=num_images_per_prompt,
139
  ).images
140
 
141
- return edited_images, seed
142
 
143
  MAX_SEED = np.iinfo(np.int32).max
144
  examples = [
@@ -154,7 +180,7 @@ with gr.Blocks() as demo:
154
  gr.Markdown("✨ **8-step lightning inferencing with lightx2v's LoRA.**")
155
  gr.Markdown("⚠️ **Prompt rewriting requires your own [Hugging Face token](https://huggingface.co/settings/tokens)**")
156
  gr.Markdown("🚧 **Work in progress, further improvements coming soon.**")
157
-
158
  with gr.Row():
159
  with gr.Column():
160
  input_image = gr.Image(label="Input Image", type="pil")
@@ -168,7 +194,6 @@ with gr.Blocks() as demo:
168
  value=0
169
  )
170
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
171
-
172
  with gr.Row():
173
  true_guidance_scale = gr.Slider(
174
  label="True Guidance Scale",
@@ -191,12 +216,12 @@ with gr.Blocks() as demo:
191
  step=1,
192
  value=1
193
  )
194
-
195
-
196
  run_button = gr.Button("Edit", variant="primary")
197
-
198
  with gr.Column():
199
  result = gr.Gallery(label="Output Images", show_label=False, columns=1)
 
 
200
 
201
  with gr.Group():
202
  rewrite_toggle = gr.Checkbox(label="Use Prompt Rewriter (Requires HF Token)", value=False, interactive=True)
@@ -207,7 +232,6 @@ with gr.Blocks() as demo:
207
  visible=False,
208
  info="Required for prompt rewriting - get yours from [Hugging Face settings](https://huggingface.co/settings/tokens). API tokens are kept safe locally, but as good practice please make sure to double check the source code. Invalid or missing keys will revert to the original prompt entered."
209
  )
210
-
211
  def toggle_token_visibility(checked):
212
  return gr.update(visible=checked)
213
 
@@ -217,8 +241,6 @@ with gr.Blocks() as demo:
217
  outputs=[hf_token_input]
218
  )
219
 
220
-
221
-
222
  gr.on(
223
  triggers=[run_button.click, prompt.submit],
224
  fn=infer,
@@ -233,7 +255,24 @@ with gr.Blocks() as demo:
233
  hf_token_input,
234
  num_images_per_prompt
235
  ],
236
- outputs=[result, seed],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  )
238
 
239
  if __name__ == "__main__":
 
4
  import torch
5
  import spaces
6
  from PIL import Image
 
7
  from diffusers import QwenImageEditPipeline
8
  from diffusers.utils import is_xformers_available
 
9
  import os
10
  import base64
11
  import json
12
  from huggingface_hub import InferenceClient
13
+ import logging
14
+ #############################
15
+ os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
16
+ os.environ.setdefault('HF_HUB_DISABLE_TELEMETRY', '1')
17
+ logging.basicConfig(level=logging.DEBUG)
18
+ logger = logging.getLogger(__name__)
19
+ #############################
20
  def get_caption_language(prompt):
21
  """Detects if the prompt contains Chinese characters."""
22
  ranges = [
 
26
  if any(start <= char <= end for start, end in ranges):
27
  return 'zh'
28
  return 'en'
 
29
  def polish_prompt(original_prompt, system_prompt, hf_token):
30
  """
31
  Rewrites the prompt using a Hugging Face InferenceClient.
 
34
  if not hf_token or not hf_token.strip():
35
  gr.Warning("HF Token is required for prompt rewriting but was not provided!")
36
  return original_prompt
 
37
  client = InferenceClient(
38
  provider="cerebras",
39
  api_key=hf_token,
 
55
  print(f"Error during Hugging Face API call: {e}")
56
  gr.Warning("Failed to rewrite prompt. Using original.")
57
  return original_prompt
 
58
  SYSTEM_PROMPT_EDIT = '''
59
  # Edit Instruction Rewriter
60
  You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable instruction based on the user's intent and the input image.
 
86
  "Rewritten": "..."
87
  }
88
  '''
 
89
  dtype = torch.bfloat16
90
  device = "cuda" if torch.cuda.is_available() else "cpu"
91
  pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=dtype).to(device)
 
94
  "lightx2v/Qwen-Image-Lightning", weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors"
95
  )
96
  pipe.fuse_lora()
 
97
  if is_xformers_available():
98
  pipe.enable_xformers_memory_efficient_attention()
99
  else:
100
  print("xformers not available or failed to load.")
 
101
  @spaces.GPU(duration=60)
102
  def infer(
103
  image,
 
114
  """
115
  Requires user-provided HF token for prompt rewriting.
116
  """
117
+ original_prompt = prompt # Save original prompt for display
118
  negative_prompt = " "
119
+ prompt_info = "" # Initialize info text
120
+
121
+ # Handle prompt rewriting with status messages
122
+ if rewrite_prompt:
123
+ if not hf_token.strip():
124
+ gr.Warning("HF Token is required for prompt rewriting but was not provided!")
125
+ prompt_info = f"""## ⚠️ Prompt Rewriting Skipped (No HF Token)
126
+ **Original Prompt:**
127
+ {original_prompt}"""
128
+ rewritten_prompt = original_prompt
129
+ else:
130
+ try:
131
+ rewritten_prompt = polish_prompt(original_prompt, SYSTEM_PROMPT_EDIT, hf_token)
132
+ prompt_info = f"""## ✅ Prompt Rewrite Successful
133
+ **Original Prompt:**
134
+ {original_prompt}
135
+
136
+ **Enhanced Prompt:**
137
+ {rewritten_prompt}"""
138
+ except Exception as e:
139
+ gr.Warning(f"Prompt rewriting failed: {str(e)}")
140
+ rewritten_prompt = original_prompt
141
+ prompt_info = f"""## ❌ Prompt Rewrite Failed
142
+ **Original Prompt:**
143
+ {original_prompt}
144
+ **Error:**
145
+ {str(e)}"""
146
+ else:
147
+ rewritten_prompt = original_prompt
148
+ prompt_info = f"""## Original Prompt (No Rewrite)
149
+ **User Input:**
150
+ {original_prompt}"""
151
+
152
+ # Generate images
153
  if randomize_seed:
154
  seed = random.randint(0, MAX_SEED)
155
  generator = torch.Generator(device=device).manual_seed(seed)
156
 
 
 
 
 
 
 
 
157
  edited_images = pipe(
158
  image,
159
+ prompt=rewritten_prompt,
160
  negative_prompt=negative_prompt,
161
  num_inference_steps=num_inference_steps,
162
  generator=generator,
 
164
  num_images_per_prompt=num_images_per_prompt,
165
  ).images
166
 
167
+ return edited_images, seed, prompt_info
168
 
169
  MAX_SEED = np.iinfo(np.int32).max
170
  examples = [
 
180
  gr.Markdown("✨ **8-step lightning inferencing with lightx2v's LoRA.**")
181
  gr.Markdown("⚠️ **Prompt rewriting requires your own [Hugging Face token](https://huggingface.co/settings/tokens)**")
182
  gr.Markdown("🚧 **Work in progress, further improvements coming soon.**")
183
+
184
  with gr.Row():
185
  with gr.Column():
186
  input_image = gr.Image(label="Input Image", type="pil")
 
194
  value=0
195
  )
196
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
 
197
  with gr.Row():
198
  true_guidance_scale = gr.Slider(
199
  label="True Guidance Scale",
 
216
  step=1,
217
  value=1
218
  )
 
 
219
  run_button = gr.Button("Edit", variant="primary")
220
+
221
  with gr.Column():
222
  result = gr.Gallery(label="Output Images", show_label=False, columns=1)
223
+ # New prompt display component
224
+ prompt_info = gr.Markdown("## Prompt Details", visible=False)
225
 
226
  with gr.Group():
227
  rewrite_toggle = gr.Checkbox(label="Use Prompt Rewriter (Requires HF Token)", value=False, interactive=True)
 
232
  visible=False,
233
  info="Required for prompt rewriting - get yours from [Hugging Face settings](https://huggingface.co/settings/tokens). API tokens are kept safe locally, but as good practice please make sure to double check the source code. Invalid or missing keys will revert to the original prompt entered."
234
  )
 
235
  def toggle_token_visibility(checked):
236
  return gr.update(visible=checked)
237
 
 
241
  outputs=[hf_token_input]
242
  )
243
 
 
 
244
  gr.on(
245
  triggers=[run_button.click, prompt.submit],
246
  fn=infer,
 
255
  hf_token_input,
256
  num_images_per_prompt
257
  ],
258
+ outputs=[result, seed, prompt_info]
259
+ )
260
+
261
+ # Show prompt info box after processing
262
+ def set_prompt_visible():
263
+ return gr.update(visible=True)
264
+
265
+ run_button.click(
266
+ fn=set_prompt_visible,
267
+ inputs=None,
268
+ outputs=[prompt_info],
269
+ queue=False
270
+ )
271
+ prompt.submit(
272
+ fn=set_prompt_visible,
273
+ inputs=None,
274
+ outputs=[prompt_info],
275
+ queue=False
276
  )
277
 
278
  if __name__ == "__main__":