"""
Demo is based on https://scikit-learn.org/stable/auto_examples/applications/plot_stock_market.html
"""

import sys
import numpy as np
import pandas as pd

# Ticker symbol -> company display name for every stock analysed. The names are
# shown to the user (hover labels and the per-cluster listings), so they are
# spelled as the companies' proper names.
symbol_dict = {
    "TOT": "Total",
    "XOM": "Exxon",
    "CVX": "Chevron",
    "COP": "ConocoPhillips",
    "VLO": "Valero Energy",
    "MSFT": "Microsoft",
    "IBM": "IBM",
    "TWX": "Time Warner",
    "CMCSA": "Comcast",
    "CVC": "Cablevision",
    "YHOO": "Yahoo",
    "DELL": "Dell",
    "HPQ": "HP",
    "AMZN": "Amazon",
    "TM": "Toyota",
    "CAJ": "Canon",
    "SNE": "Sony",
    "F": "Ford",
    "HMC": "Honda",
    "NAV": "Navistar",
    "NOC": "Northrop Grumman",
    "BA": "Boeing",
    "KO": "Coca Cola",
    "MMM": "3M",
    "MCD": "McDonald's",
    "PEP": "Pepsi",
    "K": "Kellogg",
    "UN": "Unilever",
    "MAR": "Marriott",
    "PG": "Procter & Gamble",
    "CL": "Colgate-Palmolive",
    "GE": "General Electric",
    "WFC": "Wells Fargo",
    "JPM": "JPMorgan Chase",
    "AIG": "AIG",
    "AXP": "American Express",
    "BAC": "Bank of America",
    "GS": "Goldman Sachs",
    "AAPL": "Apple",
    "SAP": "SAP",
    "CSCO": "Cisco",
    "TXN": "Texas Instruments",
    "XRX": "Xerox",
    "WMT": "Wal-Mart",
    "HD": "Home Depot",
    "GSK": "GlaxoSmithKline",
    "PFE": "Pfizer",
    "SNY": "Sanofi-Aventis",
    "NVS": "Novartis",
    "KMB": "Kimberly-Clark",
    "R": "Ryder",
    "GD": "General Dynamics",
    "RTN": "Raytheon",
    "CVS": "CVS",
    "CAT": "Caterpillar",
    "DD": "DuPont de Nemours",
}


# Split the (ticker, name) pairs, sorted by ticker, into two parallel arrays so
# that symbols[i] and names[i] describe the same stock.
symbols, names = np.array(sorted(symbol_dict.items())).T

# Per-ticker CSV location; loop-invariant, so it is built once.
url = (
    "https://raw.githubusercontent.com/scikit-learn/examples-data/"
    "master/financial-data/{}.csv"
)

# Download one quote-history DataFrame per ticker, in the same order as symbols.
quotes = []
for symbol in symbols:
    print(f"Fetching quote history for {symbol!r}", file=sys.stderr)
    quotes.append(pd.read_csv(url.format(symbol)))

# Stack the per-stock series row-wise: row i holds the prices of symbols[i].
# NOTE(review): np.vstack assumes every CSV has the same number of rows.
close_prices = np.vstack([q["close"] for q in quotes])
open_prices = np.vstack([q["open"] for q in quotes])

# The daily variations of the quotes are what carry the most information
variation = close_prices - open_prices


from sklearn import covariance

# Ten candidate regularisation strengths, log-spaced from 10**-1.5 to 10**1,
# among which GraphicalLassoCV selects by cross-validation.
alphas = np.logspace(-1.5, 1, num=10)
edge_model = covariance.GraphicalLassoCV(alphas=alphas)

# Standardize the time series: using correlations rather than covariances is
# more efficient for structure recovery.
# variation has one row per stock, so X has one column per stock and each
# column is scaled to unit standard deviation.
X = variation.copy().T
X /= X.std(axis=0)
edge_model.fit(X)


from sklearn import cluster

# Cluster the stocks with affinity propagation, using the covariance matrix
# estimated by the graphical lasso as the similarity matrix. Only the
# per-stock cluster labels (second return value) are kept.
labels = cluster.affinity_propagation(edge_model.covariance_, random_state=0)[1]
# Highest cluster label, not the number of clusters: callers iterate over
# range(n_labels + 1), assuming labels start at 0.
n_labels = np.max(labels)


# Find a low-dimensional embedding for visualization: the best position of
# each node (stock) in 3D space (n_components=3), used as the x/y/z
# coordinates of the 3D plot.

from sklearn import manifold

node_position_model = manifold.LocallyLinearEmbedding(
    n_components=3, eigen_solver="dense", n_neighbors=6
)

# X has one column per stock, so X.T presents one sample per stock; the final
# transpose makes embedding[k] the k-th coordinate of every stock.
embedding = node_position_model.fit_transform(X.T).T

import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
import plotly.graph_objs as go


def visualize_stocks():
    """Build a 3D plotly figure of the stock clusters and their partial correlations.

    Nodes come from the module-level ``embedding``: one marker per stock,
    coloured by its cluster in ``labels``, with the company name from
    ``names`` as hover text. An edge joins two stocks whenever the absolute
    partial correlation derived from ``edge_model.precision_`` exceeds 0.02.
    The edge's width and colour both scale with that absolute value.

    Returns:
        plotly.graph_objs.Figure: the marker trace first, followed by one
        line trace per edge.
    """
    # Convert the precision matrix into partial correlations by scaling rows
    # and columns by 1 / sqrt(diagonal).
    partial_correlations = edge_model.precision_.copy()
    d = 1 / np.sqrt(np.diag(partial_correlations))
    partial_correlations *= d
    partial_correlations *= d[:, np.newaxis]
    # Upper triangle only (k=1): each pair is drawn once and self-loops are skipped.
    non_zero = np.abs(np.triu(partial_correlations, k=1)) > 0.02

    # Plot the nodes using the coordinates of our embedding.
    scatter = go.Scatter3d(
        x=embedding[0],
        y=embedding[1],
        z=embedding[2],
        mode="markers",
        marker=dict(size=35 * d**2, color=labels, colorscale="Viridis"),
        hovertext=names,
        hovertemplate="%{hovertext}<br>",
    )
    fig = go.Figure(data=[scatter])

    # Plot the edges.
    start_idx, end_idx = np.where(non_zero)
    strengths = np.abs(partial_correlations[start_idx, end_idx])
    # Every edge is a separate trace holding a single colour value, so the
    # colour range must be pinned explicitly and shared by all edges.
    # Otherwise plotly derives it per trace and colour does not encode
    # strength. The reversed "Hot" scale on [0, 0.7 * max] mirrors the
    # scikit-learn example (strongest edges darkest).
    cmax = 0.7 * strengths.max() if strengths.size else 1.0

    for start, stop, strength in zip(start_idx, end_idx, strengths):
        fig.add_trace(
            go.Scatter3d(
                x=[embedding[0][start], embedding[0][stop]],
                y=[embedding[1][start], embedding[1][stop]],
                z=[embedding[2][start], embedding[2][stop]],
                mode="lines",
                line=dict(
                    color=strength,
                    colorscale="Hot",
                    reversescale=True,
                    cmin=0,
                    cmax=cmax,
                    width=25 * strength,
                ),
                hoverinfo="none",  # hover only on the stock markers
                showlegend=False,
            )
        )

    return fig


import gradio as gr

title = " 📈 Visualizing the stock market structure 📈"

with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    # Report the number of stocks actually analysed instead of a hard-coded count.
    gr.Markdown(
        f" Data is of {len(symbols)} stocks between the period of 2003 - 2008 <br>"
    )
    gr.Markdown(
        " Stocks that move together with each other are grouped together in a cluster <br>"
    )

    gr.Markdown(
        " **[Demo is based on sklearn docs](https://scikit-learn.org/stable/auto_examples/applications/plot_stock_market.html)**"
    )

    # n_labels is the highest cluster label, so range(n_labels + 1) visits every
    # cluster; they are shown 1-based to the user.
    for i in range(n_labels + 1):
        gr.Markdown(f"Cluster {i + 1}: {', '.join(names[labels == i])}")

    btn = gr.Button(value="Visualize")
    btn.click(
        visualize_stocks, outputs=gr.Plot(label="Visualizing stock into clusters")
    )

demo.launch()