Spaces:
Running
Running
import streamlit as st | |
import torch | |
from transformers import pipeline | |
st.set_page_config(page_title="ViBidLawQA - Hệ thống hỏi đáp trực tuyến luật Việt Nam", page_icon="./app/static/ai.png", layout="centered", initial_sidebar_state="expanded") | |
with open("./static/styles.css") as f: | |
st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True) | |
st.markdown(f""" | |
<div class=logo_area> | |
<img src="./app/static/ai.png"/> | |
</div> | |
""", unsafe_allow_html=True) | |
st.markdown("<h2 style='text-align: center;'>ViBidLawQA</h2>", unsafe_allow_html=True) | |
context = st.sidebar.text_area(label='Nội dung văn bản pháp luật Việt Nam:', placeholder='Vui lòng nhập nội dung văn bản pháp luật Việt Nam tại đây...', height=925) | |
device = 0 if torch.cuda.is_available() else -1 | |
if 'model' not in st.session_state: | |
print('Some errors occurred!') | |
st.session_state.model = pipeline("question-answering", model='./model/', device=device) | |
def get_answer(context, question): | |
return st.session_state.model(context=context, question=question, max_answer_len=512) | |
if 'messages' not in st.session_state: | |
st.session_state.messages = [] | |
for message in st.session_state.messages: | |
if message['role'] == 'assistant': | |
avatar_class = "assistant-avatar" | |
message_class = "assistant-message" | |
avatar = './app/static/ai.png' | |
else: | |
avatar_class = "user-avatar" | |
message_class = "user-message" | |
avatar = './app/static/human.png' | |
st.markdown(f""" | |
<div class="{message_class}"> | |
<img src="{avatar}" class="{avatar_class}" /> | |
<div class="stMarkdown">{message['content']}</div> | |
</div> | |
""", unsafe_allow_html=True) | |
if prompt := st.chat_input(placeholder='Tôi có thể giúp được gì cho bạn?'): | |
st.markdown(f""" | |
<div class="user-message"> | |
<img src="./app/static/human.png" class="user-avatar" /> | |
<div class="stMarkdown">{prompt}</div> | |
</div> | |
""", unsafe_allow_html=True) | |
st.session_state.messages.append({'role': 'user', 'content': prompt}) | |
respond = get_answer(context=context, question=prompt)['answer'] | |
respond = respond.strip() | |
respond = respond.rstrip('.,!?;:') | |
respond = respond.capitalize() | |
if not respond.endswith('.'): | |
respond += '.' | |
st.markdown(f""" | |
<div class="assistant-message"> | |
<img src="./app/static/ai.png" class="assistant-avatar" /> | |
<div class="stMarkdown">{respond}</div> | |
</div> | |
""", unsafe_allow_html=True) | |
st.session_state.messages.append({'role': 'assistant', 'content': respond}) |