prithivMLmods committed on
Commit 336fc5d · verified · 1 Parent(s): aad98bd

Delete app.py

Files changed (1)
  1. app.py +0 -349
app.py DELETED
@@ -1,349 +0,0 @@
- import os
- import time
- import tempfile
- from threading import Thread
-
- import gradio as gr
- import spaces
- import torch
- import numpy as np
- from PIL import Image
- import cv2
-
- from transformers import (
-     Qwen2VLForConditionalGeneration,
-     AutoModelForImageTextToText,
-     AutoProcessor,
-     TextIteratorStreamer,
- )
- from transformers.image_utils import load_image
-
- # Constants for text generation
- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 1024
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
- # Determine device
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- # Load VIREX-062225-exp
- MODEL_ID_M = "prithivMLmods/VIREX-062225-exp"
- processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
- model_m = Qwen2VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_M,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()
-
- # Load DREX-062225-exp
- MODEL_ID_X = "prithivMLmods/DREX-062225-exp"
- processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
- model_x = Qwen2VLForConditionalGeneration.from_pretrained(
-     MODEL_ID_X,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()
-
- # Load Gemma3n-E4B-it (multimodal, served via AutoModelForImageTextToText)
- MODEL_ID_G = "google/gemma-3n-E4B-it"
- processor_g = AutoProcessor.from_pretrained(MODEL_ID_G, trust_remote_code=True)
- model_g = AutoModelForImageTextToText.from_pretrained(
-     MODEL_ID_G,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()
-
- # Load Gemma3n-E2B-it (multimodal, served via AutoModelForImageTextToText)
- MODEL_ID_N = "google/gemma-3n-E2B-it"
- processor_n = AutoProcessor.from_pretrained(MODEL_ID_N, trust_remote_code=True)
- model_n = AutoModelForImageTextToText.from_pretrained(
-     MODEL_ID_N,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).to(device).eval()
-
- def downsample_video(video_path):
-     """
-     Downsamples the video to 10 evenly spaced frames, saved as JPEGs in a
-     temporary directory. Returns a list of (frame_path, timestamp) pairs and
-     the temp directory; the caller is responsible for cleaning it up.
-     """
-     vidcap = cv2.VideoCapture(video_path)
-     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-     fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0  # guard: some containers report 0 fps
-     frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
-     temp_dir = tempfile.mkdtemp()
-     frames = []
-     for i in frame_indices:
-         vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
-         success, image = vidcap.read()
-         if success:
-             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             frame_path = os.path.join(temp_dir, f"frame_{i}.jpg")
-             Image.fromarray(image).save(frame_path)
-             timestamp = round(i / fps, 2)  # timestamp in seconds
-             frames.append((frame_path, timestamp))
-     vidcap.release()
-     return frames, temp_dir
-
- @spaces.GPU
- def generate_image(model_name: str, text: str, image_path: str,
-                    max_new_tokens: int = 1024,
-                    temperature: float = 0.6,
-                    top_p: float = 0.9,
-                    top_k: int = 50,
-                    repetition_penalty: float = 1.2):
-     """
-     Streams a response from the selected model for an image input.
-     """
-     if model_name == "VIREX-062225-7B-exp":
-         processor = processor_m
-         model = model_m
-     elif model_name == "DREX-062225-7B-exp":
-         processor = processor_x
-         model = model_x
-     elif model_name == "Gemma3n-E4B-it":
-         processor = processor_g
-         model = model_g
-     elif model_name == "Gemma3n-E2B-it":
-         processor = processor_n
-         model = model_n
-     else:
-         yield "Invalid model selected.", "Invalid model selected."
-         return
-
-     if image_path is None:
-         yield "Please upload an image.", "Please upload an image."
-         return
-
-     messages = [{"role": "user", "content": [{"type": "text", "text": text}, {"type": "image", "image": image_path}]}]
-
-     if model_name in ["Gemma3n-E4B-it", "Gemma3n-E2B-it"]:
-         inputs = processor.apply_chat_template(
-             messages,
-             tokenize=True,
-             add_generation_prompt=True,
-             return_dict=True,
-             return_tensors="pt",
-             truncation=True,  # prevent overflow past the input budget
-             max_length=MAX_INPUT_TOKEN_LENGTH
-         ).to(device)
-     else:
-         prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-         inputs = processor(
-             text=[prompt_full],
-             images=[load_image(image_path)],  # the Qwen2-VL processor expects PIL images, not file paths
-             return_tensors="pt",
-             padding=True,
-             truncation=True,  # prevent overflow past the input budget
-             max_length=MAX_INPUT_TOKEN_LENGTH
-         ).to(device)
-
-     # Defensive check; truncation above should already cap the length.
-     input_length = inputs["input_ids"].shape[1]
-     if input_length > MAX_INPUT_TOKEN_LENGTH:
-         yield f"Input too long. Max {MAX_INPUT_TOKEN_LENGTH} tokens. Got {input_length} tokens.", ""
-         return
-
-     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-     generation_kwargs = {
-         **inputs,
-         "streamer": streamer,
-         "max_new_tokens": max_new_tokens,
-         "do_sample": True,
-         "temperature": temperature,
-         "top_p": top_p,
-         "top_k": top_k,
-         "repetition_penalty": repetition_penalty,
-     }
-
-     # Move any stray tensors to the target device (defensive; .to(device) above already did this).
-     for key in generation_kwargs:
-         if isinstance(generation_kwargs[key], torch.Tensor):
-             generation_kwargs[key] = generation_kwargs[key].to(device)
-
-     # Run the blocking generate() call in a worker thread and stream partial text.
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-     buffer = ""
-     for new_text in streamer:
-         buffer += new_text
-         time.sleep(0.01)
-         yield buffer, buffer
-
- @spaces.GPU
- def generate_video(model_name: str, text: str, video_path: str,
-                    max_new_tokens: int = 1024,
-                    temperature: float = 0.6,
-                    top_p: float = 0.9,
-                    top_k: int = 50,
-                    repetition_penalty: float = 1.2):
-     """
-     Streams a response from the selected model for a video input.
-     """
-     if model_name == "VIREX-062225-7B-exp":
-         processor = processor_m
-         model = model_m
-     elif model_name == "DREX-062225-7B-exp":
-         processor = processor_x
-         model = model_x
-     elif model_name == "Gemma3n-E4B-it":
-         processor = processor_g
-         model = model_g
-     elif model_name == "Gemma3n-E2B-it":
-         processor = processor_n
-         model = model_n
-     else:
-         yield "Invalid model selected.", "Invalid model selected."
-         return
-
-     if video_path is None:
-         yield "Please upload a video.", "Please upload a video."
-         return
-
-     # Interleave each sampled frame with its timestamp (in seconds) in the prompt.
-     frames, temp_dir = downsample_video(video_path)
-     content = [{"type": "text", "text": text}]
-     for frame_path, timestamp in frames:
-         content.append({"type": "text", "text": f"Frame {timestamp}:"})
-         content.append({"type": "image", "image": frame_path})
-     messages = [{"role": "user", "content": content}]
-
-     if model_name in ["Gemma3n-E4B-it", "Gemma3n-E2B-it"]:
-         inputs = processor.apply_chat_template(
-             messages,
-             tokenize=True,
-             add_generation_prompt=True,
-             return_dict=True,
-             return_tensors="pt",
-             truncation=True,  # prevent overflow past the input budget
-             max_length=MAX_INPUT_TOKEN_LENGTH
-         ).to(device)
-     else:
-         prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-         images = [load_image(frame_path) for frame_path, _ in frames]  # the processor expects PIL images, not file paths
-         inputs = processor(
-             text=[prompt_full],
-             images=images,
-             return_tensors="pt",
-             padding=True,
-             truncation=True,  # prevent overflow past the input budget
-             max_length=MAX_INPUT_TOKEN_LENGTH
-         ).to(device)
-
-     # Defensive check; truncation above should already cap the length.
-     input_length = inputs["input_ids"].shape[1]
-     if input_length > MAX_INPUT_TOKEN_LENGTH:
-         yield f"Input too long. Max {MAX_INPUT_TOKEN_LENGTH} tokens. Got {input_length} tokens.", ""
-         return
-
-     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-     generation_kwargs = {
-         **inputs,
-         "streamer": streamer,
-         "max_new_tokens": max_new_tokens,
-         "do_sample": True,
-         "temperature": temperature,
-         "top_p": top_p,
-         "top_k": top_k,
-         "repetition_penalty": repetition_penalty,
-     }
-
-     # Move any stray tensors to the target device (defensive; .to(device) above already did this).
-     for key in generation_kwargs:
-         if isinstance(generation_kwargs[key], torch.Tensor):
-             generation_kwargs[key] = generation_kwargs[key].to(device)
-
-     # Run the blocking generate() call in a worker thread and stream partial text.
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-     buffer = ""
-     for new_text in streamer:
-         buffer += new_text
-         buffer = buffer.replace("<|im_end|>", "")
-         time.sleep(0.01)
-         yield buffer, buffer
-
- # Define examples for image and video inference
- image_examples = [
-     ["Convert this page to doc [text] precisely.", "images/3.png"],
-     ["Convert this page to doc [text] precisely.", "images/4.png"],
-     ["Convert this page to doc [text] precisely.", "images/1.png"],
-     ["Convert chart to OTSL.", "images/2.png"]
- ]
-
- video_examples = [
-     ["Explain the video in detail.", "videos/2.mp4"],
-     ["Explain the ad in detail.", "videos/1.mp4"]
- ]
-
- # CSS to style the submit button and the "Canvas" output area
- css = """
- .submit-btn {
-     background-color: #2980b9 !important;
-     color: white !important;
- }
- .submit-btn:hover {
-     background-color: #3498db !important;
- }
- .canvas-output {
-     border: 2px solid #4682B4;
-     border-radius: 10px;
-     padding: 20px;
- }
- """
-
- # Create the Gradio interface
- with gr.Blocks(css=css, theme=gr.themes.Citrus()) as demo:
-     gr.Markdown("# **[Doc VLMs OCR](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
-     with gr.Row():
-         with gr.Column():
-             with gr.Tabs():
-                 with gr.TabItem("Image Inference"):
-                     image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-                     image_upload = gr.Image(type="filepath", label="Image")
-                     image_submit = gr.Button("Submit", elem_classes="submit-btn")
-                     gr.Examples(
-                         examples=image_examples,
-                         inputs=[image_query, image_upload]
-                     )
-                 with gr.TabItem("Video Inference"):
-                     video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
-                     video_upload = gr.Video(label="Video")
-                     video_submit = gr.Button("Submit", elem_classes="submit-btn")
-                     gr.Examples(
-                         examples=video_examples,
-                         inputs=[video_query, video_upload]
-                     )
-             with gr.Accordion("Advanced options", open=False):
-                 max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
-                 temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
-                 top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
-                 top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
-                 repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
-
-         with gr.Column():
-             with gr.Column(elem_classes="canvas-output"):
-                 gr.Markdown("## Result Canvas")
-                 output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=2)
-                 markdown_output = gr.Markdown(label="Formatted Result (Result.Md)")
-
-             model_choice = gr.Radio(
-                 choices=["DREX-062225-7B-exp", "VIREX-062225-7B-exp", "Gemma3n-E4B-it", "Gemma3n-E2B-it"],
-                 label="Select Model",
-                 value="DREX-062225-7B-exp"
-             )
-
-     image_submit.click(
-         fn=generate_image,
-         inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-         outputs=[output, markdown_output]
-     )
-     video_submit.click(
-         fn=generate_video,
-         inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-         outputs=[output, markdown_output]
-     )
-
- if __name__ == "__main__":
-     demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)
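
The deleted downsample_video helper is the piece most worth keeping around: it samples ten evenly spaced frames so a long video fits in a fixed prompt budget. Below is a minimal, self-contained sketch of the same idea with two guards the original lacked (containers that report 0 fps, and frames the decoder cannot seek to). sample_frames and num_frames are names chosen here for illustration, not from the deleted file.

import os
import tempfile

import cv2
import numpy as np
from PIL import Image

def sample_frames(video_path, num_frames=10):
    """Save num_frames evenly spaced RGB frames as JPEGs; return (path, timestamp) pairs."""
    vidcap = cv2.VideoCapture(video_path)
    total = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0  # guard: some containers report 0 fps
    temp_dir = tempfile.mkdtemp()
    frames = []
    for i in np.linspace(0, max(total - 1, 0), num_frames, dtype=int):
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        ok, bgr = vidcap.read()
        if not ok:
            continue  # guard: skip frames the decoder cannot seek to
        frame_path = os.path.join(temp_dir, f"frame_{i}.jpg")
        Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)).save(frame_path)
        frames.append((frame_path, round(i / fps, 2)))  # timestamp in seconds
    vidcap.release()
    return frames, temp_dir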
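Both generate functions share one pattern worth noting: model.generate() blocks until decoding finishes, so it runs on a worker thread while the callback drains a TextIteratorStreamer and yields partial text. A minimal sketch of that pattern follows, using a small text-only model (gpt2 is just a stand-in here, not one of the models this Space served):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Streaming generation works by", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until done, so it runs in a worker thread while the
# main thread consumes decoded text as it becomes available.
thread = Thread(target=model.generate,
                kwargs={**inputs, "streamer": streamer, "max_new_tokens": 40})
thread.start()

buffer = ""
for new_text in streamer:  # yields chunks of newly decoded text
    buffer += new_text
    print(buffer)          # a Gradio generator would `yield buffer` here
thread.join()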
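The UI wiring relies on a matching Gradio idiom: when fn is a generator, every yield updates the bound output components in place, which is what turned the token stream into live text in the "Result Canvas". A stripped-down sketch of that wiring (the echo callback is illustrative, not from the app):

import time

import gradio as gr

def stream_reply(prompt):
    # Generator callback: each yield pushes one update per bound output.
    buffer = ""
    for word in f"echo: {prompt}".split():
        buffer += word + " "
        time.sleep(0.1)
        yield buffer, buffer

with gr.Blocks() as demo:
    query = gr.Textbox(label="Query Input")
    submit = gr.Button("Submit")
    raw = gr.Textbox(label="Raw Output Stream", interactive=False)
    formatted = gr.Markdown()
    submit.click(fn=stream_reply, inputs=[query], outputs=[raw, formatted])

if __name__ == "__main__":
    demo.queue(max_size=30).launch()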