# Music-Generator / app.py
# Last commit: "fix: UI changes" (5e5aa33, verified) by nullHawk
# File size: 1.84 kB
import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from model import MusicLSTM
from train import DataLoader, Config, generate_song as generate_ABC_notation
from utils import load_vocab
from convert import abc_to_audio
import subprocess
class GradioApp():
    """Gradio front end that generates a song with a trained ``MusicLSTM``.

    Constructing the app runs ``./setup.sh``, builds the ``Config`` and
    ``DataLoader``, loads the vocabulary and restores the model weights from
    ``checkpoint/model.pth``. ``launch`` then serves a one-button UI whose
    click handler is ``generate_music``.
    """

    # Literal tokens removed from both ends of the generated ABC text.
    _START_MARKER = "<start>"
    _END_MARKER = "<end>"

    def __init__(self):
        """Run the setup script and load config, data, vocabulary and model.

        Raises:
            subprocess.CalledProcessError: if ``./setup.sh`` exits non-zero
                (``check=True``).
        """
        # Set up configuration and data
        subprocess.run(['./setup.sh'], check=True)
        self.config = Config()
        self.CHECKPOINT_FILE = "checkpoint/model.pth"
        self.data_loader = DataLoader(self.config.INPUT_FILE, self.config)

        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects, so the checkpoint file must come from a trusted source.
        # NOTE(review): if the file holds only a state dict, weights_only=True
        # would presumably be sufficient -- confirm before switching.
        self.checkpoint = torch.load(self.CHECKPOINT_FILE, weights_only=False)

        # load_vocab() returns two values; only the first is needed here, and
        # only for its length, which sizes the model's input and output layers.
        char_idx, _ = load_vocab()
        vocab_size = len(char_idx)
        self.model = MusicLSTM(
            input_size=vocab_size,
            hidden_size=self.config.HIDDEN_SIZE,
            output_size=vocab_size,
        )
        self.model.load_state_dict(self.checkpoint)
        self.model.eval()

    def launch(self):
        """Build the Gradio Blocks UI and start serving it."""
        # Define Gradio interface without a clear button
        with gr.Blocks() as demo:
            gr.Markdown("# AI Music Generator")
            gr.Markdown("Click the button below to generate a new random song using a trained RNN model.")
            generate_button = gr.Button("Generate Music")
            output_audio = gr.Audio(label="Generated Music")
            # inputs=None: the handler is invoked without any UI-provided
            # arguments, which is why generate_music's parameter is optional.
            generate_button.click(self.generate_music, inputs=None, outputs=output_audio)
        demo.launch()

    @classmethod
    def _strip_markers(cls, abc_notation):
        """Return *abc_notation* without the surrounding start/end markers.

        Surrounding whitespace is dropped, then one leading ``<start>`` and
        one trailing ``<end>`` token are removed if present, then whitespace
        left next to the removed tokens is dropped as well.

        ``str.strip("<start>")`` must not be used for this: its argument is a
        *set of characters*, so it would also eat real notes such as a final
        ``d``, ``e`` or ``a`` of the tune.
        """
        text = abc_notation.strip()
        if text.startswith(cls._START_MARKER):
            text = text[len(cls._START_MARKER):]
        if text.endswith(cls._END_MARKER):
            text = text[:-len(cls._END_MARKER)]
        return text.strip()

    def generate_music(self, input=None):
        """Generate a new song using the trained model.

        Args:
            input: Unused. Kept (now optional) so existing callers that pass
                a value keep working while the button click, which is wired
                with ``inputs=None``, can call the handler with no arguments.

        Returns:
            Whatever ``abc_to_audio`` produces for the generated ABC notation;
            it is handed to the ``gr.Audio`` output component.
        """
        abc_notation = generate_ABC_notation(self.model, self.data_loader)
        abc_notation = self._strip_markers(abc_notation)
        return abc_to_audio(abc_notation)
if __name__ == '__main__':
    # Script entry point: construct the app, then serve the UI.
    GradioApp().launch()