prithivMLmods committed on
Commit 2741d7d · verified · 1 Parent(s): ea990b4

Update app.py

Files changed (1):
  1. app.py +211 -201
app.py CHANGED
@@ -1,210 +1,220 @@
- import spaces
- import gradio as gr
- import torch
- from PIL import Image
- from diffusers import DiffusionPipeline
- import random
  import uuid
- from typing import Tuple
- import numpy as np
-
- def save_image(img):
-     unique_name = str(uuid.uuid4()) + ".png"
-     img.save(unique_name)
-     return unique_name
-
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     return seed
-
- MAX_SEED = np.iinfo(np.int32).max
-
- if not torch.cuda.is_available():
-     # Warning banner shown when no CUDA device is available.
-     DESCRIPTIONz = "\n<p>⚠️ Running on CPU; this may not work on CPU.</p>"
-
- base_model = "black-forest-labs/FLUX.1-dev"
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16)
-
- lora_repo = "strangerzonehf/3d-Station-Toon"
- trigger_word = "3d station toon"  # Leave trigger_word blank if not used.
-
- pipe.load_lora_weights(lora_repo)
- pipe.to("cuda")
-
- style_list = [
-     {
-         "name": "3840 x 2160",
-         "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
-     },
-     {
-         "name": "2560 x 1440",
-         "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
-     },
-     {
-         "name": "HD+",
-         "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
-     },
-     {
-         "name": "Style Zero",
-         "prompt": "{prompt}",
-     },
- ]
-
- styles = {k["name"]: k["prompt"] for k in style_list}
-
- DEFAULT_STYLE_NAME = "3840 x 2160"
- STYLE_NAMES = list(styles.keys())

- def apply_style(style_name: str, positive: str) -> str:
-     return styles.get(style_name, styles[DEFAULT_STYLE_NAME]).replace("{prompt}", positive)

- @spaces.GPU(duration=60, enable_queue=True)
  def generate(
-     prompt: str,
-     seed: int = 0,
-     width: int = 1024,
-     height: int = 1024,
-     guidance_scale: float = 3,
-     randomize_seed: bool = False,
-     style_name: str = DEFAULT_STYLE_NAME,
-     progress=gr.Progress(track_tqdm=True),
  ):
-     seed = int(randomize_seed_fn(seed, randomize_seed))
-
-     positive_prompt = apply_style(style_name, prompt)

-     if trigger_word:
-         positive_prompt = f"{trigger_word} {positive_prompt}"
-
-     # Note: `seed` is computed above but never passed to the pipeline
-     # (no generator argument), so generations are not reproducible.
-     images = pipe(
-         prompt=positive_prompt,
-         width=width,
-         height=height,
-         guidance_scale=guidance_scale,
-         num_inference_steps=30,
-         num_images_per_prompt=1,
-         output_type="pil",
-     ).images
-     image_paths = [save_image(img) for img in images]
-     print(image_paths)
-     return image_paths, seed
-
- examples = [
-     "Super Realism, High-resolution photograph, woman, UHD, photorealistic, shot on a Sony A7III --chaos 20 --ar 1:2 --style raw --stylize 250",
-     "Woman in a red jacket, snowy, in the style of hyper-realistic portraiture, caninecore, mountainous vistas, timeless beauty, palewave, iconic, distinctive noses --ar 72:101 --stylize 750 --v 6",
-     "Super Realism, Headshot of handsome young man, wearing dark gray sweater with buttons and big shawl collar, brown hair and short beard, serious look on his face, black background, soft studio lighting, portrait photography --ar 85:128 --v 6.0 --style",
-     "Super-realism, Purple Dreamy, a medium-angle shot of a young woman with long brown hair, wearing a pair of eye-level glasses, stands in front of a backdrop of purple and white lights. The woman's eyes are closed, her lips are slightly parted, as if she is looking up at the sky. Her hair is cascading over her shoulders, framing her face. She is wearing a sleeveless top, adorned with tiny white dots, and a gold chain necklace around her neck. Her left earrings are dangling from her ears, adding a pop of color to the scene."
- ]
-
- css = '''
- .gradio-container{max-width: 888px !important}
- h1{text-align:center}
- footer {
-     visibility: hidden
- }
- .submit-btn {
-     background-color: #e34949 !important;
-     color: white !important;
- }
- .submit-btn:hover {
-     background-color: #ff3b3b !important;
- }
- '''
-
- with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
-     with gr.Row():
-         with gr.Column(scale=1):
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-             run_button = gr.Button("Generate as ( 768 x 1024 )🤗", scale=0, elem_classes="submit-btn")
-
-             with gr.Accordion("Advanced options", open=True, visible=True):
-                 seed = gr.Slider(
-                     label="Seed",
-                     minimum=0,
-                     maximum=MAX_SEED,
-                     step=1,
-                     value=0,
-                     visible=True
-                 )
-                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-                 with gr.Row(visible=True):
-                     width = gr.Slider(
-                         label="Width",
-                         minimum=512,
-                         maximum=2048,
-                         step=64,
-                         value=768,
-                     )
-                     height = gr.Slider(
-                         label="Height",
-                         minimum=512,
-                         maximum=2048,
-                         step=64,
-                         value=1024,
-                     )
-
-                 with gr.Row():
-                     guidance_scale = gr.Slider(
-                         label="Guidance Scale",
-                         minimum=0.1,
-                         maximum=20.0,
-                         step=0.1,
-                         value=3.0,
-                     )
-                     # This slider is never wired into `generate`, which
-                     # hardcodes num_inference_steps=30.
-                     num_inference_steps = gr.Slider(
-                         label="Number of inference steps",
-                         minimum=1,
-                         maximum=40,
-                         step=1,
-                         value=30,
-                     )
-
-                 style_selection = gr.Radio(
-                     show_label=True,
-                     container=True,
-                     interactive=True,
-                     choices=STYLE_NAMES,
-                     value=DEFAULT_STYLE_NAME,
-                     label="Quality Style",
-                 )
-
-         with gr.Column(scale=2):
-             result = gr.Gallery(label="Result", columns=1, show_label=False)
-
-     gr.Examples(
-         examples=examples,
-         inputs=prompt,
-         outputs=[result, seed],
-         fn=generate,
-         cache_examples=False,
-     )
-
-     gr.on(
-         triggers=[
-             prompt.submit,
-             run_button.click,
-         ],
-         fn=generate,
-         inputs=[
-             prompt,
-             seed,
-             width,
-             height,
-             guidance_scale,
-             randomize_seed,
-             style_selection,
-         ],
-         outputs=[result, seed],
-         api_name="run",
-     )

  if __name__ == "__main__":
-     demo.queue(max_size=40).launch()
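The removed version computes a seed (randomize_seed_fn) and returns it to the UI, but never passes it to the pipeline call, so the displayed seed does not actually reproduce the image. A minimal sketch of how the seed could be wired through, assuming the `pipe` object defined above; `generate_seeded` is a hypothetical helper, and `generator` is the standard diffusers argument for seeded sampling:

    import torch

    def generate_seeded(prompt: str, seed: int, width: int = 1024, height: int = 1024):
        # Seed a torch.Generator so the returned seed reproduces the output.
        generator = torch.Generator(device="cuda").manual_seed(seed)
        return pipe(
            prompt=prompt,
            width=width,
            height=height,
            guidance_scale=3.0,
            num_inference_steps=30,
            generator=generator,  # hypothetical addition; not in the original call
        ).images
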
+ import os
+ import re
  import uuid
+ import json
+ import time
+ import random
+ import asyncio
+ import cv2
+ from datetime import datetime, timedelta
+ from threading import Thread

+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ from huggingface_hub import hf_hub_download
+ from vllm import LLM
+ from vllm.sampling_params import SamplingParams
+
+ # -----------------------------------------------------------------------------
+ # Helper functions
+ # -----------------------------------------------------------------------------
+
+ def progress_bar_html(label: str) -> str:
+     """Return an HTML snippet for an animated progress bar."""
+     return f'''
+ <div style="display: flex; align-items: center;">
+     <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+     <div style="width: 110px; height: 5px; background-color: #F0FFF0; border-radius: 2px; overflow: hidden;">
+         <div style="width: 100%; height: 100%; background-color: #00FF00; animation: loading 1.5s linear infinite;"></div>
+     </div>
+ </div>
+ <style>
+ @keyframes loading {{
+     0% {{ transform: translateX(-100%); }}
+     100% {{ transform: translateX(100%); }}
+ }}
+ </style>
+ '''
+
+ def downsample_video(video_path: str, num_frames: int = 10):
+     """
+     Downsample a video to a fixed number of evenly spaced frames.
+     Returns a list of (PIL.Image, timestamp-in-seconds) tuples.
+     """
+     vidcap = cv2.VideoCapture(video_path)
+     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+     fps = vidcap.get(cv2.CAP_PROP_FPS)
+     frames = []
+     if total_frames <= 0 or fps <= 0:
+         vidcap.release()
+         return frames
+     # Get evenly spaced frame indices.
+     frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
+     for i in frame_indices:
+         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+         success, image = vidcap.read()
+         if success:
+             # Convert BGR to RGB and then to a PIL Image.
+             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+             pil_image = Image.fromarray(image)
+             timestamp = round(i / fps, 2)
+             frames.append((pil_image, timestamp))
+     vidcap.release()
+     return frames
+
+ def load_system_prompt(repo_id: str, filename: str) -> str:
+     """
+     Load the system prompt from the given Hugging Face Hub repo file,
+     and format it with the model name and current dates.
+     """
+     file_path = hf_hub_download(repo_id=repo_id, filename=filename)
+     with open(file_path, "r") as file:
+         system_prompt = file.read()
+     today = datetime.today().strftime("%Y-%m-%d")
+     yesterday = (datetime.today() - timedelta(days=1)).strftime("%Y-%m-%d")
+     model_name = repo_id.split("/")[-1]
+     return system_prompt.format(name=model_name, today=today, yesterday=yesterday)
+
+ # -----------------------------------------------------------------------------
+ # Global Settings and Model Initialization
+ # -----------------------------------------------------------------------------
+
+ # Model details (adjust as needed)
+ MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
+ # Load the system prompt from HF Hub (make sure SYSTEM_PROMPT.txt exists in the repo).
+ SYSTEM_PROMPT = load_system_prompt(MODEL_ID, "SYSTEM_PROMPT.txt")
+ # If you prefer a hardcoded system prompt, you can use:
+ # SYSTEM_PROMPT = "You are a conversational agent that always answers straight to the point, and ends with an ASCII cat."
+
+ # Initialize the Mistral LLM via vLLM.
+ # Note: running this model on GPU may require very high VRAM.
+ llm = LLM(model=MODEL_ID, tokenizer_mode="mistral")
+
+ # -----------------------------------------------------------------------------
+ # Main Generation Function
+ # -----------------------------------------------------------------------------

  def generate(
+     input_dict: dict,
+     chat_history: list,
+     max_new_tokens: int = 512,
+     temperature: float = 0.15,
+     top_p: float = 0.9,
+     top_k: int = 50,
  ):
+     """
+     The main generation function for the Mistral chatbot.
+     It supports:
+       - Text-only inference.
+       - Image inference (attaches image file paths).
+       - Video inference (extracts and attaches sampled video frames).
+     """
+     text = input_dict["text"]
+     files = input_dict.get("files", [])
+     # Prepare the conversation with a system prompt.
+     messages = [
+         {"role": "system", "content": SYSTEM_PROMPT}
+     ]
+
+     # Check whether any files were provided.
+     video_extensions = (".mp4", ".mov", ".avi", ".mkv", ".webm")
+     if files:
+         # If any file is a video, use the video-inference branch.
+         if any(str(f).lower().endswith(video_extensions) for f in files):
+             # Remove any @video-infer tag if present.
+             prompt_clean = re.sub(r"@video-infer", "", text, flags=re.IGNORECASE).strip().strip('"')
+             video_path = files[0]  # currently only the first video file is processed
+             frames = downsample_video(video_path)
+             # Build a content list containing the prompt plus each frame.
+             user_content = [{"type": "text", "text": prompt_clean}]
+             for frame in frames:
+                 image, timestamp = frame
+                 # Save the frame to a temporary file.
+                 image_path = f"video_frame_{uuid.uuid4().hex}.png"
+                 image.save(image_path)
+                 user_content.append({"type": "text", "text": f"Frame at {timestamp} seconds:"})
+                 user_content.append({"type": "image_path", "image_path": image_path})
+             messages.append({"role": "user", "content": user_content})
+         else:
+             # Otherwise, assume the provided files are images.
+             prompt_clean = re.sub(r"@mistral", "", text, flags=re.IGNORECASE).strip().strip('"')
+             user_content = [{"type": "text", "text": prompt_clean}]
+             for file in files:
+                 try:
+                     image = Image.open(file)
+                     image_path = f"image_{uuid.uuid4().hex}.png"
+                     image.save(image_path)
+                     user_content.append({"type": "image_path", "image_path": image_path})
+                 except Exception as e:
+                     user_content.append({"type": "text", "text": f"Could not open file {file}: {e}"})
+             messages.append({"role": "user", "content": user_content})
+     else:
+         # Text-only branch.
+         messages.append({"role": "user", "content": [{"type": "text", "text": text}]})
+
+     # Show a progress bar before generating.
+     yield progress_bar_html("Processing with Mistral")
+
+     # Set up sampling parameters.
+     sampling_params = SamplingParams(
+         max_tokens=max_new_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k
      )
+     # Run the chat (synchronously) using vLLM.
+     outputs = llm.chat(messages, sampling_params=sampling_params)
+     final_response = outputs[0].outputs[0].text
+
+     # Simulate streaming output by yielding the result in growing chunks.
+     buffer = ""
+     chunk_size = 20  # number of characters per chunk
+     for i in range(0, len(final_response), chunk_size):
+         buffer = final_response[: i + chunk_size]
+         yield buffer
+         time.sleep(0.05)
+     return
+
+ # -----------------------------------------------------------------------------
+ # Gradio Interface Setup
+ # -----------------------------------------------------------------------------
+
+ demo = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Slider(label="Max new tokens", minimum=1, maximum=1024, step=1, value=512),
+         gr.Slider(label="Temperature", minimum=0.05, maximum=2.0, step=0.05, value=0.15),
+         gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
+     ],
+     examples=[
+         # Example with text only.
+         ["Explain the significance of today in the context of current events."],
+         # Example with image files (ensure you have valid image paths).
+         [{
+             "text": "Describe what you see in the image.",
+             "files": ["examples/3.jpg"]
+         }],
+         # Example with a video file (ensure you have a valid video file).
+         [{
+             "text": "@video-infer Summarize the events shown in the video.",
+             "files": ["examples/sample_video.mp4"]
+         }],
+     ],
+     cache_examples=False,
+     type="messages",
+     description="# **Mistral Multimodal Chatbot** \nSupports text, image (by reference) and video inference. Use @video-infer in your query when providing a video.",
+     fill_height=True,
+     textbox=gr.MultimodalTextbox(
+         label="Query Input",
+         file_types=["image", "video"],
+         file_count="multiple",
+         placeholder="Enter your query here. Tag with @video-infer if using a video file."
+     ),
+     stop_btn="Stop Generation",
+     examples_per_page=3,
+ )

  if __name__ == "__main__":
+     demo.queue(max_size=20).launch(share=True)
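
For reference, a minimal text-only sketch of the vLLM chat call that the new app assembles (assuming the `llm` and `SYSTEM_PROMPT` objects defined above; the message layout mirrors the text-only branch of `generate`):

    from vllm.sampling_params import SamplingParams

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": [{"type": "text", "text": "Say hello."}]},
    ]
    params = SamplingParams(max_tokens=64, temperature=0.15, top_p=0.9, top_k=50)
    outputs = llm.chat(messages, sampling_params=params)
    print(outputs[0].outputs[0].text)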
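
The frame-sampling helper can likewise be exercised on its own; a small sketch, assuming a local clip at examples/sample_video.mp4 (hypothetical path):

    # Extract four evenly spaced frames and report their timestamps.
    frames = downsample_video("examples/sample_video.mp4", num_frames=4)
    for pil_image, timestamp in frames:
        print(f"frame at {timestamp}s, size={pil_image.size}")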