prithivMLmods committed on
Commit
7d34bf2
·
verified ·
1 Parent(s): 804645f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -1
app.py CHANGED
@@ -16,11 +16,17 @@ import cv2
16
  from transformers import (
17
  Qwen2VLForConditionalGeneration,
18
  Qwen2_5_VLForConditionalGeneration,
 
19
  AutoProcessor,
20
  TextIteratorStreamer,
21
  )
22
  from transformers.image_utils import load_image
23
 
 
 
 
 
 
24
  # Constants for text generation
25
  MAX_MAX_NEW_TOKENS = 2048
26
  DEFAULT_MAX_NEW_TOKENS = 1024
@@ -55,6 +61,16 @@ model_o = Qwen2VLForConditionalGeneration.from_pretrained(
55
  torch_dtype=torch.float16
56
  ).to(device).eval()
57
 
 
 
 
 
 
 
 
 
 
 
58
  def downsample_video(video_path):
59
  """
60
  Downsamples the video to evenly spaced frames.
@@ -95,6 +111,9 @@ def generate_image(model_name: str, text: str, image: Image.Image,
95
  elif model_name == "olmOCR-7B-0225":
96
  processor = processor_o
97
  model = model_o
 
 
 
98
  else:
99
  yield "Invalid model selected."
100
  return
@@ -149,6 +168,9 @@ def generate_video(model_name: str, text: str, video_path: str,
149
  elif model_name == "olmOCR-7B-0225":
150
  processor = processor_o
151
  model = model_o
 
 
 
152
  else:
153
  yield "Invalid model selected."
154
  return
@@ -247,7 +269,7 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
247
  with gr.Column():
248
  output = gr.Textbox(label="Output", interactive=False, lines=2, scale=2)
249
  model_choice = gr.Radio(
250
- choices=["VIREX-062225-exp", "DREX-062225-exp", "olmOCR-7B-0225"],
251
  label="Select Model",
252
  value="VIREX-062225-exp"
253
  )
 
16
  from transformers import (
17
  Qwen2VLForConditionalGeneration,
18
  Qwen2_5_VLForConditionalGeneration,
19
+ AutoModelForImageTextToText,
20
  AutoProcessor,
21
  TextIteratorStreamer,
22
  )
23
  from transformers.image_utils import load_image
24
 
25
+ import subprocess
26
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
27
+
28
+ from io import BytesIO
29
+
30
  # Constants for text generation
31
  MAX_MAX_NEW_TOKENS = 2048
32
  DEFAULT_MAX_NEW_TOKENS = 1024
 
61
  torch_dtype=torch.float16
62
  ).to(device).eval()
63
 
64
+ # Load SmolVLM2-2.2B-Instruct
65
+ MODEL_ID_W = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
66
+ processor_w = AutoProcessor.from_pretrained(MODEL_ID_W, trust_remote_code=True)
67
+ model_w= AutoModelForImageTextToText.from_pretrained(
68
+ MODEL_ID_W,
69
+ trust_remote_code=True,
70
+ _attn_implementation="flash_attention_2",
71
+ torch_dtype=torch.float16
72
+ ).to(device).eval()
73
+
74
  def downsample_video(video_path):
75
  """
76
  Downsamples the video to evenly spaced frames.
 
111
  elif model_name == "olmOCR-7B-0225":
112
  processor = processor_o
113
  model = model_o
114
+ elif model_name == "SmolVLM2":
115
+ processor = processor_w
116
+ model = model_w
117
  else:
118
  yield "Invalid model selected."
119
  return
 
168
  elif model_name == "olmOCR-7B-0225":
169
  processor = processor_o
170
  model = model_o
171
+ elif model_name == "SmolVLM2":
172
+ processor = processor_w
173
+ model = model_w
174
  else:
175
  yield "Invalid model selected."
176
  return
 
269
  with gr.Column():
270
  output = gr.Textbox(label="Output", interactive=False, lines=2, scale=2)
271
  model_choice = gr.Radio(
272
+ choices=["VIREX-062225-exp", "DREX-062225-exp", "olmOCR-7B-0225", "SmolVLM2"],
273
  label="Select Model",
274
  value="VIREX-062225-exp"
275
  )