from sklearn.feature_extraction.text import TfidfVectorizer
from sentence_transformers import SentenceTransformer, util
import pandas as pd
import numpy as np
import gradio as gr
import torch
import spacy
import re

nlp = spacy.load("en_core_sci_sm")  # scispaCy small English model for scientific/biomedical text



# ----------------------------------------------
# Step 1. Read the input file and convert it to sentence-level JSON
# ----------------------------------------------

def read_text_to_json(path):
    paper = {}
    with open(path, 'r', encoding='utf-8') as txt:
        key = None
        for line in txt:
            line = line.strip()
            if line.startswith('@Paper') or line.startswith('@Section'):
                key = line.split()[1]
                paper[key] = []
            elif key and line:
                paper[key].append(line)
    return paper
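
# A sketch of the expected input format (inferred from the parser above, so treat
# the exact markers as an assumption): each "@Paper <key>" or "@Section <key>"
# line starts a new key, and the non-empty lines that follow become that key's
# paragraph list, e.g.
#
#   @Paper title
#   Deep learning for biomedical text summarization
#   @Section I
#   First paragraph of the introduction...
#
# read_text_to_json(path) would then return
#   {'title': ['Deep learning for biomedical text summarization'],
#    'I': ['First paragraph of the introduction...'], ...}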

def is_valid_format(paper):
    for key in ['title', 'I', 'M', 'R', 'D']:
        if key not in paper or len(paper[key]) == 0:
            return False
    return True

def remove_parentheses_with_useless_tokens(text):
    return re.sub(r'\s*\(\s*(?:table|fig|http|www)[^()]*\)', '', text, flags=re.I)  # re.I: case-insensitive match
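
# For example, "Results are shown (Table 2) and (Fig. 3a) here." becomes
# "Results are shown and here."; only parentheses starting with table/fig/http/www
# are removed, other parenthetical text is kept.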

def segment_sentences(section, pos_para = False):
    sents = []
    sents_break = [".", "?", "!"]
    start = para_i = pre_para_i = 0
    conn = False
    for para in section:
        para = remove_parentheses_with_useless_tokens(para).strip() # strip so trailing whitespace is not tokenized and does not block the sentence break
        doc = nlp(para)
        for sent in doc.sents:
            if any(ch in sents_break for ch in sent[-1].text): # some sentence-final tokens (e.g. "3h.") are not split, so containing a break character is enough
                para_i += 1
                text = "".join(t.text_with_ws for t in doc[start:sent.end])                         # original string (whitespace preserved)
                tokenize_text = " ".join(t.text for t in doc[start:sent.end])                       # tokenized string (space-joined)
                sentence = {"text":text, "tokenize_text":tokenize_text, "pos":pre_para_i+para_i}    # build the sentence object
                if pos_para: sentence['pos_para'] = para_i                                          # pos: position in the section, pos_para: position within its paragraph
                sents.append(sentence)
                start = sent.end
                conn = False
            else:      
                start = start if conn else sent.start   # the last token is not a break character: keep this sentence's start until a break is found
                conn = True
        pre_para_i += para_i
        start = para_i = 0
    return sents
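
# Each returned sentence is a dict shaped like (illustrative values):
#   {'text': 'We measured body weight. ',        # raw text, whitespace preserved
#    'tokenize_text': 'We measured body weight .',
#    'pos': 5,           # 1-based position within the whole section
#    'pos_para': 2}      # 1-based position within its paragraph (when pos_para=True)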

def convert_to_sentence_json(paper):
    sentJson = {
        'title': paper['title'],
        'body': {}
    }
    for key in ['I', 'M', 'R', 'D']:
        sentJson['body'][key] = segment_sentences(paper[key], True)
    return sentJson



# ----------------------------------------------
# Step 2. Extract features for each sentence
# ----------------------------------------------

# List of sentence texts
def sent_lst(sents):
    return [sent['text'] for sent in sents]

# Remove stop words and punctuation
def clean_token(doc):
    return [token for token in doc if not (token.is_stop or token.is_punct)]

# Total number of sentences in each paragraph (ns_para)
def add_num_sents_para(sents):
    reset = True
    for index, sent in reversed(list(enumerate(sents))):    
        if reset: ptr = sent['pos_para']
        reset = True if sent['pos_para'] == 1 else False
        sents[index]['ns_para'] = ptr
    return sents
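
# Walking backwards, each sentence inherits its paragraph's last pos_para as
# ns_para, e.g. pos_para [1, 2, 3, 1, 2] -> ns_para [3, 3, 3, 2, 2].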

# Positional importance
def position_imp(cur, ns):
    imp = 1 if cur == 1 else (ns-cur)/ns
    return imp
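
# Worked example: the first sentence always scores 1, later ones decay linearly,
# e.g. position_imp(1, 10) == 1, position_imp(3, 10) == 0.7, position_imp(10, 10) == 0.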

# Title word list
def title_wlst(txt):
    doc = nlp(txt)
    wlst = [token.text.lower() for token in clean_token(doc)]
    return list(set(wlst))

# Title-word score: sentence tokens matching title words, normalized by title length
def title_word_count(doc, wlst):
    titleLen = len(wlst)
    score = 0 if titleLen == 0 else len([token for token in doc if token.text.lower() in wlst])/titleLen
    return score
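
# e.g. with a two-word title list and one matching token in the sentence the score
# is 1/2 = 0.5; matches are counted per token, so repeated words can push it above 1.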

# Number of tokens with a given part-of-speech tag
def pos_token(doc, pos_type):
    return len([token for token in doc if token.pos_ == pos_type])

# Custom tokenizer: lemmatize and drop stop words, punctuation, and digits
def custom_tokenizer(txt):
    doc = nlp(txt)
    words = [token.lemma_.lower() for token in doc if not (token.is_stop or token.is_punct or token.is_digit)]
    return words

# Term frequency-inverse sentence frequency (TF-ISF)
def Tfisf(lst):
    tf = TfidfVectorizer(tokenizer=custom_tokenizer, lowercase=False, token_pattern=None)  # token_pattern=None silences the unused-pattern warning when a tokenizer is given
    tfisf_matrix = tf.fit_transform(lst)
    word_count = (tfisf_matrix != 0).sum(1)
    with np.errstate(divide='ignore', invalid='ignore'):
        mean_score = np.where(word_count == 0, 0, np.divide(tfisf_matrix.sum(1), word_count)).flatten()
    return mean_score
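
# Each sentence's score is the mean TF-ISF weight over its non-zero terms, e.g. a
# row with non-zero weights 0.5 and 0.3 (zeros elsewhere) scores (0.5 + 0.3) / 2 = 0.4.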

# Cosine similarity
def similarity(lst, ptm):
    model = SentenceTransformer(ptm)
    embeddings = model.encode(lst, convert_to_tensor=True)
    cosine = util.cos_sim(embeddings, embeddings)
    cosine = cosine.sum(1)-1                                        # drop the self-similarity of 1
    cosine = torch.divide(cosine, torch.max(cosine)).cpu().numpy()  # .cpu() in case the embeddings were computed on a GPU
    return cosine
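
# Each sentence's score is its summed similarity to all other sentences (the
# self-similarity of 1 is subtracted), normalized by the maximum. The checkpoint
# name is supplied by the caller (see feature_extraction below); note that the
# model is reloaded on every call, so frequent callers may want to cache it.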

# Feature extraction
def feature_extraction(title, section, sents): 
    lst = sent_lst(sents)
    tfisf = Tfisf(lst)
    cosine = similarity(lst, "pritamdeka/PubMedBERT-mnli-snli-scinli-scitail-mednli-stsb")
    
    # Number of sentences
    ns = len(sents)
    sents = add_num_sents_para(sents)
    # Extract the features of each sentence
    arr = np.empty((0,9))
    for index, sent in enumerate(sents):
        doc = nlp(sent["text"])
        doc = clean_token(doc)
        
        F1 = len(doc)                                           # Sentence Length (normalized below by the longest sentence length)
        F2 = position_imp(sent["pos"], ns)                      # Sentence Position
        F3 = position_imp(sent["pos_para"], sent["ns_para"])    # Sentence Position (in paragraph)
        F4 = title_word_count(doc, title)                       # Title Word
        F5 = 0 if F1 == 0 else pos_token(doc, "PROPN")/F1       # Proper Noun
        F6 = 0 if F1 == 0 else pos_token(doc, "NUM")/F1         # Numerical Token
        F7 = tfisf[index]                                       # Term Frequency-Inverse Sentence Frequency
        F10 = cosine[index]                                     # Cosine Similarity

        feat = np.array([[section, F1, F2, F3, F4, F5, F6, F7, F10]])
        arr = np.append(arr, feat, axis=0)
    # Finish F1: normalize sentence length by the longest sentence in the section
    maxLen = np.amax(arr[:,1])
    arr[:,1] = arr[:,1]/maxLen
    return arr
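
# The returned array has one row per sentence with columns
# [section, F1, F2, F3, F4, F5, F6, F7, F10], matching the DataFrame built in
# feature_from_imrd below.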

# Set column dtypes
def set_dtypes(df):
    df = df.astype({'section': 'int8', 'F1': 'float32', 'F2': 'float32',
                    'F3': 'float32', 'F4': 'float32', 'F5': 'float32',
                    'F6': 'float32', 'F7': 'float32', 'F10': 'float32'})
    return df

# Sentence features for the paper's IMRD sections
def feature_from_imrd(body, title):
    paper = np.empty((0,9))
    for index, key in enumerate(['I', 'M', 'R', 'D'], start=1):
        paper = np.append(paper, feature_extraction(title, index, body[key]), axis=0)
    df = pd.DataFrame(paper, columns=['section', 'F1', 'F2', 'F3', 'F4', 'F5', 'F6', 'F7', 'F10'])
    return set_dtypes(df)

def extract_sentence_features(sentJson):
    title = title_wlst(sentJson['title'][0])
    sentFeat = feature_from_imrd(sentJson['body'], title)
    return sentFeat
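

# Minimal end-to-end sketch (hedged: "paper.txt" is only a hypothetical filename in
# the format described above):
#   paper = read_text_to_json("paper.txt")
#   if is_valid_format(paper):
#       sentJson = convert_to_sentence_json(paper)
#       sentFeat = extract_sentence_features(sentJson)   # DataFrame: section, F1-F7, F10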