# Author: snsynth
# Commit message: precompute at start
# Commit: e4800a1
import os
import mesop as me
from dataclasses import dataclass, field
from typing import Callable, Generator, Literal
import time
# from rag_app.rag import extract_final_answer, answer_question
from rag_app.rag_2 import check_if_exists, precompute_index, answer_question
# Author of a chat message. Only these two role strings are used in this module.
Role = Literal["user", "assistant"]
_ROLE_USER = "user"
_ROLE_ASSISTANT = "assistant"

# Theme-aware bubble background colours: user messages vs. assistant replies.
_COLOR_CHAT_BUBBLE_YOU = me.theme_var("surface-container-low")
_COLOR_CHAT_BUBBLE_BOT = me.theme_var("secondary-container")

# Thin solid outline used for every side of a chat bubble.
_DEFAULT_BORDER_SIDE = me.BorderSide(
    color=me.theme_var("secondary-fixed"),
    style="solid",
    width="1px",
)

# Small bold label style; presumably for the speaker name above a bubble
# (not referenced in the code visible here — confirm usage).
_STYLE_CHAT_BUBBLE_NAME = me.Style(
    padding=me.Padding(left=15, right=15, bottom=5),
    font_size="12px",
    font_weight="bold",
)
@dataclass(kw_only=True)
class ChatMessage:
    """One entry in the chat transcript.

    Fields must be supplied by keyword (``kw_only=True``).
    """

    # Who wrote the message; defaults to the user.
    role: Role = "user"
    # Message text. Assistant replies start empty and are extended as chunks stream in.
    content: str = ""
@me.stateclass
class State:
    """Mesop page state for the chat UI."""

    # Current text in the chat input box.
    input: str = ""
    # Transcript of the conversation, in the order messages were appended.
    output: list[ChatMessage] = field(default_factory=list)
    # True while an assistant reply is being produced; new submissions are ignored meanwhile.
    in_progress: bool = False
    # Paths recorded for each uploaded PDF.
    pdf_files: list[str] = field(default_factory=list)
def respond_to_chat(query: str, history: list[ChatMessage]):
    """Produce the assistant's reply to ``query`` as a stream.

    Yields, in order:
      1. An empty assistant ``ChatMessage``. The caller appends it to the transcript.
      2. The value returned by ``answer_question(query)``. The caller appends it
         to that message's ``content``, so it is presumably a ``str`` (confirm
         against ``rag_app.rag_2``).

    Args:
        query: The user's question.
        history: The conversation so far. Currently unused, because
            ``answer_question`` receives only the latest query.
    """
    # No index work happens per query. This module rebuilds the retrieval
    # index via precompute_index() when a PDF is uploaded.
    yield ChatMessage(role=_ROLE_ASSISTANT)

    # answer_question is called the same way whether or not PDFs have been
    # uploaded, so there is no need to branch on State.pdf_files here.
    response = answer_question(query)
    print("Agent response=", response)
    yield response
def on_chat_input(e: me.InputEvent):
    """Copy the chat box's current value into ``State.input`` on each input event."""
    me.state(State).input = e.value
def on_click_submit_chat_msg(e: me.ClickEvent | me.InputEnterEvent):
    """Submit the pending chat input and stream the assistant's reply into the transcript.

    This is a Mesop generator handler: each bare ``yield`` pushes the current
    state to the UI.

    Raises:
        Any exception raised while generating the reply is re-raised. Before
        that, ``in_progress`` is cleared and flushed, so the chat is not left
        locked (a stuck ``in_progress=True`` would make every later
        submission return early).
    """
    state = me.state(State)
    # Ignore empty input, and input submitted while a reply is still streaming.
    if state.in_progress or not state.input:
        return

    query = state.input
    state.input = ""
    yield  # clear the text box immediately

    state.output.append(ChatMessage(role=_ROLE_USER, content=query))
    state.in_progress = True
    me.scroll_into_view(key="scroll-to")
    yield

    assistant_message: ChatMessage | None = None
    last_flush = time.monotonic()
    try:
        for chunk in respond_to_chat(query, state.output):
            if isinstance(chunk, ChatMessage):
                assistant_message = chunk
                state.output.append(assistant_message)
                continue
            if assistant_message is None:
                # Text arrived before any message object; create one to hold it.
                assistant_message = ChatMessage(role=_ROLE_ASSISTANT)
                state.output.append(assistant_message)
            assistant_message.content += chunk
            # Throttle UI refreshes to at most one every 0.25 s while streaming.
            if time.monotonic() - last_flush >= 0.25:
                last_flush = time.monotonic()
                yield
    except Exception:
        # Release the in-progress lock and push it to the UI before propagating.
        state.in_progress = False
        yield
        raise

    state.in_progress = False
    yield  # final flush also shows any text that arrived after the last throttled refresh
def _make_style_chat_bubble_wrapper(role: Role) -> me.Style:
    """Build the column flex container around one chat bubble.

    User bubbles are aligned to the end of the row; every other role is
    aligned to the start.
    """
    if role == _ROLE_USER:
        alignment = "end"
    else:
        alignment = "start"
    return me.Style(display="flex", flex_direction="column", align_items=alignment)
def _make_chat_bubble_style(role: Role) -> me.Style:
    """Build the style for a single chat bubble.

    Assistant bubbles use the bot colour and every other role uses the user
    colour. All bubbles share size, rounding, spacing and a thin outline.
    """
    background = (
        _COLOR_CHAT_BUBBLE_BOT if role == _ROLE_ASSISTANT else _COLOR_CHAT_BUBBLE_YOU
    )
    outline = _DEFAULT_BORDER_SIDE
    return me.Style(
        width="80%",
        font_size="13px",
        background=background,
        border_radius="15px",
        padding=me.Padding(right=15, left=15, bottom=3),
        margin=me.Margin(bottom=10),
        border=me.Border(top=outline, right=outline, bottom=outline, left=outline),
    )
def save_uploaded_file(uploaded_file: me.UploadedFile):
    """Write an uploaded file into the local ``data`` directory.

    The file name comes from the client and is untrusted. Only its final path
    component is used, so names like ``../../app.py`` or ``/etc/passwd``
    cannot write outside ``data``.

    Args:
        uploaded_file: The uploaded file. Its ``name`` and ``getvalue()``
            bytes are used.

    Returns:
        The path the file was written to.

    Raises:
        ValueError: If the name has no usable final component (``""``, ``"."``, ``".."``).
    """
    save_directory = "data"
    os.makedirs(save_directory, exist_ok=True)
    # Normalise Windows separators so basename strips client directories on any OS.
    file_name = os.path.basename(uploaded_file.name.replace("\\", "/"))
    if file_name in ("", ".", ".."):
        raise ValueError(f"Invalid upload file name: {uploaded_file.name!r}")
    file_path = os.path.join(save_directory, file_name)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getvalue())
    print(f"File saved successfully at {file_path}")
    return file_path
def handle_pdf_upload(event: me.UploadEvent):
    """Handle a PDF upload.

    Saves the file, rebuilds the retrieval index, and records the upload in
    ``State.pdf_files``.

    ``precompute_index()`` takes no arguments. It presumably scans the
    directory the file was saved into (NOTE(review): confirm in
    ``rag_app.rag_2`` that it reads ``data/``).
    """
    state = me.state(State)
    save_uploaded_file(event.file)
    print("precomputing vector indices")
    precompute_index()
    # Record the path under "data", the directory save_uploaded_file writes into.
    # Nothing in this module saves into "docs".
    state.pdf_files.append(os.path.join("data", os.path.basename(event.file.name)))