File size: 1,261 Bytes
a909949
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from nemo.config import vector_search

# Base URL handed to ChatOpenAI when the provider is "ollama" (a server on localhost).
OLLAMA_BASE_URL = "http://localhost:11434/v1"
# Base URL handed to ChatOpenAI when the provider is "groq".
GROQ_BASE_URL = "https://api.groq.com/openai/v1"

def initialize_llm(model_api_key, provider, model):
    """Create the chat-model client for the requested provider.

    Args:
        model_api_key: API key forwarded to the provider client. It is not
            used for "ollama", which is always given an empty key.
        provider: One of "gemini", "openai", "groq" or "ollama". The match is
            an exact, case-sensitive equality check.
        model: Model name forwarded to the client.

    Returns:
        A ``ChatGoogleGenerativeAI`` instance for "gemini", a ``ChatOpenAI``
        instance for the other known providers, or ``None`` when the provider
        is not recognised.
    """
    if provider == "gemini":
        return ChatGoogleGenerativeAI(google_api_key=model_api_key, model=model)

    # The remaining providers all go through ChatOpenAI and differ only in
    # the key and base URL. The kwargs are built lazily so nothing is
    # evaluated for providers that do not match.
    chat_openai_variants = (
        ("openai", lambda: {"openai_api_key": model_api_key}),
        (
            "groq",
            lambda: {
                "openai_api_key": model_api_key,
                "openai_api_base": GROQ_BASE_URL,
            },
        ),
        (
            "ollama",
            lambda: {
                "openai_api_key": "",
                "openai_api_base": OLLAMA_BASE_URL,
            },
        ),
    )
    for known_provider, make_kwargs in chat_openai_variants:
        if provider == known_provider:
            return ChatOpenAI(model_name=model, **make_kwargs())

    return None

def prompt_template(question, context):
    """Render the prompt text that combines retrieved context and a question.

    Args:
        question: The user's question; placed between triple backticks.
        context: Context to answer from; interpolated on its own line.

    Returns:
        The prompt string, ending with the "Answer in markdown:" cue.
    """
    # Every line after the first carries four leading spaces, and the first
    # two lines end with a trailing space; both are part of the emitted text.
    prompt_lines = (
        "You are an **GitDoc AI** Chatbot, a helpful assistant that assists users with their ",
        "    NVIDIA's NeMo Guardrails related questions. ",
        "    Use the following pieces of context to answer the user's question:",
        f"    {context}",
        "",
        f"    USER QUESTION: ```{question}```",
        "    Answer in markdown:",
    )
    return "\n".join(prompt_lines)

def rag_chain(llm, message):
    """Answer ``message`` with the given model, using searched context.

    Args:
        llm: Object exposing ``invoke(prompt)``; the result must have a
            ``content`` attribute.
        message: The user's question. It is used both as the
            ``vector_search`` query and as the question in the prompt.

    Returns:
        The ``content`` attribute of the model's response.
    """
    retrieved_context = vector_search(message)
    prompt = prompt_template(message, retrieved_context)
    response = llm.invoke(prompt)
    return response.content