Pavol Liška committed on
Commit
593b823
·
1 Parent(s): 3c35194
Files changed (6) hide show
  1. agent/Agent.py +0 -2
  2. api.py +39 -18
  3. emdedd/MongoEmbedding.py +4 -7
  4. rag.py +53 -28
  5. rag_langchain.py +11 -8
  6. requirements.txt +1 -0
agent/Agent.py CHANGED
@@ -1,5 +1,3 @@
1
- from langchain.embeddings import CacheBackedEmbeddings
2
- from langchain.storage import LocalFileStore
3
  from langchain_core.language_models import BaseChatModel
4
 
5
  from emdedd.Embedding import Embedding
 
 
 
1
  from langchain_core.language_models import BaseChatModel
2
 
3
  from emdedd.Embedding import Embedding
api.py CHANGED
@@ -1,6 +1,8 @@
1
  from fastapi import FastAPI, Response, Body, Security
2
  from fastapi.security import APIKeyHeader
3
- from pydantic import BaseModel
 
 
4
 
5
  from conversation.conversation_store import ConversationStore
6
  from rag_langchain import LangChainRAG
@@ -22,37 +24,58 @@ class QModel(BaseModel):
22
  temperature: str = "0.2"
23
  llm: str = default_llm
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  class EmoModel(BaseModel):
27
  qid: str
28
  helpfulness: str
29
 
 
 
 
 
 
 
 
30
 
31
  @api.get("/")
32
- def read_root():
33
  return "Empty"
34
 
35
 
36
- @api.post("/q")
37
- async def q(api_key: str = Security(api_key_header), json_body: QModel = Body(...)):
38
  # Verify the API key
39
  if not valid_api_key(api_key):
40
  return Response(status_code=401)
41
 
42
  rag = LangChainRAG(
43
  config={
44
- "retrieve_documents": json_body.retrieval_count,
45
- "temperature": json_body.temperature,
46
  "prompt_id": prompt_id,
47
  "check_prompt_id": check_prompt_id,
48
  "rewrite_prompt_id": rewrite_prompt_id
49
  }
50
  )
51
 
52
- answer, check_result, sources = rag.rag_chain(json_body.q, json_body.llm)
53
 
54
  oid = conversation_store.save_content(
55
- q=q,
56
  a=answer,
57
  sources=list(map(lambda doc: doc.page_content, sources)),
58
  params=
@@ -61,19 +84,17 @@ async def q(api_key: str = Security(api_key_header), json_body: QModel = Body(..
61
  "check_prompt_id": check_prompt_id,
62
  "rewrite_prompt_id": rewrite_prompt_id,
63
  "check_result": check_result,
64
- "temperature": json_body.temperature,
65
- "retrieve_document_count": json_body.retrieval_count,
66
  }
67
  )
68
 
69
- return Response(
70
- status_code=200,
71
- content={
72
- "response": answer,
73
- "sources": list(map(lambda doc: doc.page_content, sources)),
74
- "qid": oid
75
- }
76
- )
77
 
78
 
79
  @api.post("/emo")
 
1
  from fastapi import FastAPI, Response, Body, Security
2
  from fastapi.security import APIKeyHeader
3
+ from pydantic import BaseModel, model_validator
4
+ from typing import List
5
+ import json
6
 
7
  from conversation.conversation_store import ConversationStore
8
  from rag_langchain import LangChainRAG
 
24
  temperature: str = "0.2"
25
  llm: str = default_llm
26
 
27
+ @classmethod
28
+ @model_validator(mode='before')
29
+ def validate_to_json(cls, value):
30
+ if isinstance(value, str):
31
+ return cls(**json.loads(value))
32
+ return value
33
+
34
+
35
+ class AModel(BaseModel):
36
+ q: str
37
+ a: str
38
+ sources: List[str]
39
+ oid: str
40
+
41
 
42
  class EmoModel(BaseModel):
43
  qid: str
44
  helpfulness: str
45
 
46
+ @classmethod
47
+ @model_validator(mode='before')
48
+ def validate_to_json(cls, value):
49
+ if isinstance(value, str):
50
+ return cls(**json.loads(value))
51
+ return value
52
+
53
 
54
  @api.get("/")
55
+ async def read_root():
56
  return "Empty"
57
 
58
 
59
+ @api.post("/qa", response_model=AModel)
60
+ async def qa(api_key: str = Security(api_key_header), data: QModel = Body(...)):
61
  # Verify the API key
62
  if not valid_api_key(api_key):
63
  return Response(status_code=401)
64
 
65
  rag = LangChainRAG(
66
  config={
67
+ "retrieve_documents": data.retrieval_count,
68
+ "temperature": data.temperature,
69
  "prompt_id": prompt_id,
70
  "check_prompt_id": check_prompt_id,
71
  "rewrite_prompt_id": rewrite_prompt_id
72
  }
73
  )
74
 
75
+ answer, check_result, sources = rag.rag_chain(data.q, data.llm)
76
 
77
  oid = conversation_store.save_content(
78
+ q=data.q,
79
  a=answer,
80
  sources=list(map(lambda doc: doc.page_content, sources)),
81
  params=
 
84
  "check_prompt_id": check_prompt_id,
85
  "rewrite_prompt_id": rewrite_prompt_id,
86
  "check_result": check_result,
87
+ "temperature": data.temperature,
88
+ "retrieve_document_count": data.retrieval_count,
89
  }
90
  )
91
 
92
+ return AModel(
93
+ a=answer,
94
+ q=data.q,
95
+ sources=list(map(lambda doc: doc.page_content, sources)),
96
+ oid=oid
97
+ )
 
 
98
 
99
 
100
  @api.post("/emo")
emdedd/MongoEmbedding.py CHANGED
@@ -1,12 +1,11 @@
1
  from dataclasses import dataclass
2
 
 
3
  from langchain.embeddings import CacheBackedEmbeddings
4
- from langchain.storage import LocalFileStore
5
- from langchain_community.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch
6
  from langchain_core.embeddings import Embeddings
7
  from langchain_core.stores import InMemoryStore
 
8
  from pymongo import MongoClient
9
- from bson.objectid import ObjectId
10
 
11
  from emdedd.Embedding import Embedding
12
 
@@ -83,9 +82,7 @@ class MongoEmbedding(Embedding):
83
  )
84
 
85
  def search(self, query, search_type, doc_count):
86
- vector_store = self.get_vector_store()
87
- retriever = vector_store.as_retriever(
88
  search_type="similarity",
89
  search_kwargs={"k": doc_count}
90
- )
91
- return retriever.get_relevant_documents(query=query)
 
1
  from dataclasses import dataclass
2
 
3
+ from bson.objectid import ObjectId
4
  from langchain.embeddings import CacheBackedEmbeddings
 
 
5
  from langchain_core.embeddings import Embeddings
6
  from langchain_core.stores import InMemoryStore
7
+ from langchain_mongodb import MongoDBAtlasVectorSearch
8
  from pymongo import MongoClient
 
9
 
10
  from emdedd.Embedding import Embedding
11
 
 
82
  )
83
 
84
  def search(self, query, search_type, doc_count):
85
+ return self.get_vector_store().as_retriever(
 
86
  search_type="similarity",
87
  search_kwargs={"k": doc_count}
88
+ ).get_relevant_documents(query=query)
 
rag.py CHANGED
@@ -10,7 +10,7 @@ from langchain.chains.retrieval import create_retrieval_chain
10
  from langchain.retrievers import MultiQueryRetriever, MergerRetriever, ContextualCompressionRetriever, EnsembleRetriever
11
  from langchain_cohere import CohereRerank
12
  from langchain_core.documents import Document
13
- from langchain_core.prompts import PromptTemplate
14
 
15
  from agent.Agent import Agent
16
  from agent.agents import chat_openai_llm, deepinfra_chat
@@ -23,12 +23,12 @@ load_dotenv()
23
  conversation_store = ConversationStore()
24
  prompt_store = PromptStore()
25
 
26
- grammar_check_1 = prompt_store.get_by_name("gramar_check_1")
27
- rewrite_hyde_1 = prompt_store.get_by_name("rewrite_hyde_1")
28
- rewrite_hyde_2 = prompt_store.get_by_name("rewrite_hyde_2")
29
- rewrite_1 = prompt_store.get_by_name("rewrite_1")
30
- rewrite_2 = prompt_store.get_by_name("rewrite_2")
31
- rewrite_hyde = prompt_store.get_by_name("rewrite_hyde")
32
 
33
 
34
  def replace_nl(input: str) -> str:
@@ -52,26 +52,6 @@ def rewrite(agent: Agent, q: str, prompt: str) -> list[str]:
52
  return [x for x in questions if ("##" not in x and len(str(x).strip()) > 0)]
53
 
54
 
55
- def rag_with_rerank_check_rewrite(agent: Agent, q: str, retrieve_document_count: int, prompt: str, check_prompt: str,
56
- rewrite_prompt: str):
57
- rewritten_list: list[str] = rewrite(agent, q, rewrite_prompt)
58
-
59
- if len(rewritten_list) == 0:
60
- return "Neviem, nemám podklady!", "", ""
61
-
62
- context_doc = retrieve_subqueries(agent, retrieve_document_count, rewritten_list)
63
-
64
- if len(context_doc) == 0:
65
- return "Neviem, nemám kontext!", "", ""
66
-
67
- result = answer_pipeline(agent, context_doc, prompt, q)
68
- answer = result["text"]
69
-
70
- check_result = check_pipeline(answer, check_prompt, context_doc, q)
71
-
72
- return answer, check_result, context_doc
73
-
74
-
75
  def rag_with_rerank_check_rewrite_hyde(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
76
  check_prompt: str,
77
  rewrite_prompt: str):
@@ -120,7 +100,48 @@ def rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
120
  prompt=PromptTemplate(
121
  input_variables=["context", "question", "actual_date"],
122
  template=prompt
123
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  )
125
  ).invoke(
126
  input={
@@ -130,8 +151,12 @@ def rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
130
  }
131
  )
132
 
 
 
133
  check_result = check_pipeline(result["answer"], check_prompt, result["context"], q)
134
 
 
 
135
  return result["answer"], check_result, result["context"]
136
 
137
 
 
10
  from langchain.retrievers import MultiQueryRetriever, MergerRetriever, ContextualCompressionRetriever, EnsembleRetriever
11
  from langchain_cohere import CohereRerank
12
  from langchain_core.documents import Document
13
+ from langchain_core.prompts import PromptTemplate, BasePromptTemplate
14
 
15
  from agent.Agent import Agent
16
  from agent.agents import chat_openai_llm, deepinfra_chat
 
23
  conversation_store = ConversationStore()
24
  prompt_store = PromptStore()
25
 
26
+ grammar_check_1 = prompt_store.get_by_name("gramar_check_1").text
27
+ rewrite_hyde_1 = prompt_store.get_by_name("rewrite_hyde_1").text
28
+ rewrite_hyde_2 = prompt_store.get_by_name("rewrite_hyde_2").text
29
+ rewrite_1 = prompt_store.get_by_name("rewrite_1").text
30
+ rewrite_2 = prompt_store.get_by_name("rewrite_2").text
31
+ rewrite_hyde = prompt_store.get_by_name("rewrite_hyde").text
32
 
33
 
34
  def replace_nl(input: str) -> str:
 
52
  return [x for x in questions if ("##" not in x and len(str(x).strip()) > 0)]
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def rag_with_rerank_check_rewrite_hyde(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
56
  check_prompt: str,
57
  rewrite_prompt: str):
 
100
  prompt=PromptTemplate(
101
  input_variables=["context", "question", "actual_date"],
102
  template=prompt
103
+ ),
104
+ document_prompt=PromptTemplate(input_variables=[], template="page_content")
105
+ )
106
+ ).invoke(
107
+ input={
108
+ "question": q,
109
+ "input": q,
110
+ "actual_date": datetime.date.today().isoformat()
111
+ }
112
+ )
113
+
114
+ print(result)
115
+
116
+ check_result = check_pipeline(result["answer"], check_prompt, result["context"], q)
117
+
118
+ print(check_result)
119
+
120
+ return result["answer"], check_result, result["context"]
121
+
122
+
123
+ def vanilla_rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
124
+ check_prompt: str):
125
+ retriever = ContextualCompressionRetriever(
126
+ base_compressor=(CohereRerank(
127
+ model="rerank-multilingual-v3.0",
128
+ top_n=retrieve_document_count
129
+ )),
130
+ base_retriever=(agent.embedding.get_vector_store().as_retriever(
131
+ search_type="similarity",
132
+ search_kwargs={"k": min(retrieve_document_count * 10, 500)},
133
+ ))
134
+ )
135
+
136
+ result = create_retrieval_chain(
137
+ retriever=retriever,
138
+ combine_docs_chain=create_stuff_documents_chain(
139
+ llm=agent.llm,
140
+ prompt=PromptTemplate(
141
+ input_variables=["context", "question", "actual_date"],
142
+ template=prompt
143
+ ),
144
+ document_prompt=PromptTemplate(input_variables=[], template="page_content")
145
  )
146
  ).invoke(
147
  input={
 
151
  }
152
  )
153
 
154
+ print(result)
155
+
156
  check_result = check_pipeline(result["answer"], check_prompt, result["context"], q)
157
 
158
+ print(check_result)
159
+
160
  return result["answer"], check_result, result["context"]
161
 
162
 
rag_langchain.py CHANGED
@@ -4,6 +4,7 @@ from dotenv import load_dotenv
4
  from gptcache import Cache
5
  from gptcache.manager.factory import manager_factory
6
  from gptcache.processor.pre import get_prompt
 
7
  from langchain.retrievers import ContextualCompressionRetriever
8
  from langchain_cohere import CohereRerank, CohereEmbeddings
9
  from langchain_community.cache import GPTCache
@@ -19,10 +20,13 @@ from agent.agents import deepinfra_chat, \
19
  from emdedd.Embedding import Embedding
20
  from emdedd.MongoEmbedding import EmbeddingDbConnection, MongoEmbedding
21
  from prompt.prompt_store import PromptStore
22
- from rag import rag_chain
23
 
24
  load_dotenv()
25
 
 
 
 
26
 
27
  class LangChainRAG:
28
  embedding: Embedding
@@ -89,24 +93,23 @@ class LangChainRAG:
89
 
90
  self.retriever = ContextualCompressionRetriever(
91
  base_compressor=CohereRerank(model="rerank-multilingual-v3.0", top_n=os.getenv("retrieve_documents")),
92
- base_retriever=self.get_vector_store_mongodb().as_retriever(
93
  search_type="similarity",
94
  search_kwargs={"k": config["retrieve_documents"] * 10}
95
  )
96
  )
97
 
98
- def get_vector_store_mongodb(self):
99
- return self.embedding[0].get_vector_store()
100
-
101
  def get_llms(self):
102
  return self.llms.keys()
103
 
104
- def rag_chain(self, query, choice):
105
- # answer, check_result, context_doc = rag_with_rerank_check_rewrite(
 
106
  # answer, check_result, context_doc = rag_with_rerank_check_rewrite_hyde(
107
  # answer, check_result, context_doc = rag_with_rerank_check_multi_query_retriever(
 
108
  answer, check_result, context_doc = rag_chain(
109
- Agent(embedding=self.embedding[0], llm=self.llms[choice]),
110
  query,
111
  self.config["retrieve_documents"],
112
  self.prompt_store.get_by_name(self.config["prompt_id"]).text,
 
4
  from gptcache import Cache
5
  from gptcache.manager.factory import manager_factory
6
  from gptcache.processor.pre import get_prompt
7
+ from langchain.globals import set_debug
8
  from langchain.retrievers import ContextualCompressionRetriever
9
  from langchain_cohere import CohereRerank, CohereEmbeddings
10
  from langchain_community.cache import GPTCache
 
20
  from emdedd.Embedding import Embedding
21
  from emdedd.MongoEmbedding import EmbeddingDbConnection, MongoEmbedding
22
  from prompt.prompt_store import PromptStore
23
+ from rag import vanilla_rag_chain, rag_chain
24
 
25
  load_dotenv()
26
 
27
+ # set_verbose(True)
28
+ set_debug(True)
29
+
30
 
31
  class LangChainRAG:
32
  embedding: Embedding
 
93
 
94
  self.retriever = ContextualCompressionRetriever(
95
  base_compressor=CohereRerank(model="rerank-multilingual-v3.0", top_n=os.getenv("retrieve_documents")),
96
+ base_retriever=self.embedding.get_vector_store().as_retriever(
97
  search_type="similarity",
98
  search_kwargs={"k": config["retrieve_documents"] * 10}
99
  )
100
  )
101
 
 
 
 
102
  def get_llms(self):
103
  return self.llms.keys()
104
 
105
+ def rag_chain(self, query, llm_choice):
106
+ print("Using " + llm_choice)
107
+
108
  # answer, check_result, context_doc = rag_with_rerank_check_rewrite_hyde(
109
  # answer, check_result, context_doc = rag_with_rerank_check_multi_query_retriever(
110
+ # answer, check_result, context_doc = vanilla_rag_chain(
111
  answer, check_result, context_doc = rag_chain(
112
+ Agent(embedding=self.embedding, llm=self.llms[llm_choice]),
113
  query,
114
  self.config["retrieve_documents"],
115
  self.prompt_store.get_by_name(self.config["prompt_id"]).text,
requirements.txt CHANGED
@@ -6,6 +6,7 @@ langchain-mistralai
6
  langchain-cohere
7
  langchain-google-genai
8
  langchain-together
 
9
 
10
  fitz
11
  pypdf
 
6
  langchain-cohere
7
  langchain-google-genai
8
  langchain-together
9
+ langchain-mongodb
10
 
11
  fitz
12
  pypdf