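# NOTE(review): the three lines below appear to be text copied from a hosting
# page together with the code (a user name, a UI label, a commit hash). They
# are kept as comments so the file parses as Python.
# jerald
# disabled share
# dad3b36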
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
import gradio as gr
# Load the saved tokenizer and T5 model from the local checkpoint directory.
model_path = "trained_model"
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)

# Place the model on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)
# Define the function to generate a sequence based on the input text
def generate_sequence(input_text, max_new_tokens=None):
    """Generate an output sequence conditioned on a text description.

    Tokenizes ``input_text``, runs ``model.generate`` on ``device`` and
    decodes the first returned sequence with special tokens stripped.

    Args:
        input_text: Text description to condition generation on. Blank or
            ``None`` input short-circuits and returns an empty string without
            calling the model.
        max_new_tokens: Optional cap on the number of generated tokens,
            forwarded to ``model.generate``. When ``None`` (the default), the
            model's own generation config decides the output length.

    Returns:
        The decoded generated text.
    """
    # Nothing to condition on: skip the (comparatively expensive) model call.
    if not input_text or not input_text.strip():
        return ""

    # BatchEncoding.to() moves every returned tensor (input_ids and
    # attention_mask) to the target device.
    inputs = tokenizer(input_text, return_tensors='pt').to(device)

    gen_kwargs = {"num_return_sequences": 1}
    if max_new_tokens is not None:
        gen_kwargs["max_new_tokens"] = max_new_tokens

    model.eval()  # inference mode for layers such as dropout
    with torch.no_grad():
        # Forward the attention mask along with input_ids rather than
        # dropping it, so generate() does not have to infer it.
        output_tokens = model.generate(**inputs, **gen_kwargs)
    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)
# Create a Gradio interface.
# Components are built with the top-level gr.Textbox: the old gr.inputs /
# gr.outputs namespaces were deprecated in Gradio 3 and removed in Gradio 4.
iface = gr.Interface(
    fn=generate_sequence,
    inputs=gr.Textbox(lines=3, label="Input Text"),
    outputs=gr.Textbox(label="Generated Sequence"),
    title="MIDI Sequence Generator",
    description="Generate a MIDI sequence based on a text description",
    examples=[
        "A popular classical piano piece composed by Ludwig van Beethoven",
        "A beautiful and melancholic classical piano piece composed by Frédéric Chopin",
    ],
)
# Launch the Gradio interface. Per the Gradio docs, debug=True blocks the
# main thread so errors are surfaced in the console.
iface.launch(debug=True)