from datetime import datetime
import json
import os
import streamlit as st
import requests
import pandas as pd
from io import StringIO
import plotly.graph_objects as go

# Top down page rendering 
st.set_page_config(page_title='Hockey Breeds v3 - Pressure Meter', layout="wide",
                   page_icon=":frame_with_picture:")

st.title('Hockey Breeds v3 - Pressure Meter')
# Introductory copy rendered at the top of the page (typo "momuntum" fixed).
intro = '''Version 3 of Hockey Breeds introduces a new feature: the **Pressure Meter**.  Pressure is a term used in hockey to describe the buildup of offensive momentum which often leads to goals.

The **Pressure Meter** builds on a number of major enhancements to the Top-Shelf AI platform:
1. Improved and expanded data set and improved model
1. Parallelized processing pipeline for processing input video and generating output metrics in *real time*
1. Analysis and metrics include:
    * Team jersey color determination
    * Player team assignments
    * Skater speeds and accelerations
    * Player positions relative to nearest goalie & net
    * Improved puck tracking and interpolation
    * Game play state analysis (stoppage vs live play)
'''
st.markdown(intro)

st.subheader('Pressure Meter Visualization')

# Location of the stream analytics metadata file (must be an http(s) URL).
data_location = st.text_input('Enter the location of the stream analytics metadata file', 
                              value='https://storage.googleapis.com/topshelf-clients/pressure-meter/2025-02-09/22809/stream_metadata.json')
metadata = None
stream_base_url = None
if data_location:
    # should be an http link
    if not data_location.startswith('http'):
        st.error('Data location must be an http link')
    else:
        # Allow pointing at a directory; default to the standard metadata filename.
        if data_location.endswith('/'):
            data_location = data_location + 'stream_metadata.json'
        try:
            # Download and parse the metadata.  Surface HTTP and JSON failures
            # as st.error messages instead of crashing the app with an
            # unhandled exception; metadata stays None so the rest of the
            # page is skipped.
            data = requests.get(data_location, timeout=30)
            data.raise_for_status()
            metadata = json.loads(data.text)
        except requests.RequestException as e:
            st.error(f'Failed to download metadata: {e}')
        except json.JSONDecodeError as e:
            st.error(f'Metadata is not valid JSON: {e}')
        else:
            # determine the base url for the stream
            stream_base_url = data_location.split('stream_metadata.json')[0]


# load the data from the csv files
if metadata:
    # Per-chunk CSV files, keyed by the unix timestamp of each chunk's start.
    files = metadata['output_files']

    # Base timestamp of the stream; CSV offsets are relative to this instant.
    base_timestamp = datetime.fromisoformat(metadata['video_start_time'])

    # Collect per-chunk dataframes, then concatenate once at the end.
    dfs = []

    for ts, file in files.items():
        try:
            response = requests.get(stream_base_url + file, timeout=30)
            response.raise_for_status()

            df = pd.read_csv(StringIO(response.text))

            # Shift this chunk's local offsets onto the stream timeline.
            # Build the chunk timestamp directly in the stream's timezone:
            # the previous naive-fromtimestamp + astimezone() combination
            # produced an aware datetime and raised TypeError (aware - naive)
            # whenever video_start_time carried no timezone.
            chunk_start = datetime.fromtimestamp(int(ts), tz=base_timestamp.tzinfo)
            ts_delta = chunk_start - base_timestamp
            df['second_offset'] = df['second_offset'] + ts_delta.total_seconds()

            dfs.append(df)

        except Exception as e:
            # Best-effort loading: report the bad chunk and keep going.
            st.error(f"Failed to load data for timestamp {ts}, file: {file}")
            st.error(f"Error: {str(e)}")
            continue

    # Log the number of files processed
    st.info(f"Successfully loaded {len(dfs)} out of {len(files)} files")

    if not dfs:
        # pd.concat raises ValueError on an empty list; stop the page with a
        # clear message instead.
        st.error("No data files could be loaded; cannot continue.")
        st.stop()

    # Concatenate all dataframes and sort by the second_offset
    combined_df = pd.concat(dfs, ignore_index=True)
    combined_df = combined_df.sort_values('second_offset')
    
    # Detect whole seconds missing from the combined timeline.
    lo = int(combined_df['second_offset'].min())
    hi = int(combined_df['second_offset'].max())
    present = set(combined_df['second_offset'].astype(int))
    missing_seconds = [sec for sec in range(lo, hi + 1) if sec not in present]

    if missing_seconds:
        st.warning("Found gaps in the data sequence:")
        # Collapse consecutive missing seconds into "start-end" ranges so the
        # warning stays readable for long outages.
        gaps = []
        run_start = run_end = missing_seconds[0]
        for sec in missing_seconds[1:]:
            if sec == run_end + 1:
                run_end = sec
            else:
                gaps.append(str(run_start) if run_start == run_end else f"{run_start}-{run_end}")
                run_start = run_end = sec
        gaps.append(str(run_start) if run_start == run_end else f"{run_start}-{run_end}")

        st.warning(f"Missing seconds: {', '.join(gaps)}")

    # Cumulative pressure counts: a row counts toward a team only when its
    # pressure_balance is strictly positive (team 1) or strictly negative (team 2).
    team1_mask = combined_df['pressure_balance'] > 0
    team2_mask = combined_df['pressure_balance'] < 0
    combined_df['team1_cumulative'] = team1_mask.astype(int).cumsum()
    combined_df['team2_cumulative'] = team2_mask.astype(int).cumsum()
    combined_df['total_cumulative'] = combined_df['team1_cumulative'] + combined_df['team2_cumulative']

    # Guard the denominator: substitute 1 wherever the running total is still 0
    # so the early rows divide cleanly instead of producing NaN/inf.
    safe_total = combined_df['total_cumulative'].where(combined_df['total_cumulative'] > 0, 1)
    combined_df['team1_pressure_ratio'] = combined_df['team1_cumulative'] / safe_total
    combined_df['team2_pressure_ratio'] = combined_df['team2_cumulative'] / safe_total

    # Signed dominance in [-1, 1]: positive favours team 1, negative team 2.
    combined_df['pressure_ratio_diff'] = combined_df['team1_pressure_ratio'] - combined_df['team2_pressure_ratio']

    # Add pressure balance visualization using the ratio difference.
    # (The figure below reads combined_df directly; the unused `balance_df`
    # copy of second_offset / pressure_ratio_diff that used to be built here
    # was dead code and has been removed.)
    st.subheader("Pressure Waves")
    
    # Get team colors from metadata and parse them
    def parse_rgb(color_str):
        """Parse a CSS-style 'rgb(r, g, b)' string into an (r, g, b) tuple of ints.

        Slices out the text between the parentheses, so any 'name(...)' wrapper
        (e.g. 'rgba(...)') works.  The previous str.strip('rgb()') approach
        stripped *characters* from both ends and silently mangled prefixes
        containing other letters.
        """
        inner = color_str[color_str.index('(') + 1:color_str.rindex(')')]
        r, g, b = (int(component) for component in inner.split(',')[:3])
        return r, g, b

    # Resolve each team's display color from the metadata, falling back to
    # fixed defaults when the stream did not record one.
    team1_color = metadata.get('team1_color', 'rgb(54, 162, 235)')  # default blue if not found
    team2_color = metadata.get('team2_color', 'rgb(255, 99, 132)')  # default red if not found

    # Numeric (r, g, b) tuples for building translucent fill colors below.
    team1_rgb, team2_rgb = parse_rgb(team1_color), parse_rgb(team2_color)
    
    fig = go.Figure()

    # One filled trace per team: team 1 owns the positive half of the dominance
    # curve, team 2 the negative half.  Both traces share the same x axis and
    # differ only in clip direction, color, and label.
    trace_specs = [
        ('Team 1', team1_rgb, team1_color, combined_df['pressure_ratio_diff'].clip(lower=0)),
        ('Team 2', team2_rgb, team2_color, combined_df['pressure_ratio_diff'].clip(upper=0)),
    ]
    for team_label, rgb, line_color, y_values in trace_specs:
        fig.add_trace(
            go.Scatter(
                x=combined_df['second_offset'],
                y=y_values,
                fill='tozeroy',
                fillcolor=f'rgba{(*rgb, 0.2)}',
                line=dict(
                    color=line_color,
                    shape='hv'
                ),
                name=team_label + ' Dominant',
                hovertemplate='Time: %{x:.1f}s<br>Dominance: %{y:.2f}<br>' + team_label + '<extra></extra>',
                hoveron='points+fills'
            )
        )

    # Chart cosmetics: fixed [-1, 1] dominance axis, light grid lines, and the
    # legend pinned to the upper-left corner of the plot area.
    grid_color = 'rgba(0,0,0,0.1)'
    fig.update_layout(
        yaxis=dict(
            range=[-1, 1],
            zeroline=True,
            zerolinewidth=2,
            zerolinecolor='rgba(0,0,0,0.2)',
            gridcolor=grid_color,
            title='Team Dominance'
        ),
        xaxis=dict(
            title='Time (seconds)',
            gridcolor=grid_color
        ),
        plot_bgcolor='white',
        height=400,
        margin=dict(l=0, r=0, t=20, b=0),
        showlegend=True,
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01)
    )

    st.plotly_chart(fig, use_container_width=True)

    # Raw combined dataframe, available for inspection.
    with st.expander("Pressure Data"):
        st.write(combined_df)

    # add details in a sub section with expander
    with st.expander("Pressure Meter Details"):
        st.write("""
        The Pressure Meter is a visualization of the pressure waves in the game.  It is a line chart of the cumulative pressure counts for each team over time.
        """)

        # Create two columns for charts
        col1, col2 = st.columns(2)

        with col1:
            st.subheader("Cumulative Pressure Counts")
            st.line_chart(combined_df, x='second_offset', y=['team1_cumulative', 'team2_cumulative'])

        with col2:
            st.subheader("Pressure Ratio Over Time")
            st.area_chart(combined_df, 
                        x='second_offset', 
                        y=['team1_pressure_ratio', 'team2_pressure_ratio'])


        # Show current dominance percentage.  Read the cumulative ratio
        # difference (bounded to [-1, 1]) rather than the raw per-row
        # 'pressure_balance': the value is formatted as a percentage below,
        # which only makes sense for the normalized ratio.
        current_ratio = combined_df.iloc[-1]['pressure_ratio_diff']
        if current_ratio > 0:
            dominant_team = 'Team 1'
            pressure_value = current_ratio
        elif current_ratio < 0:
            dominant_team = 'Team 2'
            pressure_value = abs(current_ratio)
        else:
            dominant_team = 'Neutral'
            pressure_value = 0

        st.metric(
            label="Dominant Team Pressure", 
            value=f"{dominant_team}",
            delta=f"{pressure_value*100:.1f}%"
        )

    # After loading metadata
    st.subheader("Data Files Analysis")

    # Summarize the chunk timestamps: file count, covered range, and the
    # distinct spacings between consecutive chunks.
    timestamps = sorted(int(ts) for ts in files.keys())
    time_diffs = [later - earlier for earlier, later in zip(timestamps, timestamps[1:])]

    st.info(f"Number of data files: {len(files)}")
    st.info(f"Time range: {datetime.fromtimestamp(timestamps[0])} to {datetime.fromtimestamp(timestamps[-1])}")
    st.info(f"Time differences between files: {set(time_diffs)} seconds")
    
    # Show the actual files and timestamps
    with st.expander("Stream Metadata Details"):
        st.write(metadata)
        # Log the overall data range covered by the combined dataframe.
        min_offset = combined_df['second_offset'].min()
        max_offset = combined_df['second_offset'].max()
        st.write(f"Data range: {min_offset:.1f}s to {max_offset:.1f}s")
        st.write(f"Total rows: {len(combined_df)}")

        # One line per chunk, in key order (string sort, matching the original).
        for ts, fname in sorted(files.items()):
            st.text(f"Timestamp: {datetime.fromtimestamp(int(ts))} - File: {fname}")