prithivMLmods committed on
Commit e817668 · verified · 1 Parent(s): 3d4caeb

Update app.py

Files changed (1)
  1. app.py +373 -223
app.py CHANGED
@@ -10,19 +10,24 @@ import gradio as gr
10
  import spaces
11
  import torch
12
  import numpy as np
13
- from PIL import Image
14
  import cv2
15
- import edge_tts
16
 
17
  from transformers import (
18
- AutoModelForCausalLM,
19
- AutoTokenizer,
20
- TextIteratorStreamer,
21
  Qwen2VLForConditionalGeneration,
22
  AutoProcessor,
 
23
  )
24
  from transformers.image_utils import load_image
25
 
26
  # Constants for text generation
27
  MAX_MAX_NEW_TOKENS = 2048
28
  DEFAULT_MAX_NEW_TOKENS = 1024
@@ -30,271 +35,416 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
30
 
31
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
32
 
33
- # Load text-only model and tokenizer
34
- model_id = "prithivMLmods/Galactic-Qwen-14B-Exp2"
35
- tokenizer = AutoTokenizer.from_pretrained(model_id)
36
- model = AutoModelForCausalLM.from_pretrained(
37
- model_id,
38
- device_map="auto",
39
- torch_dtype=torch.bfloat16,
40
- )
41
- model.eval()
42
-
43
- # Load multimodal processor and model
44
- MODEL_ID = "prithivMLmods/Imgscope-OCR-2B-0527"
45
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
46
  model_m = Qwen2VLForConditionalGeneration.from_pretrained(
47
- MODEL_ID,
48
  trust_remote_code=True,
49
  torch_dtype=torch.float16
50
- ).to("cuda").eval()
51
-
52
- # Edge TTS voices mapping for new tags.
53
- TTS_VOICE_MAP = {
54
- "@jennyneural": "en-US-JennyNeural",
55
- "@guyneural": "en-US-GuyNeural",
56
- "@palomaneural": "es-US-PalomaNeural",
57
- "@alonsoneural": "es-US-AlonsoNeural",
58
- "@madhurneural": "hi-IN-MadhurNeural"
59
- }
60
 
61
- async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
62
- """
63
- Convert text to speech using Edge TTS and save as MP3.
64
- """
65
- communicate = edge_tts.Communicate(text, voice)
66
- await communicate.save(output_file)
67
- return output_file
68
-
69
- def clean_chat_history(chat_history):
70
- """
71
- Filter out any chat entries whose "content" is not a string.
72
- This helps prevent errors when concatenating previous messages.
73
- """
74
- cleaned = []
75
- for msg in chat_history:
76
- if isinstance(msg, dict) and isinstance(msg.get("content"), str):
77
- cleaned.append(msg)
78
- return cleaned
 
79
 
80
  def downsample_video(video_path):
81
- """
82
- Downsamples the video to 10 evenly spaced frames.
83
- Each frame is returned as a PIL image along with its timestamp.
84
- """
85
  vidcap = cv2.VideoCapture(video_path)
86
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
87
  fps = vidcap.get(cv2.CAP_PROP_FPS)
88
  frames = []
89
- # Sample 10 evenly spaced frames.
90
  frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
91
  for i in frame_indices:
92
  vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
93
  success, image = vidcap.read()
94
  if success:
95
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
96
  pil_image = Image.fromarray(image)
97
  timestamp = round(i / fps, 2)
98
  frames.append((pil_image, timestamp))
99
  vidcap.release()
100
  return frames
101
 
102
- def progress_bar_html(label: str) -> str:
103
- """
104
- Returns an HTML snippet for a thin progress bar with a label.
105
- The progress bar is styled as a light cyan animated bar.
106
- """
107
- return f'''
108
- <div style="display: flex; align-items: center;">
109
- <span style="margin-right: 10px; font-size: 14px;">{label}</span>
110
- <div style="width: 110px; height: 5px; background-color: #B0E0E6; border-radius: 2px; overflow: hidden;">
111
- <div style="width: 100%; height: 100%; background-color: #00FFFF; animation: loading 1.5s linear infinite;"></div>
112
- </div>
113
- </div>
114
- <style>
115
- @keyframes loading {{
116
- 0% {{ transform: translateX(-100%); }}
117
- 100% {{ transform: translateX(100%); }}
118
- }}
119
- </style>
120
- '''
121
 
122
  @spaces.GPU
123
- def generate(input_dict: dict, chat_history: list[dict],
124
- max_new_tokens: int = 1024,
125
- temperature: float = 0.6,
126
- top_p: float = 0.9,
127
- top_k: int = 50,
128
- repetition_penalty: float = 1.2):
129
- """
130
- Generates chatbot responses with support for multimodal input, video processing,
131
- and Edge TTS when using the new tags @JennyNeural or @GuyNeural.
132
- Special command:
133
- - "@video-infer": triggers video processing using Imgscope-OCR
134
- """
135
- text = input_dict["text"]
136
- files = input_dict.get("files", [])
137
- lower_text = text.strip().lower()
138
-
139
- # Check for TTS tag in the prompt.
140
- tts_voice = None
141
- for tag, voice in TTS_VOICE_MAP.items():
142
- if lower_text.startswith(tag):
143
- tts_voice = voice
144
- text = text[len(tag):].strip() # Remove the tag from the prompt.
145
- break
146
-
147
- # Branch for video processing with Callisto OCR3.
148
- if lower_text.startswith("@video-infer"):
149
- prompt = text[len("@video-infer"):].strip() if not tts_voice else text
150
- if files:
151
- # Assume the first file is a video.
152
- video_path = files[0]
153
- frames = downsample_video(video_path)
154
- messages = [
155
- {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
156
- {"role": "user", "content": [{"type": "text", "text": prompt}]}
157
- ]
158
- # Append each frame with its timestamp.
159
- for frame in frames:
160
- image, timestamp = frame
161
- image_path = f"video_frame_{uuid.uuid4().hex}.png"
162
- image.save(image_path)
163
- messages[1]["content"].append({"type": "text", "text": f"Frame {timestamp}:"})
164
- messages[1]["content"].append({"type": "image", "url": image_path})
165
  else:
166
- messages = [
167
- {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
168
- {"role": "user", "content": [{"type": "text", "text": prompt}]}
169
- ]
170
- # Enable truncation to avoid token/feature mismatch.
171
- inputs = processor.apply_chat_template(
172
- messages,
173
- tokenize=True,
174
- add_generation_prompt=True,
175
- return_dict=True,
176
- return_tensors="pt",
177
- truncation=True,
178
- max_length=MAX_INPUT_TOKEN_LENGTH
179
- ).to("cuda")
180
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
181
  generation_kwargs = {
182
  **inputs,
183
  "streamer": streamer,
184
  "max_new_tokens": max_new_tokens,
185
- "do_sample": True,
186
  "temperature": temperature,
187
  "top_p": top_p,
188
  "top_k": top_k,
189
  "repetition_penalty": repetition_penalty,
190
  }
191
- thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
192
  thread.start()
 
193
  buffer = ""
194
- yield progress_bar_html("Processing video with Imgscope-OCR")
195
  for new_text in streamer:
196
- buffer += new_text
197
- buffer = buffer.replace("<|im_end|>", "")
198
- time.sleep(0.01)
199
  yield buffer
200
- return
201
-
202
- # Multimodal processing when files are provided.
203
- if files:
204
- if len(files) > 1:
205
- images = [load_image(image) for image in files]
206
- elif len(files) == 1:
207
- images = [load_image(files[0])]
208
  else:
209
- images = []
210
- messages = [{
211
- "role": "user",
212
- "content": [
213
- *[{"type": "image", "image": image} for image in images],
214
- {"type": "text", "text": text},
215
- ]
216
- }]
217
- prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
218
- # Enable truncation explicitly here as well.
219
- inputs = processor(
220
- text=[prompt_full],
221
- images=images,
222
- return_tensors="pt",
223
- padding=True,
224
- truncation=True,
225
- max_length=MAX_INPUT_TOKEN_LENGTH
226
- ).to("cuda")
227
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
228
- generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
229
- thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
230
- thread.start()
231
- buffer = ""
232
- yield progress_bar_html("Processing image with Imgscope-OCR")
233
- for new_text in streamer:
234
- buffer += new_text
235
- buffer = buffer.replace("<|im_end|>", "")
236
- time.sleep(0.01)
237
- yield buffer
238
- else:
239
- # Normal text conversation processing with Pocket Llama.
240
- conversation = clean_chat_history(chat_history)
241
- conversation.append({"role": "user", "content": text})
242
- input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
243
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
244
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
245
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
246
- input_ids = input_ids.to(model.device)
247
- streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
248
  generation_kwargs = {
249
- "input_ids": input_ids,
250
  "streamer": streamer,
251
  "max_new_tokens": max_new_tokens,
252
- "do_sample": True,
253
  "top_p": top_p,
254
  "top_k": top_k,
255
- "temperature": temperature,
256
- "num_beams": 1,
257
  "repetition_penalty": repetition_penalty,
258
  }
259
- t = Thread(target=model.generate, kwargs=generation_kwargs)
260
- t.start()
261
- outputs = []
262
- yield progress_bar_html("Processing With Galactic Qwen")
 
263
  for new_text in streamer:
264
- outputs.append(new_text)
265
- yield "".join(outputs)
266
- final_response = "".join(outputs)
267
- yield final_response
268
-
269
- # If a TTS voice was specified, convert the final response to speech.
270
- if tts_voice:
271
- output_file = asyncio.run(text_to_speech(final_response, tts_voice))
272
- yield gr.Audio(output_file, autoplay=True)
273
-
274
- # Create the Gradio ChatInterface with the custom CSS applied
275
- demo = gr.ChatInterface(
276
- fn=generate,
277
- additional_inputs=[
278
- gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
279
- gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
280
- gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
281
- gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
282
- gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
283
- ],
284
- examples=[
285
- ["Write the code that converts temperatures between Celsius and Fahrenheit in short"],
286
- [{"text": "Create a short story based on the image.", "files": ["examples/1.jpg"]}],
287
- ["@JennyNeural Who was Nikola Tesla and what were his contributions?"],
288
- [{"text": "@video-infer Describe the video", "files": ["examples/Missing.mp4"]}]
289
- ],
290
- cache_examples=False,
291
- description="# **Imgscope-OCR**",
292
- type="messages",
293
- fill_height=True,
294
- textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
295
- stop_btn="Stop Generation",
296
- multimodal=True,
297
- )
298
 
299
  if __name__ == "__main__":
300
- demo.queue(max_size=20).launch(share=True)
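Aside: the removed version above drove speech output through Edge TTS via the async text_to_speech helper. A minimal standalone sketch of that pattern, assuming the edge_tts package is installed and using one of the voices from the old TTS_VOICE_MAP:

import asyncio
import edge_tts

async def text_to_speech(text: str, voice: str, output_file: str = "output.mp3") -> str:
    # Synthesize `text` with the given Edge TTS voice and save it as an MP3 file.
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)
    return output_file

# "en-US-JennyNeural" is the voice the old "@jennyneural" tag mapped to.
audio_path = asyncio.run(text_to_speech("Hello from Edge TTS.", "en-US-JennyNeural"))
print(audio_path)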
 
10
  import spaces
11
  import torch
12
  import numpy as np
13
+ from PIL import Image, ImageOps
14
  import cv2
 
15
 
16
  from transformers import (
17
  Qwen2VLForConditionalGeneration,
18
+ VisionEncoderDecoderModel,
19
+ AutoModelForVision2Seq,
20
  AutoProcessor,
21
+ TextIteratorStreamer,
22
  )
23
  from transformers.image_utils import load_image
24
 
25
+ from docling_core.types.doc import DoclingDocument, DocTagsDocument
26
+
27
+ import re
28
+ import ast
29
+ import html
30
+
31
  # Constants for text generation
32
  MAX_MAX_NEW_TOKENS = 2048
33
  DEFAULT_MAX_NEW_TOKENS = 1024
 
35
 
36
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
37
 
38
+ # Load olmOCR-7B-0225-preview
39
+ MODEL_ID_M = "allenai/olmOCR-7B-0225-preview"
40
+ processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
41
  model_m = Qwen2VLForConditionalGeneration.from_pretrained(
42
+ MODEL_ID_M,
43
  trust_remote_code=True,
44
  torch_dtype=torch.float16
45
+ ).to(device).eval()
46
 
47
+ # Load ByteDance's Dolphin
48
+ MODEL_ID_K = "ByteDance/Dolphin"
49
+ processor_k = AutoProcessor.from_pretrained(MODEL_ID_K, trust_remote_code=True)
50
+ model_k = VisionEncoderDecoderModel.from_pretrained(
51
+ MODEL_ID_K,
52
+ trust_remote_code=True,
53
+ torch_dtype=torch.float16
54
+ ).to(device).eval()
55
+
56
+ # Load SmolDocling-256M-preview
57
+ MODEL_ID_X = "ds4sd/SmolDocling-256M-preview"
58
+ processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
59
+ model_x = AutoModelForVision2Seq.from_pretrained(
60
+ MODEL_ID_X,
61
+ trust_remote_code=True,
62
+ torch_dtype=torch.float16
63
+ ).to(device).eval()
64
+
65
+
66
+ # Preprocessing functions for SmolDocling-256M
67
+ def add_random_padding(image, min_percent=0.1, max_percent=0.10):
68
+ """Add random padding to an image based on its size."""
69
+ image = image.convert("RGB")
70
+ width, height = image.size
71
+ pad_w_percent = random.uniform(min_percent, max_percent)
72
+ pad_h_percent = random.uniform(min_percent, max_percent)
73
+ pad_w = int(width * pad_w_percent)
74
+ pad_h = int(height * pad_h_percent)
75
+ corner_pixel = image.getpixel((0, 0)) # Top-left corner
76
+ padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
77
+ return padded_image
78
+
79
+ def normalize_values(text, target_max=500):
80
+ """Normalize numerical values in text to a target maximum."""
81
+ def normalize_list(values):
82
+ max_value = max(values) if values else 1
83
+ return [round((v / max_value) * target_max) for v in values]
84
+
85
+ def process_match(match):
86
+ num_list = ast.literal_eval(match.group(0))
87
+ normalized = normalize_list(num_list)
88
+ return "".join([f"<loc_{num}>" for num in normalized])
89
+
90
+ pattern = r"\[([\d\.\s,]+)\]"
91
+ normalized_text = re.sub(pattern, process_match, text)
92
+ return normalized_text
93
 
94
  def downsample_video(video_path):
95
+ """Downsample a video to evenly spaced frames, returning PIL images with timestamps."""
96
  vidcap = cv2.VideoCapture(video_path)
97
  total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
98
  fps = vidcap.get(cv2.CAP_PROP_FPS)
99
  frames = []
 
100
  frame_indices = np.linspace(0, total_frames - 1, 10, dtype=int)
101
  for i in frame_indices:
102
  vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
103
  success, image = vidcap.read()
104
  if success:
105
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
106
  pil_image = Image.fromarray(image)
107
  timestamp = round(i / fps, 2)
108
  frames.append((pil_image, timestamp))
109
  vidcap.release()
110
  return frames
111
 
112
+ # Dolphin-specific functions
113
+ def model_chat(prompt, image):
114
+ """Use Dolphin model for inference."""
115
+ processor = processor_k
116
+ model = model_k
117
+ device = "cuda" if torch.cuda.is_available() else "cpu"
118
+ inputs = processor(image, return_tensors="pt").to(device)
119
+ pixel_values = inputs.pixel_values.half()
120
+ prompt_inputs = processor.tokenizer(
121
+ f"<s>{prompt} <Answer/>",
122
+ add_special_tokens=False,
123
+ return_tensors="pt"
124
+ ).to(device)
125
+ outputs = model.generate(
126
+ pixel_values=pixel_values,
127
+ decoder_input_ids=prompt_inputs.input_ids,
128
+ decoder_attention_mask=prompt_inputs.attention_mask,
129
+ min_length=1,
130
+ max_length=4096,
131
+ pad_token_id=processor.tokenizer.pad_token_id,
132
+ eos_token_id=processor.tokenizer.eos_token_id,
133
+ use_cache=True,
134
+ bad_words_ids=[[processor.tokenizer.unk_token_id]],
135
+ return_dict_in_generate=True,
136
+ do_sample=False,
137
+ num_beams=1,
138
+ repetition_penalty=1.1
139
+ )
140
+ sequence = processor.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0]
141
+ cleaned = sequence.replace(f"<s>{prompt} <Answer/>", "").replace("<pad>", "").replace("</s>", "").strip()
142
+ return cleaned
143
+
144
+ def process_elements(layout_results, image):
145
+ """Parse layout results and extract elements from the image."""
146
+ # Placeholder parsing logic based on expected Dolphin output
147
+ # Assuming layout_results is a string like "[(x1,y1,x2,y2,label), ...]"
148
+ try:
149
+ elements = ast.literal_eval(layout_results)
150
+ except (ValueError, SyntaxError):
151
+ elements = [] # Fallback if parsing fails
152
+
153
+ recognition_results = []
154
+ reading_order = 0
155
+
156
+ for element in elements:
157
+ try:
158
+ # Each element is assumed to be (x1, y1, x2, y2, label); split off the label before the bbox.
+ *bbox, label = element
+ x1, y1, x2, y2 = map(int, bbox)
159
+ cropped = image.crop((x1, y1, x2, y2))
160
+ if cropped.size[0] > 0 and cropped.size[1] > 0:
161
+ if label == "text":
162
+ text = model_chat("Read text in the image.", cropped)
163
+ recognition_results.append({
164
+ "label": label,
165
+ "bbox": [x1, y1, x2, y2],
166
+ "text": text.strip(),
167
+ "reading_order": reading_order
168
+ })
169
+ elif label == "table":
170
+ table_text = model_chat("Parse the table in the image.", cropped)
171
+ recognition_results.append({
172
+ "label": label,
173
+ "bbox": [x1, y1, x2, y2],
174
+ "text": table_text.strip(),
175
+ "reading_order": reading_order
176
+ })
177
+ elif label == "figure":
178
+ recognition_results.append({
179
+ "label": label,
180
+ "bbox": [x1, y1, x2, y2],
181
+ "text": "[Figure]", # Placeholder for figure content
182
+ "reading_order": reading_order
183
+ })
184
+ reading_order += 1
185
+ except Exception as e:
186
+ print(f"Error processing element: {e}")
187
+ continue
188
+
189
+ return recognition_results
190
+
191
+ def generate_markdown(recognition_results):
192
+ """Generate markdown from extracted elements."""
193
+ markdown = ""
194
+ for element in sorted(recognition_results, key=lambda x: x["reading_order"]):
195
+ if element["label"] == "text":
196
+ markdown += f"{element['text']}\n\n"
197
+ elif element["label"] == "table":
198
+ markdown += f"**Table:**\n{element['text']}\n\n"
199
+ elif element["label"] == "figure":
200
+ markdown += f"{element['text']}\n\n"
201
+ return markdown.strip()
202
+
203
+ def process_image_with_dolphin(image):
204
+ """Process a single image with Dolphin model."""
205
+ layout_output = model_chat("Parse the reading order of this document.", image)
206
+ elements = process_elements(layout_output, image)
207
+ markdown_content = generate_markdown(elements)
208
+ return markdown_content
209
 
210
  @spaces.GPU
211
+ def generate_image(model_name: str, text: str, image: Image.Image,
212
+ max_new_tokens: int = 1024,
213
+ temperature: float = 0.6,
214
+ top_p: float = 0.9,
215
+ top_k: int = 50,
216
+ repetition_penalty: float = 1.2):
217
+ """Generate responses for image input using the selected model."""
218
+ if model_name == "ByteDance-s-Dolphin":
219
+ if image is None:
220
+ yield "Please upload an image."
221
+ return
222
+ markdown_content = process_image_with_dolphin(image)
223
+ yield markdown_content
224
+ else:
225
+ # Existing logic for other models
226
+ if model_name == "olmOCR-7B-0225-preview":
227
+ processor = processor_m
228
+ model = model_m
229
+ elif model_name == "SmolDocling-256M-preview":
230
+ processor = processor_x
231
+ model = model_x
232
  else:
233
+ yield "Invalid model selected."
234
+ return
235
+
236
+ if image is None:
237
+ yield "Please upload an image."
238
+ return
239
+
240
+ images = [image]
241
+
242
+ if model_name == "SmolDocling-256M-preview":
243
+ if "OTSL" in text or "code" in text:
244
+ images = [add_random_padding(img) for img in images]
245
+ if "OCR at text at" in text or "Identify element" in text or "formula" in text:
246
+ text = normalize_values(text, target_max=500)
247
+
248
+ messages = [
249
+ {
250
+ "role": "user",
251
+ "content": [{"type": "image"} for _ in images] + [
252
+ {"type": "text", "text": text}
253
+ ]
254
+ }
255
+ ]
256
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
257
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
258
+
259
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
260
  generation_kwargs = {
261
  **inputs,
262
  "streamer": streamer,
263
  "max_new_tokens": max_new_tokens,
 
264
  "temperature": temperature,
265
  "top_p": top_p,
266
  "top_k": top_k,
267
  "repetition_penalty": repetition_penalty,
268
  }
269
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
270
  thread.start()
271
+
272
  buffer = ""
273
+ full_output = ""
274
  for new_text in streamer:
275
+ full_output += new_text
276
+ buffer += new_text.replace("<|im_end|>", "")
 
277
  yield buffer
278
+
279
+ if model_name == "SmolDocling-256M-preview":
280
+ cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
281
+ if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
282
+ if "<chart>" in cleaned_output:
283
+ cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
284
+ cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
285
+ doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
286
+ doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
287
+ markdown_output = doc.export_to_markdown()
288
+ yield f"**MD Output:**\n\n{markdown_output}"
289
+ else:
290
+ yield cleaned_output
291
+
292
+ @spaces.GPU
293
+ def generate_video(model_name: str, text: str, video_path: str,
294
+ max_new_tokens: int = 1024,
295
+ temperature: float = 0.6,
296
+ top_p: float = 0.9,
297
+ top_k: int = 50,
298
+ repetition_penalty: float = 1.2):
299
+ """Generate responses for video input using the selected model."""
300
+ if model_name == "ByteDance-s-Dolphin":
301
+ if video_path is None:
302
+ yield "Please upload a video."
303
+ return
304
+ frames = downsample_video(video_path)
305
+ markdown_contents = []
306
+ for frame, _ in frames:
307
+ markdown_content = process_image_with_dolphin(frame)
308
+ markdown_contents.append(markdown_content)
309
+ combined_markdown = "\n\n".join(markdown_contents)
310
+ yield combined_markdown
311
+ else:
312
+ # Existing logic for other models
313
+ if model_name == "olmOCR-7B-0225-preview":
314
+ processor = processor_m
315
+ model = model_m
316
+ elif model_name == "SmolDocling-256M-preview":
317
+ processor = processor_x
318
+ model = model_x
319
  else:
320
+ yield "Invalid model selected."
321
+ return
322
+
323
+ if video_path is None:
324
+ yield "Please upload a video."
325
+ return
326
+
327
+ frames = downsample_video(video_path)
328
+ images = [frame for frame, _ in frames]
329
+
330
+ if model_name == "SmolDocling-256M-preview":
331
+ if "OTSL" in text or "code" in text:
332
+ images = [add_random_padding(img) for img in images]
333
+ if "OCR at text at" in text or "Identify element" in text or "formula" in text:
334
+ text = normalize_values(text, target_max=500)
335
+
336
+ messages = [
337
+ {
338
+ "role": "user",
339
+ "content": [{"type": "image"} for _ in images] + [
340
+ {"type": "text", "text": text}
341
+ ]
342
+ }
343
+ ]
344
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
345
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
346
+
347
  streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
348
  generation_kwargs = {
349
+ **inputs,
350
  "streamer": streamer,
351
  "max_new_tokens": max_new_tokens,
352
+ "temperature": temperature,
353
  "top_p": top_p,
354
  "top_k": top_k,
355
  "repetition_penalty": repetition_penalty,
356
  }
357
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
358
+ thread.start()
359
+
360
+ buffer = ""
361
+ full_output = ""
362
  for new_text in streamer:
363
+ full_output += new_text
364
+ buffer += new_text.replace("<|im_end|>", "")
365
+ yield buffer
366
+
367
+ if model_name == "SmolDocling-256M-preview":
368
+ cleaned_output = full_output.replace("<end_of_utterance>", "").strip()
369
+ if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
370
+ if "<chart>" in cleaned_output:
371
+ cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
372
+ cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
373
+ doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
374
+ doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
375
+ markdown_output = doc.export_to_markdown()
376
+ yield f"**MD Output:**\n\n{markdown_output}"
377
+ else:
378
+ yield cleaned_output
379
+
380
+ # Define examples for image and video inference
381
+ image_examples = [
382
+ ["Convert this page to docling", "images/1.png"],
383
+ ["OCR the image", "images/2.jpg"],
384
+ ["Convert this page to docling", "images/3.png"],
385
+ ]
386
+
387
+ video_examples = [
388
+ ["Explain the ad in detail", "example/1.mp4"],
389
+ ["Identify the main actions in the coca cola ad...", "example/2.mp4"]
390
+ ]
391
+
392
+ css = """
393
+ .submit-btn {
394
+ background-color: #2980b9 !important;
395
+ color: white !important;
396
+ }
397
+ .submit-btn:hover {
398
+ background-color: #3498db !important;
399
+ }
400
+ """
401
+
402
+ # Create the Gradio Interface
403
+ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
404
+ gr.Markdown("# **[Core OCR](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
405
+ with gr.Row():
406
+ with gr.Column():
407
+ with gr.Tabs():
408
+ with gr.TabItem("Image Inference"):
409
+ image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
410
+ image_upload = gr.Image(type="pil", label="Image")
411
+ image_submit = gr.Button("Submit", elem_classes="submit-btn")
412
+ gr.Examples(
413
+ examples=image_examples,
414
+ inputs=[image_query, image_upload]
415
+ )
416
+ with gr.TabItem("Video Inference"):
417
+ video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
418
+ video_upload = gr.Video(label="Video")
419
+ video_submit = gr.Button("Submit", elem_classes="submit-btn")
420
+ gr.Examples(
421
+ examples=video_examples,
422
+ inputs=[video_query, video_upload]
423
+ )
424
+ with gr.Accordion("Advanced options", open=False):
425
+ max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
426
+ temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
427
+ top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
428
+ top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
429
+ repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
430
+ with gr.Column():
431
+ output = gr.Textbox(label="Output", interactive=False, lines=3, scale=2)
432
+ model_choice = gr.Radio(
433
+ choices=["olmOCR-7B-0225-preview", "SmolDocling-256M-preview", "ByteDance-s-Dolphin"],
434
+ label="Select Model",
435
+ value="olmOCR-7B-0225-preview"
436
+ )
437
+
438
+ image_submit.click(
439
+ fn=generate_image,
440
+ inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
441
+ outputs=output
442
+ )
443
+ video_submit.click(
444
+ fn=generate_video,
445
+ inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
446
+ outputs=output
447
+ )
448
 
449
  if __name__ == "__main__":
450
+ demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)
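
For reference, the streaming pattern that the new generate_image and generate_video functions share (a background model.generate thread feeding a TextIteratorStreamer) can be exercised on its own. A minimal sketch, assuming the olmOCR checkpoint loaded above, a CUDA device, and a hypothetical local test image sample_page.png; the prompt and max_new_tokens value are illustrative:

import torch
from threading import Thread
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, TextIteratorStreamer

model_id = "allenai/olmOCR-7B-0225-preview"  # same checkpoint as MODEL_ID_M above
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.float16
).to("cuda").eval()

image = Image.open("sample_page.png")  # hypothetical test image
messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "OCR the image"}]}]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[image], return_tensors="pt").to("cuda")

# Generate in a background thread and print tokens as they stream in, as the Space does.
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 512}).start()
for chunk in streamer:
    print(chunk, end="", flush=True)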