import gc
import html
import time

import numpy as np
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering
# --- Page setup -------------------------------------------------------------
# Browser-tab title/icon and overall layout of the app.
st.set_page_config(
    page_title="ViBidLawQA - Hệ thống hỏi đáp trực tuyến luật Việt Nam",
    page_icon="./app/static/ai.png",
    layout="centered",
    initial_sidebar_state="expanded",
)

# Inject the project stylesheet. The file contents have to be read and wrapped
# in a <style> element to have any effect; opening the file without using it
# renders nothing. The encoding is pinned so the result does not depend on the
# platform's default locale.
with open("./static/styles.css", encoding="utf-8") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

# Chat history lives in the session state; create it on first use.
if 'messages' not in st.session_state:
    st.session_state.messages = []

# NOTE(review): as written this call renders only a newline -- confirm whether
# inline markup was meant to be placed here.
st.markdown(f"""
""", unsafe_allow_html=True)

# Page heading. Triple quotes are required because the text spans two lines;
# a plain double-quoted literal broken across lines does not parse.
st.markdown("""ViBidLawQA_v2
""", unsafe_allow_html=True)

# Sidebar controls: which answering strategy to use and the legal text that
# questions are answered from.
answering_method = st.sidebar.selectbox(
    options=['Extraction', 'Generation'],
    label='Chọn mô hình trả lời câu hỏi:',
    index=0,
)
context = st.sidebar.text_area(
    label='Nội dung văn bản pháp luật Việt Nam:',
    placeholder='Vui lòng nhập nội dung văn bản pháp luật Việt Nam tại đây...',
    height=500,
)
# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _evict_from_session(*keys):
    """Remove the given entries from ``st.session_state`` and free their memory.

    Each key is checked individually, so a partially populated session (for
    example a tokenizer without its model) is cleaned up instead of raising.

    Args:
        *keys: Session-state keys to drop if present.

    Returns:
        True if at least one entry was removed, False otherwise.
    """
    present = [key for key in keys if key in st.session_state]
    for key in present:
        del st.session_state[key]
    if present:
        # Collect first so objects that are only kept alive by reference
        # cycles are destroyed before the CUDA cache is asked to shrink.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return bool(present)


# Only one model is kept in the session at a time: switching the answering
# method evicts the other model before the requested one is loaded.
if answering_method == 'Generation' and 'aqa_model' not in st.session_state:
    if _evict_from_session('eqa_model', 'eqa_tokenizer'):
        print('Switching to generative model...')
    print('Loading generative model...')
    st.session_state.aqa_model = AutoModelForSeq2SeqLM.from_pretrained(
        pretrained_model_name_or_path='./models/AQA_model'
    ).to(device)
    st.session_state.aqa_tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path='./models/AQA_model'
    )

if answering_method == 'Extraction' and 'eqa_model' not in st.session_state:
    if _evict_from_session('aqa_model', 'aqa_tokenizer'):
        print('Switching to extraction model...')
    print('Loading extraction model...')
    st.session_state.eqa_model = AutoModelForQuestionAnswering.from_pretrained(
        pretrained_model_name_or_path='./models/EQA_model'
    ).to(device)
    st.session_state.eqa_tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path='./models/EQA_model'
    )
def get_abstractive_answer(context, question, max_length=1024, max_target_length=512):
    """Generate a free-form answer to ``question`` based on ``context``.

    Uses the seq2seq model and tokenizer stored in ``st.session_state`` under
    ``aqa_model`` and ``aqa_tokenizer``.

    Args:
        context: Text the answer should be based on. It is encoded as the
            second segment and is the only segment that gets truncated.
        question: The question, encoded as the first segment.
        max_length: Length the encoded input is padded / truncated to.
        max_target_length: Passed to ``generate`` as ``max_length``.

    Returns:
        The decoded answer with special tokens removed; a trailing '.' is
        appended when the text does not already end with one.
    """
    tokenizer = st.session_state.aqa_tokenizer
    inputs = tokenizer(
        question,
        context,
        max_length=max_length,
        truncation='only_second',
        padding='max_length',
        return_tensors='pt',
    )
    outputs = st.session_state.aqa_model.generate(
        inputs=inputs['input_ids'].to(device),
        attention_mask=inputs['attention_mask'].to(device),
        max_length=max_target_length,
    )
    # The decode parameter is `clean_up_tokenization_spaces` (plural); the
    # singular spelling is not a parameter the tokenizer defines.
    answer = tokenizer.decode(
        outputs[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    if not answer.endswith('.'):
        answer += '.'
    return answer
def generate_text_effect(answer, delay=0.05):
    """Yield progressively longer prefixes of ``answer``, one word at a time.

    The text is split on whitespace, so runs of spaces and line breaks are
    collapsed to single spaces in the yielded strings.

    Args:
        answer: Text to reveal word by word.
        delay: Seconds to pause before each word is revealed. Values of zero
            or less disable the pause.

    Yields:
        The first word, then the first two words joined by a space, and so on
        up to the whole (whitespace-normalised) text. Nothing is yielded for
        an empty or whitespace-only ``answer``.
    """
    shown = ""
    for word in answer.split():
        if delay > 0:
            time.sleep(delay)
        # Extend the previous prefix instead of re-joining all words each time.
        shown = f"{shown} {word}" if shown else word
        yield shown
def get_extractive_answer(context, question, stride=20, max_length=256, n_best=50, max_answer_length=512):
    """Return the highest-scoring answer span for ``question`` found in ``context``.

    The question/context pair is tokenized into overlapping windows, every
    window is scored by the model stored in ``st.session_state.eqa_model``,
    and the best (start, end) token pair across all windows is mapped back to
    a substring of ``context``.

    Args:
        context: Text to extract the answer from.
        question: The question, encoded as the first segment.
        stride: Token overlap between consecutive windows.
        max_length: Length every window is padded / truncated to.
        n_best: Number of top start and top end positions considered per window.
        max_answer_length: Maximum answer length, counted in tokens.

    Returns:
        The extracted substring of ``context``, or an empty string when the
        context is blank or no valid span exists.
    """
    # Nothing to extract from: skip tokenization and the forward pass.
    if not context or not context.strip():
        return ""

    inputs = st.session_state.eqa_tokenizer(
        question,
        context,
        max_length=max_length,
        truncation='only_second',
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding='max_length',
    )

    # Keep character offsets only for tokens of the second segment (the
    # context); every other position becomes None so it can never be chosen as
    # an answer boundary. Built as a separate list so the tokenizer output is
    # not modified.
    context_offsets = []
    for window, offsets in enumerate(inputs['offset_mapping']):
        sequence_ids = inputs.sequence_ids(window)
        context_offsets.append(
            [span if seq_id == 1 else None for span, seq_id in zip(offsets, sequence_ids)]
        )

    input_ids = torch.tensor(inputs["input_ids"]).to(device)
    attention_mask = torch.tensor(inputs["attention_mask"]).to(device)
    with torch.no_grad():
        outputs = st.session_state.eqa_model(input_ids=input_ids, attention_mask=attention_mask)
    start_logits = outputs.start_logits.cpu().numpy()
    end_logits = outputs.end_logits.cpu().numpy()

    best_text = ""
    best_score = None
    for start_logit, end_logit, offsets in zip(start_logits, end_logits, context_offsets):
        # Indices of the n_best largest logits, highest first.
        start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
        end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
        for start_index in start_indexes:
            if offsets[start_index] is None:
                continue
            for end_index in end_indexes:
                if offsets[end_index] is None:
                    continue
                if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                    continue
                score = start_logit[start_index] + end_logit[end_index]
                # Strict comparison keeps the first candidate on ties.
                if best_score is None or score > best_score:
                    best_score = score
                    best_text = context[offsets[start_index][0] : offsets[end_index][1]]
    return best_text
# --- Conversation -----------------------------------------------------------
# Re-render every message stored in the session.
for message in st.session_state.messages:
    if message['role'] == 'assistant':
        avatar_class = "assistant-avatar"
        message_class = "assistant-message"
        avatar = './app/static/ai.png'
    else:
        avatar_class = "user-avatar"
        message_class = "user-message"
        avatar = './app/static/human.png'
    # NOTE(review): avatar_class / message_class / avatar are not interpolated
    # into the template as it appears here -- presumably they belong to bubble
    # markup around the message text; confirm.
    # Message text is escaped because it is rendered with unsafe_allow_html.
    st.markdown(f"""
{html.escape(message['content'])}
""", unsafe_allow_html=True)

if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
    # Echo the question. The raw text is what gets stored; escaping happens
    # only at render time so the history is never double-escaped.
    st.markdown(f"""
{html.escape(prompt)}
""", unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'user', 'content': prompt})

    # Short "typing" animation in the slot that will hold the answer.
    message_placeholder = st.empty()
    for _ in range(2):
        for dots in ["●", "●●", "●●●"]:
            time.sleep(0.2)
            message_placeholder.markdown(f"""
{dots}
""", unsafe_allow_html=True)

    # Both strategies share the same signature; only the model differs.
    if answering_method == 'Generation':
        answer = get_abstractive_answer(context=context, question=prompt)
    else:
        answer = get_extractive_answer(context=context, question=prompt)

    # Reveal the answer word by word, with a trailing cursor while streaming.
    full_response = ""
    for partial in generate_text_effect(answer):
        full_response = partial
        message_placeholder.markdown(f"""
{html.escape(full_response)}●
""", unsafe_allow_html=True)

    # Final render without the cursor, then persist the answer.
    message_placeholder.markdown(f"""
{html.escape(full_response)}
""", unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'assistant', 'content': full_response})