import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from model import MusicLSTM
from train import DataLoader, Config, generate_song as generate_ABC_notation
from utils import load_vocab
from convert import abc_to_audio
import subprocess


class GradioApp:
    """Gradio web UI that generates a song with a trained MusicLSTM and renders it to audio."""

    def __init__(self):
        """Run the setup script, then load config, data loader, vocabulary and model weights.

        Raises:
            subprocess.CalledProcessError: if ``./setup.sh`` exits with a
                non-zero status (``check=True``).
        """
        # Set up configuration and data; check=True aborts construction if
        # the setup script fails.
        subprocess.run(['./setup.sh'], check=True)
        self.config = Config()
        self.CHECKPOINT_FILE = "checkpoint/model.pth"
        self.data_loader = DataLoader(self.config.INPUT_FILE, self.config)
        # map_location="cpu": the model below is never moved to a GPU, so load
        # the tensors on the CPU. This also lets a checkpoint saved on a CUDA
        # machine load on a CPU-only host.
        # NOTE(review): weights_only=False unpickles arbitrary Python objects;
        # only load checkpoints from a trusted source. The result is passed
        # straight to load_state_dict, so weights_only=True may suffice; confirm.
        self.checkpoint = torch.load(
            self.CHECKPOINT_FILE, map_location="cpu", weights_only=False
        )
        # load_vocab() returns a pair; only the char->index mapping is needed
        # here, and its length is the vocabulary size.
        char_idx, _ = load_vocab()
        vocab_size = len(char_idx)
        self.model = MusicLSTM(
            input_size=vocab_size,
            hidden_size=self.config.HIDDEN_SIZE,
            output_size=vocab_size,
        )
        self.model.load_state_dict(self.checkpoint)
        # Inference only: switch the model to evaluation mode.
        self.model.eval()

    def launch(self):
        """Build the Gradio Blocks UI (one button, one audio output) and start the server."""
        # Define Gradio interface without a clear button
        with gr.Blocks() as demo:
            gr.Markdown("# AI Music Generator")
            gr.Markdown("Click the button below to generate a new random song using a trained RNN model.")
            generate_button = gr.Button("Generate Music")
            output_audio = gr.Audio(label="Generated Music")
            # inputs=None: Gradio invokes the handler with no arguments.
            generate_button.click(self.generate_music, inputs=None, outputs=output_audio)
        demo.launch()

    def generate_music(self, input=None):
        """Generate a new song using the trained model.

        Args:
            input: Unused. The button is wired with ``inputs=None``, so Gradio
                calls this handler with no arguments. The parameter is kept
                (now optional) for callers that pass a value.

        Returns:
            The value returned by ``convert.abc_to_audio`` for the generated
            ABC notation. It is used as the value of the ``gr.Audio`` output.
        """
        # Pure inference: no autograd graph is needed while sampling.
        with torch.no_grad():
            abc_notation = generate_ABC_notation(self.model, self.data_loader)
        # NOTE(review): str.strip("") is a no-op. If generate_song emits
        # start/end marker tokens, strip them explicitly here; confirm against
        # train.generate_song.
        abc_notation = abc_notation.strip()
        audio = abc_to_audio(abc_notation)
        return audio


if __name__ == '__main__':
    app = GradioApp()
    app.launch()