# basics
import time
import pandas as pd
import numpy as np
import pickle
from PIL import Image

# deep learning
import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

# Streamlit
import streamlit as st
from load_css import local_css
local_css("./style.css")  # inject the custom stylesheet

# text preprocessing
import re
from pyvi import ViTokenizer    # Vietnamese word segmentation
from rank_bm25 import BM25Okapi  # keeps BM25Okapi importable for the pickled index below

# path setup: put the project root on sys.path (popped again below)
from inspect import getsourcefile
import os.path as path
import sys

current_dir = path.dirname(path.abspath(getsourcefile(lambda: 0)))  # directory of this file
sys.path.insert(0, current_dir[:current_dir.rfind(path.sep)])


def preprocess(sentence):
    """Lowercase, strip HTML tags and URLs, then word-segment with pyvi."""
    sentence = str(sentence).lower()
    sentence = sentence.replace('{html}', "")       # drop literal '{html}' placeholders
    cleantext = re.sub(r'<.*?>', '', sentence)      # strip HTML tags
    rem_url = re.sub(r'http\S+', '', cleantext)     # strip URLs
    word_list = rem_url.split()
    return ViTokenizer.tokenize(" ".join(word_list))
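
# Illustrative example (hypothetical input):
#   preprocess("<p>Luật Doanh nghiệp http://vbpl.vn</p>") -> roughly "luật doanh_nghiệp"
# pyvi joins the syllables of compound words with underscores, which is why
# deploy() below strips them again with .replace('_', ' ') before display.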

DEFAULT = '< PICK A VALUE >'

def selectbox_with_default(text, values, default=DEFAULT, sidebar=False):
    """Selectbox with a placeholder default prepended (not used in the current UI)."""
    func = st.sidebar.selectbox if sidebar else st.selectbox
    return func(text, np.insert(np.array(values, object), 0, default))

@st.cache_resource()
def loadmodels():
    """Load the T5 reader and the bi-encoder retriever from the Hugging Face Hub."""
    model = T5ForConditionalGeneration.from_pretrained("wanderer2k1/T5-LawsQA")
    tokenizer = T5TokenizerFast.from_pretrained("wanderer2k1/T5-LawsQA")

    bi_encoder = SentenceTransformer('wanderer2k1/BertCondenser_LawsQA')
    return tokenizer, model, bi_encoder
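
# Note: st.cache_resource keeps a single copy of these models per process,
# shared across reruns and sessions, so they are downloaded and loaded only once.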


def hf_run_model(tokenizer, model, input_string, **generator_args):
    """Run the T5 reader with beam search; caller kwargs override the defaults."""
    defaults = {
        "max_length": 256,
        "temperature": 0.0,  # no effect here: sampling is disabled, beam search is deterministic
        "num_beams": 4,
        "length_penalty": 0.1,
        "no_repeat_ngram_size": 8,
        "early_stopping": True,
    }
    defaults.update(generator_args)
    input_string = "generate questions: " + input_string + " </s>"
    input_ids = tokenizer.encode(input_string, return_tensors="pt")
    res = model.generate(input_ids, **defaults)
    output = tokenizer.batch_decode(res, skip_special_tokens=True)
    output = [item.split("<sep>") for item in output]
    return output
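
# hf_run_model returns a list of lists: one entry per input sequence, split on
# "<sep>"; deploy() below takes output[0][0] as the single best answer.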

#%%
sys.path.pop(0)  # undo the sys.path tweak from the import section

# 1. Load the transformed and processed corpus (tab-separated text, despite the .pkl extension).
if 'df' not in st.session_state:
    st.session_state['df'] = pd.read_csv('./data/corpus.pkl', sep='\t')
    st.session_state['passages'] = st.session_state['df']['text'].values.tolist()
    st.session_state['passage_id'] = st.session_state['df']['title'].values.tolist()


# 2. Load the precomputed corpus embeddings for dense retrieval.
if 'embedded_passages' not in st.session_state:
    with open("./data/embedded_corpus_BertCondenser_tuples.pkl", 'rb') as inp:
        embedded_passages = pickle.load(inp)
    st.session_state['embedded_passages'] = torch.Tensor(embedded_passages)

# 3. Load the pickled BM25 index (built over pyvi-segmented passages).
if 'bm25' not in st.session_state:
    with open("models/BM25_pyvi_segmented_splitted.pkl", 'rb') as inp:
        st.session_state['bm25'] = pickle.load(inp)
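
# Note: the BM25 index, the embedded passages, and the passage/title lists must
# all be aligned -- deploy() uses the same integer index to address all of them.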

# 4. Load the tokenizer, T5 reader, and bi-encoder.
if 'model' not in st.session_state:
    st.session_state['tokenizer'], st.session_state['model'], st.session_state['bi_encoder'] = loadmodels()

#%%
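# Retrieval pipeline: BM25 narrows the corpus to 50 keyword candidates, the
# bi-encoder re-ranks them by cosine similarity to the query embedding, and the
# T5 reader generates one answer per top-ranked passage.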

def deploy(question, top_k):
    # top_k: number of passages to keep after bi-encoder re-ranking

    tokenized_query = preprocess(question).split()
    query = ' '.join(tokenized_query)
    emb_query = st.session_state['bi_encoder'].encode(query)

    # stage 1: BM25 keyword search -- keep the 50 highest-scoring passages
    scores = st.session_state['bm25'].get_scores(tokenized_query)
    top_score_ids = np.argpartition(scores, -50)[-50:]

    # stage 2: look up the candidates' precomputed embeddings
    emb_candidates = st.session_state['embedded_passages'][torch.as_tensor(top_score_ids)]


    # stage 3: re-rank the candidates by cosine similarity to the query embedding
    cosine_sim = cos_sim(emb_query, emb_candidates)
    doc_inds = np.argpartition(cosine_sim.numpy()[0], -top_k)[-top_k:]
    top_score_ids = top_score_ids.take(doc_inds)

    matches = []
    ids = []
    answers = []

    for doc_ind in top_score_ids:
        # undo pyvi's underscore word-joining for display
        matches.append(st.session_state['passages'][doc_ind].replace('_', ' '))
        ids.append(st.session_state['passage_id'][doc_ind].replace('_', ' '))
    # stage 4: generate an answer from each retrieved passage with the T5 reader
    for context in matches:
        # prompt: "Trả lời câu hỏi: <question> Trong ngữ cảnh: <passage>"
        # (English: "Answer the question: ... In the context: ...")
        q = "Trả lời câu hỏi: " + query + " Trong ngữ cảnh: " + context
        a = hf_run_model(st.session_state['tokenizer'], st.session_state['model'], q)[0][0]
        answers.append(a)

    # present the results as a table
    df_results = pd.DataFrame({
        'Title': ids,
        'Answer': answers,
        'Retrieved': matches,
    })

    st.header("Results:")
    st.table(df_results)




#%%
# page layout
st.title('Closed Domain QA System on Vietnamese Laws')

sdg = Image.open('./logo.jpg')
st.sidebar.image(sdg, width=300)
st.sidebar.title('Settings')

st.caption("by HoangNV - on a custom laws QA dataset")
returns = st.sidebar.slider('Number of answer suggestions:', 1, 3, 2)


question = st.text_input('Type in your legal question:')

if question:
    t0 = time.time()
    with st.spinner('Finding best answers...'):
        deploy(question, returns)
        st.write(f"Runtime: {time.time() - t0:.2f} s")



