# Spaces: Sleeping  (appears to be page-banner residue captured with the file; not part of the program)
import os | |
import mesop as me | |
from dataclasses import dataclass, field | |
from typing import Callable, Generator, Literal | |
import time | |
# from rag_app.rag import extract_final_answer, answer_question | |
from rag_app.rag_2 import check_if_exists, precompute_index, answer_question | |
# The two speakers in a conversation; the constants below are the only
# values a Role may take.
Role = Literal["user", "assistant"]

_ROLE_USER = "user"
_ROLE_ASSISTANT = "assistant"

# Bubble backgrounds, one per speaker, expressed as theme variables.
_COLOR_CHAT_BUBBLE_YOU = me.theme_var("surface-container-low")
_COLOR_CHAT_BUBBLE_BOT = me.theme_var("secondary-container")

# Thin solid outline; the bubble style applies it to all four sides.
_DEFAULT_BORDER_SIDE = me.BorderSide(
    width="1px",
    style="solid",
    color=me.theme_var("secondary-fixed"),
)

# Small bold text with horizontal padding.
# NOTE(review): presumably used for the speaker-name label — the usage is
# not visible in this part of the file.
_STYLE_CHAT_BUBBLE_NAME = me.Style(
    font_weight="bold",
    font_size="12px",
    padding=me.Padding(left=15, right=15, bottom=5),
)
@dataclass
class ChatMessage:
    """One turn of the conversation.

    Declared as a dataclass so that instances can be created with keyword
    arguments (``ChatMessage(role=..., content=...)``), which is how the
    handlers in this file construct them. Without the decorator the class
    has no generated ``__init__`` and those calls raise ``TypeError``.
    """

    # Which speaker produced the message.
    role: Role = "user"
    # Message text; the submit handler appends to it as the reply arrives.
    content: str = ""
@me.stateclass
class State:
    """UI state read and written through ``me.state(State)``.

    The ``@me.stateclass`` decorator registers the class with Mesop and
    turns it into a dataclass, so the ``field(default_factory=list)``
    defaults below produce a fresh list per instance instead of leaving a
    raw ``Field`` object on the class.
    """

    # Current text of the chat input box.
    input: str = ""
    # Conversation so far, oldest message first.
    output: list[ChatMessage] = field(default_factory=list)
    # True while a reply is being produced; blocks further submissions.
    in_progress: bool = False
    # Paths recorded for the PDFs uploaded so far.
    pdf_files: list[str] = field(default_factory=list)
def respond_to_chat(query: str, history: list[ChatMessage]):
    """Yield an empty assistant message, then the answer to ``query``.

    The first item yielded is a new ``ChatMessage`` with the assistant role,
    for the caller to append to the conversation. The second is the value
    returned by ``answer_question(query)``, which the caller concatenates
    onto that message's ``content``.

    Args:
        query: The user's question.
        history: The conversation so far. Currently unused; kept so the
            signature stays compatible with existing callers.

    Yields:
        The assistant ``ChatMessage`` placeholder, then the answer.
    """
    # NOTE: the retrieval index is not built here. handle_pdf_upload() runs
    # precompute_index() after every upload instead.
    yield ChatMessage(role=_ROLE_ASSISTANT)

    # answer_question() is called the same way whether or not any PDFs have
    # been uploaded, so no branching on the upload list is needed.
    response = answer_question(query)
    print("Agent response=", response)
    # NOTE(review): the caller does `content += response`, so this assumes
    # answer_question() returns a str — confirm against rag_2.
    yield response
def on_chat_input(e: me.InputEvent):
    """Copy the input event's value into ``State.input``."""
    me.state(State).input = e.value
def on_click_submit_chat_msg(e: me.ClickEvent | me.InputEnterEvent):
    """Submit the pending input and stream the assistant's reply into state.

    This is a generator handler: each bare ``yield`` is a point at which the
    state changes made so far are handed back to the framework.

    Steps, in order:
      1. Do nothing if a reply is already in progress or the input is empty.
      2. Clear the input box and yield.
      3. Append the user's message, mark the chat as in progress, request a
         scroll to the component keyed ``"scroll-to"``, and yield.
      4. Consume ``respond_to_chat``: a ``ChatMessage`` item becomes the
         current assistant message and is appended to the conversation; any
         other item is concatenated onto that message's content, yielding at
         most once every 0.25 seconds.
      5. Clear the in-progress flag and yield a final time.
    """
    state = me.state(State)
    if state.in_progress:
        return
    if not state.input:
        return

    # Take the text out of the input box before doing anything slow.
    question, state.input = state.input, ""
    yield

    messages = state.output
    messages.append(ChatMessage(role=_ROLE_USER, content=question))
    state.in_progress = True
    me.scroll_into_view(key="scroll-to")
    yield

    last_flush = time.time()
    for item in respond_to_chat(question, state.output):
        if not isinstance(item, ChatMessage):
            # A chunk of reply text: extend the current assistant message.
            reply.content += item
            # Throttle intermediate updates to one per quarter second.
            if (time.time() - last_flush) >= 0.25:
                last_flush = time.time()
                yield
            continue
        # A new assistant message: make it the target for later chunks.
        reply = item
        messages.append(reply)
        state.output = messages

    state.in_progress = False
    yield
def _make_style_chat_bubble_wrapper(role: Role) -> me.Style:
    """Return the flex-column wrapper style for one message.

    Messages from the user are aligned to ``"end"``; messages from any
    other role are aligned to ``"start"``.
    """
    if role == _ROLE_USER:
        alignment = "end"
    else:
        alignment = "start"
    return me.Style(display="flex", flex_direction="column", align_items=alignment)
def _make_chat_bubble_style(role: Role) -> me.Style:
    """Return the bubble style for a message from ``role``.

    The assistant gets the bot background colour; every other role gets the
    user background colour. All remaining properties are the same for both.
    """
    bubble_color = (
        _COLOR_CHAT_BUBBLE_BOT if role == _ROLE_ASSISTANT else _COLOR_CHAT_BUBBLE_YOU
    )
    # The same border side is applied to top, right, bottom and left.
    outline = me.Border.all(_DEFAULT_BORDER_SIDE)
    return me.Style(
        width="80%",
        font_size="13px",
        background=bubble_color,
        border_radius="15px",
        padding=me.Padding(right=15, left=15, bottom=3),
        margin=me.Margin(bottom=10),
        border=outline,
    )
def save_uploaded_file(uploaded_file: me.UploadedFile):
    """Write an uploaded file into the local ``data`` directory.

    Only the final path component of the client-supplied name is used, so a
    crafted name such as ``../../app.py`` or an absolute path cannot place
    the file outside ``data``.

    Args:
        uploaded_file: The upload to persist; its ``name`` and
            ``getvalue()`` are read.

    Raises:
        ValueError: If the upload's name has no usable file-name component.
    """
    save_directory = "data"

    # The name is chosen by the client. Treat both separator styles as path
    # separators and keep the last component only.
    safe_name = os.path.basename(uploaded_file.name.replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Invalid upload file name: {uploaded_file.name!r}")

    os.makedirs(save_directory, exist_ok=True)
    file_path = os.path.join(save_directory, safe_name)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getvalue())
    print(f"File saved successfully at {file_path}")
def handle_pdf_upload(event: me.UploadEvent):
    """Persist an uploaded PDF, rebuild the index, and record the upload.

    Args:
        event: The upload event; ``event.file`` is the uploaded file.
    """
    state = me.state(State)
    save_uploaded_file(event.file)

    print("precomputing vector indices")
    precompute_index()

    # save_uploaded_file() writes into the "data" directory, so record the
    # path there; a "docs/..." path would point at a file this code never
    # created.
    state.pdf_files.append(os.path.join("data", event.file.name))