Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,210 +1,220 @@
- import
- import
- import torch
- from PIL import Image
- from diffusers import DiffusionPipeline
- import random
  import uuid
-
- import
-
-
-
-
-
-
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     return seed
-
- MAX_SEED = np.iinfo(np.int32).max
-
- if not torch.cuda.is_available():
-     DESCRIPTIONz += "\n<p>⚠️Running on CPU, This may not work on CPU.</p>"
-
- base_model = "black-forest-labs/FLUX.1-dev"
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
-
- lora_repo = "strangerzonehf/3d-Station-Toon"
- trigger_word = "3d station toon" # Leave trigger_word blank if not used.
-
- pipe.load_lora_weights(lora_repo)
- pipe.to("cuda")
-
- style_list = [
-     {
-         "name": "3840 x 2160",
-         "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
-     },
-     {
-         "name": "2560 x 1440",
-         "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
-     },
-     {
-         "name": "HD+",
-         "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
-     },
-     {
-         "name": "Style Zero",
-         "prompt": "{prompt}",
-     },
- ]
-
- styles = {k["name"]: k["prompt"] for k in style_list}
-
- DEFAULT_STYLE_NAME = "3840 x 2160"
- STYLE_NAMES = list(styles.keys())

-
-

- @spaces.GPU(duration=60, enable_queue=True)
  def generate(
-
-
-
-
-
-
-     style_name: str = DEFAULT_STYLE_NAME,
-     progress=gr.Progress(track_tqdm=True),
  ):
-
-
-

-     if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-     }
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-         placeholder="Enter your prompt",
-         container=False,
-     )
-     run_button = gr.Button("Generate as ( 768 x 1024 )🤗", scale=0, elem_classes="submit-btn")
-
-     with gr.Accordion("Advanced options", open=True, visible=True):
-         seed = gr.Slider(
-             label="Seed",
-             minimum=0,
-             maximum=MAX_SEED,
-             step=1,
-             value=0,
-             visible=True
-         )
-         randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-         with gr.Row(visible=True):
-             width = gr.Slider(
-                 label="Width",
-                 minimum=512,
-                 maximum=2048,
-                 step=64,
-                 value=768,
-             )
-             height = gr.Slider(
-                 label="Height",
-                 minimum=512,
-                 maximum=2048,
-                 step=64,
-                 value=1024,
-             )
-
-         with gr.Row():
-             guidance_scale = gr.Slider(
-                 label="Guidance Scale",
-                 minimum=0.1,
-                 maximum=20.0,
-                 step=0.1,
-                 value=3.0,
-             )
-             num_inference_steps = gr.Slider(
-                 label="Number of inference steps",
-                 minimum=1,
-                 maximum=40,
-                 step=1,
-                 value=30,
-             )
-
-         style_selection = gr.Radio(
-             show_label=True,
-             container=True,
-             interactive=True,
-             choices=STYLE_NAMES,
-             value=DEFAULT_STYLE_NAME,
-             label="Quality Style",
-         )
-
-     with gr.Column(scale=2):
-         result = gr.Gallery(label="Result", columns=1, show_label=False)
-
-     gr.Examples(
-         examples=examples,
-         inputs=prompt,
-         outputs=[result, seed],
-         fn=generate,
-         cache_examples=False,
-     )
-
-     gr.on(
-         triggers=[
-             prompt.submit,
-             run_button.click,
-         ],
-         fn=generate,
-         inputs=[
-             prompt,
-             seed,
-             width,
-             height,
-             guidance_scale,
-             randomize_seed,
-             style_selection,
-         ],
-         outputs=[result, seed],
-         api_name="run",
      )

  if __name__ == "__main__":
-     demo.queue(max_size=
+ import os
+ import re
  import uuid
+ import json
+ import time
+ import random
+ import asyncio
+ import cv2
+ from datetime import datetime, timedelta
+ from threading import Thread

+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ from huggingface_hub import hf_hub_download
+ from vllm import LLM
+ from vllm.sampling_params import SamplingParams
+
+ # -----------------------------------------------------------------------------
+ # Helper functions
+ # -----------------------------------------------------------------------------
+
+ def progress_bar_html(label: str) -> str:
+     """Return an HTML snippet for a progress bar."""
+     return f'''
+     <div style="display: flex; align-items: center;">
+         <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+         <div style="width: 110px; height: 5px; background-color: #F0FFF0; border-radius: 2px; overflow: hidden;">
+             <div style="width: 100%; height: 100%; background-color: #00FF00; animation: loading 1.5s linear infinite;"></div>
+         </div>
+     </div>
+     <style>
+     @keyframes loading {{
+         0% {{ transform: translateX(-100%); }}
+         100% {{ transform: translateX(100%); }}
+     }}
+     </style>
+     '''
+
+ def downsample_video(video_path: str, num_frames: int = 10):
+     """
+     Downsample a video to extract a set number of evenly spaced frames.
+     Returns a list of tuples (PIL.Image, timestamp in seconds).
+     """
+     vidcap = cv2.VideoCapture(video_path)
+     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+     fps = vidcap.get(cv2.CAP_PROP_FPS)
+     frames = []
+     if total_frames <= 0 or fps <= 0:
+         vidcap.release()
+         return frames
+     # Get evenly spaced frame indices.
+     frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+     for i in frame_indices:
+         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+         success, image = vidcap.read()
+         if success:
+             # Convert BGR to RGB and then to a PIL Image.
+             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+             pil_image = Image.fromarray(image)
+             timestamp = round(i / fps, 2)
+             frames.append((pil_image, timestamp))
+     vidcap.release()
+     return frames
+
+ def load_system_prompt(repo_id: str, filename: str) -> str:
+     """
+     Load the system prompt from the given Hugging Face Hub repo file,
+     and format it with the model name and current dates.
+     """
+     file_path = hf_hub_download(repo_id=repo_id, filename=filename)
+     with open(file_path, "r") as file:
+         system_prompt = file.read()
+     today = datetime.today().strftime("%Y-%m-%d")
+     yesterday = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
+     model_name = repo_id.split("/")[-1]
+     return system_prompt.format(name=model_name, today=today, yesterday=yesterday)
+
+ # -----------------------------------------------------------------------------
+ # Global Settings and Model Initialization
+ # -----------------------------------------------------------------------------
+
+ # Model details (adjust as needed)
+ MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
+ # Load the system prompt from HF Hub (make sure SYSTEM_PROMPT.txt exists in the repo)
+ SYSTEM_PROMPT = load_system_prompt(MODEL_ID, "SYSTEM_PROMPT.txt")
+ # If you prefer a hardcoded system prompt, you can use:
+ # SYSTEM_PROMPT = "You are a conversational agent that always answers straight to the point, and ends with an ASCII cat."
+
+ # Initialize the Mistral LLM via vllm.
+ # Note: Running this model on GPU may require very high VRAM.
+ llm = LLM(model=MODEL_ID, tokenizer_mode="mistral")
+
+ # -----------------------------------------------------------------------------
+ # Main Generation Function
+ # -----------------------------------------------------------------------------

  def generate(
+     input_dict: dict,
+     chat_history: list,
+     max_new_tokens: int = 512,
+     temperature: float = 0.15,
+     top_p: float = 0.9,
+     top_k: int = 50,
  ):
+     """
+     The main generation function for the Mistral chatbot.
+     It supports:
+       - Text-only inference.
+       - Image inference (attaches image file paths).
+       - Video inference (extracts and attaches sampled video frames).
+     """
+     text = input_dict["text"]
+     files = input_dict.get("files", [])
+     # Prepare the conversation with a system prompt.
+     messages = [
+         {"role": "system", "content": SYSTEM_PROMPT}
+     ]

+     # Check if any file is provided
+     video_extensions = (".mp4", ".mov", ".avi", ".mkv", ".webm")
+     if files:
+         # If any file is a video, use video inference branch.
+         if any(str(f).lower().endswith(video_extensions) for f in files):
+             # Remove any @video-infer tag if present.
+             prompt_clean = re.sub(r"@video-infer", "", text, flags=re.IGNORECASE).strip().strip('"')
+             video_path = files[0] # currently process the first video file
+             frames = downsample_video(video_path)
+             # Build a list that contains the prompt plus each frame information.
+             user_content = [{"type": "text", "text": prompt_clean}]
+             for frame in frames:
+                 image, timestamp = frame
+                 # Save the frame to a temporary file.
+                 image_path = f"video_frame_{uuid.uuid4().hex}.png"
+                 image.save(image_path)
+                 user_content.append({"type": "text", "text": f"Frame at {timestamp} seconds:"})
+                 user_content.append({"type": "image_path", "image_path": image_path})
+             messages.append({"role": "user", "content": user_content})
+         else:
+             # Assume provided files are images.
+             prompt_clean = re.sub(r"@mistral", "", text, flags=re.IGNORECASE).strip().strip('"')
+             user_content = [{"type": "text", "text": prompt_clean}]
+             for file in files:
+                 try:
+                     image = Image.open(file)
+                     image_path = f"image_{uuid.uuid4().hex}.png"
+                     image.save(image_path)
+                     user_content.append({"type": "image_path", "image_path": image_path})
+                 except Exception as e:
+                     user_content.append({"type": "text", "text": f"Could not open file {file}"})
+             messages.append({"role": "user", "content": user_content})
+     else:
+         # Text-only branch.
+         messages.append({"role": "user", "content": [{"type": "text", "text": text}]})
+
+     # Show a progress bar before generating.
+     yield progress_bar_html("Processing with Mistral")
+
+     # Set up sampling parameters.
+     sampling_params = SamplingParams(
+         max_tokens=max_new_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k
      )
+     # Run the chat (synchronously) using vllm.
+     outputs = llm.chat(messages, sampling_params=sampling_params)
+     final_response = outputs[0].outputs[0].text
+
+     # Simulate streaming output by chunking the result.
+     buffer = ""
+     chunk_size = 20 # number of characters per chunk
+     for i in range(0, len(final_response), chunk_size):
+         buffer = final_response[: i + chunk_size]
+         yield buffer
+         time.sleep(0.05)
+     return
+
+ # -----------------------------------------------------------------------------
+ # Gradio Interface Setup
+ # -----------------------------------------------------------------------------
+
+ demo = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Slider(label="Max new tokens", minimum=1, maximum=1024, step=1, value=512),
+         gr.Slider(label="Temperature", minimum=0.05, maximum=2.0, step=0.05, value=0.15),
+         gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
+     ],
+     examples=[
+         # Example with text only.
+         ["Explain the significance of today in the context of current events."],
+         # Example with image files (ensure you have valid image paths).
+         [{
+             "text": "Describe what you see in the image.",
+             "files": ["examples/3.jpg"]
+         }],
+         # Example with video file (ensure you have a valid video file).
+         [{
+             "text": "@video-infer Summarize the events shown in the video.",
+             "files": ["examples/sample_video.mp4"]
+         }],
+     ],
+     cache_examples=False,
+     type="messages",
+     description="# **Mistral Multimodal Chatbot** \nSupports text, image (by reference) and video inference. Use @video-infer in your query when providing a video.",
+     fill_height=True,
+     textbox=gr.MultimodalTextbox(
+         label="Query Input",
+         file_types=["image", "video"],
+         file_count="multiple",
+         placeholder="Enter your query here. Tag with @video-infer if using a video file."
+     ),
+     stop_btn="Stop Generation",
+     examples_per_page=3,
+ )

  if __name__ == "__main__":
+     demo.queue(max_size=20).launch(share=True)
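
Note on the new generate() signature: gr.MultimodalTextbox submits a dict with "text" and "files" keys, and the image branch turns that into the messages list handed to llm.chat(). The sketch below is illustrative only; the file name, the re-saved image path, and the system-prompt placeholder are stand-ins, not values taken from the Space.

    # Illustrative payload shape for generate(), and roughly what its image
    # branch builds from it (all literal values here are placeholders).
    input_dict = {
        "text": "@mistral Describe what you see in the image.",
        "files": ["examples/3.jpg"],  # path borrowed from the examples list above
    }

    messages = [
        {"role": "system", "content": "<formatted contents of SYSTEM_PROMPT.txt>"},
        {"role": "user", "content": [
            {"type": "text", "text": "Describe what you see in the image."},  # "@mistral" tag stripped
            {"type": "image_path", "image_path": "image_<uuid>.png"},  # re-saved copy of the upload
        ]},
    ]
    print(messages)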
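
The loop at the end of generate() fakes streaming: vLLM returns the whole completion at once, and the function yields progressively longer prefixes so the chat UI appears to stream. Below is a minimal standalone sketch of that pattern; the helper name stream_in_chunks and the sample string are hypothetical, not part of app.py.

    import time

    def stream_in_chunks(final_response: str, chunk_size: int = 20, delay: float = 0.05):
        # Yield growing prefixes of the finished response, chunk_size characters
        # at a time, mirroring the loop at the end of generate().
        for i in range(0, len(final_response), chunk_size):
            yield final_response[: i + chunk_size]
            time.sleep(delay)

    for partial in stream_in_chunks("vLLM returned this answer in one shot; the UI repaints these prefixes."):
        print(partial)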