import tempfile

import gradio as gr
import librosa
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
import scipy.signal as ssig

def plot_stft(audio_file):
    """Render the STFT magnitude spectrogram of an audio file to a PNG image.

    The audio is loaded with ``librosa.load`` (its default resampling applies)
    and transformed with ``scipy.signal.stft``. The STFT uses a 512-sample
    Hann window, a 412-sample overlap (hop of 100 samples) and a 1024-point
    FFT. The magnitude is converted to decibels relative to the loudest bin
    and drawn as a Plotly heatmap, with time in seconds on x and frequency in
    Hz on y.

    Args:
        audio_file: Path to the audio file, as supplied by a ``gr.Audio``
            component with ``type="filepath"``. ``None`` when the user
            submits without providing audio.

    Returns:
        Path of the newly written PNG file, or ``None`` if no audio was given.
    """
    # Gradio passes None when the form is submitted with no audio. There is
    # nothing to plot, and librosa.load(None) would fail with an opaque error.
    if audio_file is None:
        return None

    audio, sampling_rate = librosa.load(audio_file)

    # Because fs is given, scipy returns `times` in seconds and `freqs` in Hz.
    freqs, times, stft = ssig.stft(audio,
                                   fs=sampling_rate,
                                   window='hann',
                                   nperseg=512,
                                   noverlap=412,
                                   nfft=1024,
                                   return_onesided=True,
                                   boundary='zeros',
                                   padded=True,
                                   axis=-1)

    # Magnitude in dB relative to the peak bin. librosa's default top_db=80
    # floors anything more than 80 dB below that peak.
    spectrogram = go.Heatmap(z=librosa.amplitude_to_db(np.abs(stft), ref=np.max),
                             x=times,
                             y=freqs,
                             colorscale='Viridis',
                             colorbar=dict(thickness=8),
                             showscale=True,
                             showlegend=False,
                             visible=True)

    fig = go.Figure(data=[spectrogram])

    axis_font = dict(family='Latin Modern Roman', size=18)
    fig.update_layout(
        font=axis_font,
        # `title.font` replaces the `titlefont` axis property, which plotly
        # deprecated and has since removed; `titlefont` raises on current
        # plotly versions.
        xaxis=dict(title=dict(text='Time (seconds)', font=axis_font),
                   visible=True,
                   showgrid=False),
        yaxis=dict(title=dict(text='Frequency (Hz)', font=axis_font),
                   visible=True,
                   showgrid=False),
        margin=dict(l=0, r=0, t=0, b=0),
    )

    # Write to a fresh temporary file instead of a fixed name in the working
    # directory. Concurrent requests then cannot overwrite each other's image,
    # and no writable CWD is required. The file is closed before write_image
    # so it can be reopened on every platform. NOTE: these files are not
    # deleted here and will accumulate in the system temp directory.
    with tempfile.NamedTemporaryFile(prefix='stft_', suffix='.png',
                                     delete=False) as tmp:
        image_path = tmp.name
    fig.write_image(image_path)

    return image_path

# Gradio interface: takes an audio file path and shows the rendered STFT image.
demo = gr.Interface(fn=plot_stft,
                    inputs=gr.Audio(type="filepath"),
                    outputs="image")

# Start the server only when run as a script, so that importing this module
# (e.g. from tests or another app) does not block on demo.launch().
if __name__ == "__main__":
    demo.launch()