import gradio as gr
import sys
import threading
import queue
import traceback
from io import TextIOBase
from inference import inference_patch
import datetime
import subprocess
import os

# Load the predefined set of valid (period, composer, instrumentation) prompt
# combinations. Each usable line of prompts.txt is split on "_" and its first
# three fields are taken; any further fields are ignored.
with open('prompts.txt', 'r', encoding='utf-8') as f:
    prompts = f.readlines()

valid_combinations = set()
for prompt in prompts:
    prompt = prompt.strip()
    parts = prompt.split('_')
    if len(parts) < 3:
        # Blank or malformed line (e.g. an extra trailing newline): skip it
        # instead of crashing the whole app with an IndexError at import.
        continue
    valid_combinations.add((parts[0], parts[1], parts[2]))

# Options available in the dropdowns, derived from the valid combinations.
periods = sorted({p for p, _, _ in valid_combinations})
composers = sorted({c for _, c, _ in valid_combinations})
instruments = sorted({i for _, _, i in valid_combinations})


def update_components(period, composer):
    """Recompute the composer and instrumentation dropdowns.

    Args:
        period: Currently selected period (falsy when nothing is selected).
        composer: Currently selected composer (may be falsy).

    Returns:
        A two-element list ``[composer_dropdown, instrument_dropdown]`` of
        ``gr.Dropdown`` updates. With no period, both are emptied and
        disabled. Otherwise the composer choices are restricted to ``period``
        (the current composer is kept only if still valid), the instrument
        choices are restricted to the (period, composer) pair, and the
        instrument selection is always reset to ``None``.
    """
    if not period:
        return [
            gr.Dropdown(choices=[], value=None, interactive=False),
            gr.Dropdown(choices=[], value=None, interactive=False)
        ]

    valid_composers = sorted({c for p, c, _ in valid_combinations if p == period})
    valid_instruments = sorted({i for p, c, i in valid_combinations if p == period and c == composer}) if composer else []

    return [
        gr.Dropdown(
            choices=valid_composers,
            value=composer if composer in valid_composers else None,
            interactive=True
        ),
        gr.Dropdown(
            choices=valid_instruments,
            value=None,
            # Only enable the instrument selector when there is something to pick.
            interactive=bool(valid_instruments)
        )
    ]


class RealtimeStream(TextIOBase):
    """Text stream that forwards every ``write`` call to a queue.

    ``generate_music`` installs an instance as ``sys.stdout`` so that text
    printed during inference can be streamed to the UI.
    """

    def __init__(self, queue):
        # NOTE: the parameter name shadows the ``queue`` module, but only
        # inside this method.
        self.queue = queue

    def write(self, text):
        self.queue.put(text)
        return len(text)


def save_and_convert(abc_content, period, composer, instrumentation):
    """Save ABC text to a timestamped ``.abc`` file and convert it to XML.

    Files are written to the current working directory, named
    ``<YYYYmmdd_HHMMSS>_<period>_<composer>_<instrumentation>``. Conversion is
    delegated to ``abc2xml.py`` (resolved relative to the working directory),
    run with the same Python interpreter as this app.

    Returns:
        A status message naming the ``.abc`` and ``.xml`` files.

    Raises:
        gr.Error: if the selection is incomplete or there is no ABC text to
            save, or if the conversion script exits with a non-zero status.
    """
    if not all([period, composer, instrumentation]) or not (abc_content or "").strip():
        raise gr.Error("Please complete a valid generation first before saving")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    prompt_str = f"{period}_{composer}_{instrumentation}"
    filename_base = f"{timestamp}_{prompt_str}"

    abc_filename = f"{filename_base}.abc"
    with open(abc_filename, "w", encoding="utf-8") as f:
        f.write(abc_content)

    xml_filename = f"{filename_base}.xml"
    try:
        subprocess.run(
            # sys.executable rather than "python": PATH may not contain a
            # "python" binary, or it may belong to a different environment.
            [sys.executable, "abc2xml.py", '-o', '.', abc_filename],
            check=True,
            capture_output=True,
            text=True
        )
    except subprocess.CalledProcessError as e:
        error_msg = f"Conversion failed: {e.stderr}" if e.stderr else "Unknown error"
        raise gr.Error(f"ABC to XML conversion failed: {error_msg}. Please try to generate another composition.") from e

    return f"Saved successfully: {abc_filename} -> {xml_filename}"


def generate_music(period, composer, instrumentation):
    """Run inference in a background thread, streaming its printed progress.

    Generator used as a Gradio event handler. While ``inference_patch`` runs,
    everything written to ``sys.stdout`` is accumulated and yielded as
    ``(process_text, None)``. When inference finishes, one final
    ``(process_text, result)`` is yielded; ``result`` is ``""`` if none was
    produced.

    Raises:
        gr.Error: if the combination is not in ``valid_combinations``, or if
            ``inference_patch`` raises.

    NOTE: ``sys.stdout`` is replaced process-wide for the duration of the
    run, so prints from concurrent runs or other threads would be captured
    and interleaved here as well.
    """
    if (period, composer, instrumentation) not in valid_combinations:
        raise gr.Error("Invalid prompt combination! Please re-select from the period options")

    output_queue = queue.Queue()
    original_stdout = sys.stdout
    sys.stdout = RealtimeStream(output_queue)

    result_container = []
    error_container = []

    def run_inference():
        try:
            result_container.append(inference_patch(period, composer, instrumentation))
        except Exception as exc:
            # Keep the traceback in the server log (stderr is not redirected)
            # and hand the failure back to the generator, so the UI shows an
            # error instead of silently receiving an empty result.
            traceback.print_exc()
            error_container.append(exc)
        finally:
            sys.stdout = original_stdout

    thread = threading.Thread(target=run_inference)
    thread.start()

    process_output = ""
    while thread.is_alive():
        try:
            text = output_queue.get(timeout=0.1)
        except queue.Empty:
            continue
        process_output += text
        yield process_output, None

    thread.join()
    # Drain anything written just before the worker thread finished.
    while not output_queue.empty():
        process_output += output_queue.get()
        yield process_output, None

    if error_container:
        raise gr.Error(f"Generation failed: {error_container[0]}") from error_container[0]

    final_result = result_container[0] if result_container else ""
    yield process_output, final_result


with gr.Blocks() as demo:
    gr.Markdown("## NotaGen")
    with gr.Row():
        # Left column: prompt selection and live generation log.
        with gr.Column():
            period_dd = gr.Dropdown(
                choices=periods,
                value=None,
                label="Period",
                interactive=True
            )
            composer_dd = gr.Dropdown(
                choices=[],
                value=None,
                label="Composer",
                interactive=False
            )
            instrument_dd = gr.Dropdown(
                choices=[],
                value=None,
                label="Instrumentation",
                interactive=False
            )

            generate_btn = gr.Button("Generate!", variant="primary")

            process_output = gr.Textbox(
                label="Generation process",
                interactive=False,
                lines=15,
                max_lines=15,
                placeholder="Generation progress will be shown here...",
                elem_classes="process-output"
            )

        # Right column: final ABC output and saving.
        with gr.Column():
            final_output = gr.Textbox(
                label="Post-processed ABC notation scores",
                interactive=True,
                lines=23,
                placeholder="Post-processed ABC scores will be shown here...",
                elem_classes="final-output"
            )

            with gr.Row():
                # elem_id lets the hover rule in the CSS below target this button.
                save_btn = gr.Button("💾 Save as ABC & XML files", variant="secondary", elem_id="save-convert")

            save_status = gr.Textbox(
                label="Save Status",
                interactive=False,
                visible=True,
                max_lines=2
            )

    # Changing either the period or the composer refreshes the dependent dropdowns.
    period_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )
    composer_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )

    generate_btn.click(
        generate_music,
        inputs=[period_dd, composer_dd, instrument_dd],
        outputs=[process_output, final_output]
    )

    save_btn.click(
        save_and_convert,
        inputs=[final_output, period_dd, composer_dd, instrument_dd],
        outputs=[save_status]
    )


css = """
.process-output {
    background-color: #f0f0f0;
    font-family: monospace;
    padding: 10px;
    border-radius: 5px;
}
.final-output {
    background-color: #ffffff;
    font-family: sans-serif;
    padding: 10px;
    border-radius: 5px;
}

.process-output textarea {
    max-height: 500px !important;
    overflow-y: auto !important;
    white-space: pre-wrap;
}
"""

# Hover highlight for the save button (matches its elem_id "save-convert").
css += """
button#save-convert:hover {
    background-color: #ffe6e6;
}
"""

demo.css = css

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7861
    )