# Prat0's picture
# Update app.py
# 91b0ff8 verified
# raw
# history blame
# 2.36 kB
import os

import gradio as gr
import qdrant_client
from llama_index.core import Settings
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, StorageContext
from llama_index.core.indices.query.schema import QueryBundle
from llama_index.core.indices.vector_store.base import VectorStoreIndex
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.llms.gemini import Gemini
from llama_index.vector_stores.qdrant import QdrantVectorStore
def embed_setup():
    """Install the Gemini embedding model and LLM on the global ``Settings``.

    The API key is taken from the ``GEMINI_API_KEY`` environment variable.
    ``Settings.embed_model`` is set to ``models/embedding-001`` and
    ``Settings.llm`` to ``models/gemini-pro`` with temperature 0.1.
    Nothing is returned; the function only mutates ``Settings``.
    """
    gemini_key = os.getenv("GEMINI_API_KEY")

    Settings.embed_model = GeminiEmbedding(
        model_name="models/embedding-001",
        api_key=gemini_key,
    )
    Settings.llm = Gemini(
        model_name="models/gemini-pro",
        temperature=0.1,
        api_key=gemini_key,
    )
def qdrant_setup():
    """Create a Qdrant client configured from the environment.

    ``QDRANT_URL`` is passed as the client's first positional argument and
    ``QDRANT_API_KEY`` as its ``api_key``.

    Returns:
        The newly constructed ``qdrant_client.QdrantClient``.
    """
    endpoint = os.getenv("QDRANT_URL")
    secret = os.getenv("QDRANT_API_KEY")
    return qdrant_client.QdrantClient(endpoint, api_key=secret)
def llm_setup():
    """Build a standalone Gemini LLM (``models/gemini-pro``, temperature 0.6).

    Returns:
        A ``Gemini`` instance created with the key from ``GEMINI_API_KEY``.
    """
    return Gemini(
        model_name="models/gemini-pro",
        temperature=0.6,
        api_key=os.getenv("GEMINI_API_KEY"),
    )
def query_index(index, similarity_top_k=3, streaming=True):
    """Build a context-mode chat engine on top of ``index``.

    A fresh ``ChatMemoryBuffer`` (4000-token limit) is created on every call
    and handed to the engine together with the system prompt below.

    Args:
        index: Object exposing ``as_chat_engine`` (a ``VectorStoreIndex`` in
            this file).
        similarity_top_k: Number of nodes to retrieve per user message.
            Forwarded to ``index.as_chat_engine`` so the declared default
            actually takes effect.
        streaming: Kept for backward compatibility with existing callers;
            it is not used when building the engine.

    Returns:
        Whatever ``index.as_chat_engine`` returns (the chat engine).
    """
    # NOTE(review): whether the {context_str} / {query_str} placeholders are
    # substituted in "context" chat mode appears to depend on the installed
    # llama_index version -- confirm against the version in use.
    system_prompt = (
        """You are an AI assistant named Gemini, created by Google. Your task is to provide helpful, accurate, and concise responses to user queries.
Context information is below:
----------------
{context_str}
----------------
Always answer based on the information in the context and be precise
Given this context, please respond to the following user query:
{query_str}
Also suggest 3 more questions based on the the context that the user can ask
Your response:"""
    )
    memory = ChatMemoryBuffer.from_defaults(token_limit=4000)
    chat_engine = index.as_chat_engine(
        chat_mode="context",
        memory=memory,
        system_prompt=system_prompt,
        # Previously this argument was accepted but silently dropped.
        similarity_top_k=similarity_top_k,
    )
    return chat_engine
def get_response(text, history=None):
    """Answer a single chat message via the module-level ``chat_engine``.

    Args:
        text: The user's message.
        history: Accepted so the function can be used as a chat callback;
            it is not read here.

    Returns:
        The engine's reply converted to ``str``.
    """
    reply = chat_engine.chat(text)
    return str(reply)
# --- Application start-up (runs at import time) ---------------------------

# Register the global embedding model / LLM before the index is created.
embed_setup()

# Connect to Qdrant and wrap the collection named by COLLECTION_NAME.
client = qdrant_setup()
llm = llm_setup()
vector_store = QdrantVectorStore(
    collection_name=os.getenv("COLLECTION_NAME"),
    client=client,
)
index = VectorStoreIndex.from_vector_store(
    vector_store=vector_store,
    llm=llm,
)

# Single chat engine instance, read as a global by get_response().
chat_engine = query_index(index)

# Expose get_response through a Gradio chat UI and start serving.
t = gr.ChatInterface(get_response, analytics_enabled=True)
t.launch(debug=True, share=True)