# AI assistant with a RAG system to query information from the CAMELS cosmological simulations using LangChain
# Author: Pablo Villanueva Domingo

import gradio as gr
from langchain import hub
from langchain_chroma import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_mistralai import ChatMistralAI
import requests
from langchain_community.document_loaders import WebBaseLoader
import bs4
from langchain_core.rate_limiters import InMemoryRateLimiter
from urllib.parse import urljoin

# Define a limiter to avoid rate limit issues with MistralAI
rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.1,  # <-- MistralAI free tier: roughly one request every 10 seconds
    check_every_n_seconds=0.01,  # Wake up every 10 ms to check whether a request is allowed
    max_bucket_size=10,  # Controls the maximum burst size.
)

# Function to get all the subpages from a base url
def get_subpages(base_url):
    visited_urls = []
    urls_to_visit = [base_url]

    while urls_to_visit:
        url = urls_to_visit.pop(0)
        if url in visited_urls:
            continue
        
        visited_urls.append(url)
        response = requests.get(url)
        soup = bs4.BeautifulSoup(response.content, "html.parser")

        for link in soup.find_all("a", href=True):
            full_url = urljoin(base_url, link['href'])
            if base_url in full_url and full_url.endswith(".html") and full_url not in visited_urls:
                urls_to_visit.append(full_url)
    # Drop the base url itself, keeping only the subpages
    visited_urls = visited_urls[1:]

    return visited_urls
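
# A more defensive variant of the crawl loop above (a sketch, not used here)
# could wrap the download in a try/except and skip pages that fail, e.g.:
#     try:
#         response = requests.get(url, timeout=30)
#         response.raise_for_status()
#     except requests.RequestException:
#         continue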

# Get urls
base_url = "https://camels.readthedocs.io/en/latest/"
urls = get_subpages(base_url)

# Load the contents of the documentation pages
loader = WebBaseLoader(urls)
docs = loader.load()

print("Pages loaded:",len(docs))

# Join the page contents of the retrieved documents into a single context string
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# Create a RAG chain
def RAG(llm, docs, embeddings):

    # Split text
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)

    # Create vector store
    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
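
    # Optionally (a sketch, not used here), the vector store could be persisted to
    # disk so the documentation is not re-embedded on every launch:
    # vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings,
    #                                     persist_directory="./chroma_db")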

    # Retrieve and generate using the relevant snippets of the documents
    retriever = vectorstore.as_retriever()

    # Pull a reference prompt for RAG systems from the LangChain hub
    prompt = hub.pull("rlm/rag-prompt")

    # Create the chain
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    return rag_chain

# LLM model
llm = ChatMistralAI(model="mistral-large-latest", rate_limiter=rate_limiter)

# Embeddings
embed_model = "sentence-transformers/multi-qa-distilbert-cos-v1"
# embed_model = "nvidia/NV-Embed-v2"
embeddings = HuggingFaceInstructEmbeddings(model_name=embed_model)

# RAG chain
rag_chain = RAG(llm, docs, embeddings)
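
# Direct (non-streaming) use of the chain, e.g. as a quick sanity check
# (assumes MISTRAL_API_KEY is set in the environment):
# answer = rag_chain.invoke("Which simulation suites are included in CAMELS?")
# print(answer)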

# Function to handle prompt and query the RAG chain
def handle_prompt(message, history):
    try:
        # Stream output
        out=""
        for chunk in rag_chain.stream(message):
            out += chunk
            yield out
    except Exception:
        raise gr.Error("Request failed; the MistralAI rate limit may have been exceeded")
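
# The same generator can be exercised outside Gradio, e.g.:
# for partial in handle_prompt("How can I read a halo file?", []):
#     print(partial)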

# Predefined messages and examples
description = "AI-powered assistant that answers questions related to the [CAMELS simulations](https://www.camel-simulations.org/)."
greetingsmessage = "Hi, I'm the CAMELS DocBot. I'm here to assist you with any question related to the CAMELS simulations."
example_questions = [
                    "How can I read a halo file?",
                    "Which simulation suites are included in CAMELS?",
                    "Which are the largest volumes in CAMELS simulations?",
                    "Write a complete snippet of code getting the power spectrum of a simulation"
                     ]

# Define customized Gradio chatbot
chatbot = gr.Chatbot([{"role":"assistant", "content":greetingsmessage}],
                     type="messages",
                     avatar_images=["ims/userpic.png","ims/camelslogo.jpg"],
                     height="60vh")

# Define Gradio interface
demo = gr.ChatInterface(handle_prompt,
                        type="messages",
                        title="CAMELS DocBot",
                        examples=example_questions,
                        theme=gr.themes.Soft(),
                        description=description,
                        chatbot=chatbot)

demo.launch()
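
# When running locally, a temporary public link can be obtained with
# demo.launch(share=True) instead (not needed when deployed on Hugging Face Spaces).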