hysts (HF staff) committed
Commit 3a57265 · Parent: 08f127f
Files changed (1):
  1. app.py +18 -18
app.py CHANGED
@@ -1,12 +1,15 @@
 #!/usr/bin/env python
 
+import re
+import tempfile
 from collections.abc import Iterator
 from threading import Thread
 
+import cv2
 import gradio as gr
 import spaces
 import torch
-import re
+from PIL import Image
 from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
 
 model_id = "google/gemma-3-12b-it"
@@ -15,17 +18,13 @@ model = Gemma3ForConditionalGeneration.from_pretrained(
     model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
 )
 
-import cv2
-from PIL import Image
-import numpy as np
-import tempfile
 
 def downsample_video(video_path):
     vidcap = cv2.VideoCapture(video_path)
     fps = vidcap.get(cv2.CAP_PROP_FPS)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-    frame_interval = int(fps / 3)
+
+    frame_interval = int(fps / 3)
     frames = []
 
     for i in range(0, total_frames, frame_interval):
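
Note: the relocated helper still assumes a source clip of at least 3 fps. `int(fps / 3)` truncates to 0 for anything slower (and for streams where OpenCV reports fps as 0), and `range()` raises `ValueError: range() arg 3 must not be zero`. A defensive variant, not part of this commit (`safe_frame_interval` and `target_fps` are illustrative names):

def safe_frame_interval(fps: float, target_fps: float = 3.0) -> int:
    # int(fps / target_fps) truncates to 0 when fps < target_fps;
    # clamping to 1 falls back to keeping every frame instead of crashing.
    return max(1, int(fps / target_fps))
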
@@ -34,7 +33,7 @@ def downsample_video(video_path):
         if success:
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
             pil_image = Image.fromarray(image)
-            timestamp = round(i / fps, 2)
+            timestamp = round(i / fps, 2)
             frames.append((pil_image, timestamp))
 
     vidcap.release()
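
The timestamp line pairs each kept frame with its position in seconds. Checking the arithmetic at a common frame rate:

fps = 30.0
frame_interval = int(fps / 3)  # 10, i.e. roughly 3 sampled frames per second
for i in range(0, 40, frame_interval):
    print(round(i / fps, 2))  # 0.0, 0.33, 0.67, 1.0
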
@@ -46,8 +45,8 @@ def process_new_user_message(message: dict) -> list[dict]:
         if "<image>" in message["text"]:
             content = []
             print("message[files]", message["files"])
-            parts = re.split(r'(<image>)', message["text"])
-            image_index = 0
+            parts = re.split(r"(<image>)", message["text"])
+            image_index = 0
             print("parts", parts)
             for part in parts:
                 print("part", part)
@@ -55,29 +54,30 @@ def process_new_user_message(message: dict) -> list[dict]:
                     content.append({"type": "image", "url": message["files"][image_index]})
                     print("file", message["files"][image_index])
                     image_index += 1
-                elif part.strip():
+                elif part.strip():
                     content.append({"type": "text", "text": part.strip()})
                 elif isinstance(part, str) and not part == "<image>":
                     content.append({"type": "text", "text": part})
             print(content)
             return content
-        elif message["files"][0].endswith(".mp4"):
+        if message["files"][0].endswith(".mp4"):
             content = []
             video = message["files"].pop(0)
             frames = downsample_video(video)
             for frame in frames:
                 pil_image, timestamp = frame
-                with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
+                with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
                     pil_image.save(temp_file.name)
                     content.append({"type": "text", "text": f"Frame {timestamp}:"})
                     content.append({"type": "image", "url": temp_file.name})
             print(content)
             return content
-        else:
-            # non interleaved images
-            return [{"type": "text", "text": message["text"]}, *[{"type": "image", "url": path} for path in message["files"]]]
-    else:
-        return [{"type": "text", "text": message["text"]}]
+        # non interleaved images
+        return [
+            {"type": "text", "text": message["text"]},
+            *[{"type": "image", "url": path} for path in message["files"]],
+        ]
+    return [{"type": "text", "text": message["text"]}]
 
 
 def process_history(history: list[dict]) -> list[dict]:
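
The last hunk is a pure control-flow cleanup: `elif` after a returning branch becomes `if`, and both `else:` blocks are dropped since every branch above them returns, so behavior is unchanged. For the interleaved path, a sketch of what the function returns (file names are hypothetical):

message = {"text": "compare <image> with <image> done", "files": ["a.png", "b.png"]}
# process_new_user_message(message) yields, in order:
# [
#     {"type": "text", "text": "compare"},
#     {"type": "image", "url": "a.png"},
#     {"type": "text", "text": "with"},
#     {"type": "image", "url": "b.png"},
#     {"type": "text", "text": "done"},
# ]

On the video branch, `delete=False` on the `NamedTemporaryFile` is what keeps each PNG around after the `with` block exits so it can still be read when the content list is consumed; the flip side is that the files are never cleaned up from the temp directory.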
 