import gc
import time
import torch
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering
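
# ViBidLawQA v2 - a Streamlit chat app for question answering over Vietnamese
# legal documents ("Hệ thống hỏi đáp trực tuyến luật Việt Nam" = "online
# question answering system for Vietnamese law"). Image paths of the form
# ./app/static/... assume Streamlit static file serving is enabled
# (enableStaticServing = true in .streamlit/config.toml), which exposes files
# under ./static at the /app/static URL; CSS is read directly from ./static.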

st.set_page_config(page_title="ViBidLawQA - Hệ thống hỏi đáp trực tuyến luật Việt Nam", page_icon="./app/static/ai.png", layout="centered", initial_sidebar_state="expanded")

with open("./static/styles.css") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

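# Chat history lives in session_state so it survives Streamlit reruns.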
if 'messages' not in st.session_state:
    st.session_state.messages = []

st.markdown(f"""
<div class=logo_area>
    <img src="./app/static/ai.png"/>
</div>
""", unsafe_allow_html=True)
st.markdown("<h2 style='text-align: center;'>ViBidLawQA_v2</h2>", unsafe_allow_html=True)

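# Sidebar controls: the answering method (span extraction vs. free-form
# generation) and the Vietnamese legal text used as context for every answer.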
answering_method = st.sidebar.selectbox(options=['Extraction', 'Generation'], label='Chọn mô hình trả lời câu hỏi:', index=0)
context = st.sidebar.text_area(label='Nội dung văn bản pháp luật Việt Nam:', placeholder='Vui lòng nhập nội dung văn bản pháp luật Việt Nam tại đây...', height=500)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


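# Keep only one model in memory at a time: when the user switches methods,
# drop the other model (and flush the CUDA cache) before loading the new one.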
if answering_method == 'Generation' and 'aqa_model' not in st.session_state:
    if 'eqa_model' in st.session_state and 'eqa_tokenizer' in st.session_state:
        del st.session_state.eqa_model
        del st.session_state.eqa_tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        print('Switching to generative model...')
    print('Loading generative model...')
    st.session_state.aqa_model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path='./models/AQA_model').to(device)
    st.session_state.aqa_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='./models/AQA_model')

if answering_method == 'Extraction' and 'eqa_model' not in st.session_state:
    if 'aqa_model' in st.session_state and 'aqa_tokenizer' in st.session_state:
        del st.session_state.aqa_model
        del st.session_state.aqa_tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        print('Switching to extraction model...')
    print('Loading extraction model...')
    st.session_state.eqa_model = AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path='./models/EQA_model').to(device)
    st.session_state.eqa_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='./models/EQA_model')

def get_abstractive_answer(context, question, max_length=1024, max_target_length=512):
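    """Generate a free-form answer with the seq2seq (AQA) model.

    The question/context pair is truncated to `max_length` tokens (context
    side only, via truncation='only_second') and the decoder output is
    capped at `max_target_length` tokens.
    """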
    inputs = st.session_state.aqa_tokenizer(question,
                                            context,
                                            max_length=max_length,
                                            truncation='only_second',
                                            padding='max_length',
                                            return_tensors='pt')
    outputs = st.session_state.aqa_model.generate(inputs=inputs['input_ids'].to(device),
                                                  attention_mask=inputs['attention_mask'].to(device),
                                                  max_length=max_target_length)
    answer = st.session_state.aqa_tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

    if not answer.endswith('.'):
        answer += '.'

    return answer

def generate_text_effect(answer):
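    """Yield the answer word by word to mimic a streaming typing effect."""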
    words = answer.split()
    for i in range(len(words)):
        time.sleep(0.05)
        yield " ".join(words[:i+1])

def get_extractive_answer(context, question, stride=20, max_length=256, n_best=50, max_answer_length=512):
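    """Extract the best answer span from `context` with the EQA model.

    Long contexts are split into overlapping chunks (`stride` tokens of
    overlap between consecutive chunks); the highest-scoring valid span
    across all chunks is returned, or an empty string if none is found.
    """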
    inputs = st.session_state.eqa_tokenizer(question,
                                            context,
                                            max_length=max_length,
                                            truncation='only_second',
                                            stride=stride,
                                            return_overflowing_tokens=True,
                                            return_offsets_mapping=True,
                                            padding='max_length')
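    # Mask the offsets of non-context tokens (question and special tokens)
    # so candidate spans can only be drawn from the legal text itself.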
    for i in range(len(inputs['input_ids'])):
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs['offset_mapping'][i]
        inputs['offset_mapping'][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

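    # Run all chunks through the model in a single batched forward pass.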
    input_ids = torch.tensor(inputs["input_ids"]).to(device)
    attention_mask = torch.tensor(inputs["attention_mask"]).to(device)
    
    with torch.no_grad():
        outputs = st.session_state.eqa_model(input_ids=input_ids, attention_mask=attention_mask)
    
    start_logits = outputs.start_logits.cpu().numpy()
    end_logits = outputs.end_logits.cpu().numpy()
    
    answers = []
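    # Score every (start, end) pair drawn from the top-n_best start/end logits
    # of each chunk, skipping spans outside the context or longer than allowed.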
    for i in range(len(inputs["input_ids"])):
        start_logit = start_logits[i]
        end_logit = end_logits[i]
        offsets = inputs["offset_mapping"][i]
        
        start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
        end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
        for start_index in start_indexes:
            for end_index in end_indexes:
                if offsets[start_index] is None or offsets[end_index] is None:
                    continue
                if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                    continue
                
                answer = {
                    "text": context[offsets[start_index][0] : offsets[end_index][1]],
                    "logit_score": start_logit[start_index] + end_logit[end_index],
                }
                answers.append(answer)
    
    if len(answers) > 0:
        best_answer = max(answers, key=lambda x: x["logit_score"])
        return best_answer["text"]
    else:
        return ""

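# Replay the saved chat history on every rerun.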
for message in st.session_state.messages:
    if message['role'] == 'assistant':
        avatar_class = "assistant-avatar"
        message_class = "assistant-message"
        avatar = './app/static/ai.png'
    else:
        avatar_class = "user-avatar"
        message_class = "user-message"
        avatar = './app/static/human.png'
    st.markdown(f"""
    <div class="{message_class}">
        <img src="{avatar}" class="{avatar_class}" />
        <div class="stMarkdown">{message['content']}</div>
    </div>
    """, unsafe_allow_html=True)

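# Handle a new user question from the chat input box
# (placeholder: "Tôi có thể giúp được gì cho bạn?" = "How can I help you?").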
if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
    st.markdown(f"""
    <div class="user-message">
        <img src="./app/static/human.png" class="user-avatar" />
        <div class="stMarkdown">{prompt}</div>
    </div>
    """, unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'user', 'content': prompt})
    
    message_placeholder = st.empty()
    
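    # Animated "typing" dots while the answer is being computed.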
    for _ in range(2):
        for dots in ["●", "●●", "●●●"]:
            time.sleep(0.2)
            message_placeholder.markdown(f"""
            <div class="assistant-message">
                <img src="./app/static/ai.png" class="assistant-avatar" />
                <div class="stMarkdown">{dots}</div>
            </div>
            """, unsafe_allow_html=True)
    
    full_response = ""
    if answering_method == 'Generation':
        abs_answer = get_abstractive_answer(context=context, question=prompt) 
        for word in generate_text_effect(abs_answer):
            full_response = word

            message_placeholder.markdown(f"""
            <div class="assistant-message">
                <img src="./app/static/ai.png" class="assistant-avatar" />
                <div class="stMarkdown">{full_response}●</div>
            </div>
            """, unsafe_allow_html=True)
        
    else:
        ext_answer = get_extractive_answer(context=context, question=prompt)
        for word in generate_text_effect(ext_answer):
            full_response = word
            
            message_placeholder.markdown(f"""
            <div class="assistant-message">
                <img src="./app/static/ai.png" class="assistant-avatar" />
                <div class="stMarkdown">{full_response}●</div>
            </div>
            """, unsafe_allow_html=True)

    message_placeholder.markdown(f"""
    <div class="assistant-message">
        <img src="./app/static/ai.png" class="assistant-avatar" />
        <div class="stMarkdown">{full_response}</div>
    </div>
    """, unsafe_allow_html=True)
    
    st.session_state.messages.append({'role': 'assistant', 'content': full_response})