qnguyen3 committed on
Commit b6c6d0c · verified · 1 Parent(s): 13b8f1b

Update app.py

Files changed (1)
  1. app.py +61 -35
app.py CHANGED
@@ -6,23 +6,17 @@ from threading import Thread
 import re
 import time
 from PIL import Image
-import torch
 import spaces
 import subprocess
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
-torch.set_default_device('cuda')
-
+# Initialize tokenizer (doesn't require CUDA)
 tokenizer = AutoTokenizer.from_pretrained(
     'qnguyen3/nanoLLaVA-1.5',
     trust_remote_code=True)
 
-model = LlavaQwen2ForCausalLM.from_pretrained(
-    'qnguyen3/nanoLLaVA-1.5',
-    torch_dtype=torch.float16,
-    attn_implementation="flash_attention_2",
-    trust_remote_code=True,
-    device_map='cpu')
+# Don't initialize model here - move it to the GPU-decorated function
+model = None
 
 class KeywordsStoppingCriteria(StoppingCriteria):
     def __init__(self, keywords, tokenizer, input_ids):
@@ -61,15 +55,34 @@ class KeywordsStoppingCriteria(StoppingCriteria):
 
 @spaces.GPU
 def bot_streaming(message, history):
-    messages = []
+    global model
+
+    # Initialize the model inside the GPU-decorated function
+    if model is None:
+        model = LlavaQwen2ForCausalLM.from_pretrained(
+            'qnguyen3/nanoLLaVA-1.5',
+            torch_dtype=torch.float16,
+            attn_implementation="flash_attention_2",
+            trust_remote_code=True,
+            device_map="auto")  # Use "auto" instead of 'cpu' then manual to('cuda')
+
+    # Get image path
+    image = None
     if message["files"]:
-        image = message["files"][-1]["path"]
+        image = message["files"][-1]["path"]
     else:
-        for i, hist in enumerate(history):
-            if type(hist[0])==tuple:
-                image = hist[0][0]
-                image_turn = i
-
+        for i, hist in enumerate(history):
+            if type(hist[0])==tuple:
+                image = hist[0][0]
+                image_turn = i
+                break
+
+    # Check if image is available
+    if image is None:
+        return "Please upload an image for LLaVA to work."
+
+    # Prepare conversation messages
+    messages = []
     if len(history) > 0 and image is not None:
         messages.append({"role": "user", "content": f'<image>\n{history[1][0]}'})
         messages.append({"role": "assistant", "content": history[1][1] })
@@ -86,44 +99,57 @@ def bot_streaming(message, history):
         messages.append({"role": "user", "content": f"<image>\n{message['text']}"})
     elif len(history) == 0 and image is None:
         messages.append({"role": "user", "content": message['text'] })
-    model = model.to('cuda')
 
-    # if image is None:
-    #     gr.Error("You need to upload an image for LLaVA to work.")
+    # Process image
     image = Image.open(image).convert("RGB")
+
+    # Prepare input for generation
     text = tokenizer.apply_chat_template(
         messages,
         tokenize=False,
         add_generation_prompt=True)
     text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
     input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
+
+    # Prepare stopping criteria
    stop_str = '<|im_end|>'
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
+    # Process image and generate text
     image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
-    generation_kwargs = dict(input_ids=input_ids.to('cuda'),
-                             images=image_tensor.to('cuda'),
-                             streamer=streamer, max_new_tokens=512,
-                             stopping_criteria=[stopping_criteria], temperature=0.01)
-    generated_text = ""
+    generation_kwargs = dict(
+        input_ids=input_ids,
+        images=image_tensor,
+        streamer=streamer,
+        max_new_tokens=512,
+        stopping_criteria=[stopping_criteria],
+        temperature=0.01
+    )
+
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
-    text_prompt =f"<|im_start|>user\n{message['text']}<|im_end|>"
 
+    # Stream response
     buffer = ""
     for new_text in streamer:
-
-        buffer += new_text
-
-        generated_text_without_prompt = buffer[:]
-        time.sleep(0.04)
-        yield generated_text_without_prompt
+        buffer += new_text
+        generated_text_without_prompt = buffer[:]
+        time.sleep(0.04)
+        yield generated_text_without_prompt
+
 
 
-demo = gr.ChatInterface(fn=bot_streaming, title="🚀nanoLLaVA-1.5", examples=[{"text": "Who is this guy?", "files":["./demo_1.jpg"]},
-                        {"text": "What does the text say?", "files":["./demo_2.jpeg"]}],
-                        description="Try [nanoLLaVA](https://huggingface.co/qnguyen3/nanoLLaVA-1.5) in this demo. Built on top of [Quyen-SE-v0.1](https://huggingface.co/vilm/Quyen-SE-v0.1) (Qwen1.5-0.5B) and [Google SigLIP-400M](https://huggingface.co/google/siglip-so400m-patch14-384). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
-                        stop_btn="Stop Generation", multimodal=True)
+demo = gr.ChatInterface(
+    fn=bot_streaming,
+    title="🚀nanoLLaVA-1.5",
+    examples=[
+        {"text": "Who is this guy?", "files":["./demo_1.jpg"]},
+        {"text": "What does the text say?", "files":["./demo_2.jpeg"]}
+    ],
+    description="Try [nanoLLaVA](https://huggingface.co/qnguyen3/nanoLLaVA-1.5) in this demo. Built on top of [Quyen-SE-v0.1](https://huggingface.co/vilm/Quyen-SE-v0.1) (Qwen1.5-0.5B) and [Google SigLIP-400M](https://huggingface.co/google/siglip-so400m-patch14-384). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
+    stop_btn="Stop Generation",
+    multimodal=True
+)
 
 demo.queue().launch()
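
The core of the diff is a lazy-initialization pattern: on ZeroGPU Spaces a CUDA device is only attached while a @spaces.GPU-decorated function is running, so the model is now constructed inside bot_streaming on the first request instead of at import time. Below is a minimal sketch of that pattern, assuming the Hugging Face ZeroGPU `spaces` package; `load_model()` and `handler()` are hypothetical stand-ins, not names from app.py.

# Minimal sketch of the lazy-loading pattern adopted in this commit (not the full app).
# Assumes the Hugging Face ZeroGPU `spaces` package; load_model() is a hypothetical
# stand-in for the LlavaQwen2ForCausalLM.from_pretrained(...) call shown in the diff.
import spaces

model = None  # nothing touches CUDA at import time


def load_model():
    # placeholder for the real from_pretrained(..., device_map="auto") call
    raise NotImplementedError


@spaces.GPU
def handler(message, history):
    # ZeroGPU attaches a GPU only while this decorated call runs, so the
    # model is built here on the first request and reused afterwards.
    global model
    if model is None:
        model = load_model()
    # ... tokenize, generate, and stream as in app.py ...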