# AI assistant with a RAG system to query information from
# the gwIAS search pipeline,
# built with LangChain and deployed with Gradio.

from rag import RAG, load_docs
# from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain.chat_models import ChatOpenAI
import gradio as gr

# Load the documentation
docs = load_docs()
print("Pages loaded:", len(docs))

# LLM model
llm = ChatOpenAI(model="gpt-4o-mini")
# llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
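# ChatOpenAI authenticates via the OPENAI_API_KEY environment variable.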

# Embeddings
embed_model = "sentence-transformers/multi-qa-distilbert-cos-v1"
# embed_model = "nvidia/NV-Embed-v2"
embeddings = HuggingFaceInstructEmbeddings(model_name=embed_model)
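# The embedding model is downloaded from the Hugging Face Hub on first run and executed locally.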

# RAG chain
rag_chain = RAG(llm, docs, embeddings)
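# RAG (defined in the local rag module) presumably builds a vector index over the
# docs with the chosen embeddings and combines it with the LLM; the resulting
# chain only needs to expose .stream(question), which the handler below uses.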

# Function to handle prompt and query the RAG chain
def handle_prompt(message, history):
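    """Stream the RAG chain's answer to a user message.

    `message` is the latest user input; `history` is the prior conversation
    supplied by gr.ChatInterface (unused here, since only the current message
    is passed to the chain).
    """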
    try:
        # Stream output
        out = ""
        for chunk in rag_chain.stream(message):
            out += chunk
            yield out
    except Exception as e:
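        # Raising gr.Error shows the message as an error popup in the Gradio UI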
        raise gr.Error(f"An error occurred: {str(e)}")


if __name__ == "__main__":

    # Predefined messages and examples
    description = "AI-powered assistant to help with the [gwfast](https://github.com/CosmoStatGW/gwfast) Fisher matrix pipeline."
    greetingsmessage = "Hi, I'm the gwfast Bot. I'm here to assist you with the Fisher matrix pipeline."
    example_questions = [
        "How can I calculate the fisher matrix?",
        "What waveforms are available in gwfast?",
        "How is derivative of waveforms calculated?"
    ]

    # Define customized Gradio chatbot
    chatbot = gr.Chatbot([{"role": "assistant", "content": greetingsmessage}],
                         type="messages",
                         avatar_images=["ims/userpic.png", "ims/gwIASlogo.jpg"],
                         height="60vh")

    # Define Gradio interface
    demo = gr.ChatInterface(handle_prompt,
                            type="messages",
                            title="gwIAS DocBot",
                            fill_height=True,
                            examples=example_questions,
                            theme=gr.themes.Soft(),
                            description=description,
                            # cache_examples=False,
                            chatbot=chatbot)

    demo.launch()

# Reference papers (presumably the documentation sources indexed for retrieval):
# https://arxiv.org/html/2405.17400v2
# https://arxiv.org/html/2312.06631v1
# https://arxiv.org/html/2310.15233v2