import itertools
import json
import re
from collections import defaultdict
from functools import partial
from pathlib import Path

import pandas as pd
import requests
import streamlit as st

from generate_text_api import SummarizerGenerator
from model_inferences.utils.files import get_captions_from_vtt, get_transcript


def segmented_control(labels, key, default = None, max_size = 3) -> str:
    """Render one button per label in a single row and return the active label.

    The active label is stored in ``st.session_state[key]``. The first time
    ``key`` is seen it is seeded with ``default``, or with the first label
    when ``default`` is falsy. Each button's ``on_click`` callback overwrites
    the stored label with its own. ``max_size`` is accepted but not used here.
    """
    if key not in st.session_state:
        initial = default or labels[0]
        st.session_state[key] = initial

    active = st.session_state[key]

    def set_label(label: str) -> None:
        # Invoked by Streamlit as the clicked button's on_click callback.
        st.session_state.update(**{key: label})

    # One equally weighted column per label.
    columns = st.columns([1] * len(labels))

    for column, label in zip(columns, labels):
        if active == label:
            button_type = "primary"
        else:
            button_type = "secondary"
        column.button(
            label,
            on_click=set_label,
            args=(label,),
            use_container_width=True,
            type=button_type,
        )

    return active

# Feature flag: when True, a second segmenter client (`paragrapher`) is created
# and the processing step also splits each chapter into paragraphs.
USE_PARAGRAPHING_MODEL = True

def get_sublist_by_flattened_index(A, i):
    """Locate the sublist of ``A`` that holds flattened position ``i``.

    ``A`` is treated as if its sublists were concatenated; ``i`` indexes into
    that concatenation.

    Returns:
        ``(sublist, position)`` where ``position`` is the index of that
        sublist within ``A``, or ``(None, None)`` when ``i`` is out of range.
    """
    current_index = 0
    # enumerate gives the true position of the sublist. Looking it up with
    # A.index(sublist) returned the first *equal* sublist instead, which is
    # wrong whenever two sublists have identical contents.
    for position, sublist in enumerate(A):
        sublist_length = len(sublist)
        if current_index <= i < current_index + sublist_length:
            return sublist, position
        current_index += sublist_length
    return None, None

import requests


def get_talk_metadata(video_id):
    """Query the TED GraphQL endpoint for one talk's metadata.

    The query asks for the video's title, presenter display name and the
    "medium" native download entry.

    Returns:
        The decoded JSON response body when the server answers with HTTP 200.
        For any other status code the code and response text are printed and
        None is returned.
    """
    graphql_query = """
        query GetTalk($videoId: ID!) {
            video(id: $videoId) {
                title,
                presenterDisplayName,
                nativeDownloads {medium}
            }
        }
        """

    request_headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        # NOTE(review): placeholder operation name carried over from the
        # original code; confirm the value the endpoint expects.
        "x-operation-name": "Transcript",
    }

    request_body = {
        "query": graphql_query,
        "variables": {"videoId": video_id},
    }

    response = requests.post(
        "https://www.ted.com/graphql",
        json=request_body,
        headers=request_headers,
    )

    if response.status_code != 200:
        print(f"Error: {response.status_code}, {response.text}")
        return None

    return response.json()

class OfflineTextSegmenterClient:
    """Small HTTP client for a segmentation service's ``/segment`` route."""

    def __init__(self, host_url):
        """Store the full endpoint URL derived from ``host_url``."""
        # Drop trailing slashes first so the result never contains "//segment".
        base_url = host_url.rstrip("/")
        self.host_url = base_url + "/segment"

    def segment(self, text, captions=None, generate_titles=False, threshold=0.4):
        """POST ``text`` to the service and return its segmentation.

        The request body is JSON containing the text, the optional captions,
        the ``generate_titles`` flag, ``prefix_titles`` fixed to True, and the
        ``threshold``.

        Returns:
            A dict with the ``segments``, ``titles`` and ``sentences`` entries
            taken from the JSON response.
        """
        body = json.dumps({
            'text': text,
            'captions': captions,
            'generate_titles': generate_titles,
            "prefix_titles": True,
            "threshold": threshold,
        })

        reply = requests.post(
            self.host_url,
            data=body,
            headers={'Content-Type': 'application/json'},
        ).json()

        # Only these three fields are handed back to the caller.
        return {field: reply[field] for field in ('segments', 'titles', 'sentences')}

class Toc:
    """Renders headings with anchor ids and builds a linked table of contents."""

    def __init__(self):
        self._placeholder = None
        self._items = []

    def title(self, text):
        """Render ``text`` as an h1 heading and record an unindented entry."""
        self._markdown(text, "h1")

    def header(self, text):
        """Render ``text`` as an h2 heading and record a 2-space indented entry."""
        self._markdown(text, "h2", " " * 2)

    def subheader(self, text):
        """Render ``text`` as an h3 heading and record a 4-space indented entry."""
        self._markdown(text, "h3", " " * 4)

    def placeholder(self, sidebar=False):
        """Reserve the empty element that ``generate`` will later fill."""
        if sidebar:
            self._placeholder = st.sidebar.empty()
        else:
            self._placeholder = st.empty()

    def generate(self):
        """Write the recorded entries into the placeholder, if one was reserved."""
        if not self._placeholder:
            return
        listing = "\n".join(self._items)
        self._placeholder.markdown(listing, unsafe_allow_html=True)

    def _markdown(self, text, level, space=""):
        """Emit one heading and append its link line to the entry list."""
        # Anchor id: spaces and apostrophes become dashes, the text is
        # lower-cased, then anything that is not a word character or dash
        # is removed.
        dashed = text.replace(" ", "-").replace("'", "-").lower()
        key = re.sub(r'[^\w-]', '', dashed)
        st.markdown(f"<{level} id='{key}'>{text}</{level}>", unsafe_allow_html=True)
        self._items.append(f"{space}* <a href='#{key}'>{text}</a>")

import os

# Service URLs are read from environment variables; os.getenv returns None
# when a variable is not set.
endpoint = os.getenv('summarize_stream_url')

# NOTE(review): OfflineTextSegmenterClient calls host_url.rstrip("/"), so an
# unset chapterize_url / paragraph_url raises AttributeError right here.
client = OfflineTextSegmenterClient(os.getenv('chapterize_url'))
if USE_PARAGRAPHING_MODEL:
    # Second client of the same class, pointed at the paragraphing service.
    paragrapher = OfflineTextSegmenterClient(os.getenv('paragraph_url'))
summarizer = SummarizerGenerator(endpoint)

import re


def replace_newlines(text):
    """Return ``text`` with each run of one or more newlines replaced by two."""
    return re.sub(pattern=r'\n+', repl=r'\n\n', string=text)

def generate_summary(summarizer, generated_text_box, input_, prefix=""):
    """Stream a summary of ``input_`` into ``generated_text_box``.

    Starting from ``prefix``, every chunk yielded by
    ``summarizer.generate_summary_stream`` is passed through
    ``replace_newlines`` (chunk by chunk) and appended; after each chunk the
    box is refreshed with the text accumulated so far.

    Returns:
        The full accumulated text with surrounding whitespace stripped. The
        unstripped text is also printed to stdout.
    """
    accumulated = prefix
    chunks = summarizer.generate_summary_stream(input_)
    for chunk in chunks:
        accumulated = accumulated + replace_newlines(chunk)
        # Re-render the box so the summary grows as chunks arrive.
        generated_text_box.info(accumulated)
    print(accumulated)
    return accumulated.strip()

st.header("Demo: Intelligent Recap")

# Load every preloaded document collection once and keep it as an attribute
# on the `st` module; the hasattr check skips this block whenever that
# attribute already exists.
if not hasattr(st, 'global_state'):
    # Build into a local dict and publish it only after everything loaded.
    # Assigning st.global_state up front meant an exception part-way through
    # left a half-filled state (with None entries) that was never reloaded.
    loaded_state = {'NIPS 2021 Talks': None, 'TED Talks': None}

    # NIPS 2021 Talks: at most 15 Whisper transcripts under demo_data/nips-2021/.
    transcript_files = itertools.islice(Path("demo_data/nips-2021/").rglob("transcript_whisper_large-v2.vtt"), 15)
    # Titles come from the metadata.json stored next to each transcript.
    transcripts_map = {}
    for transcript_file in transcript_files:
        base_path = transcript_file.parent
        metadata = base_path / "metadata.json"
        txt_file = base_path / "transcript_whisper_large-v2.txt"
        with open(metadata) as f:
            metadata = json.load(f)
            title = metadata["title"]
            # At this point get_transcript is still the helper imported from
            # model_inferences.utils.files; a same-named function is defined
            # later in this script.
            transcript = get_transcript(txt_file)
            captions = get_captions_from_vtt(transcript_file)
            transcripts_map[title] = {"transcript": transcript, "captions": captions, "video": base_path / "video.mp4"}
    loaded_state['NIPS 2021 Talks'] = transcripts_map

    # TED Talks: transcripts come from a local JSON file, titles and video
    # URLs from one metadata request per talk.
    data = pd.read_json("demo_data/ted_talks.json")
    video_ids = data.talk_id.tolist()
    transcripts = data.text.apply(lambda x: " ".join(x)).tolist()
    transcripts_map = {}
    for video_id, transcript in zip(video_ids, transcripts):
        metadata = get_talk_metadata(video_id)
        # get_talk_metadata returns None on a non-200 response, and the
        # payload may carry no "data" / "video" entry; skip such talks
        # instead of failing on the subscripts below.
        if not metadata or not metadata.get("data") or not metadata["data"].get("video"):
            continue
        title = metadata["data"]["video"]["title"]
        presenter = metadata["data"]["video"]["presenterDisplayName"]
        print(metadata["data"])
        # Talks without a downloadable video cannot be shown in the player.
        if metadata["data"]["video"]["nativeDownloads"] is None:
            continue
        video_url = metadata["data"]["video"]["nativeDownloads"]["medium"]
        transcripts_map[title] = {"transcript": transcript, "video": video_url, "presenter": presenter}
    loaded_state['TED Talks'] = transcripts_map

    def get_lecture_id(path):
        """Return the integer after the first '-' in the parent directory name."""
        return int(path.parts[-2].split('-')[1])

    # KIT Lectures: one English.vtt per lecture directory, ordered by lecture id.
    transcript_files = Path("demo_data/lectures/").rglob("English.vtt")
    sorted_path_list = sorted(transcript_files, key=get_lecture_id)

    transcripts_map = {}
    for transcript_file in sorted_path_list:
        base_path = transcript_file.parent
        lecture_id = base_path.parts[-1]
        # Join the caption texts into a single line of text.
        transcript = " ".join([c["text"].strip() for c in get_captions_from_vtt(transcript_file)]).replace("\n", " ")
        video_path = Path(base_path, "video.mp4")
        transcripts_map["Machine Translation: " + lecture_id] = {"transcript": transcript, "video": video_path}
    loaded_state['KIT Lectures'] = transcripts_map

    st.global_state = loaded_state

# Input-source selector. A row of buttons is used; the st.tabs variant below
# is kept commented out.
#preloaded_document, youtube_video, custom_text = st.tabs(["Preloaded Document", "YouTube Video", "Custom Text"])
selected = segmented_control(["Preloaded Document", "YouTube Video", "Custom Text"], default="Preloaded Document", key="tabs")

# Defaults for the branches below. With a defaultdict, looking up a talk that
# has no entry yields an empty dict instead of raising KeyError.
input_text = ""
transcripts_map = defaultdict(dict)

if selected == "Preloaded Document":
    print("Preloaded Document")
    # First pick a collection, then a document within it.
    type_of_document = st.selectbox('What kind of document do you want to test it on?', list(st.global_state.keys()))

    transcripts_map = st.global_state[type_of_document]

    # NOTE(review): if the chosen collection is empty, the subscripts on
    # transcripts_map[selected_talk] below will presumably fail - confirm
    # what st.selectbox returns for an empty option list.
    selected_talk = st.selectbox("Choose a document...", list(transcripts_map.keys()))

    # "video" is a local Path or a URL string; str() covers both.
    st.video(str(transcripts_map[selected_talk]['video']), format="video/mp4", start_time=0)

    # The transcript is shown in an editable text area; its current content
    # becomes the input text.
    input_text = st.text_area("Transcript", value=transcripts_map[selected_talk]['transcript'], height=300)

from youtube_transcript_api import NoTranscriptFound, TranscriptsDisabled, YouTubeTranscriptApi


def get_transcript(video_id, lang="en"):
    """Fetch a manually created YouTube transcript for ``video_id``.

    The transcript in ``lang`` is tried first. If the lookup raises
    ``NoTranscriptFound``, a manually created transcript in one of several
    English variants is fetched instead; a failure of that second lookup
    propagates to the caller.
    """
    english_variants = ["en", "en-US", "en-GB", "en-CA"]
    try:
        available = YouTubeTranscriptApi.list_transcripts(video_id)
        return available.find_manually_created_transcript([lang]).fetch()
    except NoTranscriptFound:
        # Fall back to any of the English language codes.
        return available.find_manually_created_transcript(english_variants).fetch()

def get_title(video_url):
    """Return the title that noembed.com reports for ``video_url``.

    The video URL is passed via ``params`` so that requests percent-encodes
    it. Interpolating it raw into the query string let any ``&`` inside the
    video URL (for example ``&t=30s``) be read as a separate noembed
    parameter, truncating the URL that was actually looked up.

    Raises:
        KeyError: if the JSON response contains no "title" entry.
    """
    response = requests.get(
        "https://noembed.com/embed",
        params={"dataType": "json", "url": video_url},
    )
    result = response.json()
    return result["title"]

if selected == "YouTube Video":
    print("YouTube Video")
    video_url = st.text_input("Enter YouTube Link", value="https://www.youtube.com/watch?v=YuIc4mq7zMU")
    # The id is what follows the last "v=", cut at the next "&" or "#".
    # Without the cut, links such as ...watch?v=ID&t=30s produced the invalid
    # id "ID&t=30s" and the transcript lookup failed.
    video_id = re.split(r"[&#]", video_url.split("v=")[-1])[0]
    try:
        subs = get_transcript(video_id)
        selected_talk = get_title(video_url)
    except (TranscriptsDisabled, NoTranscriptFound):
        # No usable transcript: report it below instead of crashing.
        subs = None
    if subs is not None:
        st.video(video_url, format="video/mp4", start_time=0)
        # Flatten the caption snippets into one line of text.
        input_text = " ".join([sub["text"] for sub in subs])
        input_text = re.sub(r'\n+', r' ', input_text).replace("  ", " ")
        input_text = st.text_area("Transcript", value=input_text, height=300)
    else:
        st.error("No transcript found for this video.")

if selected == "Custom Text":
    print("Custom Text")
    input_text = st.text_area("Transcript", height=300, placeholder="Insert your transcript here...")
    # Collapse each run of newlines into a single space.
    input_text = re.sub(r'\n+', r' ', input_text)
    # Fixed heading used for pasted text.
    selected_talk = "Your Transcript"

# Table-of-contents builder, filled while the chapters are rendered.
toc = Toc()

# Summary jobs are collected here and executed at the very end of the script.
summarization_todos = []

# Both sliders range from 0 to 1 in steps of 0.05 and default to 0.5.
with st.expander("Adjust Thresholds"):
    threshold = st.slider('Chapter Segmentation Threshold', 0.00, 1.00, value=0.5, step=0.05)
    paragraphing_threshold = st.slider('Paragraphing Threshold', 0.00, 1.00, value=0.5, step=0.05)

# The button is disabled while the input text is empty or whitespace-only.
if st.button("Process Transcript", disabled=not bool(input_text.strip())):
    with st.sidebar:
        st.header("Table of Contents")
        # Reserve the slot (created inside the sidebar context) that
        # toc.generate() fills once all chapters are known.
        toc.placeholder()

    st.header(selected_talk, divider='rainbow')
    # if 'presenter' in transcripts_map[selected_talk]:
    #     st.markdown(f"### *by **{transcripts_map[selected_talk]['presenter']}***")

    # Captions are forwarded only when the selected entry has them; for the
    # defaultdict used by the other input modes the lookup yields an empty dict.
    captions = transcripts_map[selected_talk]['captions'] if 'captions' in transcripts_map[selected_talk] else None
    # Chapter segmentation, with titles.
    result = client.segment(input_text, captions, generate_titles=True, threshold=threshold)
    if USE_PARAGRAPHING_MODEL:
        # Paragraph segmentation of the same text, without titles.
        presult = paragrapher.segment(input_text, captions, generate_titles=False, threshold=paragraphing_threshold)
        paragraphs = presult['segments']
    segments, titles, sentences = result['segments'], result['titles'], result['sentences']

    if USE_PARAGRAPHING_MODEL:
        # Walk the sentences in order, tracking which chapter and paragraph
        # each belongs to, and flush the buffered text when the chapter changes.
        # NOTE(review): assumes both services split the text into the same
        # sentence sequence, so flattened index i is valid for both - confirm.
        prev_chapter_idx = 0
        prev_paragraph_idx = 0
        segment = []
        for i, sentence in enumerate(sentences):
            chapter, chapter_idx = get_sublist_by_flattened_index(segments, i)
            paragraph, paragraph_idx = get_sublist_by_flattened_index(paragraphs, i)

            # Logically this condition reduces to "the chapter changed",
            # whether or not the paragraph changed as well.
            if (chapter_idx != prev_chapter_idx and paragraph_idx == prev_paragraph_idx) or (paragraph_idx != prev_paragraph_idx and chapter_idx != prev_chapter_idx):
                print("Chapter / Chapter & Paragraph")
                segment_text = " ".join(segment)
                toc.subheader(titles[prev_chapter_idx])
                # Only chapters longer than 450 characters get a summary box.
                if len(segment_text) > 450:
                    generated_text_box = st.info("")
                    # Defer the streaming call; it runs after the page is laid out.
                    summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text))
                st.write(segment_text)
                segment = []
            elif paragraph_idx != prev_paragraph_idx and chapter_idx == prev_chapter_idx:
                print("Paragraph")
                # Paragraph break inside the current chapter.
                segment.append("\n\n")

            segment.append(sentence)

            prev_chapter_idx = chapter_idx
            prev_paragraph_idx = paragraph_idx

        # Flush the last chapter.
        # NOTE(review): unlike the in-loop flush, this always queues a summary
        # regardless of length - confirm whether the 450-character check was
        # meant to apply here too.
        segment_text = " ".join(segment)
        toc.subheader(titles[prev_chapter_idx])
        generated_text_box = st.info("")
        summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment_text))
        st.write(segment_text)

    else:
        # Without paragraphing: one heading, summary box and text per chapter.
        segments = [" ".join([sentence for sentence in segment]) for segment in segments]
        for title, segment in zip(titles, segments):
            toc.subheader(title)
            generated_text_box = st.info("")
            summarization_todos.append(partial(generate_summary, summarizer, generated_text_box, segment))
            st.write(segment)
    # Fill the sidebar placeholder with the collected chapter links.
    toc.generate()

# Run the queued summaries one after another; each streams into the info box
# created for its chapter.
for summarization_todo in summarization_todos:
    summarization_todo()