import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Module-level log of regression attempts: plot_regression() appends to it and
# clear_history() empties it. Each entry is a dict with keys 'm' (slope),
# 'b' (intercept), 'sse' (sum of squared errors) and 'color' (matplotlib RGBA).
# NOTE(review): being module-level, this list is presumably shared by every
# browser session served by this process — consider gr.State for per-user history.
history = []

def predict_house_price(area, price_per_unit=0.1):
    """Predict a house price from its floor area with a one-parameter linear model.

    Computes ``price = price_per_unit * area``. The default rate of 0.1
    matches the model shown in the UI ("Price = 0.1 × Area").

    Args:
        area: House area in m². ``None`` (e.g. a cleared input field) is
            tolerated rather than raising.
        price_per_unit: Price per m², in the same unit as the result (the UI
            labels prices in $M). Defaults to 0.1.

    Returns:
        The predicted price as a float, or ``None`` when *area* is ``None``
        so the output field is cleared instead of an error being raised.
    """
    if area is None:
        return None
    return float(price_per_unit * area)

def calculate_sse(x, y, m, b):
    """Return the sum of squared vertical distances between points and a line.

    The candidate line is ``y = m * x + b``; each residual is the observed
    ``y`` minus the line's value at the matching ``x``.

    Args:
        x: Independent values (numpy array or pandas Series).
        y: Observed values, aligned with *x*.
        m: Slope of the candidate line.
        b: Intercept of the candidate line.

    Returns:
        ``np.sum`` of the squared residuals.
    """
    fitted = b + m * x
    squared_residuals = (y - fitted) ** 2
    return np.sum(squared_residuals)

def _load_xy_frame(data):
    """Return a new DataFrame with numeric ``X`` and ``Y`` columns built from *data*.

    *data* may be a DataFrame (copied, so the caller's object is never
    mutated), a sequence/ndarray of ``[x, y]`` rows, or anything
    ``pd.read_csv`` accepts. Raises if the columns are missing or non-numeric.
    """
    if isinstance(data, pd.DataFrame):
        df = data.copy()
    elif isinstance(data, (list, tuple, np.ndarray)):
        df = pd.DataFrame(data, columns=['X', 'Y'])
    else:
        df = pd.read_csv(data)
    df['X'] = pd.to_numeric(df['X'])
    df['Y'] = pd.to_numeric(df['Y'])
    return df


def _plot_attempt_lines(ax, df, attempts):
    """Draw the data points plus one colored line per attempt on *ax*."""
    ax.scatter(df['X'], df['Y'], color='black', alpha=0.5, label='Data points')

    # Every line spans the same x-range, so compute it once.
    x_range = np.linspace(df['X'].min(), df['X'].max(), 100)
    for i, attempt in enumerate(attempts, start=1):
        y_line = attempt['m'] * x_range + attempt['b']
        label = f"m={attempt['m']:.1f}, b={attempt['b']:.1f}"
        ax.plot(x_range, y_line, color=attempt['color'], linewidth=2,
                label=f"Try {i}: {label}")

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_title('Linear Regression Attempts')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')


def _plot_sse_history(ax, attempts):
    """Plot the SSE of each attempt against its attempt number on *ax*."""
    attempt_numbers = range(1, len(attempts) + 1)
    sse_values = [attempt['sse'] for attempt in attempts]
    colors = [attempt['color'] for attempt in attempts]

    ax.scatter(attempt_numbers, sse_values, c=colors)
    ax.plot(attempt_numbers, sse_values, 'gray', alpha=0.3)

    for number, attempt in zip(attempt_numbers, attempts):
        label = f"m={attempt['m']:.1f}\nb={attempt['b']:.1f}"
        ax.annotate(label, (number, attempt['sse']),
                    xytext=(5, 5), textcoords='offset points')

    ax.set_xlabel('Attempt Number')
    ax.set_ylabel('Sum of Squared Errors')
    ax.set_title('SSE for Each Attempt')
    ax.grid(True, alpha=0.3)


def plot_regression(data, m, b):
    """Record a line attempt and plot it together with all previous attempts.

    Computes the SSE of ``y = m * x + b`` on *data*, then builds a two-panel
    figure: the data with every attempted line (left) and the SSE of each
    attempt (right). The attempt is appended to the module-level ``history``
    only after the figure was built successfully.

    Args:
        data: A DataFrame with ``X`` and ``Y`` columns, a sequence of
            ``[x, y]`` rows, or a path/buffer readable by ``pd.read_csv``.
            It is not modified.
        m: Slope of the attempted line.
        b: Intercept of the attempted line.

    Returns:
        The matplotlib Figure, or None if the data could not be parsed or
        plotted (the error is printed).
    """
    try:
        df = _load_xy_frame(data)
        sse = calculate_sse(df['X'], df['Y'], m, b)

        new_attempt = {
            'm': m,
            'b': b,
            'sse': sse,
            # Sample the rainbow colormap at 10 evenly spaced points; colors
            # repeat every 10 attempts.
            'color': plt.cm.rainbow(len(history) % 10 / 10),
        }
        attempts = history + [new_attempt]

        fig = plt.figure(figsize=(15, 6))
        try:
            _plot_attempt_lines(fig.add_subplot(121), df, attempts)
            _plot_sse_history(fig.add_subplot(122), attempts)
            fig.tight_layout()
        finally:
            # Always release this specific figure from pyplot's global
            # registry, even if drawing failed, so figures do not accumulate.
            plt.close(fig)

    except Exception as e:
        print(f"Error: {e}")
        return None
    else:
        # Only a successfully rendered attempt is remembered, so a failure
        # does not leave a phantom entry in the history.
        history.append(new_attempt)
        return fig

def clear_history():
    """Forget every recorded regression attempt.

    Empties the module-level ``history`` list in place (the same list object
    is kept) and returns ``None`` so the connected plot output is emptied.
    """
    del history[:]
    return None

# Build the Gradio UI: one tab for the fixed-rate price predictor and one
# interactive tab for exploring how a line's slope/intercept affect its SSE.
with gr.Blocks() as app:
    gr.Markdown("# Linear Regression Learning Tools")
    
    with gr.Tabs():
        # First Tab - House Price Prediction
        with gr.TabItem("House Price Predictor"):
            gr.Markdown("""
            # House Price Predictor
            Enter the area of the house (in m²) to predict its price.
            Based on the simple model: Price = 0.1 × Area
            """)
            
            with gr.Row():
                area_input = gr.Number(
                    label="House Area (m²)",
                    value=100
                )
                price_output = gr.Number(
                    label="Predicted Price ($M)",
                    value=None
                )
            
            predict_button = gr.Button("Predict Price")
            predict_button.click(
                fn=predict_house_price,
                inputs=area_input,
                outputs=price_output
            )
            
            # Example table (values follow Price = 0.1 × Area)
            gr.Markdown("""
            ### Example Data Points:
            | Area (m²) | Price ($M) |
            |-----------|------------|
            | 100       | 10         |
            | 200       | 20         |
            | 300       | 30         |
            | 400       | 40         |
            | 500       | 50         |
            """)

        # Second Tab - Regression Playground
        with gr.TabItem("Understanding Squared Error"):
            gr.Markdown("""
            # Understanding Squared Error
            See how different lines affect the total squared error:
            - The default data shows the relationship between house area (X) and price (Y); you can edit it
            - Try different slopes (m) and y-intercepts (b) for the line, then press Submit
            - Watch how the sum of squared errors (right-hand chart) changes with each attempt
            - Lower total squared error means a better-fitting line
            """)
            
            with gr.Row():
                data_input = gr.Dataframe(
                    headers=["X", "Y"],
                    datatype=["number", "number"],
                    row_count=5,
                    col_count=2,
                    label="Dataset",
                    interactive=True,
                    value=[[100, 10],
                          [200, 20],
                          [300, 30],
                          [400, 40],
                          [500, 50]]
                )
                
                with gr.Column():
                    m_slider = gr.Slider(
                        minimum=-10,
                        maximum=10,
                        value=1.0,
                        step=0.1,
                        label="Slope (m)",
                    )

                    b_slider = gr.Slider(
                        minimum=-10,
                        maximum=10,
                        value=0.0,
                        step=0.1,
                        label="Intercept (b)",
                    )
                    
                    submit_button = gr.Button("Submit")
                    clear_button = gr.Button("Clear History")
            
            plot_output = gr.Plot()
            
            # Set up the event handlers: Submit records and plots an attempt;
            # Clear History empties the attempt log and the plot.
            inputs = [data_input, m_slider, b_slider]
            clear_button.click(fn=clear_history, inputs=None, outputs=plot_output)
            submit_button.click(fn=plot_regression, inputs=inputs, outputs=plot_output)

# Launch the app only when this file is executed directly, not when imported.
# NOTE(review): show_api=False presumably hides Gradio's auto-generated API
# page/link — confirm against the installed Gradio version.
if __name__ == "__main__":
    app.launch(show_api=False)