Spaces:
Sleeping
Sleeping
import os | |
import weaviate | |
from sentence_transformers import SentenceTransformer, CrossEncoder | |
from src.llama_cpp_chat_engine import LlamaCPPChatEngine | |
class ChatRagAgent:
    """Chat agent that grounds replies in documents retrieved from Weaviate.

    Each call to :meth:`chat` embeds the user message, fetches the nearest
    stored objects from the ``"Collection"`` collection, reranks them with a
    cross-encoder, and forwards the surviving question/answer pairs to the
    chat engine as context.

    The tunables below are class attributes so a subclass (or an instance)
    can override them without touching the method bodies.
    """

    # GGUF model file handed to ``LlamaCPPChatEngine`` when none is given.
    DEFAULT_MODEL = "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf"
    # Number of candidates requested from the vector search.
    RETRIEVAL_LIMIT = 10
    # Number of candidates kept by the reranker.
    RERANK_TOP_K = 2
    # Reranked candidates scoring at or below this value are discarded.
    MIN_RERANK_SCORE = 0.2

    def __init__(self, model_path=DEFAULT_MODEL):
        """Load the models and connect to the Weaviate cluster.

        Args:
            model_path: Model file passed to ``LlamaCPPChatEngine``. Defaults
                to :attr:`DEFAULT_MODEL`, so ``ChatRagAgent()`` behaves as
                before.

        Raises:
            RuntimeError: If ``WCS_URL`` or ``WCS_KEY`` is not set in the
                environment.
        """
        # Fail fast with a readable message before loading any heavy model.
        cluster_url = os.getenv("WCS_URL")
        api_key = os.getenv("WCS_KEY")
        if not cluster_url or not api_key:
            raise RuntimeError(
                "WCS_URL and WCS_KEY environment variables must be set"
            )

        self._chat_engine = LlamaCPPChatEngine(model_path)
        self.n_ctx = self._chat_engine.n_ctx
        self._vectorizer = SentenceTransformer(
            "jinaai/jina-embeddings-v2-base-en",
            trust_remote_code=True,
        )
        self._reranker = CrossEncoder(
            "jinaai/jina-reranker-v1-turbo-en",
            trust_remote_code=True,
        )
        # Keep the client itself (not only the collection handle) so the
        # connection can be released through close().
        self._client = weaviate.connect_to_wcs(
            cluster_url=cluster_url,
            auth_credentials=weaviate.auth.AuthApiKey(api_key),
        )
        self._collection = self._client.collections.get("Collection")

    def close(self):
        """Close the Weaviate connection opened by ``__init__``."""
        self._client.close()

    def chat(self, messages, user_message):
        """Answer ``user_message`` using retrieved documents as context.

        Args:
            messages: Conversation history, passed through unchanged to the
                chat engine.
            user_message: The latest user message; used as the retrieval
                query and the reranking query.

        Returns:
            A ``(reply, sources)`` tuple. ``reply`` is whatever the chat
            engine's ``chat`` returns; ``sources`` is the list of ``link``
            properties of the documents that were supplied as context.
        """
        embedding = self._vectorizer.encode(user_message).tolist()
        response = self._collection.query.near_vector(
            near_vector=embedding,
            limit=self.RETRIEVAL_LIMIT,
        )
        candidates = response.objects

        # Nothing retrieved: skip the reranker (there is nothing to rank)
        # and let the chat engine answer without context.
        if not candidates:
            return self._chat_engine.chat(messages, user_message, []), []

        ranks = self._reranker.rank(
            user_message,
            [obj.properties["answer"] for obj in candidates],
            top_k=self.RERANK_TOP_K,
            apply_softmax=True,
        )

        # Filter once so context and sources always describe the same
        # documents; ``corpus_id`` indexes into the list given to rank().
        relevant = [
            candidates[rank["corpus_id"]].properties
            for rank in ranks
            if rank["score"] > self.MIN_RERANK_SCORE
        ]
        context = [
            f"Question: {props['question']}\nAnswer: {props['answer']}\n"
            for props in relevant
        ]
        sources = [props["link"] for props in relevant]

        return self._chat_engine.chat(messages, user_message, context), sources