Spaces: Running on Zero

陈硕 committed · Commit e7b0784 · Parent(s): 538c34c

add vlm
app.py CHANGED
```diff
@@ -28,6 +28,8 @@ from diffusers import (
 )
 from diffusers.utils import load_video, load_image
 from datetime import datetime, timedelta
+from PIL import Image
+from transformers import AutoModelForCausalLM, LlamaTokenizer
 
 from diffusers.image_processor import VaeImageProcessor
 from openai import OpenAI
@@ -171,55 +173,73 @@ def center_crop_resize(input_video_path, target_width=720, target_height=480):
     return temp_video_path
 
 
-def convert_prompt(prompt: str, retry_times: int = 3) -> str:
-    ...  # 46 removed lines, collapsed in the diff view
+def convert_prompt(prompt: str, image_path: str = None, retry_times: int = 3) -> str:
+    # Define model and tokenizer paths
+    MODEL_PATH = "THUDM/cogagent-chat-hf"
+    TOKENIZER_PATH = "lmsys/vicuna-7b-v1.5"
+    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+    torch_type = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
+    # Initialize model and tokenizer
+    tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_PATH,
+        torch_dtype=torch_type,
+        low_cpu_mem_usage=True,
+        trust_remote_code=True
+    ).to(DEVICE).eval()
+
+    # Conversation template for text-only queries
+    text_only_template = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
+
+    # Check if image is available
+    if image_path and os.path.isfile(image_path):
+        image = Image.open(image_path).convert('RGB')
+    else:
+        image = None
+
+    # Initialize history for conversation context
+    history = []
+    query = prompt.strip()
+
+    for _ in range(retry_times):
+        if image is None:
+            # Text-only query, format as required by CogAgent
+            query = text_only_template.format(query)
+            input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, template_version='base')
+            inputs = {
+                'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
+                'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
+                'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE)
+            }
+        else:
+            # Image-based input with initial query
+            input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])
+            inputs = {
+                'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
+                'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
+                'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
+                'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]]
+            }
+            if 'cross_images' in input_by_model and input_by_model['cross_images']:
+                inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]
+
+        # Generation settings
+        gen_kwargs = {"max_length": 2048, "do_sample": False}
+
+        with torch.no_grad():
+            outputs = model.generate(**inputs, **gen_kwargs)
+            outputs = outputs[:, inputs['input_ids'].shape[1]:]
+            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+            response = response.split("</s>")[0].strip()  # Clean up response
+
+        if response:
+            return response  # Return the response if generated successfully
+
+    # Return original prompt if all retries fail
     return prompt
 
+
 @spaces.GPU
 def infer(
     prompt: str,
```
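For orientation, here is a minimal usage sketch (not part of the commit) of how the reworked `convert_prompt` could be called; the sample prompts and the `example.png` path are placeholders:

```python
# Hypothetical usage sketch for the new convert_prompt; assumes app.py's
# imports (torch, os, PIL, transformers) are in scope. Note the function
# loads CogAgent on every call, so the first invocation is slow.

# Text-only enhancement: the prompt is wrapped in text_only_template.
print(convert_prompt("A cat running on the grass", retry_times=1))

# Image-grounded enhancement: the file is opened as an RGB PIL image and
# passed to build_conversation_input_ids alongside the text query.
print(convert_prompt(
    "Describe the scene for a cinematic orbit shot",
    image_path="example.png",  # placeholder path
    retry_times=1,
))
```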
```diff
@@ -323,11 +343,11 @@ with gr.Blocks() as demo:
     """)
     with gr.Row():
         with gr.Column():
-            image_in = gr.Image(label="Image
+            image_in = gr.Image(label="Input Image (will be cropped to 720 * 480)")
             examples_component_images = gr.Examples(examples_images, inputs=[image_in], cache_examples=False)
             # prompt = gr.Textbox(label="Prompt")
             orbit_type = gr.Radio(label="Orbit type", choices=["Left", "Up"], value="Left", interactive=True)
-            submit_btn = gr.Button("Submit")
+            # submit_btn = gr.Button("Submit")
 
             # with gr.Column():
             #     with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
@@ -341,9 +361,9 @@ with gr.Blocks() as demo:
 
             with gr.Row():
                 gr.Markdown(
-                    "✨Upon pressing the enhanced prompt button, we will use [
+                    "✨Upon pressing the enhanced prompt button, we will use [CogVLM](https://github.com/THUDM/CogVLM) to polish the prompt and overwrite the original one."
                 )
-                enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
+                enhance_button = gr.Button("✨ Enhance Prompt (Optional but highly recommended)")
             with gr.Group():
                 with gr.Column():
                     with gr.Row():
```
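The commit itself does not show how `enhance_button` is connected to `convert_prompt`. A typical Gradio hookup, sketched under the assumptions that a `prompt` textbox exists (the diff shows it commented out) and that `image_in` returns a file path (e.g. `gr.Image(type="filepath")`), might look like:

```python
# Hypothetical wiring sketch, not part of this commit: feed the current
# prompt and input image through convert_prompt and overwrite the prompt.
def enhance_prompt_func(prompt_text, image_path):
    return convert_prompt(prompt_text, image_path=image_path, retry_times=1)

enhance_button.click(
    fn=enhance_prompt_func,
    inputs=[prompt, image_in],  # assumes `prompt` is a gr.Textbox
    outputs=[prompt],
)
```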