Pavol Liška committed
Commit ae95c3d · 1 Parent(s): 869eb7d
Files changed (4):
  1. agent/agents.py +0 -2
  2. rag.py +0 -28
  3. rag_langchain.py +3 -3
  4. task_splitting.py +0 -101
agent/agents.py CHANGED
@@ -33,8 +33,6 @@ def cohere_llm():
         model="command-r-plus",
         max_tokens=2048,
         temperature=os.environ["temperature"],
-        # p=os.environ["top_p"],
-        # frequency_penalty=os.environ["frequency_penalty"],
     )
 
 
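The two dropped parameters were already commented out, so this hunk is pure cleanup. Worth noting in passing: os.environ[...] always returns a string, so the surviving temperature=os.environ["temperature"] relies on the client coercing that string to a number. A minimal sketch of reading such sampling knobs with explicit casts and defaults (the SamplingConfig helper and its default values are illustrative, not part of this repo):

import os
from dataclasses import dataclass


@dataclass
class SamplingConfig:
    """Illustrative helper: env vars are strings, so cast them explicitly."""
    temperature: float
    top_p: float
    frequency_penalty: float

    @classmethod
    def from_env(cls) -> "SamplingConfig":
        # os.environ.get returns str (or the default), hence the float() casts;
        # the fallback values below are assumptions, not repo defaults.
        return cls(
            temperature=float(os.environ.get("temperature", "0.3")),
            top_p=float(os.environ.get("top_p", "1.0")),
            frequency_penalty=float(os.environ.get("frequency_penalty", "0.0")),
        )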
rag.py CHANGED
@@ -35,34 +35,6 @@ def replace_nl(input: str) -> str:
     return input.replace('\r\n', '<br>').replace('\n', '<br>').replace('\r', '<br>')
 
 
-def rag(agent: Agent, q: str, retrieve_document_count: int):
-    k = retrieve_document_count
-
-    context_doc = retrieve(agent.embedding, q, k)
-
-    prompt_template = PromptTemplate(
-        input_variables=["context", "question"],
-        template=os.environ["RAG_TEMPLATE"]
-    )
-
-    llm_chain = LLMChain(
-        llm=agent.llm,
-        prompt=prompt_template,
-        verbose=False
-    )
-
-    # llm_chain = prompt_template | agent.llm
-
-    result: dict[str, Any] = llm_chain.invoke(
-        input={
-            "question": q,
-            "context": context_doc
-        }
-    )
-
-    return result["text"]
-
-
 def rewrite(agent: Agent, q: str, prompt: str) -> list[str]:
     prompt_template = PromptTemplate(
         input_variables=["question"],
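The deleted rag() was built on LangChain's legacy LLMChain; the commented-out prompt_template | agent.llm line already pointed at its pipe-style replacement, and rag_chain is what rag_langchain.py now imports from this module. A minimal sketch of what an equivalent pipe-style (LCEL) version could look like, reusing the deleted function's own retrieve helper and RAG_TEMPLATE env var (the actual rag_chain kept in rag.py is not shown in this commit):

import os

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate


def rag(agent, q: str, retrieve_document_count: int) -> str:
    # Same retrieval step as the deleted function; retrieve() is the
    # module's own helper, not defined here.
    context_doc = retrieve(agent.embedding, q, retrieve_document_count)

    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template=os.environ["RAG_TEMPLATE"],
    )

    # Pipe-style (LCEL) composition instead of the legacy LLMChain;
    # StrOutputParser returns a plain string, replacing the result["text"] lookup.
    chain = prompt_template | agent.llm | StrOutputParser()

    return chain.invoke({"question": q, "context": context_doc})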
rag_langchain.py CHANGED
@@ -5,7 +5,7 @@ from gptcache import Cache
 from gptcache.manager.factory import manager_factory
 from gptcache.processor.pre import get_prompt
 from langchain.retrievers import ContextualCompressionRetriever
-from langchain_cohere import CohereRerank
+from langchain_cohere import CohereRerank, CohereEmbeddings
 from langchain_community.cache import GPTCache
 from langchain_core.language_models import BaseChatModel
 from langchain_core.prompts import PromptTemplate
@@ -17,7 +17,7 @@ from agent.Agent import Agent
 from agent.agents import deepinfra_chat, \
     together_ai_chat, groq_chat, cohere_llm
 from emdedd.Embedding import Embedding
-from emdedd.embeddings import chroma_embedding, cohere_embeddings
+from emdedd.MongoEmbedding import EmbeddingDbConnection, MongoEmbedding
 from prompt.prompt_store import PromptStore
 from rag import rag_chain
 
@@ -25,7 +25,7 @@ load_dotenv()
 
 
 class LangChainRAG:
-    embedding: tuple[Embedding]
+    embedding: Embedding
     llms: dict[str, BaseChatModel]
     retriever: BaseRetriever
     prompt_template: PromptTemplate
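Besides the import swap from the Chroma helpers to the Mongo-backed embedding layer, the class annotation narrows from a one-element tuple[Embedding] to a plain Embedding. The repo-internal MongoEmbedding/EmbeddingDbConnection API is not visible in this diff, so here is only a minimal sketch of how the newly imported CohereEmbeddings and the existing CohereRerank/ContextualCompressionRetriever pieces typically fit together (the model names and the build_retriever helper are assumptions, not code from this commit):

from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereEmbeddings, CohereRerank

# Embedding model used when building the vector store; the model name here
# is a guess, not taken from this repo.
embeddings = CohereEmbeddings(model="embed-multilingual-v3.0")


def build_retriever(vector_store) -> ContextualCompressionRetriever:
    # vector_store: any LangChain VectorStore, assumed to be backed by the
    # repo's Mongo embedding layer and built with the embeddings above.
    base_retriever = vector_store.as_retriever(search_kwargs={"k": 10})
    # Over-fetch, then let the Cohere reranker compress to the best few docs.
    compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
    return ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=base_retriever,
    )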
task_splitting.py DELETED
@@ -1,101 +0,0 @@
-import datetime
-from time import sleep
-
-from langchain.chains import LLMChain
-from langchain_core.prompts import PromptTemplate
-
-from agent.Agent import Agent
-from agent.agents import chat_groq_llama3_70
-from emdedd.embeddings import cohere_embeddings, chroma_embedding, embed_zakonnik_prace
-from promts import for_tree_llama3_rag_sub, for_tree_llama3_rag_tree, for_tree_llama3_rag_group
-from retrieval import retrieve_with_rerank
-from questions import questions
-
-
-def rag_tree(agent: Agent, q: str, retrieve_document_count: int) -> str:
-    tree_template = PromptTemplate(
-        input_variables=["context", "question"],
-        template=for_tree_llama3_rag_tree
-    )
-
-    context_doc = retrieve_with_rerank(agent.embedding, q, retrieve_document_count * 2)
-
-    sub_qs = LLMChain(
-        llm=agent.llm,
-        prompt=tree_template,
-        verbose=False
-    ).invoke(
-        input={
-            "question": q,
-            "context": context_doc
-        }
-    )["text"]
-
-    print(sub_qs)
-    sleep(60)
-
-    print("_________")
-
-    sub_template = PromptTemplate(
-        input_variables=["context", "question"],
-        template=for_tree_llama3_rag_sub
-    )
-
-    sub_answers: dict[str, str] = {}
-
-    for sub_q in sub_qs.splitlines():
-        if "?" not in sub_q: continue
-        print(sub_q)
-        sub_answers[sub_q] = LLMChain(
-            llm=agent.llm,
-            prompt=sub_template,
-            verbose=False
-        ).invoke(
-            input={
-                "question": sub_q,
-                "context": retrieve_with_rerank(agent.embedding, sub_q, retrieve_document_count)
-            }
-        )["text"]
-        print(sub_answers[sub_q])
-        sleep(60)
-
-
-    final_template = PromptTemplate(
-        input_variables=["context", "question", "subs"],
-        template=for_tree_llama3_rag_group
-    )
-
-    result = LLMChain(
-        llm=agent.llm,
-        prompt=final_template,
-        verbose=True
-    ).invoke(
-        input={
-            "question": q,
-            "context": context_doc,
-            "subs": sub_answers.items()
-        }
-    )
-
-    return result["text"]
-
-
-def tree_of_thought(name: str, agent: Agent, emded: bool = False, retrieve_document_count=5):
-    try:
-        result_file = open(name + "_test.md", "a")
-        if emded:
-            embed_zakonnik_prace(agent.embedding)
-        for q in questions:
-            print("--- Q: " + q)
-            result_file.write("\n\n| " + name + str(datetime.datetime.now()) + " | " + q + " |")
-            result_file.write("\n|-------|-----------|")
-
-            answer = rag_tree(agent, q, retrieve_document_count)
-            print(answer)
-            result_file.write(
-                "\n| tree | " + answer.replace('\r\n', '<br>').replace('\n', '<br>').replace('\r', '<br>') + " |")
-            sleep(60)
-    finally:
-        result_file.write("\n\n")
-        result_file.flush()
-        result_file.close()
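The deleted file was the experiment driver for a decomposition-style RAG: one chain proposes sub-questions, each sub-question is answered against its own reranked retrieval, and a final chain merges the sub-answers. Should the pattern be revived, a minimal sketch of the same loop without the legacy LLMChain (the prompt templates and retrieve_with_rerank are the deleted file's own names; the pipe composition and StrOutputParser are the substitution, and the rate-limit sleeps and debug prints are dropped):

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate


def rag_tree(agent, q: str, k: int) -> str:
    parse = StrOutputParser()
    context_doc = retrieve_with_rerank(agent.embedding, q, k * 2)

    # Step 1: decompose the question into sub-questions.
    tree_chain = (
        PromptTemplate.from_template(for_tree_llama3_rag_tree) | agent.llm | parse
    )
    sub_qs = tree_chain.invoke({"question": q, "context": context_doc})

    # Step 2: answer each sub-question against its own reranked retrieval.
    sub_chain = (
        PromptTemplate.from_template(for_tree_llama3_rag_sub) | agent.llm | parse
    )
    sub_answers = {
        sub_q: sub_chain.invoke({
            "question": sub_q,
            "context": retrieve_with_rerank(agent.embedding, sub_q, k),
        })
        for sub_q in sub_qs.splitlines()
        if "?" in sub_q
    }

    # Step 3: synthesize the sub-answers into one final answer.
    final_chain = (
        PromptTemplate.from_template(for_tree_llama3_rag_group) | agent.llm | parse
    )
    return final_chain.invoke(
        {"question": q, "context": context_doc, "subs": sub_answers.items()}
    )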