import os
import shutil
import subprocess
import sys
import tempfile

import gradio as gr

"""
# Set the PATH and LD_LIBRARY_PATH for CUDA 12.3
cuda_bin_path = "/usr/local/cuda/bin"
cuda_lib_path = "/usr/local/cuda/lib64"

# Update the environment variables
os.environ['PATH'] = f"{cuda_bin_path}:{os.environ.get('PATH', '')}"
os.environ['LD_LIBRARY_PATH'] = f"{cuda_lib_path}:{os.environ.get('LD_LIBRARY_PATH', '')}"
"""

# Install required package
def install_flash_attn():
    """Install the ``flash-attn`` package into the running interpreter.

    Runs ``<this python> -m pip install flash-attn --no-build-isolation``.
    Invoking pip through ``sys.executable`` guarantees the package lands in
    the same environment this script runs in. A bare ``pip`` found on PATH
    could belong to a different interpreter.

    Exits the process with status 1 if pip reports failure or cannot be
    launched at all.
    """
    try:
        print("Installing flash-attn...")
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
            check=True,
        )
        print("flash-attn installed successfully!")
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers the interpreter/pip not being launchable at all.
        print(f"Failed to install flash-attn: {e}")
        sys.exit(1)

# Install flash-attn before the rest of the setup runs.
install_flash_attn()

from huggingface_hub import snapshot_download 

# infer() launches "infer.py" by relative path, so the app must end up running
# from inside this directory (see the chdir at the end of this section).
inference_dir = "./inference"

# Local target for the xcodec_mini_infer snapshot, nested inside inference_dir.
folder_path = './inference/xcodec_mini_infer'

# Fail fast with a clear message if the inference directory is missing.
# Otherwise the script would crash with a bare traceback while creating
# folder_path, and the "not found" message below would never be reached.
if not os.path.isdir(inference_dir):
    print(f"Directory not found: {inference_dir}")
    sys.exit(1)

# Create the folder if it doesn't exist
if not os.path.exists(folder_path):
    os.makedirs(folder_path, exist_ok=True)
    print(f"Folder created at: {folder_path}")
else:
    print(f"Folder already exists at: {folder_path}")

snapshot_download(
    repo_id="m-a-p/xcodec_mini_infer",
    local_dir=folder_path,
)

# Change to the "inference" directory
try:
    os.chdir(inference_dir)
    print(f"Changed working directory to: {os.getcwd()}")
except FileNotFoundError:
    print(f"Directory not found: {inference_dir}")
    sys.exit(1)

def empty_output_folder(output_dir):
    """Delete every entry inside *output_dir*, leaving the directory itself.

    - Real sub-directories are removed recursively.
    - Files and symbolic links are unlinked. This includes links that point
      at directories, so a link's target is never followed or deleted.
    - Deletion is best-effort: an entry that cannot be removed is reported
      and skipped.

    Args:
        output_dir: Directory to empty. ``os.listdir`` raises if it does not
            exist.
    """
    for name in os.listdir(output_dir):
        entry_path = os.path.join(output_dir, name)
        try:
            # os.path.isdir follows symlinks, and shutil.rmtree refuses to
            # operate on a symlink. Treat links as plain entries to unlink.
            if os.path.isdir(entry_path) and not os.path.islink(entry_path):
                shutil.rmtree(entry_path)
            else:
                os.remove(entry_path)
        except OSError as e:
            print(f"Error deleting file {entry_path}: {e}")

# Function to create a temporary file with string content
def create_temp_file(content, prefix, suffix=".txt"):
    """Write *content* to a new named temporary file and return its path.

    Before writing, the text is:
    - stripped, then given a trailing blank line (two newlines);
    - normalized so CRLF and lone CR line endings become LF;
    - encoded as UTF-8.

    The file is created with ``delete=False``, so the caller is responsible
    for removing it.

    Args:
        content: Text to write. ``None`` is treated as an empty string.
        prefix: Filename prefix for the temporary file.
        suffix: Filename suffix (default ``".txt"``).

    Returns:
        Path of the written file.
    """
    if content is None:
        content = ""
    # Ensure content ends with newline and normalize line endings
    content = content.strip() + "\n\n"  # Add extra newline at end
    content = content.replace("\r\n", "\n").replace("\r", "\n")

    # Explicit UTF-8 so non-ASCII lyrics don't depend on the host locale.
    # The context manager closes the handle even if the write fails.
    with tempfile.NamedTemporaryFile(
        delete=False, mode="w", encoding="utf-8", prefix=prefix, suffix=suffix
    ) as temp_file:
        temp_file.write(content)

    # Debug: Print file contents
    print(f"\nContent written to {prefix}{suffix}:")
    print(content)
    print("---")

    return temp_file.name

def get_last_mp3_file(output_dir):
    """Return the path of the most recently modified ``.mp3`` in *output_dir*.

    Candidates are the entries of ``os.listdir(output_dir)`` whose names end
    in ``.mp3``. The check is case-sensitive. If nothing matches, a notice is
    printed and ``None`` is returned.
    """
    candidates = [
        os.path.join(output_dir, name)
        for name in os.listdir(output_dir)
        if name.endswith('.mp3')
    ]

    if not candidates:
        print("No .mp3 files found in the output folder.")
        return None

    # max() yields the first of several equally-new files. This is the same
    # pick as a stable reverse sort by mtime followed by taking element 0.
    return max(candidates, key=os.path.getmtime)

def infer(genre_txt_content, lyrics_txt_content):
    """Generate a song with ``infer.py`` and return the resulting MP3 path.

    Steps:
    1. Write the genre and lyrics text to temporary files.
    2. Empty ``./output``.
    3. Run ``infer.py`` as a subprocess on CUDA device 0, using stage-1
       model ``m-a-p/YuE-s1-7B-anneal-en-cot`` and stage-2 model
       ``m-a-p/YuE-s2-1B-general``.

    The temporary files are always removed afterwards.

    Args:
        genre_txt_content: Genre/tag text from the UI.
        lyrics_txt_content: Lyrics text from the UI.

    Returns:
        Path of the most recently modified ``.mp3`` in ``./output``, or
        ``None`` if the subprocess failed or produced no MP3.
    """
    genre_txt_path = None
    lyrics_txt_path = None
    try:
        # Create temporary files inside the try so they are cleaned up even if
        # a later setup step raises.
        genre_txt_path = create_temp_file(genre_txt_content, prefix="genre_")
        lyrics_txt_path = create_temp_file(lyrics_txt_content, prefix="lyrics_")

        print(f"Genre TXT path: {genre_txt_path}")
        print(f"Lyrics TXT path: {lyrics_txt_path}")

        # Ensure the output folder exists and is empty. The newest .mp3 found
        # afterwards then belongs to this run.
        output_dir = "./output"
        os.makedirs(output_dir, exist_ok=True)
        print(f"Output folder ensured at: {output_dir}")

        empty_output_folder(output_dir)

        # sys.executable runs infer.py under the same interpreter (and hence
        # the same installed packages) as this app. A bare "python" is
        # resolved from PATH and may be a different interpreter.
        command = [
            sys.executable, "infer.py",
            "--stage1_model", "m-a-p/YuE-s1-7B-anneal-en-cot",
            "--stage2_model", "m-a-p/YuE-s2-1B-general",
            "--genre_txt", genre_txt_path,
            "--lyrics_txt", lyrics_txt_path,
            "--run_n_segments", "2",
            "--stage2_batch_size", "4",
            "--output_dir", output_dir,
            "--cuda_idx", "0",
            "--max_new_tokens", "3000",
            "--disable_offload_model",
        ]

        # Set up environment variables for CUDA (subprocess only; this
        # process's own environment is left untouched).
        env = os.environ.copy()
        env.update({
            "CUDA_VISIBLE_DEVICES": "0",
            "CUDA_HOME": "/usr/local/cuda",
            "PATH": f"/usr/local/cuda/bin:{env.get('PATH', '')}",
            "LD_LIBRARY_PATH": f"/usr/local/cuda/lib64:{env.get('LD_LIBRARY_PATH', '')}",
        })

        try:
            subprocess.run(command, check=True, env=env)
        except subprocess.CalledProcessError as e:
            print(f"Error occurred: {e}")
            return None
        print("Command executed successfully!")

        # Check and print the contents of the output folder
        output_files = os.listdir(output_dir)
        if not output_files:
            print("Output folder is empty.")
            return None

        print("Output folder contents:")
        for file in output_files:
            print(f"- {file}")

        last_mp3 = get_last_mp3_file(output_dir)
        if last_mp3:
            print("Last .mp3 file:", last_mp3)
            return last_mp3
        return None
    finally:
        # Clean up temporary files. A path is None if its creation never
        # happened; a missing file needs no cleanup.
        for path in (genre_txt_path, lyrics_txt_path):
            if path is None:
                continue
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
        print("Temporary files deleted.")

# Gradio UI

with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# YuE")
        with gr.Row():
            with gr.Column():
                genre_txt = gr.Textbox(label="Genre")
                # Lyrics are multi-line text (create_temp_file normalizes their
                # line endings). Give the box room instead of the one-line
                # default.
                lyrics_txt = gr.Textbox(label="Lyrics", lines=12)
                submit_btn = gr.Button("Submit")
            with gr.Column():
                music_out = gr.Audio(label="Audio Result")

    # infer() returns an MP3 file path (or None), displayed by music_out.
    submit_btn.click(
        fn=infer,
        inputs=[genre_txt, lyrics_txt],
        outputs=[music_out],
    )
demo.queue().launch(show_api=False, show_error=True)