prithivMLmods committed
Commit 235f049 · verified · 1 Parent(s): 82d9471

Update app.py

Files changed (1)
  1. app.py +299 -299
app.py CHANGED
@@ -1,300 +1,300 @@
- import os
- import random
- import uuid
- import json
- import time
- import asyncio
- from threading import Thread
-
- import gradio as gr
- import spaces
- import torch
- import numpy as np
- from PIL import Image
- import cv2
- import edge_tts
-
- from transformers import (
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     TextIteratorStreamer,
-     Qwen2VLForConditionalGeneration,
-     AutoProcessor,
- )
- from transformers.image_utils import load_image
-
- # Constants for text generation
- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 1024
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- # Load text-only model and tokenizer
- model_id = "prithivMLmods/Galactic-Qwen-14B-Exp2"
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(
-     model_id,
-     device_map="auto",
-     torch_dtype=torch.bfloat16,
- )
- model.eval()
-
- # Load multimodal processor and model
- MODEL_ID = "prithivMLmods/Imgscope-OCR-2B-0527"
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
- model_m = Qwen2VLForConditionalGeneration.from_pretrained(
-     MODEL_ID,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to("cuda").eval()
-
- # Edge TTS voices mapping for new tags.
- TTS_VOICE_MAP = {
-     "@jennyneural": "en-US-JennyNeural",
-     "@guyneural": "en-US-GuyNeural",
-     "@palomaneural": "es-US-PalomaNeural",
-     "@alonsoneural": "es-US-AlonsoNeural",
-     "@madhurneural": "hi-IN-MadhurNeural"
- }
-
- async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
-     """
-     Convert text to speech using Edge TTS and save as MP3.
-     """
-     communicate = edge_tts.Communicate(text, voice)
-     await communicate.save(output_file)
-     return output_file
-
- def clean_chat_history(chat_history):
-     """
-     Filter out any chat entries whose "content" is not a string.
-     This helps prevent errors when concatenating previous messages.
-     """
-     cleaned = []
-     for msg in chat_history:
-         if isinstance(msg, dict) and isinstance(msg.get("content"), str):
-             cleaned.append(msg)
-     return cleaned
-
- def downsample_video(video_path):
-     """
-     Downsamples the video to 10 evenly spaced frames.
-     Each frame is returned as a PIL image along with its timestamp.
-     """
-     vidcap = cv2.VideoCapture(video_path)
-     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-     fps = vidcap.get(cv2.CAP_PROP_FPS)
-     frames = []
-     # Sample 10 evenly spaced frames.
-     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
-     for i in frame_indices:
-         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
-         success, image = vidcap.read()
-         if success:
-             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
-             pil_image = Image.fromarray(image)
-             timestamp = round(i / fps, 2)
-             frames.append((pil_image, timestamp))
-     vidcap.release()
-     return frames
-
- def progress_bar_html(label: str) -> str:
-     """
-     Returns an HTML snippet for a thin progress bar with a label.
-     The progress bar is styled as a light cyan animated bar.
-     """
-     return f'''
-     <div style="display: flex; align-items: center;">
-         <span style="margin-right: 10px; font-size: 14px;">{label}</span>
-         <div style="width: 110px; height: 5px; background-color: #B0E0E6; border-radius: 2px; overflow: hidden;">
-             <div style="width: 100%; height: 100%; background-color: #00FFFF; animation: loading 1.5s linear infinite;"></div>
-         </div>
-     </div>
-     <style>
-     @keyframes loading {{
-         0% {{ transform: translateX(-100%); }}
-         100% {{ transform: translateX(100%); }}
-     }}
-     </style>
-     '''
-
- @spaces.GPU
- def generate(input_dict: dict, chat_history: list[dict],
-              max_new_tokens: int = 1024,
-              temperature: float = 0.6,
-              top_p: float = 0.9,
-              top_k: int = 50,
-              repetition_penalty: float = 1.2):
-     """
-     Generates chatbot responses with support for multimodal input, video processing,
-     and Edge TTS when using the new tags @JennyNeural or @GuyNeural.
-     Special command:
-     - "@video-infer": triggers video processing using Imgscope-OCR
-     """
-     text = input_dict["text"]
-     files = input_dict.get("files", [])
-     lower_text = text.strip().lower()
-
-     # Check for TTS tag in the prompt.
-     tts_voice = None
-     for tag, voice in TTS_VOICE_MAP.items():
-         if lower_text.startswith(tag):
-             tts_voice = voice
-             text = text[len(tag):].strip() # Remove the tag from the prompt.
-             break
-
-     # Branch for video processing with Callisto OCR3.
-     if lower_text.startswith("@video-infer"):
-         prompt = text[len("@video-infer"):].strip() if not tts_voice else text
-         if files:
-             # Assume the first file is a video.
-             video_path = files[0]
-             frames = downsample_video(video_path)
-             messages = [
-                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-                 {"role": "user", "content": [{"type": "text", "text": prompt}]}
-             ]
-             # Append each frame with its timestamp.
-             for frame in frames:
-                 image, timestamp = frame
-                 image_path = f"video_frame_{uuid.uuid4().hex}.png"
-                 image.save(image_path)
-                 messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
-                 messages[1]["content"].append({"type": "image", "url": image_path})
-         else:
-             messages = [
-                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
-                 {"role": "user", "content": [{"type": "text", "text": prompt}]}
-             ]
-         # Enable truncation to avoid token/feature mismatch.
-         inputs = processor.apply_chat_template(
-             messages,
-             tokenize=True,
-             add_generation_prompt=True,
-             return_dict=True,
-             return_tensors="pt",
-             truncation=True,
-             max_length=MAX_INPUT_TOKEN_LENGTH
-         ).to("cuda")
-         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = {
-             **inputs,
-             "streamer": streamer,
-             "max_new_tokens": max_new_tokens,
-             "do_sample": True,
-             "temperature": temperature,
-             "top_p": top_p,
-             "top_k": top_k,
-             "repetition_penalty": repetition_penalty,
-         }
-         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-         thread.start()
-         buffer = ""
-         yield progress_bar_html("Processing video with Imgscope-OCR")
-         for new_text in streamer:
-             buffer += new_text
-             buffer = buffer.replace("<|im_end|>", "")
-             time.sleep(0.01)
-             yield buffer
-         return
-
-     # Multimodal processing when files are provided.
-     if files:
-         if len(files) > 1:
-             images = [load_image(image) for image in files]
-         elif len(files) == 1:
-             images = [load_image(files[0])]
-         else:
-             images = []
-         messages = [{
-             "role": "user",
-             "content": [
-                 *[{"type": "image", "image": image} for image in images],
-                 {"type": "text", "text": text},
-             ]
-         }]
-         prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-         # Enable truncation explicitly here as well.
-         inputs = processor(
-             text=[prompt_full],
-             images=images,
-             return_tensors="pt",
-             padding=True,
-             truncation=True,
-             max_length=MAX_INPUT_TOKEN_LENGTH
-         ).to("cuda")
-         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-         thread.start()
-         buffer = ""
-         yield progress_bar_html("Processing image with Imgscope-OCR")
-         for new_text in streamer:
-             buffer += new_text
-             buffer = buffer.replace("<|im_end|>", "")
-             time.sleep(0.01)
-             yield buffer
-     else:
-         # Normal text conversation processing with Pocket Llama.
-         conversation = clean_chat_history(chat_history)
-         conversation.append({"role": "user", "content": text})
-         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-         input_ids = input_ids.to(model.device)
-         streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-         generation_kwargs = {
-             "input_ids": input_ids,
-             "streamer": streamer,
-             "max_new_tokens": max_new_tokens,
-             "do_sample": True,
-             "top_p": top_p,
-             "top_k": top_k,
-             "temperature": temperature,
-             "num_beams": 1,
-             "repetition_penalty": repetition_penalty,
-         }
-         t = Thread(target=model.generate, kwargs=generation_kwargs)
-         t.start()
-         outputs = []
-         yield progress_bar_html("Processing With Galactic Qwen")
-         for new_text in streamer:
-             outputs.append(new_text)
-             yield "".join(outputs)
-         final_response = "".join(outputs)
-         yield final_response
-
-         # If a TTS voice was specified, convert the final response to speech.
-         if tts_voice:
-             output_file = asyncio.run(text_to_speech(final_response, tts_voice))
-             yield gr.Audio(output_file, autoplay=True)
-
- # Create the Gradio ChatInterface with the custom CSS applied
- demo = gr.ChatInterface(
-     fn=generate,
-     additional_inputs=[
-         gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
-         gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
-         gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
-         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
-         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
-     ],
-     examples=[
-         ["Write the code that converts temperatures between Celsius and Fahrenheit in short"],
-         [{"text": "Create a short story based on the image.", "files": ["examples/1.jpg"]}],
-         ["@JennyNeural Who was Nikola Tesla and what were his contributions?"],
-         [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}]
-     ],
-     cache_examples=False,
-     description="# **Imgscope-OCR-Mini**",
-     type="messages",
-     fill_height=True,
-     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
-     stop_btn="Stop Generation",
-     multimodal=True,
- )
-
- if __name__ == "__main__":
      demo.queue(max_size=20).launch(share=True)
 
+ import os
+ import random
+ import uuid
+ import json
+ import time
+ import asyncio
+ from threading import Thread
+
+ import gradio as gr
+ import spaces
+ import torch
+ import numpy as np
+ from PIL import Image
+ import cv2
+ import edge_tts
+
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TextIteratorStreamer,
+     Qwen2VLForConditionalGeneration,
+     AutoProcessor,
+ )
+ from transformers.image_utils import load_image
+
+ # Constants for text generation
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ # Load text-only model and tokenizer
+ model_id = "prithivMLmods/Galactic-Qwen-14B-Exp2"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+ )
+ model.eval()
+
+ # Load multimodal processor and model
+ MODEL_ID = "prithivMLmods/Imgscope-OCR-2B-0527"
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
+     MODEL_ID,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to("cuda").eval()
+
+ # Edge TTS voices mapping for new tags.
+ TTS_VOICE_MAP = {
+     "@jennyneural": "en-US-JennyNeural",
+     "@guyneural": "en-US-GuyNeural",
+     "@palomaneural": "es-US-PalomaNeural",
+     "@alonsoneural": "es-US-AlonsoNeural",
+     "@madhurneural": "hi-IN-MadhurNeural"
+ }
+
+ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
+     """
+     Convert text to speech using Edge TTS and save as MP3.
+     """
+     communicate = edge_tts.Communicate(text, voice)
+     await communicate.save(output_file)
+     return output_file
+
+ def clean_chat_history(chat_history):
+     """
+     Filter out any chat entries whose "content" is not a string.
+     This helps prevent errors when concatenating previous messages.
+     """
+     cleaned = []
+     for msg in chat_history:
+         if isinstance(msg, dict) and isinstance(msg.get("content"), str):
+             cleaned.append(msg)
+     return cleaned
+
+ def downsample_video(video_path):
+     """
+     Downsamples the video to 10 evenly spaced frames.
+     Each frame is returned as a PIL image along with its timestamp.
+     """
+     vidcap = cv2.VideoCapture(video_path)
+     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+     fps = vidcap.get(cv2.CAP_PROP_FPS)
+     frames = []
+     # Sample 10 evenly spaced frames.
+     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
+     for i in frame_indices:
+         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+         success, image = vidcap.read()
+         if success:
+             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
+             pil_image = Image.fromarray(image)
+             timestamp = round(i / fps, 2)
+             frames.append((pil_image, timestamp))
+     vidcap.release()
+     return frames
+
+ def progress_bar_html(label: str) -> str:
+     """
+     Returns an HTML snippet for a thin progress bar with a label.
+     The progress bar is styled as a light cyan animated bar.
+     """
+     return f'''
+     <div style="display: flex; align-items: center;">
+         <span style="margin-right: 10px; font-size: 14px;">{label}</span>
+         <div style="width: 110px; height: 5px; background-color: #B0E0E6; border-radius: 2px; overflow: hidden;">
+             <div style="width: 100%; height: 100%; background-color: #00FFFF; animation: loading 1.5s linear infinite;"></div>
+         </div>
+     </div>
+     <style>
+     @keyframes loading {{
+         0% {{ transform: translateX(-100%); }}
+         100% {{ transform: translateX(100%); }}
+     }}
+     </style>
+     '''
+
+ @spaces.GPU
+ def generate(input_dict: dict, chat_history: list[dict],
+              max_new_tokens: int = 1024,
+              temperature: float = 0.6,
+              top_p: float = 0.9,
+              top_k: int = 50,
+              repetition_penalty: float = 1.2):
+     """
+     Generates chatbot responses with support for multimodal input, video processing,
+     and Edge TTS when using the new tags @JennyNeural or @GuyNeural.
+     Special command:
+     - "@video-infer": triggers video processing using Imgscope-OCR
+     """
+     text = input_dict["text"]
+     files = input_dict.get("files", [])
+     lower_text = text.strip().lower()
+
+     # Check for TTS tag in the prompt.
+     tts_voice = None
+     for tag, voice in TTS_VOICE_MAP.items():
+         if lower_text.startswith(tag):
+             tts_voice = voice
+             text = text[len(tag):].strip() # Remove the tag from the prompt.
+             break
+
+     # Branch for video processing with Callisto OCR3.
+     if lower_text.startswith("@video-infer"):
+         prompt = text[len("@video-infer"):].strip() if not tts_voice else text
+         if files:
+             # Assume the first file is a video.
+             video_path = files[0]
+             frames = downsample_video(video_path)
+             messages = [
+                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                 {"role": "user", "content": [{"type": "text", "text": prompt}]}
+             ]
+             # Append each frame with its timestamp.
+             for frame in frames:
+                 image, timestamp = frame
+                 image_path = f"video_frame_{uuid.uuid4().hex}.png"
+                 image.save(image_path)
+                 messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
+                 messages[1]["content"].append({"type": "image", "url": image_path})
+         else:
+             messages = [
+                 {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                 {"role": "user", "content": [{"type": "text", "text": prompt}]}
+             ]
+         # Enable truncation to avoid token/feature mismatch.
+         inputs = processor.apply_chat_template(
+             messages,
+             tokenize=True,
+             add_generation_prompt=True,
+             return_dict=True,
+             return_tensors="pt",
+             truncation=True,
+             max_length=MAX_INPUT_TOKEN_LENGTH
+         ).to("cuda")
+         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = {
+             **inputs,
+             "streamer": streamer,
+             "max_new_tokens": max_new_tokens,
+             "do_sample": True,
+             "temperature": temperature,
+             "top_p": top_p,
+             "top_k": top_k,
+             "repetition_penalty": repetition_penalty,
+         }
+         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
+         thread.start()
+         buffer = ""
+         yield progress_bar_html("Processing video with Imgscope-OCR")
+         for new_text in streamer:
+             buffer += new_text
+             buffer = buffer.replace("<|im_end|>", "")
+             time.sleep(0.01)
+             yield buffer
+         return
+
+     # Multimodal processing when files are provided.
+     if files:
+         if len(files) > 1:
+             images = [load_image(image) for image in files]
+         elif len(files) == 1:
+             images = [load_image(files[0])]
+         else:
+             images = []
+         messages = [{
+             "role": "user",
+             "content": [
+                 *[{"type": "image", "image": image} for image in images],
+                 {"type": "text", "text": text},
+             ]
+         }]
+         prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         # Enable truncation explicitly here as well.
+         inputs = processor(
+             text=[prompt_full],
+             images=images,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+             max_length=MAX_INPUT_TOKEN_LENGTH
+         ).to("cuda")
+         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
+         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
+         thread.start()
+         buffer = ""
+         yield progress_bar_html("Processing image with Imgscope-OCR")
+         for new_text in streamer:
+             buffer += new_text
+             buffer = buffer.replace("<|im_end|>", "")
+             time.sleep(0.01)
+             yield buffer
+     else:
+         # Normal text conversation processing with Pocket Llama.
+         conversation = clean_chat_history(chat_history)
+         conversation.append({"role": "user", "content": text})
+         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+         input_ids = input_ids.to(model.device)
+         streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = {
+             "input_ids": input_ids,
+             "streamer": streamer,
+             "max_new_tokens": max_new_tokens,
+             "do_sample": True,
+             "top_p": top_p,
+             "top_k": top_k,
+             "temperature": temperature,
+             "num_beams": 1,
+             "repetition_penalty": repetition_penalty,
+         }
+         t = Thread(target=model.generate, kwargs=generation_kwargs)
+         t.start()
+         outputs = []
+         yield progress_bar_html("Processing With Galactic Qwen")
+         for new_text in streamer:
+             outputs.append(new_text)
+             yield "".join(outputs)
+         final_response = "".join(outputs)
+         yield final_response
+
+         # If a TTS voice was specified, convert the final response to speech.
+         if tts_voice:
+             output_file = asyncio.run(text_to_speech(final_response, tts_voice))
+             yield gr.Audio(output_file, autoplay=True)
+
+ # Create the Gradio ChatInterface with the custom CSS applied
+ demo = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
+         gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
+         gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
+         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
+     ],
+     examples=[
+         ["Write the code that converts temperatures between Celsius and Fahrenheit in short"],
+         [{"text": "Create a short story based on the image.", "files": ["examples/1.jpg"]}],
+         ["@JennyNeural Who was Nikola Tesla and what were his contributions?"],
+         [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}]
+     ],
+     cache_examples=False,
+     description="# **Imgscope-OCR**",
+     type="messages",
+     fill_height=True,
+     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
+     stop_btn="Stop Generation",
+     multimodal=True,
+ )
+
+ if __name__ == "__main__":
      demo.queue(max_size=20).launch(share=True)