import os

import streamlit as st
import wandb
import weave

from medrag_multi_modal.assistant import (
    FigureAnnotatorFromPageImage,
    LLMClient,
    MedQAAssistant,
)
from medrag_multi_modal.assistant.llm_client import (
    GOOGLE_MODELS,
    MISTRAL_MODELS,
    OPENAI_MODELS,
)
from medrag_multi_modal.retrieval import MedCPTRetriever

# Authenticate with Weights & Biases using the API key from the environment
wandb.login(relogin=True, key=os.getenv("WANDB_API_KEY"))
# Define constants
ALL_AVAILABLE_MODELS = GOOGLE_MODELS + MISTRAL_MODELS + OPENAI_MODELS
# Sidebar for configuration settings
st.sidebar.title("Configuration Settings")
project_name = st.sidebar.text_input(
    label="Project Name",
    value="ml-colabs/medrag-multi-modal",
    placeholder="wandb project name",
    help="format: wandb_username/wandb_project_name",
)
chunk_dataset_name = st.sidebar.text_input(
    label="Text Chunk WandB Dataset Name",
    value="grays-anatomy-chunks:v0",
    placeholder="wandb dataset name",
    help="format: wandb_dataset_name:version",
)
index_artifact_address = st.sidebar.text_input(
    label="WandB Index Artifact Address",
    value="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
    placeholder="wandb artifact address",
    help="format: wandb_username/wandb_project_name/wandb_artifact_name:version",
)
image_artifact_address = st.sidebar.text_input(
    label="WandB Image Artifact Address",
    value="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
    placeholder="wandb artifact address",
    help="format: wandb_username/wandb_project_name/wandb_artifact_name:version",
)
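# Model selection: one model answers the query, one extracts figures from
# page images, and one converts the extraction into structured output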
llm_client_model_name = st.sidebar.selectbox(
    label="LLM Client Model Name",
    options=ALL_AVAILABLE_MODELS,
    index=ALL_AVAILABLE_MODELS.index("gemini-1.5-flash"),
    help="select a model from the list",
)
figure_extraction_model_name = st.sidebar.selectbox(
    label="Figure Extraction Model Name",
    options=ALL_AVAILABLE_MODELS,
    index=ALL_AVAILABLE_MODELS.index("pixtral-12b-2409"),
    help="select a model from the list",
)
structured_output_model_name = st.sidebar.selectbox(
    label="Structured Output Model Name",
    options=ALL_AVAILABLE_MODELS,
    index=ALL_AVAILABLE_MODELS.index("gpt-4o"),
    help="select a model from the list",
)
# Streamlit app layout
st.title("MedQA Assistant App")
# Initialize Weave
weave.init(project_name=project_name)
# Initialize clients and assistants
llm_client = LLMClient(model_name=llm_client_model_name)
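# Build the MedCPT retriever from the W&B text-chunk dataset and index artifact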
retriever = MedCPTRetriever.from_wandb_artifact(
    chunk_dataset_name=chunk_dataset_name,
    index_artifact_address=index_artifact_address,
)
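# The figure annotator pairs a vision model (figure extraction) with a second
# model that turns its raw output into structured annotations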
figure_annotator = FigureAnnotatorFromPageImage(
    figure_extraction_llm_client=LLMClient(model_name=figure_extraction_model_name),
    structured_output_llm_client=LLMClient(model_name=structured_output_model_name),
    image_artifact_address=image_artifact_address,
)
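# Compose the assistant from the LLM client, retriever, and figure annotator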
medqa_assistant = MedQAAssistant(
    llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
)
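# Chat interface: echo the user's question, then render the assistant's answer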
query = st.chat_input("Enter your question here")

if query:
    with st.chat_message("user"):
        st.markdown(query)
    response = medqa_assistant.predict(query=query)
    with st.chat_message("assistant"):
        st.markdown(response)
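# To run this app locally (assuming the medrag_multi_modal package is installed
# and WANDB_API_KEY plus the relevant model-provider API keys are exported):
#   streamlit run app.py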