Pavol Liška committed
Commit 593b823 · Parent(s): 3c35194
v1-fix
- agent/Agent.py +0 -2
- api.py +39 -18
- emdedd/MongoEmbedding.py +4 -7
- rag.py +53 -28
- rag_langchain.py +11 -8
- requirements.txt +1 -0
agent/Agent.py
CHANGED
@@ -1,5 +1,3 @@
-from langchain.embeddings import CacheBackedEmbeddings
-from langchain.storage import LocalFileStore
 from langchain_core.language_models import BaseChatModel
 
 from emdedd.Embedding import Embedding
api.py
CHANGED
@@ -1,6 +1,8 @@
 from fastapi import FastAPI, Response, Body, Security
 from fastapi.security import APIKeyHeader
-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator
+from typing import List
+import json
 
 from conversation.conversation_store import ConversationStore
 from rag_langchain import LangChainRAG
@@ -22,37 +24,58 @@ class QModel(BaseModel):
     temperature: str = "0.2"
     llm: str = default_llm
 
+    @classmethod
+    @model_validator(mode='before')
+    def validate_to_json(cls, value):
+        if isinstance(value, str):
+            return cls(**json.loads(value))
+        return value
+
+
+class AModel(BaseModel):
+    q: str
+    a: str
+    sources: List[str]
+    oid: str
+
 
 class EmoModel(BaseModel):
     qid: str
     helpfulness: str
 
+    @classmethod
+    @model_validator(mode='before')
+    def validate_to_json(cls, value):
+        if isinstance(value, str):
+            return cls(**json.loads(value))
+        return value
+
 
 @api.get("/")
-def read_root():
+async def read_root():
     return "Empty"
 
 
-@api.post("/
-async def
+@api.post("/qa", response_model=AModel)
+async def qa(api_key: str = Security(api_key_header), data: QModel = Body(...)):
     # Verify the API key
     if not valid_api_key(api_key):
         return Response(status_code=401)
 
     rag = LangChainRAG(
         config={
-            "retrieve_documents":
-            "temperature":
+            "retrieve_documents": data.retrieval_count,
+            "temperature": data.temperature,
             "prompt_id": prompt_id,
             "check_prompt_id": check_prompt_id,
             "rewrite_prompt_id": rewrite_prompt_id
         }
     )
 
-    answer, check_result, sources = rag.rag_chain(
+    answer, check_result, sources = rag.rag_chain(data.q, data.llm)
 
     oid = conversation_store.save_content(
-        q=q,
+        q=data.q,
         a=answer,
         sources=list(map(lambda doc: doc.page_content, sources)),
         params=
@@ -61,19 +84,17 @@ async def q(api_key: str = Security(api_key_header), json_body: QModel = Body(..
             "check_prompt_id": check_prompt_id,
             "rewrite_prompt_id": rewrite_prompt_id,
             "check_result": check_result,
-            "temperature":
-            "retrieve_document_count":
+            "temperature": data.temperature,
+            "retrieve_document_count": data.retrieval_count,
         }
     )
 
-    return
-
-
-
-
-
-        }
-    )
+    return AModel(
+        a=answer,
+        q=data.q,
+        sources=list(map(lambda doc: doc.page_content, sources)),
+        oid=oid
+    )
 
 
 @api.post("/emo")
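The `validate_to_json` hook added to QModel and EmoModel is the pydantic v2 pattern for accepting a raw JSON string in place of a parsed object. A minimal standalone sketch of the same pattern (the `retrieval_count` default and the `llm` placeholder are assumptions for this sketch; only the field names and `temperature = "0.2"` come from the diff):

import json

from pydantic import BaseModel, model_validator


class QModel(BaseModel):
    q: str
    retrieval_count: str = "4"  # assumed default; not visible in the diff
    temperature: str = "0.2"
    llm: str = "default"        # stands in for default_llm defined elsewhere in api.py

    @model_validator(mode='before')
    @classmethod
    def validate_to_json(cls, value):
        # If the body arrives as a JSON string, parse it before field validation.
        if isinstance(value, str):
            return cls(**json.loads(value))
        return value


# A dict and an equivalent JSON string both validate:
print(QModel.model_validate({"q": "hello"}))
print(QModel.model_validate('{"q": "hello", "temperature": "0.1"}'))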
emdedd/MongoEmbedding.py
CHANGED
@@ -1,12 +1,11 @@
 from dataclasses import dataclass
 
+from bson.objectid import ObjectId
 from langchain.embeddings import CacheBackedEmbeddings
-from langchain.storage import LocalFileStore
-from langchain_community.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch
 from langchain_core.embeddings import Embeddings
 from langchain_core.stores import InMemoryStore
+from langchain_mongodb import MongoDBAtlasVectorSearch
 from pymongo import MongoClient
-from bson.objectid import ObjectId
 
 from emdedd.Embedding import Embedding
 
@@ -83,9 +82,7 @@ class MongoEmbedding(Embedding):
         )
 
     def search(self, query, search_type, doc_count):
-
-        retriever = vector_store.as_retriever(
+        return self.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": doc_count}
-        )
-        return retriever.get_relevant_documents(query=query)
+        ).get_relevant_documents(query=query)
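The rewritten `search` also fixes a latent NameError: the old body referenced a local `vector_store` that was never defined, while the new one goes through `self.get_vector_store()`. A hedged usage sketch, assuming an already-configured `MongoEmbedding` instance named `embedding`:

# Hypothetical call; the Mongo connection and index setup are outside this diff.
docs = embedding.search(query="opening hours", search_type="similarity", doc_count=4)
for doc in docs:
    print(doc.page_content[:80])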
rag.py
CHANGED
@@ -10,7 +10,7 @@ from langchain.chains.retrieval import create_retrieval_chain
 from langchain.retrievers import MultiQueryRetriever, MergerRetriever, ContextualCompressionRetriever, EnsembleRetriever
 from langchain_cohere import CohereRerank
 from langchain_core.documents import Document
-from langchain_core.prompts import PromptTemplate
+from langchain_core.prompts import PromptTemplate, BasePromptTemplate
 
 from agent.Agent import Agent
 from agent.agents import chat_openai_llm, deepinfra_chat
@@ -23,12 +23,12 @@ load_dotenv()
 conversation_store = ConversationStore()
 prompt_store = PromptStore()
 
-grammar_check_1 = prompt_store.get_by_name("gramar_check_1")
-rewrite_hyde_1 = prompt_store.get_by_name("rewrite_hyde_1")
-rewrite_hyde_2 = prompt_store.get_by_name("rewrite_hyde_2")
-rewrite_1 = prompt_store.get_by_name("rewrite_1")
-rewrite_2 = prompt_store.get_by_name("rewrite_2")
-rewrite_hyde = prompt_store.get_by_name("rewrite_hyde")
+grammar_check_1 = prompt_store.get_by_name("gramar_check_1").text
+rewrite_hyde_1 = prompt_store.get_by_name("rewrite_hyde_1").text
+rewrite_hyde_2 = prompt_store.get_by_name("rewrite_hyde_2").text
+rewrite_1 = prompt_store.get_by_name("rewrite_1").text
+rewrite_2 = prompt_store.get_by_name("rewrite_2").text
+rewrite_hyde = prompt_store.get_by_name("rewrite_hyde").text
 
 
 def replace_nl(input: str) -> str:
@@ -52,26 +52,6 @@ def rewrite(agent: Agent, q: str, prompt: str) -> list[str]:
     return [x for x in questions if ("##" not in x and len(str(x).strip()) > 0)]
 
 
-def rag_with_rerank_check_rewrite(agent: Agent, q: str, retrieve_document_count: int, prompt: str, check_prompt: str,
-                                  rewrite_prompt: str):
-    rewritten_list: list[str] = rewrite(agent, q, rewrite_prompt)
-
-    if len(rewritten_list) == 0:
-        return "Neviem, nemám podklady!", "", ""
-
-    context_doc = retrieve_subqueries(agent, retrieve_document_count, rewritten_list)
-
-    if len(context_doc) == 0:
-        return "Neviem, nemám kontext!", "", ""
-
-    result = answer_pipeline(agent, context_doc, prompt, q)
-    answer = result["text"]
-
-    check_result = check_pipeline(answer, check_prompt, context_doc, q)
-
-    return answer, check_result, context_doc
-
-
 def rag_with_rerank_check_rewrite_hyde(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
                                        check_prompt: str,
                                        rewrite_prompt: str):
@@ -120,7 +100,48 @@ def rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
             prompt=PromptTemplate(
                 input_variables=["context", "question", "actual_date"],
                 template=prompt
-            )
+            ),
+            document_prompt=PromptTemplate(input_variables=[], template="page_content")
+        )
+    ).invoke(
+        input={
+            "question": q,
+            "input": q,
+            "actual_date": datetime.date.today().isoformat()
+        }
+    )
+
+    print(result)
+
+    check_result = check_pipeline(result["answer"], check_prompt, result["context"], q)
+
+    print(check_result)
+
+    return result["answer"], check_result, result["context"]
+
+
+def vanilla_rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
+                      check_prompt: str):
+    retriever = ContextualCompressionRetriever(
+        base_compressor=(CohereRerank(
+            model="rerank-multilingual-v3.0",
+            top_n=retrieve_document_count
+        )),
+        base_retriever=(agent.embedding.get_vector_store().as_retriever(
+            search_type="similarity",
+            search_kwargs={"k": min(retrieve_document_count * 10, 500)},
+        ))
+    )
+
+    result = create_retrieval_chain(
+        retriever=retriever,
+        combine_docs_chain=create_stuff_documents_chain(
+            llm=agent.llm,
+            prompt=PromptTemplate(
+                input_variables=["context", "question", "actual_date"],
+                template=prompt
+            ),
+            document_prompt=PromptTemplate(input_variables=[], template="page_content")
         )
     ).invoke(
         input={
@@ -130,8 +151,12 @@ def rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
         }
     )
 
+    print(result)
+
     check_result = check_pipeline(result["answer"], check_prompt, result["context"], q)
 
+    print(check_result)
+
     return result["answer"], check_result, result["context"]
 
 
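The new `vanilla_rag_chain` uses a common rerank pattern: over-fetch from the vector store (10x the requested count, capped at 500) and let `CohereRerank` cut the pool back down to `retrieve_document_count`. A self-contained sketch of that retriever composition, using an in-memory toy store instead of the project's MongoDB Atlas index (requires COHERE_API_KEY; the documents and query are illustrative):

from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_core.vectorstores import InMemoryVectorStore

# Toy store so the sketch runs standalone; the real code retrieves from MongoDB Atlas.
vector_store = InMemoryVectorStore(embedding=DeterministicFakeEmbedding(size=256))
vector_store.add_texts(["Doc about pricing.", "Doc about opening hours.", "Doc about returns."])

retrieve_document_count = 2  # illustrative value

retriever = ContextualCompressionRetriever(
    # The reranker keeps only the best top_n of the candidates it is handed.
    base_compressor=CohereRerank(model="rerank-multilingual-v3.0", top_n=retrieve_document_count),
    # Over-fetch 10x (capped at 500) so the reranker has a wide pool to choose from.
    base_retriever=vector_store.as_retriever(
        search_type="similarity",
        search_kwargs={"k": min(retrieve_document_count * 10, 500)},
    ),
)

print(retriever.get_relevant_documents("When are you open?"))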
rag_langchain.py
CHANGED
@@ -4,6 +4,7 @@ from dotenv import load_dotenv
 from gptcache import Cache
 from gptcache.manager.factory import manager_factory
 from gptcache.processor.pre import get_prompt
+from langchain.globals import set_debug
 from langchain.retrievers import ContextualCompressionRetriever
 from langchain_cohere import CohereRerank, CohereEmbeddings
 from langchain_community.cache import GPTCache
@@ -19,10 +20,13 @@ from agent.agents import deepinfra_chat, \
 from emdedd.Embedding import Embedding
 from emdedd.MongoEmbedding import EmbeddingDbConnection, MongoEmbedding
 from prompt.prompt_store import PromptStore
-from rag import rag_chain
+from rag import vanilla_rag_chain, rag_chain
 
 load_dotenv()
 
+# set_verbose(True)
+set_debug(True)
+
 
 class LangChainRAG:
     embedding: Embedding
@@ -89,24 +93,23 @@
 
         self.retriever = ContextualCompressionRetriever(
             base_compressor=CohereRerank(model="rerank-multilingual-v3.0", top_n=os.getenv("retrieve_documents")),
-            base_retriever=self.
+            base_retriever=self.embedding.get_vector_store().as_retriever(
                 search_type="similarity",
                 search_kwargs={"k": config["retrieve_documents"] * 10}
             )
         )
 
-    def get_vector_store_mongodb(self):
-        return self.embedding[0].get_vector_store()
-
     def get_llms(self):
         return self.llms.keys()
 
-    def rag_chain(self, query,
-
+    def rag_chain(self, query, llm_choice):
+        print("Using " + llm_choice)
+
         # answer, check_result, context_doc = rag_with_rerank_check_rewrite_hyde(
         # answer, check_result, context_doc = rag_with_rerank_check_multi_query_retriever(
+        # answer, check_result, context_doc = vanilla_rag_chain(
        answer, check_result, context_doc = rag_chain(
-            Agent(embedding=self.embedding
+            Agent(embedding=self.embedding, llm=self.llms[llm_choice]),
             query,
             self.config["retrieve_documents"],
             self.prompt_store.get_by_name(self.config["prompt_id"]).text,
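With `rag_chain(self, query, llm_choice)` completed, callers now select the model by key. A hedged driver sketch mirroring the call in api.py (the prompt names are placeholders; the config keys are the ones this commit reads):

from rag_langchain import LangChainRAG

rag = LangChainRAG(
    config={
        "retrieve_documents": 4,
        "temperature": "0.2",
        "prompt_id": "answer_1",            # placeholder prompt names
        "check_prompt_id": "gramar_check_1",
        "rewrite_prompt_id": "rewrite_hyde",
    }
)

# llm_choice must be one of the keys reported by get_llms()
llm_choice = next(iter(rag.get_llms()))
answer, check_result, sources = rag.rag_chain("What changed in v1?", llm_choice)
print(answer)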
requirements.txt
CHANGED
@@ -6,6 +6,7 @@ langchain-mistralai
 langchain-cohere
 langchain-google-genai
 langchain-together
+langchain-mongodb
 
 fitz
 pypdf