MaziyarPanahi KingNish commited on
Commit
868ab7f
·
verified ·
1 Parent(s): df30ad6

Added Video Support (#18)

Browse files

- Added Video Support (24f8595f5086f8051c077849203d663bfba52f7e)
- Update requirements.txt (cb25f513b7f5cca52c02e57a1ffadb3d0bbbd80f)
- Added streaming output and error handling (b9caa337fc287469039d003e38787dc4db8123b4)


Co-authored-by: Nishith Jain <[email protected]>

Files changed (2) hide show
  1. app.py +89 -69
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,79 +1,95 @@
1
  import gradio as gr
2
  import spaces
3
- from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
4
  from qwen_vl_utils import process_vision_info
5
  import torch
6
  from PIL import Image
7
  import subprocess
8
- from datetime import datetime
9
  import numpy as np
10
  import os
 
 
 
11
 
12
-
13
- # subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
14
-
15
- # models = {
16
- # "Qwen/Qwen2-VL-2B-Instruct": AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype="auto", _attn_implementation="flash_attention_2").cuda().eval()
17
-
18
- # }
19
- def array_to_image_path(image_array):
20
- # Convert numpy array to PIL Image
21
- img = Image.fromarray(np.uint8(image_array))
22
-
23
- # Generate a unique filename using timestamp
24
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
25
- filename = f"image_{timestamp}.png"
26
-
27
- # Save the image
28
- img.save(filename)
29
-
30
- # Get the full path of the saved image
31
- full_path = os.path.abspath(filename)
32
-
33
- return full_path
34
-
35
- models = {
36
- "Qwen/Qwen2-VL-2B-Instruct": Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype="auto").cuda().eval()
37
-
38
- }
39
-
40
- processors = {
41
- "Qwen/Qwen2-VL-2B-Instruct": AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
42
- }
43
 
44
  DESCRIPTION = "[Qwen2-VL-2B Demo](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)"
45
 
46
- kwargs = {}
47
- kwargs['torch_dtype'] = torch.bfloat16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- user_prompt = '<|user|>\n'
50
- assistant_prompt = '<|assistant|>\n'
51
- prompt_suffix = "<|end|>\n"
52
 
53
  @spaces.GPU
54
- def run_example(image, text_input=None, model_id="Qwen/Qwen2-VL-2B-Instruct"):
55
- image_path = array_to_image_path(image)
56
-
57
- print(image_path)
58
- model = models[model_id]
59
- processor = processors[model_id]
60
-
61
- prompt = f"{user_prompt}<|image_1|>\n{text_input}{prompt_suffix}{assistant_prompt}"
62
- image = Image.fromarray(image).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
63
  messages = [
64
- {
65
  "role": "user",
66
  "content": [
67
  {
68
- "type": "image",
69
- "image": image_path,
 
70
  },
71
  {"type": "text", "text": text_input},
72
  ],
73
  }
74
  ]
75
-
76
- # Preparation for inference
77
  text = processor.apply_chat_template(
78
  messages, tokenize=False, add_generation_prompt=True
79
  )
@@ -84,19 +100,20 @@ def run_example(image, text_input=None, model_id="Qwen/Qwen2-VL-2B-Instruct"):
84
  videos=video_inputs,
85
  padding=True,
86
  return_tensors="pt",
 
 
 
 
87
  )
88
- inputs = inputs.to("cuda")
89
-
90
- # Inference: Generation of the output
91
- generated_ids = model.generate(**inputs, max_new_tokens=128)
92
- generated_ids_trimmed = [
93
- out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
94
- ]
95
- output_text = processor.batch_decode(
96
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
97
- )
98
-
99
- return output_text[0]
100
 
101
  css = """
102
  #output {
@@ -108,17 +125,20 @@ css = """
108
 
109
  with gr.Blocks(css=css) as demo:
110
  gr.Markdown(DESCRIPTION)
111
- with gr.Tab(label="Qwen2-VL-2B Input"):
 
112
  with gr.Row():
113
  with gr.Column():
114
- input_img = gr.Image(label="Input Picture")
115
- model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="Qwen/Qwen2-VL-2B-Instruct")
 
116
  text_input = gr.Textbox(label="Question")
117
  submit_btn = gr.Button(value="Submit")
118
  with gr.Column():
119
  output_text = gr.Textbox(label="Output Text")
120
 
121
- submit_btn.click(run_example, [input_img, text_input, model_selector], [output_text])
 
 
122
 
123
- demo.queue(api_open=False)
124
  demo.launch(debug=True)
 
1
  import gradio as gr
2
  import spaces
3
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
4
  from qwen_vl_utils import process_vision_info
5
  import torch
6
  from PIL import Image
7
  import subprocess
 
8
  import numpy as np
9
  import os
10
+ from threading import Thread
11
+ import uuid
12
+ import io
13
 
14
+ # Model and Processor Loading (Done once at startup)
15
+ MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
16
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
17
+ MODEL_ID,
18
+ trust_remote_code=True,
19
+ torch_dtype=torch.float16
20
+ ).to("cuda").eval()
21
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  DESCRIPTION = "[Qwen2-VL-2B Demo](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)"
24
 
25
+ image_extensions = Image.registered_extensions()
26
+ video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")
27
+
28
+
29
+ def identify_and_save_blob(blob_path):
30
+ """Identifies if the blob is an image or video and saves it accordingly."""
31
+ try:
32
+ with open(blob_path, 'rb') as file:
33
+ blob_content = file.read()
34
+
35
+ # Try to identify if it's an image
36
+ try:
37
+ Image.open(io.BytesIO(blob_content)).verify() # Check if it's a valid image
38
+ extension = ".png" # Default to PNG for saving
39
+ media_type = "image"
40
+ except (IOError, SyntaxError):
41
+ # If it's not a valid image, assume it's a video
42
+ extension = ".mp4" # Default to MP4 for saving
43
+ media_type = "video"
44
+
45
+ # Create a unique filename
46
+ filename = f"temp_{uuid.uuid4()}_media{extension}"
47
+ with open(filename, "wb") as f:
48
+ f.write(blob_content)
49
+
50
+ return filename, media_type
51
+
52
+ except FileNotFoundError:
53
+ raise ValueError(f"The file {blob_path} was not found.")
54
+ except Exception as e:
55
+ raise ValueError(f"An error occurred while processing the file: {e}")
56
 
 
 
 
57
 
58
  @spaces.GPU
59
+ def qwen_inference(media_input, text_input=None):
60
+ if isinstance(media_input, str): # If it's a filepath
61
+ media_path = media_input
62
+ if media_path.endswith(tuple([i for i, f in image_extensions.items()])):
63
+ media_type = "image"
64
+ elif media_path.endswith(video_extensions):
65
+ media_type = "video"
66
+ else:
67
+ try:
68
+ media_path, media_type = identify_and_save_blob(media_input)
69
+ print(media_path, media_type)
70
+ except Exception as e:
71
+ print(e)
72
+ raise ValueError(
73
+ "Unsupported media type. Please upload an image or video."
74
+ )
75
+
76
+
77
+ print(media_path)
78
+
79
  messages = [
80
+ {
81
  "role": "user",
82
  "content": [
83
  {
84
+ "type": media_type,
85
+ media_type: media_path,
86
+ **({"fps": 8.0} if media_type == "video" else {}),
87
  },
88
  {"type": "text", "text": text_input},
89
  ],
90
  }
91
  ]
92
+
 
93
  text = processor.apply_chat_template(
94
  messages, tokenize=False, add_generation_prompt=True
95
  )
 
100
  videos=video_inputs,
101
  padding=True,
102
  return_tensors="pt",
103
+ ).to("cuda")
104
+
105
+ streamer = TextIteratorStreamer(
106
+ processor, skip_prompt=True, **{"skip_special_tokens": True}
107
  )
108
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
109
+
110
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
111
+ thread.start()
112
+
113
+ buffer = ""
114
+ for new_text in streamer:
115
+ buffer += new_text
116
+ yield buffer
 
 
 
117
 
118
  css = """
119
  #output {
 
125
 
126
  with gr.Blocks(css=css) as demo:
127
  gr.Markdown(DESCRIPTION)
128
+
129
+ with gr.Tab(label="Image/Video Input"):
130
  with gr.Row():
131
  with gr.Column():
132
+ input_media = gr.File(
133
+ label="Upload Image or Video", type="filepath"
134
+ )
135
  text_input = gr.Textbox(label="Question")
136
  submit_btn = gr.Button(value="Submit")
137
  with gr.Column():
138
  output_text = gr.Textbox(label="Output Text")
139
 
140
+ submit_btn.click(
141
+ qwen_inference, [input_media, text_input], [output_text]
142
+ )
143
 
 
144
  demo.launch(debug=True)
requirements.txt CHANGED
@@ -5,4 +5,5 @@ torch
5
  torchvision
6
  git+https://github.com/huggingface/transformers.git
7
  accelerate
8
- qwen-vl-utils
 
 
5
  torchvision
6
  git+https://github.com/huggingface/transformers.git
7
  accelerate
8
+ qwen-vl-utils
9
+ av