prithivMLmods committed (verified)
Commit 86a82e4 · 1 Parent(s): 8c4fbc4

Update app.py

Files changed (1)
  1. app.py +59 -89
app.py CHANGED
@@ -7,123 +7,93 @@ import torch
 import spaces
 import subprocess
 
+# Install flash-attn with no CUDA build
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
-from io import BytesIO
-
 processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
-model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct",
-                                                    _attn_implementation="flash_attention_2",
-                                                    torch_dtype=torch.bfloat16).to("cuda:0")
+model = AutoModelForImageTextToText.from_pretrained(
+    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
+    _attn_implementation="flash_attention_2",
+    torch_dtype=torch.bfloat16
+).to("cuda:0")
 
 @spaces.GPU
-def model_inference(
-    input_dict, history, max_tokens
-):
-    text = input_dict["text"]
-    images = []
-    user_content = []
+def model_inference(input_dict, history, max_tokens):
+    text = input_dict.get("text", "").strip()
     media_queue = []
-    if history == []:
-        text = input_dict["text"].strip()
-
-        for file in input_dict.get("files", []):
-            if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
-                media_queue.append({"type": "image", "path": file})
-            elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
-                media_queue.append({"type": "video", "path": file})
+    user_content = []
+
+    # Process uploaded media files
+    for file in input_dict.get("files", []):
+        if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
+            media_queue.append({"type": "image", "path": file})
+        elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
+            media_queue.append({"type": "video", "path": file})
 
-        if "<image>" in text or "<video>" in text:
-            parts = re.split(r'(<image>|<video>)', text)
-            for part in parts:
-                if part == "<image>" and media_queue:
-                    user_content.append(media_queue.pop(0))
-                elif part == "<video>" and media_queue:
-                    user_content.append(media_queue.pop(0))
-                elif part.strip():
-                    user_content.append({"type": "text", "text": part.strip()})
-        else:
-            user_content.append({"type": "text", "text": text})
-
-            for media in media_queue:
-                user_content.append(media)
+    # Construct user content with placeholders
+    if "<image>" in text or "<video>" in text:
+        parts = re.split(r'(<image>|<video>)', text)
+        for part in parts:
+            if part == "<image>" and media_queue:
+                user_content.append(media_queue.pop(0))
+            elif part == "<video>" and media_queue:
+                user_content.append(media_queue.pop(0))
+            elif part.strip():
+                user_content.append({"type": "text", "text": part.strip()})
+    else:
+        user_content.append({"type": "text", "text": text})
+        user_content.extend(media_queue)
 
-        resulting_messages = [{"role": "user", "content": user_content}]
+    resulting_messages = [{"role": "user", "content": user_content}]
 
-    elif len(history) > 0:
-        resulting_messages = []
-        user_content = []
-        media_queue = []
-        for hist in history:
-            if hist["role"] == "user" and isinstance(hist["content"], tuple):
-                file_name = hist["content"][0]
-                if file_name.endswith((".png", ".jpg", ".jpeg")):
-                    media_queue.append({"type": "image", "path": file_name})
-                elif file_name.endswith(".mp4"):
-                    media_queue.append({"type": "video", "path": file_name})
-
+    # Process history
+    if history:
         for hist in history:
-            if hist["role"] == "user" and isinstance(hist["content"], str):
-                text = hist["content"]
-                parts = re.split(r'(<image>|<video>)', text)
-
-                for part in parts:
-                    if part == "<image>" and media_queue:
-                        user_content.append(media_queue.pop(0))
-                    elif part == "<video>" and media_queue:
-                        user_content.append(media_queue.pop(0))
-                    elif part.strip():
-                        user_content.append({"type": "text", "text": part.strip()})
-
+            if hist["role"] == "user":
+                if isinstance(hist["content"], tuple) and len(hist["content"]) > 0:
+                    file_name = hist["content"][0]
+                    if file_name.endswith((".png", ".jpg", ".jpeg")):
+                        media_queue.append({"type": "image", "path": file_name})
+                    elif file_name.endswith(".mp4"):
+                        media_queue.append({"type": "video", "path": file_name})
+
             elif hist["role"] == "assistant":
-                resulting_messages.append({
-                    "role": "user",
-                    "content": user_content
-                })
-                resulting_messages.append({
-                    "role": "assistant",
-                    "content": [{"type": "text", "text": hist["content"]}]
-                })
-                user_content = []
+                resulting_messages.append({"role": "assistant", "content": [{"type": "text", "text": hist["content"]}]})
 
-    if text == "" and not images:
-        gr.Error("Please input a query and optionally image(s).")
+    if not text and not media_queue:
+        gr.Warning("Please provide text or an image/video.")
 
-    if text == "" and images:
-        gr.Error("Please input a text query along the images(s).")
-    print("resulting_messages", resulting_messages)
+    # Process inputs
     inputs = processor.apply_chat_template(
         resulting_messages,
         add_generation_prompt=True,
         tokenize=True,
         return_dict=True,
-        return_tensors="pt",
-    )
-
-    inputs = inputs.to(model.device)
-
-    # Generate
-    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-    generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
-    generated_text = ""
+        return_tensors="pt"
+    ).to(model.device, dtype=torch.bfloat16)  # Ensure dtype consistency
 
-    thread = Thread(target=model.generate, kwargs=generation_args)
+    # Generate output
+    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
+    thread = Thread(target=model.generate, kwargs={"input_ids": inputs["input_ids"], "streamer": streamer, "max_new_tokens": max_tokens})
     thread.start()
 
     yield "..."
     buffer = ""
-
     for new_text in streamer:
         buffer += new_text
         time.sleep(0.01)
         yield buffer
 
-demo = gr.ChatInterface(fn=model_inference, title="SmolVLM2: The Smollest Video Model Ever 📺",
-                description="Play with [SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) in this demo. To get started, upload an image and text. This demo doesn't use history for the chat, so every chat you start is a new conversation.",
-                textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", ".mp4"], file_count="multiple"), stop_btn="Stop Generation", multimodal=True,
-                cache_examples=False,
-                additional_inputs=[gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens")],
-                type="messages"
-                )
+demo = gr.ChatInterface(
+    fn=model_inference,
+    title="SmolVLM2: The Smollest Video Model Ever 📺",
+    description="Play with [SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) in this demo. To get started, upload an image and text.",
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", ".mp4"], file_count="multiple"),
+    stop_btn="Stop Generation",
+    multimodal=True,
+    cache_examples=False,
+    additional_inputs=[gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens")],
+    type="messages"
+)
 
 demo.launch(debug=True, share=True)
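
For reference, the generation path after this change follows the standard transformers streaming pattern: processor.apply_chat_template builds the model inputs, model.generate runs in a background thread, and a TextIteratorStreamer yields decoded text as it is produced. The snippet below is a minimal sketch of that pattern in isolation, assuming the model and processor objects defined in app.py are already loaded; stream_reply and its example message are hypothetical names added for illustration, not part of the commit.

from threading import Thread

import torch
from transformers import TextIteratorStreamer

def stream_reply(messages, max_new_tokens=200):
    # Build chat-formatted model inputs, as app.py does.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)

    # generate() runs in a background thread; the streamer yields decoded text chunks.
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens})
    thread.start()

    for chunk in streamer:
        yield chunk
    thread.join()

# Hypothetical usage:
# for piece in stream_reply([{"role": "user", "content": [{"type": "text", "text": "Describe SmolVLM2."}]}]):
#     print(piece, end="", flush=True)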