import gc
import time
import torch
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering
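
# Page configuration and custom CSS for the chat-style interface.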
st.set_page_config(page_title="ViBidLawQA - Hệ thống hỏi đáp trực tuyến luật Việt Nam", page_icon="./app/static/ai.png", layout="centered", initial_sidebar_state="expanded")

with open("./static/styles.css") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
if 'messages' not in st.session_state:
    st.session_state.messages = []
st.markdown(f""" | |
<div class=logo_area> | |
<img src="./app/static/ai.png"/> | |
</div> | |
""", unsafe_allow_html=True) | |
st.markdown("<h2 style='text-align: center;'>ViBidLawQA_v2</h2>", unsafe_allow_html=True) | |
answering_method = st.sidebar.selectbox(options=['Extraction', 'Generation'], label='Chọn mô hình trả lời câu hỏi:', index=0) | |
context = st.sidebar.text_area(label='Nội dung văn bản pháp luật Việt Nam:', placeholder='Vui lòng nhập nội dung văn bản pháp luật Việt Nam tại đây...', height=500) | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
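
# Load only the model for the selected answering method and release the other one
# to free (GPU) memory; './models/AQA_model' and './models/EQA_model' are expected
# to be local fine-tuned checkpoints.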
if answering_method == 'Generation' and 'aqa_model' not in st.session_state:
    if 'eqa_model' in st.session_state and 'eqa_tokenizer' in st.session_state:
        del st.session_state.eqa_model
        del st.session_state.eqa_tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        print('Switching to generative model...')
    print('Loading generative model...')
    st.session_state.aqa_model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path='./models/AQA_model').to(device)
    st.session_state.aqa_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='./models/AQA_model')
if answering_method == 'Extraction' and 'eqa_model' not in st.session_state:
    if 'aqa_model' in st.session_state and 'aqa_tokenizer' in st.session_state:
        del st.session_state.aqa_model
        del st.session_state.aqa_tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        print('Switching to extraction model...')
    print('Loading extraction model...')
    st.session_state.eqa_model = AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path='./models/EQA_model').to(device)
    st.session_state.eqa_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='./models/EQA_model')
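
# Generate an abstractive answer with the seq2seq model; the tokenizer pairs the
# question with the legal context and truncates only the context when too long.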
def get_abstractive_answer(context, question, max_length=1024, max_target_length=512):
    inputs = st.session_state.aqa_tokenizer(question,
                                            context,
                                            max_length=max_length,
                                            truncation='only_second',
                                            padding='max_length',
                                            return_tensors='pt')
    outputs = st.session_state.aqa_model.generate(inputs=inputs['input_ids'].to(device),
                                                  attention_mask=inputs['attention_mask'].to(device),
                                                  max_length=max_target_length)
    answer = st.session_state.aqa_tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
    if not answer.endswith('.'):
        answer += '.'
    return answer
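
# Yield the answer word by word to create a typing animation in the chat.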
def generate_text_effect(answer):
    words = answer.split()
    for i in range(len(words)):
        time.sleep(0.05)
        yield " ".join(words[:i + 1])
def get_extractive_answer(context, question, stride=20, max_length=256, n_best=50, max_answer_length=512):
    inputs = st.session_state.eqa_tokenizer(question,
                                            context,
                                            max_length=max_length,
                                            truncation='only_second',
                                            stride=stride,
                                            return_overflowing_tokens=True,
                                            return_offsets_mapping=True,
                                            padding='max_length')
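    # Keep only offsets that belong to the context (sequence id 1) so that
    # answer spans cannot be taken from the question tokens.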
    for i in range(len(inputs['input_ids'])):
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs['offset_mapping'][i]
        inputs['offset_mapping'][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]
    input_ids = torch.tensor(inputs["input_ids"]).to(device)
    attention_mask = torch.tensor(inputs["attention_mask"]).to(device)
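    # Run the model over all overflowing windows in a single batch.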
    with torch.no_grad():
        outputs = st.session_state.eqa_model(input_ids=input_ids, attention_mask=attention_mask)
    start_logits = outputs.start_logits.cpu().numpy()
    end_logits = outputs.end_logits.cpu().numpy()
    answers = []
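    # For each window, consider the n_best start/end positions and keep every
    # valid candidate span, scored by the sum of its start and end logits.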
for i in range(len(inputs["input_ids"])): | |
start_logit = start_logits[i] | |
end_logit = end_logits[i] | |
offsets = inputs["offset_mapping"][i] | |
start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist() | |
end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist() | |
for start_index in start_indexes: | |
for end_index in end_indexes: | |
if offsets[start_index] is None or offsets[end_index] is None: | |
continue | |
if end_index < start_index or end_index - start_index + 1 > max_answer_length: | |
continue | |
answer = { | |
"text": context[offsets[start_index][0] : offsets[end_index][1]], | |
"logit_score": start_logit[start_index] + end_logit[end_index], | |
} | |
answers.append(answer) | |
if len(answers) > 0: | |
best_answer = max(answers, key=lambda x: x["logit_score"]) | |
return best_answer["text"] | |
else: | |
return "" | |
for message in st.session_state.messages:
    if message['role'] == 'assistant':
        avatar_class = "assistant-avatar"
        message_class = "assistant-message"
        avatar = './app/static/ai.png'
    else:
        avatar_class = "user-avatar"
        message_class = "user-message"
        avatar = './app/static/human.png'
    st.markdown(f"""
    <div class="{message_class}">
        <img src="{avatar}" class="{avatar_class}" />
        <div class="stMarkdown">{message['content']}</div>
    </div>
    """, unsafe_allow_html=True)

if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
    st.markdown(f"""
    <div class="user-message">
        <img src="./app/static/human.png" class="user-avatar" />
        <div class="stMarkdown">{prompt}</div>
    </div>
    """, unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'user', 'content': prompt})
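    # Animated typing indicator shown while the answer is being prepared.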
    message_placeholder = st.empty()
    for _ in range(2):
        for dots in ["●", "●●", "●●●"]:
            time.sleep(0.2)
            message_placeholder.markdown(f"""
            <div class="assistant-message">
                <img src="./app/static/ai.png" class="assistant-avatar" />
                <div class="stMarkdown">{dots}</div>
            </div>
            """, unsafe_allow_html=True)
full_response = "" | |
if answering_method == 'Generation': | |
abs_answer = get_abstractive_answer(context=context, question=prompt) | |
for word in generate_text_effect(abs_answer): | |
full_response = word | |
message_placeholder.markdown(f""" | |
<div class="assistant-message"> | |
<img src="./app/static/ai.png" class="assistant-avatar" /> | |
<div class="stMarkdown">{full_response}●</div> | |
</div> | |
""", unsafe_allow_html=True) | |
else: | |
ext_answer = get_extractive_answer(context=context, question=prompt) | |
for word in generate_text_effect(ext_answer): | |
full_response = word | |
message_placeholder.markdown(f""" | |
<div class="assistant-message"> | |
<img src="./app/static/ai.png" class="assistant-avatar" /> | |
<div class="stMarkdown">{full_response}●</div> | |
</div> | |
""", unsafe_allow_html=True) | |
    message_placeholder.markdown(f"""
    <div class="assistant-message">
        <img src="./app/static/ai.png" class="assistant-avatar" />
        <div class="stMarkdown">{full_response}</div>
    </div>
    """, unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'assistant', 'content': full_response})