suric's picture
adjuct captioning
061e8b0
raw
history blame
17.2 kB
import os
import gradio as gr
from gradio_components.image import generate_caption
from gradio_components.prediction import predict, transcribe
# Visual theme for the demo: Gradio's Glass base theme with a custom palette
# and font stack, followed by per-variable overrides.
theme = gr.themes.Glass(
    primary_hue="fuchsia",
    secondary_hue="indigo",
    neutral_hue="slate",
    font=[
        gr.themes.GoogleFont("Source Sans Pro"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
)
# Rebind to the value returned by ``set`` exactly as the chained call would.
# Values starting with "*" refer to other theme variables.
theme = theme.set(
    # --- page and background fills ---
    body_background_fill_dark="*background_fill_primary",
    embed_radius="*table_radius",
    background_fill_primary="*neutral_50",
    background_fill_primary_dark="*neutral_950",
    background_fill_secondary_dark="*neutral_900",
    # --- borders ---
    border_color_accent="*neutral_600",
    border_color_accent_subdued="*color_accent",
    border_color_primary_dark="*neutral_700",
    # --- blocks, block labels and titles ---
    block_background_fill="*background_fill_primary",
    block_background_fill_dark="*neutral_800",
    block_border_width="1px",
    block_label_background_fill="*background_fill_primary",
    block_label_background_fill_dark="*background_fill_secondary",
    block_label_text_color="*neutral_500",
    block_label_text_size="*text_sm",
    block_label_text_weight="400",
    block_shadow="none",
    block_shadow_dark="none",
    block_title_text_color="*neutral_500",
    block_title_text_weight="400",
    # --- panels ---
    panel_border_width="0",
    panel_border_width_dark="0",
    # --- checkboxes ---
    checkbox_background_color_dark="*neutral_800",
    checkbox_border_width="*input_border_width",
    checkbox_label_border_width="*input_border_width",
    # --- text inputs ---
    input_background_fill="*neutral_100",
    input_background_fill_dark="*neutral_700",
    input_border_color_focus_dark="*neutral_700",
    input_border_width="0px",
    input_border_width_dark="0px",
    # --- sliders ---
    slider_color="#2563eb",
    slider_color_dark="#2563eb",
    # --- tables ---
    table_even_background_fill_dark="*neutral_950",
    table_odd_background_fill_dark="*neutral_900",
    # --- buttons (hover states deliberately reuse the resting fill) ---
    button_border_width="*input_border_width",
    button_shadow_active="none",
    button_primary_background_fill="*primary_200",
    button_primary_background_fill_dark="*primary_700",
    button_primary_background_fill_hover="*button_primary_background_fill",
    button_primary_background_fill_hover_dark="*button_primary_background_fill",
    button_secondary_background_fill="*neutral_200",
    button_secondary_background_fill_dark="*neutral_600",
    button_secondary_background_fill_hover="*button_secondary_background_fill",
    button_secondary_background_fill_hover_dark="*button_secondary_background_fill",
    button_cancel_background_fill="*button_secondary_background_fill",
    button_cancel_background_fill_dark="*button_secondary_background_fill",
    button_cancel_background_fill_hover="*button_cancel_background_fill",
    button_cancel_background_fill_hover_dark="*button_cancel_background_fill",
)
_AUDIOCRAFT_MODELS = [
"facebook/musicgen-melody",
"facebook/musicgen-medium",
"facebook/musicgen-small",
"facebook/musicgen-large",
"facebook/musicgen-melody-large",
"facebook/audiogen-medium",
]
def generate_prompt(difficulty, style):
    """Build the text prompt for a piano practice piece.

    Args:
        difficulty: One of "Easy", "Medium" or "Hard" (the labels offered by
            the difficulty radio in the UI).
        style: Music genre text inserted verbatim into the prompt.

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: If ``difficulty`` is not one of the known labels.
    """
    # Keys must match the radio labels exactly. The "Medium" entry used to be
    # misspelled, which made that choice raise KeyError.
    player_by_difficulty = {
        "Easy": "beginner player",
        "Medium": "player who has 2-3 years experience",
        "Hard": "player who has more than 4 years experiences",
    }
    # NOTE(review): the template wording (including "pratice") is kept
    # verbatim because the string is fed to the model; change deliberately.
    return "piano only music for a {} to pratice with the touch of {}".format(
        player_by_difficulty[difficulty], style
    )
def toggle_melody_condition(melody_condition):
    """Return the melody audio input, shown or hidden.

    Args:
        melody_condition: Truthy to show the recorder/uploader, falsy to
            hide it.

    Returns:
        A ``gr.Audio`` configured like the melody input, with ``visible``
        set from ``melody_condition``.
    """
    # Both outcomes describe the same widget; only visibility differs.
    is_shown = True if melody_condition else False
    return gr.Audio(
        sources=["microphone", "upload"],
        label="Record or upload your audio",
        show_label=True,
        visible=is_shown,
    )
def show_caption(show_caption_condition, description, prompt):
    """Show or hide the caption and prompt boxes, keeping their contents.

    Args:
        show_caption_condition: Truthy to reveal both text boxes, falsy to
            hide them.
        description: Current image caption text.
        prompt: Current generated prompt text.

    Returns:
        A 3-tuple ``(caption_box, prompt_box, generate_button)``. The two
        text boxes take their visibility from ``show_caption_condition``;
        the "Generate Music" button is always visible and interactive.
    """
    boxes_visible = bool(show_caption_condition)
    caption_box = gr.Textbox(
        label="Image Caption",
        value=description,
        interactive=False,
        show_label=True,
        visible=boxes_visible,
    )
    prompt_box = gr.Textbox(
        label="Generated Prompt",
        value=prompt,
        interactive=True,
        show_label=True,
        visible=boxes_visible,
    )
    # The button text is the positional ``value``; the hide branch used to
    # pass it as ``label=``, which is not a Button parameter.
    generate_button = gr.Button("Generate Music", interactive=True, visible=True)
    return caption_box, prompt_box, generate_button
def post_submit(show_caption, model_path, image_input):
    """Caption the image and return refreshed caption/prompt/button widgets.

    Args:
        show_caption: Whether the caption and prompt boxes should be visible.
        model_path: Selected model identifier, forwarded to
            ``generate_caption``.
        image_input: Image to caption, forwarded to ``generate_caption``.

    Returns:
        A 3-tuple ``(caption_box, prompt_box, generate_button)``.
    """
    # generate_caption yields three values; the first one is not used here.
    _unused, caption_text, music_prompt = generate_caption(image_input, model_path)

    caption_box = gr.Textbox(
        label="Image Caption",
        value=caption_text,
        interactive=False,
        show_label=True,
        visible=show_caption,
    )
    prompt_box = gr.Textbox(
        label="Generated Prompt",
        value=music_prompt,
        interactive=True,
        show_label=True,
        visible=show_caption,
    )
    generate_button = gr.Button("Generate Music", interactive=True, visible=True)
    return (caption_box, prompt_box, generate_button)
def UI():
    """Build the two-tab Gradio demo and launch it.

    Tab "Generate Music by melody" collects generation settings, a difficulty
    and genre (turned into a text prompt by ``generate_prompt``) and an
    optional melody, then calls ``predict``.

    Tab "Generate Music by image" captions an uploaded image via
    ``post_submit``, optionally shows the caption and prompt, then calls
    ``predict`` with the generated prompt.

    Both tabs can pass the generated audio to ``transcribe``.
    """

    def _asset(relative_path):
        # Example files live next to this script; anchor them there so the
        # demo does not depend on the current working directory.
        return os.path.join(os.path.dirname(__file__), relative_path)

    def _refresh_prompt(level, genre, keep_custom, current_prompt):
        # Leave a hand-written prompt alone while "Customize the prompt" is
        # ticked; otherwise rebuild it from the selected difficulty and genre.
        if keep_custom:
            return current_prompt
        return generate_prompt(level, genre)

    # NOTE(review): the module-level ``theme`` is not passed to gr.Blocks, so
    # the default theme is used. Pass ``theme=theme`` if that is an oversight.
    with gr.Blocks() as demo:
        with gr.Tab("Generate Music by melody"):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        model_path = gr.Dropdown(
                            choices=_AUDIOCRAFT_MODELS,
                            label="Select the model",
                            value="facebook/musicgen-melody-large",
                        )
                    with gr.Row():
                        duration = gr.Slider(
                            minimum=10,
                            maximum=60,
                            value=10,
                            label="Duration",
                            interactive=True,
                        )
                    with gr.Row():
                        topk = gr.Number(label="Top-k", value=250, interactive=True)
                        topp = gr.Number(label="Top-p", value=0, interactive=True)
                        temperature = gr.Number(
                            label="Temperature", value=1.0, interactive=True
                        )
                        sample_rate = gr.Number(
                            label="output music sample rate",
                            value=32000,
                            interactive=True,
                        )
                    difficulty = gr.Radio(
                        ["Easy", "Medium", "Hard"],
                        label="Difficulty",
                        value="Easy",
                        interactive=True,
                    )
                    style = gr.Radio(
                        ["Jazz", "Classical Music", "Hip Hop", "Others"],
                        value="Classical Music",
                        label="music genre",
                        interactive=True,
                    )
                    customize = gr.Checkbox(
                        label="Customize the prompt", interactive=True
                    )
                    # The old build-time tests `style == "Others"` and
                    # `if customize:` compared component objects, not user
                    # input, so the radios never influenced the prompt.
                    # Start from the prompt for the default selections and
                    # keep it in sync through change events instead.
                    prompt = gr.Textbox(
                        label="Type your prompt",
                        value=generate_prompt(difficulty.value, style.value),
                        interactive=True,
                    )
                    for control in (difficulty, style, customize):
                        control.change(
                            fn=_refresh_prompt,
                            inputs=[difficulty, style, customize, prompt],
                            outputs=prompt,
                        )
                with gr.Column():
                    with gr.Row():
                        melody = gr.Audio(
                            sources=["microphone", "upload"],
                            label="Record or upload your audio",
                            show_label=True,
                        )
                    with gr.Row():
                        submit = gr.Button("Generate Music")
                        # The first positional argument of gr.Audio is the
                        # initial value, so the caption goes in ``label``.
                        output_audio = gr.Audio(
                            label="listen to the generated music", type="filepath"
                        )
                    with gr.Row():
                        transcribe_button = gr.Button("Transcribe")
                        d = gr.DownloadButton("Download the file", visible=False)
                    transcribe_button.click(
                        transcribe, inputs=[output_audio], outputs=d
                    )
                    submit.click(
                        fn=predict,
                        inputs=[
                            model_path,
                            prompt,
                            melody,
                            duration,
                            topk,
                            topp,
                            temperature,
                            sample_rate,
                        ],
                        outputs=output_audio,
                    )
            # Each row: melody file, difficulty, sample rate, duration.
            gr.Examples(
                examples=[
                    [
                        _asset(
                            "./data/audio/twinkle_twinkle_little_stars_mozart_20sec"
                            ".mp3"
                        ),
                        "Easy",
                        32000,
                        20,
                    ],
                    [_asset("./data/audio/golden_hour_20sec.mp3"), "Easy", 32000, 20],
                    [
                        _asset("./data/audio/turkish_march_mozart_20sec.mp3"),
                        "Easy",
                        32000,
                        20,
                    ],
                    [_asset("./data/audio/golden_hour_20sec.mp3"), "Hard", 32000, 20],
                    [_asset("./data/audio/golden_hour_20sec.mp3"), "Hard", 32000, 40],
                    [_asset("./data/audio/golden_hour_20sec.mp3"), "Hard", 16000, 20],
                ],
                inputs=[melody, difficulty, sample_rate, duration],
                label="Audio Examples",
                outputs=[output_audio],
            )
        with gr.Tab("Generate Music by image"):
            with gr.Row():
                with gr.Column():
                    # As with gr.Audio, the first positional argument is the
                    # initial value, so the caption is passed as ``label``.
                    image_input = gr.Image(label="Upload an image", type="filepath")
                    melody_condition = gr.Checkbox(
                        label="Generate music by melody", interactive=True, value=False
                    )
                    melody = gr.Audio(
                        sources=["microphone", "upload"],
                        label="Record or upload your audio",
                        show_label=True,
                        visible=False,
                    )
                    melody_condition.change(
                        fn=toggle_melody_condition,
                        inputs=[melody_condition],
                        outputs=melody,
                    )
                    description = gr.Textbox(
                        label="Image Captioning",
                        show_label=True,
                        interactive=False,
                        visible=False,
                    )
                    prompt = gr.Textbox(
                        label="Generated Prompt",
                        show_label=True,
                        interactive=True,
                        visible=False,
                    )
                    show_prompt = gr.Checkbox(label="Show the prompt", interactive=True)
                    submit = gr.Button("submit", interactive=True, visible=True)
                    # Hidden until post_submit returns a visible button.
                    generate = gr.Button(
                        "Generate Music", interactive=True, visible=False
                    )
                with gr.Column():
                    with gr.Row():
                        model_path = gr.Dropdown(
                            choices=_AUDIOCRAFT_MODELS,
                            label="Select the model",
                            value="facebook/musicgen-large",
                        )
                    with gr.Row():
                        duration = gr.Slider(
                            minimum=10,
                            maximum=60,
                            value=10,
                            label="Duration",
                            interactive=True,
                        )
                    topk = gr.Number(label="Top-k", value=250, interactive=True)
                    topp = gr.Number(label="Top-p", value=0, interactive=True)
                    temperature = gr.Number(
                        label="Temperature", value=1.0, interactive=True
                    )
                    sample_rate = gr.Number(
                        label="output music sample rate", value=32000, interactive=True
                    )
                with gr.Column():
                    output_audio = gr.Audio(
                        label="listen to the generated music",
                        type="filepath",
                        show_label=True,
                    )
                    transcribe_button = gr.Button("Transcribe")
                    d = gr.DownloadButton("Download the file", visible=False)
            # Inputs are passed positionally, so their order must follow
            # post_submit(show_caption, model_path, image_input). The image
            # and model were previously swapped.
            submit.click(
                fn=post_submit,
                inputs=[show_prompt, model_path, image_input],
                outputs=[description, prompt, generate],
            )
            show_prompt.change(
                fn=show_caption,
                inputs=[show_prompt, description, prompt],
                outputs=[description, prompt, generate],
            )
            transcribe_button.click(transcribe, inputs=[output_audio], outputs=d)
            generate.click(
                fn=predict,
                inputs=[
                    model_path,
                    prompt,
                    melody,
                    duration,
                    topk,
                    topp,
                    temperature,
                    sample_rate,
                ],
                outputs=output_audio,
            )
            # Each row: image file, use-melody flag, melody file, model.
            gr.Examples(
                examples=[
                    [
                        _asset("./data/image/kids_drawing.jpeg"),
                        False,
                        None,
                        "facebook/musicgen-large",
                    ],
                    [
                        _asset("./data/image/cat.jpeg"),
                        False,
                        None,
                        "facebook/musicgen-large",
                    ],
                    [
                        _asset("./data/image/cat.jpeg"),
                        True,
                        # Anchored to the script directory like every other
                        # example; it used to be relative to the working dir.
                        _asset(
                            "./data/audio/the_nutcracker_dance_of_the_reed_flutes.mp3"
                        ),
                        "facebook/musicgen-melody-large",
                    ],
                    [
                        _asset("./data/image/beach.jpeg"),
                        False,
                        None,
                        "facebook/audiogen-medium",
                    ],
                ],
                inputs=[image_input, melody_condition, melody, model_path],
                label="Audio Examples",
                outputs=[output_audio],
            )
    demo.queue().launch()
if __name__ == "__main__":
UI()