suric commited on
Commit
061e8b0
·
1 Parent(s): a99ae87

adjust captioning

Browse files
app.py CHANGED
@@ -145,8 +145,8 @@ def show_caption(show_caption_condition, description, prompt):
145
  )
146
 
147
 
148
- def post_submit(show_caption, image_input):
149
- _, description, prompt = generate_caption(image_input)
150
  return (
151
  gr.Textbox(
152
  label="Image Caption",
@@ -349,16 +349,6 @@ def UI():
349
  generate = gr.Button(
350
  "Generate Music", interactive=True, visible=False
351
  )
352
- submit.click(
353
- fn=post_submit,
354
- inputs=[show_prompt, image_input],
355
- outputs=[description, prompt, generate],
356
- )
357
- show_prompt.change(
358
- fn=show_caption,
359
- inputs=[show_prompt, description, prompt],
360
- outputs=[description, prompt, generate],
361
- )
362
 
363
  with gr.Column():
364
  with gr.Row():
@@ -391,6 +381,16 @@ def UI():
391
  )
392
  transcribe_button = gr.Button("Transcribe")
393
  d = gr.DownloadButton("Download the file", visible=False)
 
 
 
 
 
 
 
 
 
 
394
  transcribe_button.click(transcribe, inputs=[output_audio], outputs=d)
395
  generate.click(
396
  fn=predict,
 
145
  )
146
 
147
 
148
+ def post_submit(show_caption, model_path, image_input):
149
+ _, description, prompt = generate_caption(image_input, model_path)
150
  return (
151
  gr.Textbox(
152
  label="Image Caption",
 
349
  generate = gr.Button(
350
  "Generate Music", interactive=True, visible=False
351
  )
 
 
 
 
 
 
 
 
 
 
352
 
353
  with gr.Column():
354
  with gr.Row():
 
381
  )
382
  transcribe_button = gr.Button("Transcribe")
383
  d = gr.DownloadButton("Download the file", visible=False)
384
+ submit.click(
385
+ fn=post_submit,
386
+ inputs=[show_prompt, image_input, model_path],
387
+ outputs=[description, prompt, generate],
388
+ )
389
+ show_prompt.change(
390
+ fn=show_caption,
391
+ inputs=[show_prompt, description, prompt],
392
+ outputs=[description, prompt, generate],
393
+ )
394
  transcribe_button.click(transcribe, inputs=[output_audio], outputs=d)
395
  generate.click(
396
  fn=predict,
gradio_components/image.py CHANGED
@@ -22,15 +22,25 @@ Try to make the prompt simple and concise with only 1-2 sentences
22
 
23
  Make sure the ouput is in JSON fomat, with two items `description` and `prompt`"""
24
 
 
 
 
 
 
 
25
 
26
- def generate_caption(image_file, progress=gr.Progress()):
 
 
 
 
27
  with open(image_file, "rb") as f:
28
  image_encoded = base64.b64encode(f.read()).decode("utf-8")
29
  progress(0, desc="Starting image captioning...")
30
  message = client.messages.create(
31
  model="claude-3-opus-20240229",
32
  max_tokens=1024,
33
- system=SYSTEM_PROMPT,
34
  messages=[
35
  {
36
  "role": "user",
 
22
 
23
  Make sure the ouput is in JSON fomat, with two items `description` and `prompt`"""
24
 
25
+ SYSTEM_PROMPT_AUDIO = """You are an expert llm prompt engineer, you understand the structure of llms and facebook musicgen text to audio model. You will be provided with an image, and require to output a prompt for the musicgen model to capture the essense of the image. Try to do it step by step, evaluate and analyze the image thoroughly. After that, develop a prompt that contains the detail of what background sounds this image should have. This prompt will be provided to audiogen model to generate a 15s audio clip.
26
+ Try to make the prompt simple and concise with only 1-2 sentences
27
+
28
+ Make sure the ouput is in JSON fomat, with two items `description` and `prompt`
29
+ """
30
+
31
 
32
+ def generate_caption(image_file, model_file, progress=gr.Progress()):
33
+ if model_file == "facebook/audiogen-medium":
34
+ system_prompt = SYSTEM_PROMPT_AUDIO
35
+ else:
36
+ system_prompt = SYSTEM_PROMPT
37
  with open(image_file, "rb") as f:
38
  image_encoded = base64.b64encode(f.read()).decode("utf-8")
39
  progress(0, desc="Starting image captioning...")
40
  message = client.messages.create(
41
  model="claude-3-opus-20240229",
42
  max_tokens=1024,
43
+ system=system_prompt,
44
  messages=[
45
  {
46
  "role": "user",
gradio_components/prediction.py CHANGED
@@ -21,6 +21,7 @@ def load_model(version="facebook/musicgen-melody"):
21
 
22
 
23
  def _do_predictions(
 
24
  model,
25
  texts,
26
  melodies,
@@ -65,8 +66,16 @@ def _do_predictions(
65
  return_tokens=False,
66
  )
67
  else:
68
- # text only
69
- outputs = model.generate(texts, progress=progress, return_tokens=False)
 
 
 
 
 
 
 
 
70
  except RuntimeError as e:
71
  raise gr.Error("Error while generating " + e.args[0])
72
  outputs = outputs.detach().cpu().float()
@@ -132,6 +141,7 @@ def predict(
132
  model.set_custom_progress_callback(_progress)
133
 
134
  wavs = _do_predictions(
 
135
  model,
136
  [text],
137
  [melody],
 
21
 
22
 
23
  def _do_predictions(
24
+ model_file,
25
  model,
26
  texts,
27
  melodies,
 
66
  return_tokens=False,
67
  )
68
  else:
69
+ if model_file == "facebook/audiogen-medium":
70
+ # audio condition
71
+ outputs = model.generate(
72
+ texts,
73
+ progress=progress
74
+ )
75
+ else:
76
+ # text only
77
+ outputs = model.generate(texts, progress=progress)
78
+
79
  except RuntimeError as e:
80
  raise gr.Error("Error while generating " + e.args[0])
81
  outputs = outputs.detach().cpu().float()
 
141
  model.set_custom_progress_callback(_progress)
142
 
143
  wavs = _do_predictions(
144
+ model_path,
145
  model,
146
  [text],
147
  [melody],