import gc
import time

import numpy as np
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering

# Page title: "ViBidLawQA - Online question answering system for Vietnamese law"
st.set_page_config(page_title="ViBidLawQA - Hệ thống hỏi đáp trực tuyến luật Việt Nam",
                   page_icon="./app/static/ai.png",
                   layout="centered",
                   initial_sidebar_state="expanded")

# Inject the local stylesheet (standard pattern for custom CSS in Streamlit).
with open("./static/styles.css") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

if 'messages' not in st.session_state:
    st.session_state.messages = []

# Page header. NOTE: the exact markup is illustrative; only the title text is
# certain, and the styling lives in static/styles.css.
st.markdown("""
<div class="header">
    <h1>ViBidLawQA_v2</h1>
</div>
""", unsafe_allow_html=True)

# Sidebar: pick the answering model ("Chọn mô hình trả lời câu hỏi") and paste
# the Vietnamese legal text ("Nội dung văn bản pháp luật Việt Nam") used as context.
answering_method = st.sidebar.selectbox(options=['Extraction', 'Generation'],
                                        label='Chọn mô hình trả lời câu hỏi:',
                                        index=0)
context = st.sidebar.text_area(label='Nội dung văn bản pháp luật Việt Nam:',
                               placeholder='Vui lòng nhập nội dung văn bản pháp luật Việt Nam tại đây...',
                               height=500)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the selected model lazily, evicting the other one first so that only a
# single model occupies (GPU) memory at a time.
if answering_method == 'Generation' and 'aqa_model' not in st.session_state:
    if 'eqa_model' in st.session_state and 'eqa_tokenizer' in st.session_state:
        del st.session_state.eqa_model
        del st.session_state.eqa_tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        print('Switching to generative model...')
    print('Loading generative model...')
    st.session_state.aqa_model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path='./models/AQA_model').to(device)
    st.session_state.aqa_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='./models/AQA_model')

if answering_method == 'Extraction' and 'eqa_model' not in st.session_state:
    if 'aqa_model' in st.session_state and 'aqa_tokenizer' in st.session_state:
        del st.session_state.aqa_model
        del st.session_state.aqa_tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        print('Switching to extraction model...')
    print('Loading extraction model...')
    st.session_state.eqa_model = AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name_or_path='./models/EQA_model').to(device)
    st.session_state.eqa_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='./models/EQA_model')


def get_abstractive_answer(context, question, max_length=1024, max_target_length=512):
    # Encode the (question, context) pair, truncating only the context when the
    # pair exceeds max_length.
    inputs = st.session_state.aqa_tokenizer(question, context,
                                            max_length=max_length,
                                            truncation='only_second',
                                            padding='max_length',
                                            return_tensors='pt')
    outputs = st.session_state.aqa_model.generate(inputs=inputs['input_ids'].to(device),
                                                  attention_mask=inputs['attention_mask'].to(device),
                                                  max_length=max_target_length)
    answer = st.session_state.aqa_tokenizer.decode(outputs[0],
                                                   skip_special_tokens=True,
                                                   clean_up_tokenization_spaces=True)
    # Make sure the generated answer ends with a full stop.
    if not answer.endswith('.'):
        answer += '.'
    return answer


def generate_text_effect(answer):
    # Yield progressively longer prefixes of the answer to simulate typing.
    words = answer.split()
    for i in range(len(words)):
        time.sleep(0.05)
        yield " ".join(words[:i + 1])


def get_extractive_answer(context, question, stride=20, max_length=256, n_best=50, max_answer_length=512):
    # Split a long context into overlapping windows, keeping character offsets
    # so predicted token spans can be mapped back to the original text.
    inputs = st.session_state.eqa_tokenizer(question, context,
                                            max_length=max_length,
                                            truncation='only_second',
                                            stride=stride,
                                            return_overflowing_tokens=True,
                                            return_offsets_mapping=True,
                                            padding='max_length')

    # Null out the offsets of tokens that are not part of the context
    # (sequence id != 1) so they can never be picked as answer boundaries.
    for i in range(len(inputs['input_ids'])):
        sequence_ids = inputs.sequence_ids(i)
        offset = inputs['offset_mapping'][i]
        inputs['offset_mapping'][i] = [o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)]

    input_ids = torch.tensor(inputs["input_ids"]).to(device)
    attention_mask = torch.tensor(inputs["attention_mask"]).to(device)

    with torch.no_grad():
        outputs = st.session_state.eqa_model(input_ids=input_ids, attention_mask=attention_mask)

    start_logits = outputs.start_logits.cpu().numpy()
    end_logits = outputs.end_logits.cpu().numpy()

    # For every window, pair the n_best start and end candidates and score each
    # valid span by the sum of its start and end logits.
    answers = []
    for i in range(len(inputs["input_ids"])):
        start_logit = start_logits[i]
        end_logit = end_logits[i]
        offsets = inputs["offset_mapping"][i]
        start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
        end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
        for start_index in start_indexes:
            for end_index in end_indexes:
                # Skip spans that leave the context, run backwards, or are too long.
                if offsets[start_index] is None or offsets[end_index] is None:
                    continue
                if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                    continue
                answer = {
                    "text": context[offsets[start_index][0] : offsets[end_index][1]],
                    "logit_score": start_logit[start_index] + end_logit[end_index],
                }
                answers.append(answer)

    if len(answers) > 0:
        best_answer = max(answers, key=lambda x: x["logit_score"])
        return best_answer["text"]
    return ""
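
# Illustrative use of the two QA helpers (hypothetical inputs; each call assumes
# the corresponding model has already been loaded into st.session_state above):
#   get_abstractive_answer(context='Điều 1. ...', question='Điều 1 quy định gì?')
#     -> a free-form sentence generated by the seq2seq model, always '.'-terminated
#   get_extractive_answer(context='Điều 1. ...', question='Điều 1 quy định gì?')
#     -> the best-scoring span copied verbatim from the context, or '' when no
#        valid span is found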
# Re-render the conversation history on every Streamlit rerun.
for message in st.session_state.messages:
    if message['role'] == 'assistant':
        avatar_class = "assistant-avatar"
        message_class = "assistant-message"
        avatar = './app/static/ai.png'
    else:
        avatar_class = "user-avatar"
        message_class = "user-message"
        avatar = './app/static/human.png'
    # Chat bubble. NOTE: the div structure is approximate; the CSS classes come
    # from the variables assigned above.
    st.markdown(f"""
    <div class="{message_class}">
        <img src="{avatar}" class="{avatar_class}">
        <div class="message-content">{message['content']}</div>
    </div>
    """, unsafe_allow_html=True)
# Chat input placeholder: "How can I help you?"
if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'):
    # Show the user's message immediately (bubble markup approximate, as above).
    st.markdown(f"""
    <div class="user-message">
        <img src="./app/static/human.png" class="user-avatar">
        <div class="message-content">{prompt}</div>
    </div>
    """, unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'user', 'content': prompt})

    # Animated "thinking" dots while the answer is being prepared.
    message_placeholder = st.empty()
    for _ in range(2):
        for dots in ["●", "●●", "●●●"]:
            time.sleep(0.2)
            message_placeholder.markdown(f"""
            <div class="assistant-message">
                <img src="./app/static/ai.png" class="assistant-avatar">
                <div class="message-content">{dots}</div>
            </div>
            """, unsafe_allow_html=True)
    # Stream the answer word by word; the trailing "●" acts as a typing cursor.
    full_response = ""
    if answering_method == 'Generation':
        abs_answer = get_abstractive_answer(context=context, question=prompt)
        for word in generate_text_effect(abs_answer):
            full_response = word
            message_placeholder.markdown(f"""
            <div class="assistant-message">
                <img src="./app/static/ai.png" class="assistant-avatar">
                <div class="message-content">{full_response}●</div>
            </div>
            """, unsafe_allow_html=True)
    else:
        ext_answer = get_extractive_answer(context=context, question=prompt)
        for word in generate_text_effect(ext_answer):
            full_response = word
            message_placeholder.markdown(f"""
            <div class="assistant-message">
                <img src="./app/static/ai.png" class="assistant-avatar">
                <div class="message-content">{full_response}●</div>
            </div>
            """, unsafe_allow_html=True)

    # Replace the streaming bubble with the final answer and store it in history.
    message_placeholder.markdown(f"""
    <div class="assistant-message">
        <img src="./app/static/ai.png" class="assistant-avatar">
        <div class="message-content">{full_response}</div>
    </div>
    """, unsafe_allow_html=True)
    st.session_state.messages.append({'role': 'assistant', 'content': full_response})