prithivMLmods committed (verified)
Commit: aad98bd · Parent: cb3f55e

Update app.py

Files changed (1): app.py (+33 -9)

app.py CHANGED
@@ -16,7 +16,6 @@ import cv2
 
 from transformers import (
     Qwen2VLForConditionalGeneration,
-    Qwen2_5_VLForConditionalGeneration,
     AutoModelForImageTextToText,
     AutoProcessor,
     TextIteratorStreamer,
@@ -28,12 +27,13 @@ MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
+# Determine device
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 # Load VIREX-062225-exp
 MODEL_ID_M = "prithivMLmods/VIREX-062225-exp"
 processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
-model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     MODEL_ID_M,
     trust_remote_code=True,
     torch_dtype=torch.float16
@@ -42,13 +42,13 @@ model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 # Load DREX-062225-exp
 MODEL_ID_X = "prithivMLmods/DREX-062225-exp"
 processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
-model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+model_x = Qwen2VLForConditionalGeneration.from_pretrained(
     MODEL_ID_X,
     trust_remote_code=True,
     torch_dtype=torch.float16
 ).to(device).eval()
 
-# Load Gemma3n-E4B-it
+# Load Gemma3n-E4B-it (Placeholder: Adjust model class if incorrect)
 MODEL_ID_G = "google/gemma-3n-E4B-it"
 processor_g = AutoProcessor.from_pretrained(MODEL_ID_G, trust_remote_code=True)
 model_g = AutoModelForImageTextToText.from_pretrained(
@@ -57,7 +57,7 @@ model_g = AutoModelForImageTextToText.from_pretrained(
     torch_dtype=torch.float16
 ).to(device).eval()
 
-# Load Gemma3n-E2B-it
+# Load Gemma3n-E2B-it (Placeholder: Adjust model class if incorrect)
 MODEL_ID_N = "google/gemma-3n-E2B-it"
 processor_n = AutoProcessor.from_pretrained(MODEL_ID_N, trust_remote_code=True)
 model_n = AutoModelForImageTextToText.from_pretrained(
@@ -128,7 +128,7 @@ def generate_image(model_name: str, text: str, image_path: str,
             add_generation_prompt=True,
             return_dict=True,
             return_tensors="pt",
-            truncation=False,
+            truncation=True,  # Enable truncation to prevent overflow
             max_length=MAX_INPUT_TOKEN_LENGTH
         ).to(device)
     else:
@@ -138,10 +138,16 @@ def generate_image(model_name: str, text: str, image_path: str,
             images=[image_path],
             return_tensors="pt",
             padding=True,
-            truncation=False,
+            truncation=True,  # Enable truncation to prevent overflow
             max_length=MAX_INPUT_TOKEN_LENGTH
         ).to(device)
 
+    # Check input token length
+    input_length = inputs["input_ids"].shape[1]
+    if input_length > MAX_INPUT_TOKEN_LENGTH:
+        yield f"Input too long. Max {MAX_INPUT_TOKEN_LENGTH} tokens. Got {input_length} tokens.", ""
+        return
+
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = {
         **inputs,
@@ -153,6 +159,12 @@ def generate_image(model_name: str, text: str, image_path: str,
         "top_k": top_k,
         "repetition_penalty": repetition_penalty,
     }
+
+    # Ensure all tensors are on the correct device
+    for key in generation_kwargs:
+        if isinstance(generation_kwargs[key], torch.Tensor):
+            generation_kwargs[key] = generation_kwargs[key].to(device)
+
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
     buffer = ""
@@ -205,7 +217,7 @@ def generate_video(model_name: str, text: str, video_path: str,
             add_generation_prompt=True,
             return_dict=True,
             return_tensors="pt",
-            truncation=False,
+            truncation=True,  # Enable truncation to prevent overflow
             max_length=MAX_INPUT_TOKEN_LENGTH
         ).to(device)
     else:
@@ -216,10 +228,16 @@ def generate_video(model_name: str, text: str, video_path: str,
             images=images,
             return_tensors="pt",
             padding=True,
-            truncation=False,
+            truncation=True,  # Enable truncation to prevent overflow
             max_length=MAX_INPUT_TOKEN_LENGTH
         ).to(device)
 
+    # Check input token length
+    input_length = inputs["input_ids"].shape[1]
+    if input_length > MAX_INPUT_TOKEN_LENGTH:
+        yield f"Input too long. Max {MAX_INPUT_TOKEN_LENGTH} tokens. Got {input_length} tokens.", ""
+        return
+
     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
     generation_kwargs = {
         **inputs,
@@ -231,6 +249,12 @@ def generate_video(model_name: str, text: str, video_path: str,
         "top_k": top_k,
        "repetition_penalty": repetition_penalty,
     }
+
+    # Ensure all tensors are on the correct device
+    for key in generation_kwargs:
+        if isinstance(generation_kwargs[key], torch.Tensor):
+            generation_kwargs[key] = generation_kwargs[key].to(device)
+
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
     buffer = ""
 