import argparse
import base64
import io
import json
import os
import tempfile

import gradio as gr
import numpy as np
import onnxruntime as rt
import rtmidi
from huggingface_hub import hf_hub_download

import MIDI
from midi_synthesizer import MidiSynthesizer
from midi_tokenizer import MIDITokenizer

# Number of output slots (audio player + MIDI download pairs) built in the UI.
# Must match the corresponding JavaScript constant (presumably in the front-end
# script that drives the visualizers — confirm when changing this value).
MIDI_OUTPUT_BATCH_SIZE: int = 4

class MIDIDeviceManager:
    """Manages MIDI input/output devices.

    Holds one ``rtmidi.MidiOut`` and one ``rtmidi.MidiIn`` handle that are used
    to enumerate the MIDI ports currently exposed by the system.
    """

    def __init__(self):
        self.midiout = rtmidi.MidiOut()
        self.midiin = rtmidi.MidiIn()

    def get_device_info(self):
        """Return a human-readable listing of MIDI output and input ports.

        A placeholder line is shown for a direction that has no ports.
        """
        out_ports = self.midiout.get_ports() or ["No MIDI output devices"]
        in_ports = self.midiin.get_ports() or ["No MIDI input devices"]
        # Join outside the f-string: a backslash inside an f-string expression
        # is a SyntaxError on Python versions before 3.12.
        out_text = "\n".join(out_ports)
        in_text = "\n".join(in_ports)
        return f"Output Devices:\n{out_text}\n\nInput Devices:\n{in_text}"

    def close(self):
        """Close any open MIDI ports and release the rtmidi handles.

        Safe to call more than once: handles already released are skipped.
        """
        for attr in ("midiout", "midiin"):
            port = getattr(self, attr, None)
            if port is None:
                continue
            if port.is_port_open():
                port.close_port()
            # Drop our reference so the underlying rtmidi object can be freed.
            delattr(self, attr)

class MIDIManager:
    """Handles MIDI processing, generation, and playback.

    Construction downloads a soundfont, the tokenizer config and two ONNX
    models from the Hugging Face Hub repository ``REPO_ID``, so creating an
    instance performs network and disk I/O.
    """

    # Hub repository hosting the soundfont, tokenizer config and ONNX models.
    REPO_ID = "skytnt/midi-model"
    # ONNX Runtime execution providers in priority order (CUDA, then CPU fallback).
    PROVIDERS = ("CUDAExecutionProvider", "CPUExecutionProvider")

    def __init__(self):
        # Load soundfont and models from Hugging Face
        self.soundfont_path = hf_hub_download(repo_id=self.REPO_ID, filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont_path)
        self.tokenizer = self._load_tokenizer(self.REPO_ID)
        self.model_base = self._load_session("onnx/model_base.onnx")
        self.model_token = self._load_session("onnx/model_token.onnx")
        # Base64-encoded MIDI bytes of each generated piece, in generation order.
        # NOTE(review): never trimmed, so it grows for the lifetime of the process.
        self.generated_files = []

    def _load_session(self, filename):
        """Download ``filename`` from ``REPO_ID`` and open an ONNX Runtime session on it."""
        model_path = hf_hub_download(repo_id=self.REPO_ID, filename=filename)
        return rt.InferenceSession(model_path, providers=list(self.PROVIDERS))

    def _load_tokenizer(self, repo_id):
        """Build a ``MIDITokenizer`` from the ``config.json`` stored in ``repo_id``.

        Reads ``config["tokenizer"]["version"]`` and
        ``config["tokenizer"]["optimise_midi"]``.
        """
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        tokenizer = MIDITokenizer(config["tokenizer"]["version"])
        tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
        return tokenizer

    def load_midi(self, file_path):
        """Load a MIDI file from ``file_path`` via ``MIDI.load``."""
        return MIDI.load(file_path)

    def generate_onnx(self, midi_data, max_len=1024, top_p=0.98, top_k=20):
        """Continue the tokenized ``midi_data`` with the ONNX models.

        Tokens are sampled one at a time until the sequence holds ``max_len``
        tokens. The result is detokenized, and its saved MIDI bytes are appended,
        base64-encoded, to ``self.generated_files``.

        Args:
            midi_data: MIDI input accepted by ``MIDI.midi2score``.
            max_len: total sequence length (prompt + generated) to stop at.
            top_p: nucleus-sampling probability-mass threshold.
            top_k: maximum number of candidate tokens per step.

        Returns:
            The object returned by ``self.tokenizer.detokenize``.

        Raises:
            ValueError: if the tokenizer produced an empty sequence.
        """
        mid_seq = self.tokenizer.tokenize(MIDI.midi2score(midi_data))
        tokens = np.array([mid_seq], dtype=np.int64)
        if tokens.shape[1] == 0:
            raise ValueError("tokenizer produced an empty sequence; nothing to continue from")
        while tokens.shape[1] < max_len:
            # NOTE(review): only the most recent token is fed to the base model each
            # step, so earlier context is not passed explicitly — confirm this matches
            # the exported model's expected inputs.
            hidden = self.model_base.run(None, {"x": tokens[:, -1:]})[0]
            logits = self.model_token.run(None, {"hidden": hidden})[0]
            probs = self._softmax(logits, axis=-1)
            next_token = self._sample_top_p_k(probs, top_p, top_k)
            tokens = np.concatenate([tokens, next_token], axis=1)
        generated_midi = self.tokenizer.detokenize(tokens[0].tolist())
        # Store base64-encoded MIDI data for downloads / playback (see play_midi).
        midi_bytes = MIDI.save(generated_midi)
        self.generated_files.append(base64.b64encode(midi_bytes).decode("utf-8"))
        return generated_midi

    def play_midi(self, midi_data):
        """Render base64-encoded MIDI data to audio.

        Args:
            midi_data: base64 string of MIDI file bytes (as stored in
                ``generated_files``).

        Returns:
            An ``io.BytesIO`` holding the synthesizer output, rewound to the start.
        """
        midi_bytes = base64.b64decode(midi_data)
        midi_file = MIDI.load(io.BytesIO(midi_bytes))
        audio = io.BytesIO()
        self.synthesizer.render_midi(midi_file, audio)
        audio.seek(0)
        return audio

    @staticmethod
    def _softmax(x, axis=-1):
        """Return softmax probabilities of ``x`` along ``axis``.

        The per-slice maximum is subtracted first for numerical stability.
        """
        exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
        return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

    @staticmethod
    def _sample_top_p_k(probs, p, k, rng=None):
        """Sample one token id using top-k then top-p (nucleus) filtering.

        The distribution is taken from the last row along the final axis of
        ``probs``, so both ``(1, vocab)`` and ``(1, 1, vocab)`` inputs work.

        Args:
            probs: array of probabilities whose last axis is the vocabulary.
            p: keep the smallest set of most-probable tokens whose mass reaches ``p``
                (the single most probable token is always kept).
            k: keep at most ``k`` candidates; ``None`` or ``<= 0`` disables the cap.
            rng: optional ``numpy.random.Generator``; defaults to the global
                ``numpy.random`` state.

        Returns:
            ``np.ndarray`` of shape ``(1, 1)`` and dtype int64 holding the token id.
        """
        rng = np.random if rng is None else rng
        dist = np.asarray(probs, dtype=np.float64)
        dist = dist.reshape(-1, dist.shape[-1])[-1]
        vocab_size = dist.shape[0]
        if k is None or k <= 0 or k > vocab_size:
            k = vocab_size
        # Candidate ids, most probable first, truncated to the top-k.
        candidates = np.argsort(dist, kind="stable")[::-1][: int(k)]
        weights = dist[candidates]
        # Keep a token while the mass ranked above it is still below p.
        mass_before = np.cumsum(weights) - weights
        keep = mass_before < p
        keep[0] = True
        candidates, weights = candidates[keep], weights[keep]
        total = weights.sum()
        if not np.isfinite(total) or total <= 0:
            # Degenerate distribution: fall back to uniform over the candidates.
            weights = np.full(len(candidates), 1.0 / len(candidates))
        else:
            weights = weights / total
        choice = rng.choice(candidates, p=weights)
        return np.array([[choice]], dtype=np.int64)

def _write_temp_file(data, suffix):
    """Write ``data`` (bytes) to a new temporary file and return its path.

    The file is created with ``delete=False`` so it outlives this call and can
    be served by Gradio; it is not cleaned up here.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        handle.write(data)
    return handle.name


def process_midi(files):
    """Process uploaded MIDI files and yield updates for Gradio components.

    Each yielded list has ``1 + 2 * MIDI_OUTPUT_BATCH_SIZE`` entries laid out as
    ``[js_msg, audio_0 .. audio_{N-1}, midi_0 .. midi_{N-1}]``, which is the order of
    the ``outputs`` list bound to the Generate button. When there are more files
    than slots, later files reuse slots (``idx % MIDI_OUTPUT_BATCH_SIZE``).

    Args:
        files: uploaded files, either path strings or objects exposing the path
            as ``.name`` (the shape differs between Gradio versions).

    Yields:
        Lists of ``gr.update`` objects, one per output component.
    """
    n_outputs = 1 + 2 * MIDI_OUTPUT_BATCH_SIZE

    def no_change():
        # A fresh "leave everything as is" update list.
        return [gr.update() for _ in range(n_outputs)]

    if not files:
        yield no_change()
        return

    for idx, file in enumerate(files):
        output_idx = idx % MIDI_OUTPUT_BATCH_SIZE
        audio_slot = 1 + output_idx
        midi_slot = 1 + MIDI_OUTPUT_BATCH_SIZE + output_idx

        file_path = getattr(file, "name", file)
        midi_data = midi_processor.load_midi(file_path)
        generated_midi = midi_processor.generate_onnx(midi_data)
        # generate_onnx appends the saved MIDI as base64 to generated_files; that is
        # the form play_midi decodes, not the detokenized object it returns.
        midi_b64 = midi_processor.generated_files[-1]

        # TODO: derive visualizer events from generated_midi via the tokenizer.
        # Expected format: ["note", delta_time, track, channel, pitch, velocity, duration]
        events = [
            ["note", 0, 0, 0, 60, 100, 1000],  # Example event
        ]

        updates = no_change()

        # Clear visualizer
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_clear", "data": [output_idx, "v2"]}]))
        yield list(updates)

        # Send MIDI events
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_append", "data": [output_idx, events]}]))
        yield list(updates)

        # Finalize visualizer and publish the audio / MIDI files as paths.
        audio = midi_processor.play_midi(midi_b64)
        # NOTE(review): the ".wav" suffix assumes render_midi writes WAV data — confirm.
        audio_path = _write_temp_file(audio.getvalue(), ".wav")
        midi_path = _write_temp_file(base64.b64decode(midi_b64), ".mid")
        updates[0] = gr.update(value=json.dumps([{"name": "visualizer_end", "data": output_idx}]))
        updates[audio_slot] = gr.update(value=audio_path)
        updates[midi_slot] = gr.update(value=midi_path, label=f"Generated MIDI {output_idx}")
        yield list(updates)

    # Final yield to ensure all components are in a stable state
    yield no_change()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MIDI Composer App")
    parser.add_argument("--port", type=int, default=7860, help="Server port")
    parser.add_argument("--share", action="store_true", help="Share the app publicly")
    opt = parser.parse_args()

    device_manager = MIDIDeviceManager()
    try:
        # process_midi looks this name up as a module global, so it must be bound here.
        midi_processor = MIDIManager()

        with gr.Blocks(theme=gr.themes.Soft()) as app:
            # Hidden textbox for sending messages to JS
            js_msg = gr.Textbox(visible=False, elem_id="msg_receiver")

            with gr.Tabs():
                # MIDI Prompt Tab
                with gr.Tab("MIDI Prompt"):
                    midi_upload = gr.File(label="Upload MIDI File(s)", file_count="multiple")
                    generate_btn = gr.Button("Generate")
                    # NOTE(review): nothing updates this status box yet.
                    status = gr.Textbox(label="Status", value="Ready", interactive=False)

                # Outputs Tab: one visualizer, audio player and MIDI download per slot
                with gr.Tab("Outputs"):
                    output_audios = []
                    output_midis = []
                    for i in range(MIDI_OUTPUT_BATCH_SIZE):
                        with gr.Column():
                            gr.Markdown(f"## Output {i+1}")
                            gr.HTML(elem_id=f"midi_visualizer_container_{i}")
                            # gr.Audio accepts type="numpy" or "filepath"; "bytes" is rejected.
                            output_audio = gr.Audio(
                                label="Generated Audio",
                                type="filepath",
                                autoplay=True,
                                elem_id=f"midi_audio_{i}",
                            )
                            output_midi = gr.File(label="Generated MIDI", file_types=[".mid"])
                            output_audios.append(output_audio)
                            output_midis.append(output_midi)

                # Devices Tab
                with gr.Tab("Devices"):
                    device_info = gr.Textbox(
                        label="Connected MIDI Devices",
                        value=device_manager.get_device_info(),
                        interactive=False,
                    )
                    refresh_btn = gr.Button("Refresh Devices")
                    refresh_btn.click(fn=device_manager.get_device_info, outputs=[device_info])

            # Output layout: [js_msg, audio_0..audio_{N-1}, midi_0..midi_{N-1}].
            # The update lists yielded by process_midi must follow this order.
            outputs = [js_msg] + output_audios + output_midis

            # Bind the generate button to the processing function
            generate_btn.click(fn=process_midi, inputs=[midi_upload], outputs=outputs)

        # process_midi is a generator; Gradio 3.x only streams generator
        # output when the queue is enabled.
        app.queue()
        app.launch(server_port=opt.port, share=opt.share, inbrowser=True)
    finally:
        # Release MIDI handles even if model loading or the server fails.
        device_manager.close()