陈硕 committed
Commit e7b0784 · Parent(s): 538c34c
Files changed (1):
  1. app.py +71 -51
app.py CHANGED
@@ -28,6 +28,8 @@ from diffusers import (
 )
 from diffusers.utils import load_video, load_image
 from datetime import datetime, timedelta
+from PIL import Image
+from transformers import AutoModelForCausalLM, LlamaTokenizer
 
 from diffusers.image_processor import VaeImageProcessor
 from openai import OpenAI
@@ -171,55 +173,73 @@ def center_crop_resize(input_video_path, target_width=720, target_height=480):
     return temp_video_path
 
 
-def convert_prompt(prompt: str, retry_times: int = 3) -> str:
-    if not os.environ.get("OPENAI_API_KEY"):
-        return prompt
-    client = OpenAI()
-    text = prompt.strip()
-
-    for i in range(retry_times):
-        response = client.chat.completions.create(
-            messages=[
-                {"role": "system", "content": sys_prompt},
-                {
-                    "role": "user",
-                    "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"',
-                },
-                {
-                    "role": "assistant",
-                    "content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
-                },
-                {
-                    "role": "user",
-                    "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"',
-                },
-                {
-                    "role": "assistant",
-                    "content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
-                },
-                {
-                    "role": "user",
-                    "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
-                },
-                {
-                    "role": "assistant",
-                    "content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.",
-                },
-                {
-                    "role": "user",
-                    "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
-                },
-            ],
-            model="glm-4-plus",
-            temperature=0.01,
-            top_p=0.7,
-            stream=False,
-            max_tokens=200,
-        )
-        if response.choices:
-            return response.choices[0].message.content
+def convert_prompt(prompt: str, image_path: str = None, retry_times: int = 3) -> str:
+    # Define model and tokenizer paths
+    MODEL_PATH = "THUDM/cogagent-chat-hf"
+    TOKENIZER_PATH = "lmsys/vicuna-7b-v1.5"
+    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+    torch_type = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
+    # Initialize model and tokenizer
+    tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_PATH,
+        torch_dtype=torch_type,
+        low_cpu_mem_usage=True,
+        trust_remote_code=True
+    ).to(DEVICE).eval()
+
+    # Conversation template for text-only queries
+    text_only_template = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
+
+    # Check if image is available
+    if image_path and os.path.isfile(image_path):
+        image = Image.open(image_path).convert('RGB')
+    else:
+        image = None
+
+    # Initialize history for conversation context
+    history = []
+    query = prompt.strip()
+
+    for _ in range(retry_times):
+        if image is None:
+            # Text-only query, format as required by CogAgent
+            query = text_only_template.format(query)
+            input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, template_version='base')
+            inputs = {
+                'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
+                'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
+                'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE)
+            }
+        else:
+            # Image-based input with initial query
+            input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])
+            inputs = {
+                'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
+                'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
+                'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
+                'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]]
+            }
+            if 'cross_images' in input_by_model and input_by_model['cross_images']:
+                inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]
+
+        # Generation settings
+        gen_kwargs = {"max_length": 2048, "do_sample": False}
+
+        with torch.no_grad():
+            outputs = model.generate(**inputs, **gen_kwargs)
+            outputs = outputs[:, inputs['input_ids'].shape[1]:]
+            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+            response = response.split("</s>")[0].strip()  # Clean up response
+
+        if response:
+            return response  # Return the response if generated successfully
+
+    # Return original prompt if all retries fail
     return prompt
 
+
 @spaces.GPU
 def infer(
     prompt: str,
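Two things worth noting about the rewritten helper: it loads CogAgent from scratch on every call, and in the text-only branch `query` is re-wrapped by the template on each retry iteration. A minimal usage sketch of the new signature follows (the example image path is hypothetical, not part of this commit):

    # Hypothetical calls against the new signature; "examples/beach.png" is an assumed path.
    enhanced = convert_prompt("a girl is on the beach")                         # text-only path
    enhanced = convert_prompt("a girl is on the beach", "examples/beach.png")   # image-conditioned path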
@@ -323,11 +343,11 @@ with gr.Blocks() as demo:
     """)
     with gr.Row():
         with gr.Column():
-            image_in = gr.Image(label="Image Input", type="filepath")
+            image_in = gr.Image(label="Input Image (will be cropped to 720 * 480)")
             examples_component_images = gr.Examples(examples_images, inputs=[image_in], cache_examples=False)
             # prompt = gr.Textbox(label="Prompt")
             orbit_type = gr.Radio(label="Orbit type", choices=["Left", "Up"], value="Left", interactive=True)
-            submit_btn = gr.Button("Submit")
+            # submit_btn = gr.Button("Submit")
 
         # with gr.Column():
         #     with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
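One behavioral detail of this hunk: the old component declared type="filepath", which the new one drops, so gr.Image falls back to Gradio's default return type (a numpy array) rather than a file path. If downstream code such as convert_prompt's image_path argument still expects a path, the component would presumably need to keep the flag; a sketch, assuming that intent:

    # Sketch only: keeps the new label while still handing callbacks a file path.
    image_in = gr.Image(label="Input Image (will be cropped to 720 * 480)", type="filepath")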
@@ -341,9 +361,9 @@
 
     with gr.Row():
         gr.Markdown(
-            "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one."
+            "✨Upon pressing the enhanced prompt button, we will use [CogVLM](https://github.com/THUDM/CogVLM) to polish the prompt and overwrite the original one."
         )
-        enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
+        enhance_button = gr.Button("✨ Enhance Prompt(Optional but highly recommend)")
         with gr.Group():
             with gr.Column():
                 with gr.Row():
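The click wiring for enhance_button falls outside these hunks. A plausible sketch, assuming a prompt textbox component named prompt (shown commented out in the hunk above) and an image component that yields a file path:

    # Assumed wiring, not shown in this diff: feed the current prompt (and image)
    # to convert_prompt, then overwrite the textbox with the enhanced caption.
    enhance_button.click(fn=convert_prompt, inputs=[prompt, image_in], outputs=[prompt])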