import os
import weaviate
from sentence_transformers import SentenceTransformer, CrossEncoder
from src.llama_cpp_chat_engine import LlamaCPPChatEngine


class ChatRagAgent:
    """Retrieval-augmented chat agent backed by a Weaviate collection."""

    def __init__(self):
        # Local GGUF model served through llama.cpp.
        # self._chat_engine = LlamaCPPChatEngine("Phi-3-mini-4k-instruct-q4.gguf")
        self._chat_engine = LlamaCPPChatEngine("Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf")
        self.n_ctx = self._chat_engine.n_ctx

        # Embedding model used to vectorize the user query for retrieval.
        self._vectorizer = SentenceTransformer(
            "jinaai/jina-embeddings-v2-base-en",
            trust_remote_code=True,
        )

        # Cross-encoder used to rerank the retrieved documents.
        self._reranker = CrossEncoder(
            "jinaai/jina-reranker-v1-turbo-en",
            trust_remote_code=True,
        )

        # Weaviate Cloud collection holding the question/answer documents.
        # Requires WCS_URL and WCS_KEY to be set in the environment.
        self._collection = weaviate.connect_to_wcs(
            cluster_url=os.getenv("WCS_URL"),
            auth_credentials=weaviate.auth.AuthApiKey(os.getenv("WCS_KEY")),
        ).collections.get("Collection")
    def chat(self, messages, user_message):
        # Embed the user message and retrieve the 10 nearest documents.
        embedding = self._vectorizer.encode(user_message).tolist()
        docs = self._collection.query.near_vector(
            near_vector=embedding,
            limit=10,
        )

        # Rerank the retrieved answers against the query and keep the top 2.
        ranks = self._reranker.rank(
            user_message,
            [obj.properties['answer'] for obj in docs.objects],
            top_k=2,
            apply_softmax=True,
        )

        # Keep only confidently relevant hits (softmax score > 0.2) and format
        # them as question/answer context for the chat engine.
        relevant = [rank for rank in ranks if rank["score"] > 0.2]
        context = [
            f"""\
Question: {docs.objects[rank['corpus_id']].properties['question']}
Answer: {docs.objects[rank['corpus_id']].properties['answer']}
"""
            for rank in relevant
        ]
        sources = [
            docs.objects[rank['corpus_id']].properties['link']
            for rank in relevant
        ]

        return self._chat_engine.chat(messages, user_message, context), sources
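

# Minimal usage sketch, not part of the original file. Assumptions: the chat
# history is a list of {"role": ..., "content": ...} dicts (adjust to whatever
# LlamaCPPChatEngine.chat actually expects), WCS_URL and WCS_KEY are set in the
# environment, and the GGUF model file is available on disk.
if __name__ == "__main__":
    agent = ChatRagAgent()
    history = []  # prior chat turns, if any
    reply, sources = agent.chat(history, "How do I reset my password?")
    print(reply)
    print("Sources:", sources)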