File size: 2,599 Bytes
6ff89e0
cfb1a62
daee42b
c6040d0
cfb1a62
 
9cb07f9
669d93a
183168e
669d93a
 
 
 
 
 
 
 
 
daee42b
 
c6040d0
669d93a
daee42b
 
cfb1a62
 
 
6ff89e0
cfb1a62
daee42b
c6040d0
d377a8f
daee42b
 
cfb1a62
 
 
 
 
d0d7d0e
cfb1a62
 
daee42b
 
 
cfb1a62
daee42b
cfb1a62
 
c6040d0
 
 
9cb07f9
c6040d0
6a076b8
 
c6040d0
6a076b8
 
 
daee42b
c6040d0
 
 
cfb1a62
c6040d0
 
daee42b
c6040d0
 
 
 
 
 
 
 
cfb1a62
c6040d0
 
 
 
 
 
 
 
 
daee42b
 
d762ede
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
import json
from rag.rag_pipeline import RAGPipeline
from utils.prompts import highlight_prompt, evidence_based_prompt, sample_questions
from config import STUDY_FILES

# Maps study name -> already-constructed RAGPipeline, so each study's
# pipeline is built at most once per process.
rag_cache = {}


def get_rag_pipeline(study_name):
    """Return the RAG pipeline for ``study_name``, building it on first use.

    Args:
        study_name: Key into ``STUDY_FILES`` identifying the study.

    Returns:
        The cached ``RAGPipeline`` for the study.

    Raises:
        ValueError: If ``study_name`` has no (truthy) entry in ``STUDY_FILES``.
    """
    # Fast path: a pipeline was already created for this study.
    if study_name in rag_cache:
        return rag_cache[study_name]

    source_path = STUDY_FILES.get(study_name)
    if not source_path:
        raise ValueError(f"Invalid study name: {study_name}")

    pipeline = RAGPipeline(source_path)
    rag_cache[study_name] = pipeline
    return pipeline


def chat_function(message, history, study_name, prompt_type):
    """Answer ``message`` using the RAG pipeline of the selected study.

    Args:
        message: The user's question.
        history: Chat history; accepted for the caller's convenience but not
            used when building the answer.
        study_name: Study whose pipeline should be queried.
        prompt_type: "Highlight" or "Evidence-based" to select the matching
            prompt template; any other value passes ``None`` as the template.

    Returns:
        The ``response`` attribute of the pipeline's query result.

    Raises:
        ValueError: Propagated from ``get_rag_pipeline`` for an unknown study.
    """
    pipeline = get_rag_pipeline(study_name)

    # Labels that map to a dedicated template; everything else (including
    # "Default") falls through to None.
    templates_by_label = {
        "Highlight": highlight_prompt,
        "Evidence-based": evidence_based_prompt,
    }
    template = templates_by_label.get(prompt_type)

    result = pipeline.query(message, prompt_template=template)
    return result.response


def get_study_info(study_name):
    """Return a short Markdown summary of the documents in a study file.

    Args:
        study_name: Key into ``STUDY_FILES`` identifying the study.

    Returns:
        A Markdown string with the number of documents and the title of the
        first document, only the (zero) count when the file contains no
        documents, or ``"Invalid study name"`` when the study is not
        configured.
    """
    study_file = STUDY_FILES.get(study_name)
    if not study_file:
        return "Invalid study name"

    # Read as UTF-8 explicitly so non-ASCII titles decode the same way on
    # every platform instead of depending on the locale's default encoding.
    with open(study_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    summary = f"**Number of documents:** {len(data)}"
    if not data:
        # An empty study has no first document to show; avoid indexing it.
        return summary
    return f"{summary}\n\n**First document title:** {data[0]['title']}"


with gr.Blocks() as demo:
    gr.Markdown("# RAG Pipeline Demo")

    # Resolve the configured study names once; the first one is the initial
    # dropdown selection and the source of the example questions.
    study_names = list(STUDY_FILES)
    default_study = study_names[0]

    with gr.Row():
        study_dropdown = gr.Dropdown(
            choices=study_names,
            label="Select Study",
            value=default_study,
        )
        study_info = gr.Markdown()

    prompt_type = gr.Radio(
        ["Default", "Highlight", "Evidence-based"],
        label="Prompt Type",
        value="Default",
    )

    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Append the user's message as a new turn and clear the textbox.

        Returns the new textbox value (empty string) and the history extended
        with a ``[user_message, None]`` pair; the ``None`` slot is filled in
        by ``bot``.
        """
        return "", history + [[user_message, None]]

    def bot(history, study_name, prompt_type):
        """Fill in the assistant reply for the most recent turn in ``history``."""
        user_message = history[-1][0]
        bot_message = chat_function(user_message, history, study_name, prompt_type)
        history[-1][1] = bot_message
        return history

    # Two-step submit: show the user's message first, then compute and
    # attach the answer.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, study_dropdown, prompt_type], chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

    # Switching study refreshes the info panel and then resets the chat.
    study_dropdown.change(
        fn=get_study_info,
        inputs=study_dropdown,
        outputs=study_info,
    ).then(lambda: None, None, chatbot, queue=False)

    # The change event above only fires when the selection changes, so fill
    # the info panel for the initially selected study when the page loads.
    demo.load(
        fn=get_study_info,
        inputs=study_dropdown,
        outputs=study_info,
    )

    gr.Examples(examples=sample_questions[default_study], inputs=msg)

if __name__ == "__main__":
    # Start the app only when run as a script. NOTE(review): per Gradio's
    # launch() options, share=True requests a public share link and
    # debug=True enables debug output — confirm both are intended outside
    # local development.
    demo.launch(debug=True, share=True)