Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -23,9 +23,6 @@ from transformers import (
 from transformers.image_utils import load_image
 from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
-# ============================================
-# CHAT & TTS SETUP
-# ============================================
 
 DESCRIPTION = """
 # QwQ Edge 💬
@@ -61,13 +58,11 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
-# TTS voices
 TTS_VOICES = [
     "en-US-JennyNeural",  # @tts1
     "en-US-GuyNeural",    # @tts2
 ]
 
-# Load multimodal (OCR) model and processor
 MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 model_m = Qwen2VLForConditionalGeneration.from_pretrained(
@@ -93,10 +88,6 @@ def clean_chat_history(chat_history):
             cleaned.append(msg)
     return cleaned
 
-# ============================================
-# IMAGE GENERATION SETUP
-# ============================================
-
 # Environment variables and parameters for Stable Diffusion XL
 MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL Model repository path via env variable
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
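
The hunk above reads the SDXL configuration from environment variables. As a minimal sketch of how those values are typically wired to the StableDiffusionXLPipeline and EulerAncestralDiscreteScheduler imports from the first hunk (the dtype and device below are assumptions, not the Space's actual settings):

    import os
    import torch
    from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler

    # MODEL_VAL_PATH must name an SDXL checkpoint repo; fp16 on CUDA is only an
    # illustrative choice here.
    MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")
    pipe = StableDiffusionXLPipeline.from_pretrained(MODEL_ID_SD, torch_dtype=torch.float16)
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe.to("cuda")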
@@ -187,10 +178,6 @@ def generate_image_fn(
     image_paths = [save_image(img) for img in images]
     return image_paths, seed
 
-# ============================================
-# MAIN GENERATION FUNCTION (CHAT)
-# ============================================
-
 @spaces.GPU
 def generate(
     input_dict: dict,
@@ -210,9 +197,6 @@ def generate(
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
-    # ----------------------------
-    # IMAGE GENERATION BRANCH
-    # ----------------------------
     if text.strip().lower().startswith("@image"):
         # Remove the "@image" tag and use the rest as prompt
         prompt = text[len("@image"):].strip()
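
The @image branch routes a message to image generation by stripping the tag and using the remainder as the diffusion prompt. A self-contained illustration of that extraction (the sample text is made up):

    # Standalone sketch of the "@image" prompt extraction in the hunk above.
    text = "@image a watercolor fox in the snow"
    if text.strip().lower().startswith("@image"):
        prompt = text[len("@image"):].strip()
        # prompt == "a watercolor fox in the snow"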
@@ -234,9 +218,6 @@ def generate(
         yield gr.Image(image_paths[0])
         return  # Exit early
 
-    # ----------------------------
-    # TTS Branch (if query starts with @tts)
-    # ----------------------------
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
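
The @tts check maps @tts1 and @tts2 onto the two entries of TTS_VOICES. The helper below is a hypothetical standalone rewrite of that lookup (resolve_tts_voice does not exist in app.py), shown only to make the prefix-to-voice mapping explicit:

    TTS_VOICES = ["en-US-JennyNeural", "en-US-GuyNeural"]

    def resolve_tts_voice(text: str):
        # "@tts1 ..." selects TTS_VOICES[0], "@tts2 ..." selects TTS_VOICES[1],
        # anything else returns None (regular chat).
        lowered = text.strip().lower()
        for i in range(1, len(TTS_VOICES) + 1):
            if lowered.startswith(f"@tts{i}"):
                return TTS_VOICES[i - 1]
        return None

    assert resolve_tts_voice("@tts2 read this aloud") == "en-US-GuyNeural"
    assert resolve_tts_voice("hello") is None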
@@ -253,9 +234,6 @@ def generate(
     conversation = clean_chat_history(chat_history)
     conversation.append({"role": "user", "content": text})
 
-    # ----------------------------
-    # Multimodal (image + text) branch
-    # ----------------------------
     if files:
         if len(files) > 1:
             images = [load_image(image) for image in files]
@@ -285,9 +263,7 @@ def generate(
             time.sleep(0.01)
             yield buffer
     else:
-        # ----------------------------
-        # Text-only branch
-        # ----------------------------
+
         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
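
In the text-only branch the tokenized conversation is left-truncated so only the most recent MAX_INPUT_TOKEN_LENGTH tokens reach the model. A tiny numeric illustration of that slice (values chosen only for the example):

    import torch

    MAX_INPUT_TOKEN_LENGTH = 4  # deliberately tiny for illustration
    input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    # The oldest tokens are dropped: input_ids is now tensor([[3, 4, 5, 6]])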
@@ -321,10 +297,6 @@ def generate(
         output_file = asyncio.run(text_to_speech(final_response, voice))
         yield gr.Audio(output_file, autoplay=True)
 
-# ============================================
-# GRADIO DEMO SETUP
-# ============================================
-
 demo = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
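
The text_to_speech coroutine awaited above is defined elsewhere in app.py and does not appear in this diff. Given the en-US-*Neural voice names, one plausible implementation would use the edge-tts package; the sketch below is an assumption, not the Space's confirmed code:

    import edge_tts

    async def text_to_speech(text: str, voice: str, output_file: str = "output.mp3") -> str:
        # Synthesize `text` with the requested Edge voice and write it to an MP3 file.
        communicate = edge_tts.Communicate(text, voice)
        await communicate.save(output_file)
        return output_file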