danicafisher committed
Commit 37cf481 · 2 Parent(s): 72762da a57d32e

Merge branch 'main' of https://huggingface.co/spaces/CoExperiences/aie4-final

.gitignore CHANGED
@@ -1 +1,2 @@
 .env
+/__pycache__
__pycache__/constants.cpython-311.pyc CHANGED
Binary files a/__pycache__/constants.cpython-311.pyc and b/__pycache__/constants.cpython-311.pyc differ
 
__pycache__/models.cpython-311.pyc CHANGED
Binary files a/__pycache__/models.cpython-311.pyc and b/__pycache__/models.cpython-311.pyc differ
 
load_existing_docs.py CHANGED
@@ -6,12 +6,7 @@ from langchain_community.document_loaders import PyPDFLoader, UnstructuredURLLoader
 from qdrant_client.http.models import VectorParams
 import pymupdf
 import requests
-
-#qdrant = QdrantVectorStore.from_existing_collection(
-#    embedding=models.basic_embeddings,
-#    collection_name="kai_test_documents",
-#    url=constants.QDRANT_ENDPOINT,
-#)
+from transformers import AutoTokenizer
 
 def extract_links_from_pdf(pdf_path):
     links = []
@@ -78,26 +73,22 @@ for link in unique_links:
 
 
 #print(len(documents))
-semantic_split_docs = models.semanticChunker.split_documents(documents)
-RCTS_split_docs = models.RCTS.split_documents(documents)
-
-
-#for file in filepaths:
-#    loader = PyPDFLoader(file)
-#    documents = loader.load()
-#    for doc in documents:
-#        doc.metadata = {
-#            "source": file,
-#            "tag": "employee" if "employee" in file.lower() else "employer"
-#        }
-#    all_documents.extend(documents)
-
-#chunk them
-#semantic_split_docs = models.semanticChunker.split_documents(all_documents)
-
-
+#semantic_split_docs = models.semanticChunker.split_documents(documents)
+semantic_tuned_split_docs = models.semanticChunker_tuned.split_documents(documents)
+#RCTS_split_docs = models.RCTS.split_documents(documents)
+#print(len(semantic_split_docs))
+print(len(semantic_tuned_split_docs))
+#tokenizer = models.tuned_embeddings.client.tokenizer
+#
+#token_sizes = [len(tokenizer.encode(chunk)) for chunk in semantic_tuned_split_docs]
+
+# Display the token sizes
+#for idx, size in enumerate(token_sizes):
+#    print(f"Chunk {idx + 1}: {size} tokens")
+#
+#exit()
 #add them to the existing qdrant client
-collection_name = "docs_from_ripped_urls_recursive"
+collection_name = "docs_from_ripped_urls_semantic_tuned"
 
 collections = models.qdrant_client.get_collections()
 collection_names = [collection.name for collection in collections.collections]
@@ -105,16 +96,16 @@ collection_names = [collection.name for collection in collections.collections]
 if collection_name not in collection_names:
     models.qdrant_client.create_collection(
         collection_name=collection_name,
-        vectors_config=VectorParams(size=1536, distance="Cosine")
+        vectors_config=VectorParams(size=1024, distance="Cosine")
     )
 
 qdrant_vector_store = QdrantVectorStore(
     client=models.qdrant_client,
     collection_name=collection_name,
-    embedding=models.te3_small
+    embedding=models.tuned_embeddings
 )
 
-qdrant_vector_store.add_documents(RCTS_split_docs)
+qdrant_vector_store.add_documents(semantic_tuned_split_docs)
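Note: the new `from transformers import AutoTokenizer` import and the commented-out tokenizer block above point at a chunk-size audit for the tuned splitter. A minimal sketch of that check, assuming the tokenizer can be loaded straight from the tuned model's Hugging Face repo (rather than via `models.tuned_embeddings.client.tokenizer`):

```python
from transformers import AutoTokenizer

# Assumption: the tuned embedding model's repo also hosts its tokenizer.
tokenizer = AutoTokenizer.from_pretrained("CoExperiences/snowflake-l-marketing-tuned")

# split_documents returns Document objects, so measure their page_content.
token_sizes = [len(tokenizer.encode(doc.page_content)) for doc in semantic_tuned_split_docs]
for idx, size in enumerate(token_sizes):
    print(f"Chunk {idx + 1}: {size} tokens")
```

This would confirm the semantic chunks fit the fine-tuned model's context window before they are embedded.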
models.py CHANGED
@@ -5,9 +5,11 @@ from langchain.callbacks.tracers import LangChainTracer
 from langchain_huggingface.embeddings import HuggingFaceEmbeddings
 from langchain_experimental.text_splitter import SemanticChunker
 from langchain_openai.embeddings import OpenAIEmbeddings
-from langchain_community.vectorstores import Qdrant
+from langchain_qdrant import QdrantVectorStore, Qdrant
+from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
 from qdrant_client import QdrantClient
 from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_cohere import CohereRerank
 import constants
 import os
 
@@ -18,7 +20,9 @@ os.environ["LANGCHAIN_ENDPOINT"] = constants.LANGCHAIN_ENDPOINT
 tracer = LangChainTracer()
 callback_manager = CallbackManager([tracer])
 
-qdrant_client = QdrantClient(url=constants.QDRANT_ENDPOINT, api_key=constants.QDRANT_API_KEY)
+########################
+###   Chat Models    ###
+########################
 
 opus3 = ChatAnthropic(
     api_key=constants.ANTRHOPIC_API_KEY,
@@ -65,22 +69,68 @@ gpt4o_mini = ChatOpenAI(
     callbacks=callback_manager
 )
 
+########################
+### Embedding Models ###
+########################
+
 basic_embeddings = HuggingFaceEmbeddings(model_name="snowflake/snowflake-arctic-embed-l")
-#hkunlp_instructor_large = HuggingFaceInstructEmbeddings(
-#    model_name = "hkunlp/instructor-large",
-#    query_instruction="Represent the query for retrieval: "
-#)
+
+tuned_embeddings = HuggingFaceEmbeddings(model_name="CoExperiences/snowflake-l-marketing-tuned")
 
 te3_small = OpenAIEmbeddings(api_key=constants.OPENAI_API_KEY, model="text-embedding-3-small")
 
+#######################
+###  Text Splitters ###
+#######################
+
 semanticChunker = SemanticChunker(
     te3_small,
     breakpoint_threshold_type="percentile"
 )
 
+semanticChunker_tuned = SemanticChunker(
+    tuned_embeddings,
+    breakpoint_threshold_type="percentile",
+    breakpoint_threshold_amount=85
+)
+
 RCTS = RecursiveCharacterTextSplitter(
     # Set a really small chunk size, just to show.
     chunk_size=500,
     chunk_overlap=25,
     length_function=len,
+)
+
+#######################
+###  Vector Stores  ###
+#######################
+
+qdrant_client = QdrantClient(url=constants.QDRANT_ENDPOINT, api_key=constants.QDRANT_API_KEY)
+
+semantic_Qdrant_vs = QdrantVectorStore(
+    client=qdrant_client,
+    collection_name="docs_from_ripped_urls",
+    embedding=te3_small
+)
+
+rcts_Qdrant_vs = QdrantVectorStore(
+    client=qdrant_client,
+    collection_name="docs_from_ripped_urls_recursive",
+    embedding=te3_small
+)
+
+semantic_tuned_Qdrant_vs = QdrantVectorStore(
+    client=qdrant_client,
+    collection_name="docs_from_ripped_urls_semantic_tuned",
+    embedding=tuned_embeddings
+)
+
+#######################
+###   Retrievers    ###
+#######################
+semantic_tuned_retriever = semantic_tuned_Qdrant_vs.as_retriever(search_kwargs={"k" : 10})
+
+compressor = CohereRerank(model="rerank-english-v3.0")
+compression_retriever = ContextualCompressionRetriever(
+    base_compressor=compressor, base_retriever=semantic_tuned_retriever
 )
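Note: the new retriever stack over-fetches ten candidates from the tuned-embedding collection and lets Cohere's reranker reorder and trim them. A minimal usage sketch, assuming COHERE_API_KEY is set in the environment and the collection has been populated by load_existing_docs.py (the query string is illustrative):

```python
import models

# Pull 10 candidates from the tuned collection, then rerank with rerank-english-v3.0.
docs = models.compression_retriever.invoke("example query")
for doc in docs:
    # CohereRerank attaches a relevance_score to each document's metadata.
    print(doc.metadata.get("relevance_score"), doc.page_content[:80])
```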
multiagent.py ADDED
@@ -0,0 +1,336 @@
+#Change to requirements caller
+import sys
+import subprocess
+
+def run_pip_install():
+    packages = [
+        "langgraph",
+        "langchain",
+        "langchain_openai",
+        "langchain_experimental",
+        "qdrant-client",
+        "pymupdf",
+        "tiktoken",
+        "huggingface_hub",
+        "openai",
+        "tavily-python"
+    ]
+
+    package_string = " ".join(packages)
+
+    try:
+        subprocess.check_call([sys.executable, "-m", "pip", "install", "-qU"] + packages)
+        print("All required packages have been installed successfully.")
+    except subprocess.CalledProcessError:
+        print("Failed to install packages. Please run the following command manually:")
+        print(f"%pip install -qU {package_string}")
+        sys.exit(1)
+
+# Run pip install
+run_pip_install()
+
+import os
+import functools
+import operator
+from operator import itemgetter
+from typing import Annotated, List, Tuple, Union, Dict, Optional
+from typing_extensions import TypedDict
+import uuid
+from pathlib import Path
+
+import tiktoken
+from langchain_core.tools import tool
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_openai import ChatOpenAI
+from langchain.agents import AgentExecutor, create_openai_functions_agent
+from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
+from langchain_community.document_loaders import PyMuPDFLoader
+from langchain_community.tools.tavily_search import TavilySearchResults
+from langchain_community.vectorstores import Qdrant
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_openai.embeddings import OpenAIEmbeddings
+from langgraph.graph import END, StateGraph
+from huggingface_hub import hf_hub_download
+
+# Environment setup
+OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
+TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")
+
+if not OPENAI_API_KEY:
+    raise ValueError("OPENAI_API_KEY not found in environment variables")
+if not TAVILY_API_KEY:
+    raise ValueError("TAVILY_API_KEY not found in environment variables")
+
+os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
+os.environ["TAVILY_API_KEY"] = TAVILY_API_KEY
+
+# CHANGE TO HF DIRECTORY
+WORKING_DIRECTORY = Path("/tmp/content/data")
+WORKING_DIRECTORY.mkdir(parents=True, exist_ok=True)
+
+# Utility functions
+def create_random_subdirectory():
+    random_id = str(uuid.uuid4())[:8]
+    subdirectory_path = WORKING_DIRECTORY / random_id
+    subdirectory_path.mkdir(exist_ok=True)
+    return subdirectory_path
+
+def get_current_files():
+    try:
+        files = [f.relative_to(WORKING_DIRECTORY) for f in WORKING_DIRECTORY.rglob("*") if f.is_file()]
+        return "\n".join(str(f) for f in files) if files else "No files written."
+    except Exception:
+        return "Unable to retrieve current files."
+
+# Document loading change to upload in HF
+def fetch_hbr_article():
+    pdf_path = hf_hub_download(repo_id="your-username/your-repo-name", filename="murthy-loneliness.pdf")
+    return PyMuPDFLoader(pdf_path).load()
+
+# Document processing
+def tiktoken_len(text):
+    tokens = tiktoken.encoding_for_model("gpt-4o-mini").encode(text)
+    return len(tokens)
+
+text_splitter = RecursiveCharacterTextSplitter(
+    chunk_size=300,
+    chunk_overlap=0,
+    length_function=tiktoken_len,
+)
+
+docs = fetch_hbr_article()
+split_chunks = text_splitter.split_documents(docs)
+
+# Embedding and vector store setup
+embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
+qdrant_vectorstore = Qdrant.from_documents(
+    split_chunks,
+    embedding_model,
+    location=":memory:",
+    collection_name="extending_context_window_llama_3",
+)
+qdrant_retriever = qdrant_vectorstore.as_retriever()
+
+# RAG setup
+RAG_PROMPT = """
+CONTEXT:
+{context}
+
+QUERY:
+{question}
+
+You are a helpful assistant. Use the available context to answer the question. If you can't answer the question, say you don't know.
+"""
+rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
+openai_chat_model = ChatOpenAI(model="gpt-4o-mini")
+
+rag_chain = (
+    {"context": itemgetter("question") | qdrant_retriever, "question": itemgetter("question")}
+    | rag_prompt | openai_chat_model | StrOutputParser()
+)
+
+# Tool definitions
+@tool
+def create_outline(points: List[str], file_name: str) -> str:
+    """Create and save an outline."""
+    with (WORKING_DIRECTORY / file_name).open("w") as file:
+        for i, point in enumerate(points):
+            file.write(f"{i + 1}. {point}\n")
+    return f"Outline saved to {file_name}"
+
+@tool
+def read_document(file_name: str, start: Optional[int] = None, end: Optional[int] = None) -> str:
+    """Read the specified document."""
+    with (WORKING_DIRECTORY / file_name).open("r") as file:
+        lines = file.readlines()
+    if start is None:  # default to the start of the file (the committed code reset a caller-supplied start instead)
+        start = 0
+    return "\n".join(lines[start:end])
+
+@tool
+def write_document(content: str, file_name: str) -> str:
+    """Create and save a text document."""
+    with (WORKING_DIRECTORY / file_name).open("w") as file:
+        file.write(content)
+    return f"Document saved to {file_name}"
+
+@tool
+def edit_document(file_name: str, inserts: Dict[int, str] = {}) -> str:
+    """Edit a document by inserting text at specific line numbers."""
+    with (WORKING_DIRECTORY / file_name).open("r") as file:
+        lines = file.readlines()
+
+    sorted_inserts = sorted(inserts.items())
+    for line_number, text in sorted_inserts:
+        if 1 <= line_number <= len(lines) + 1:
+            lines.insert(line_number - 1, text + "\n")
+        else:
+            return f"Error: Line number {line_number} is out of range."
+
+    with (WORKING_DIRECTORY / file_name).open("w") as file:
+        file.writelines(lines)
+    return f"Document edited and saved to {file_name}"
+
+@tool
+def retrieve_information(query: str):
+    """Use Retrieval Augmented Generation to retrieve information about the 'murthy-loneliness' paper."""
+    return rag_chain.invoke({"question": query})
+
+# Agent creation helpers
+# NOTE: create_agent and agent_node are assumed to be defined elsewhere
+# (they follow the LangGraph hierarchical-teams example); they are not
+# defined in this file.
+def create_team_agent(llm, tools, system_prompt, agent_name, team_members):
+    return create_agent(
+        llm,
+        tools,
+        f"{system_prompt}\nBelow are files currently in your directory:\n{{current_files}}",
+        team_members
+    )
+
+def create_agent_node(agent, name):
+    return functools.partial(agent_node, agent=agent, name=name)
+
+def add_agent_to_graph(graph, agent_name, agent_node):
+    graph.add_node(agent_name, agent_node)
+    graph.add_edge(agent_name, "supervisor")
+
+def create_team_supervisor(llm, team_description, team_members):
+    # BUG FIX: the committed version called create_team_supervisor from inside
+    # itself and would recurse forever. It presumably meant to wrap a base
+    # supervisor factory (as in the LangGraph hierarchical-teams example);
+    # that helper, assumed here as `base_team_supervisor`, is not defined in
+    # this file.
+    return base_team_supervisor(
+        llm,
+        f"You are a supervisor tasked with managing a conversation between the"
+        f" following workers: {', '.join(team_members)}. {team_description}"
+        f" When all workers are finished, you must respond with FINISH.",
+        team_members
+    )
+
+def create_team_chain(graph, team_members):
+    # enter_chain is assumed to be defined elsewhere.
+    return (
+        functools.partial(enter_chain, members=team_members)
+        | graph.compile()
+    )
+
+# LLM setup
+llm = ChatOpenAI(model="gpt-4-turbo")
+
+# Agent creation
+tavily_tool = TavilySearchResults(max_results=5)
+
+search_agent = create_team_agent(
+    llm,
+    [tavily_tool],
+    "You are a research assistant who can search for up-to-date info using the tavily search engine.",
+    "Search",
+    ["Search", "PaperInformationRetriever"]
+)
+
+research_agent = create_team_agent(
+    llm,
+    [retrieve_information],
+    "You are a research assistant who can provide specific information on the provided paper: 'murthy-loneliness.pdf'. You must only respond with information about the paper related to the request.",
+    "PaperInformationRetriever",
+    ["Search", "PaperInformationRetriever"]
+)
+
+doc_writer_agent = create_team_agent(
+    llm,
+    [write_document, edit_document, read_document],
+    "You are an expert writing technical social media posts.",
+    "DocWriter",
+    ["DocWriter", "NoteTaker", "CopyEditor", "VoiceEditor"]
+)
+
+note_taking_agent = create_team_agent(
+    llm,
+    [create_outline, read_document],
+    "You are an expert senior researcher tasked with writing a social media post outline and taking notes to craft a social media post.",
+    "NoteTaker",
+    ["DocWriter", "NoteTaker", "CopyEditor", "VoiceEditor"]
+)
+
+copy_editor_agent = create_team_agent(
+    llm,
+    [write_document, edit_document, read_document],
+    "You are an expert copy editor who focuses on fixing grammar, spelling, and tone issues.",
+    "CopyEditor",
+    ["DocWriter", "NoteTaker", "CopyEditor", "VoiceEditor"]
+)
+
+voice_editor_agent = create_team_agent(
+    llm,
+    [write_document, edit_document, read_document],
+    "You are an expert in crafting and refining the voice and tone of social media posts. You edit the document to ensure it has a consistent, professional, and engaging voice appropriate for social media platforms.",
+    "VoiceEditor",
+    ["DocWriter", "NoteTaker", "CopyEditor", "VoiceEditor"]
+)
+
+# Node creation
+search_node = create_agent_node(search_agent, "Search")
+research_node = create_agent_node(research_agent, "PaperInformationRetriever")
+doc_writing_node = create_agent_node(doc_writer_agent, "DocWriter")
+note_taking_node = create_agent_node(note_taking_agent, "NoteTaker")
+copy_editing_node = create_agent_node(copy_editor_agent, "CopyEditor")
+voice_node = create_agent_node(voice_editor_agent, "VoiceEditor")
+
+# Graph creation
+# (ResearchTeamState and DocWritingState are TypedDict state schemas assumed
+# to be defined elsewhere.)
+research_graph = StateGraph(ResearchTeamState)
+add_agent_to_graph(research_graph, "Search", search_node)
+add_agent_to_graph(research_graph, "PaperInformationRetriever", research_node)
+
+authoring_graph = StateGraph(DocWritingState)
+add_agent_to_graph(authoring_graph, "DocWriter", doc_writing_node)
+add_agent_to_graph(authoring_graph, "NoteTaker", note_taking_node)
+add_agent_to_graph(authoring_graph, "CopyEditor", copy_editing_node)
+add_agent_to_graph(authoring_graph, "VoiceEditor", voice_node)
+
+# Supervisor creation
+research_supervisor = create_team_supervisor(
+    llm,
+    "Given the following user request, determine the subject to be researched and respond with the worker to act next.",
+    ["Search", "PaperInformationRetriever"]
+)
+
+doc_writing_supervisor = create_team_supervisor(
+    llm,
+    "Given the following user request, determine which worker should act next. Each worker will perform a task and respond with their results and status.",
+    ["DocWriter", "NoteTaker", "CopyEditor", "VoiceEditor"]
+)
+
+# Graph compilation
+research_graph.add_node("supervisor", research_supervisor)
+research_graph.set_entry_point("supervisor")
+research_chain = create_team_chain(research_graph, research_graph.nodes)
+
+authoring_graph.add_node("supervisor", doc_writing_supervisor)
+authoring_graph.set_entry_point("supervisor")
+authoring_chain = create_team_chain(authoring_graph, authoring_graph.nodes)
+
+# Meta-supervisor setup
+# (State, get_last_message, join_graph, and supervisor_node are likewise
+# assumed to be defined elsewhere.)
+super_graph = StateGraph(State)
+super_graph.add_node("Research team", get_last_message | research_chain | join_graph)
+super_graph.add_node("SocialMedia team", get_last_message | authoring_chain | join_graph)
+super_graph.add_node("supervisor", supervisor_node)
+
+super_graph.add_edge("Research team", "supervisor")
+super_graph.add_edge("SocialMedia team", "supervisor")
+super_graph.add_conditional_edges(
+    "supervisor",
+    lambda x: x["next"],
+    {
+        "SocialMedia team": "SocialMedia team",
+        "Research team": "Research team",
+        "FINISH": END,
+    },
+)
+super_graph.set_entry_point("supervisor")
+super_graph = super_graph.compile()
+
+# Example usage
+user_input = input("Enter your request for the social media post: ")
+
+for s in super_graph.stream(
+    {
+        "messages": [
+            HumanMessage(content=user_input)
+        ],
+    },
+    {"recursion_limit": 50},
+):
+    if "__end__" not in s:
+        print(s)
+        print("---")
public/logo_light.svg DELETED
requirements.txt CHANGED
@@ -41,6 +41,7 @@ langchain-core==0.3.1
 langchain-openai==0.2.0
 langchain-qdrant==0.1.4
 langchain-text-splitters==0.3.0
+langgraph
 langsmith==0.1.121
 Lazify==0.4.0
 marshmallow==3.22.0
@@ -82,6 +83,7 @@ sniffio==1.3.1
 SQLAlchemy==2.0.35
 starlette==0.27.0
 syncer==2.0.3
+tavily-python
 tenacity==8.5.0
 tiktoken==0.7.0
 tomli==2.0.1
tuning/requirements.in ADDED
@@ -0,0 +1,14 @@
+langchain_openai
+langchain_huggingface
+langchain_core==0.2.38
+langchain
+langchain_community
+langchain-text-splitters
+faiss-cpu
+unstructured==0.15.7
+python-pptx==1.0.2
+nltk==3.9.1
+pyarrow
+sentence_transformers
+datasets
+ragas
tuning/requirements.txt ADDED
@@ -0,0 +1,412 @@
+#
+# This file is autogenerated by pip-compile with Python 3.11
+# by the following command:
+#
+#    pip-compile requirements.in
+#
+aiohappyeyeballs==2.4.3
+    # via aiohttp
+aiohttp==3.10.10
+    # via
+    #   datasets
+    #   fsspec
+    #   langchain
+    #   langchain-community
+aiosignal==1.3.1
+    # via aiohttp
+annotated-types==0.7.0
+    # via pydantic
+anyio==4.6.2.post1
+    # via
+    #   httpx
+    #   openai
+appdirs==1.4.4
+    # via ragas
+attrs==24.2.0
+    # via aiohttp
+backoff==2.2.1
+    # via unstructured
+beautifulsoup4==4.12.3
+    # via unstructured
+certifi==2024.8.30
+    # via
+    #   httpcore
+    #   httpx
+    #   requests
+cffi==1.17.1
+    # via cryptography
+chardet==5.2.0
+    # via unstructured
+charset-normalizer==3.4.0
+    # via requests
+click==8.1.7
+    # via nltk
+cryptography==43.0.1
+    # via unstructured-client
+dataclasses-json==0.6.7
+    # via
+    #   langchain-community
+    #   unstructured
+datasets==3.0.1
+    # via
+    #   -r requirements.in
+    #   ragas
+dill==0.3.8
+    # via
+    #   datasets
+    #   multiprocess
+distro==1.9.0
+    # via openai
+emoji==2.14.0
+    # via unstructured
+eval-type-backport==0.2.0
+    # via unstructured-client
+faiss-cpu==1.9.0
+    # via -r requirements.in
+filelock==3.16.1
+    # via
+    #   datasets
+    #   huggingface-hub
+    #   torch
+    #   transformers
+    #   triton
+filetype==1.2.0
+    # via unstructured
+frozenlist==1.4.1
+    # via
+    #   aiohttp
+    #   aiosignal
+fsspec[http]==2024.6.1
+    # via
+    #   datasets
+    #   huggingface-hub
+    #   torch
+greenlet==3.1.1
+    # via sqlalchemy
+h11==0.14.0
+    # via httpcore
+httpcore==1.0.6
+    # via httpx
+httpx==0.27.2
+    # via
+    #   langsmith
+    #   openai
+    #   unstructured-client
+huggingface-hub==0.26.0
+    # via
+    #   datasets
+    #   langchain-huggingface
+    #   sentence-transformers
+    #   tokenizers
+    #   transformers
+idna==3.10
+    # via
+    #   anyio
+    #   httpx
+    #   requests
+    #   yarl
+jinja2==3.1.4
+    # via torch
+jiter==0.6.1
+    # via openai
+joblib==1.4.2
+    # via
+    #   nltk
+    #   scikit-learn
+jsonpatch==1.33
+    # via langchain-core
+jsonpath-python==1.0.6
+    # via unstructured-client
+jsonpointer==3.0.0
+    # via jsonpatch
+langchain==0.2.16
+    # via
+    #   -r requirements.in
+    #   langchain-community
+    #   ragas
+langchain-community==0.2.16
+    # via
+    #   -r requirements.in
+    #   ragas
+langchain-core==0.2.38
+    # via
+    #   -r requirements.in
+    #   langchain
+    #   langchain-community
+    #   langchain-huggingface
+    #   langchain-openai
+    #   langchain-text-splitters
+    #   ragas
+langchain-huggingface==0.0.3
+    # via -r requirements.in
+langchain-openai==0.1.23
+    # via
+    #   -r requirements.in
+    #   ragas
+langchain-text-splitters==0.2.4
+    # via
+    #   -r requirements.in
+    #   langchain
+langdetect==1.0.9
+    # via unstructured
+langsmith==0.1.136
+    # via
+    #   langchain
+    #   langchain-community
+    #   langchain-core
+lxml==5.3.0
+    # via
+    #   python-pptx
+    #   unstructured
+markupsafe==3.0.2
+    # via jinja2
+marshmallow==3.23.0
+    # via dataclasses-json
+mpmath==1.3.0
+    # via sympy
+multidict==6.1.0
+    # via
+    #   aiohttp
+    #   yarl
+multiprocess==0.70.16
+    # via datasets
+mypy-extensions==1.0.0
+    # via typing-inspect
+nest-asyncio==1.6.0
+    # via
+    #   ragas
+    #   unstructured-client
+networkx==3.4.1
+    # via torch
+nltk==3.9.1
+    # via
+    #   -r requirements.in
+    #   unstructured
+numpy==1.26.4
+    # via
+    #   datasets
+    #   faiss-cpu
+    #   langchain
+    #   langchain-community
+    #   pandas
+    #   pyarrow
+    #   ragas
+    #   scikit-learn
+    #   scipy
+    #   transformers
+    #   unstructured
+nvidia-cublas-cu12==12.4.5.8
+    # via
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.4.127
+    # via torch
+nvidia-cuda-nvrtc-cu12==12.4.127
+    # via torch
+nvidia-cuda-runtime-cu12==12.4.127
+    # via torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via torch
+nvidia-cufft-cu12==11.2.1.3
+    # via torch
+nvidia-curand-cu12==10.3.5.147
+    # via torch
+nvidia-cusolver-cu12==11.6.1.9
+    # via torch
+nvidia-cusparse-cu12==12.3.1.170
+    # via
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-nccl-cu12==2.21.5
+    # via torch
+nvidia-nvjitlink-cu12==12.4.127
+    # via
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+    #   torch
+nvidia-nvtx-cu12==12.4.127
+    # via torch
+openai==1.52.0
+    # via
+    #   langchain-openai
+    #   ragas
+orjson==3.10.7
+    # via langsmith
+packaging==24.1
+    # via
+    #   datasets
+    #   faiss-cpu
+    #   huggingface-hub
+    #   langchain-core
+    #   marshmallow
+    #   transformers
+pandas==2.2.3
+    # via datasets
+pillow==11.0.0
+    # via
+    #   python-pptx
+    #   sentence-transformers
+propcache==0.2.0
+    # via yarl
+psutil==6.1.0
+    # via unstructured
+pyarrow==17.0.0
+    # via
+    #   -r requirements.in
+    #   datasets
+pycparser==2.22
+    # via cffi
+pydantic==2.9.2
+    # via
+    #   langchain
+    #   langchain-core
+    #   langsmith
+    #   openai
+    #   ragas
+    #   unstructured-client
+pydantic-core==2.23.4
+    # via pydantic
+pypdf==5.0.1
+    # via unstructured-client
+pysbd==0.3.4
+    # via ragas
+python-dateutil==2.8.2
+    # via
+    #   pandas
+    #   unstructured-client
+python-iso639==2024.4.27
+    # via unstructured
+python-magic==0.4.27
+    # via unstructured
+python-pptx==1.0.2
+    # via -r requirements.in
+pytz==2024.2
+    # via pandas
+pyyaml==6.0.2
+    # via
+    #   datasets
+    #   huggingface-hub
+    #   langchain
+    #   langchain-community
+    #   langchain-core
+    #   transformers
+ragas==0.2.1
+    # via -r requirements.in
+rapidfuzz==3.10.0
+    # via unstructured
+regex==2024.9.11
+    # via
+    #   nltk
+    #   tiktoken
+    #   transformers
+requests==2.32.3
+    # via
+    #   datasets
+    #   huggingface-hub
+    #   langchain
+    #   langchain-community
+    #   langsmith
+    #   requests-toolbelt
+    #   tiktoken
+    #   transformers
+    #   unstructured
+requests-toolbelt==1.0.0
+    # via
+    #   langsmith
+    #   unstructured-client
+safetensors==0.4.5
+    # via transformers
+scikit-learn==1.5.2
+    # via sentence-transformers
+scipy==1.14.1
+    # via
+    #   scikit-learn
+    #   sentence-transformers
+sentence-transformers==3.2.0
+    # via
+    #   -r requirements.in
+    #   langchain-huggingface
+six==1.16.0
+    # via
+    #   langdetect
+    #   python-dateutil
+sniffio==1.3.1
+    # via
+    #   anyio
+    #   httpx
+    #   openai
+soupsieve==2.6
+    # via beautifulsoup4
+sqlalchemy==2.0.36
+    # via
+    #   langchain
+    #   langchain-community
+sympy==1.13.1
+    # via torch
+tabulate==0.9.0
+    # via unstructured
+tenacity==8.5.0
+    # via
+    #   langchain
+    #   langchain-community
+    #   langchain-core
+threadpoolctl==3.5.0
+    # via scikit-learn
+tiktoken==0.8.0
+    # via
+    #   langchain-openai
+    #   ragas
+tokenizers==0.20.1
+    # via
+    #   langchain-huggingface
+    #   transformers
+torch==2.5.0
+    # via sentence-transformers
+tqdm==4.66.5
+    # via
+    #   datasets
+    #   huggingface-hub
+    #   nltk
+    #   openai
+    #   sentence-transformers
+    #   transformers
+    #   unstructured
+transformers==4.45.2
+    # via
+    #   langchain-huggingface
+    #   sentence-transformers
+triton==3.1.0
+    # via torch
+typing-extensions==4.12.2
+    # via
+    #   huggingface-hub
+    #   langchain-core
+    #   openai
+    #   pydantic
+    #   pydantic-core
+    #   python-pptx
+    #   sqlalchemy
+    #   torch
+    #   typing-inspect
+    #   unstructured
+typing-inspect==0.9.0
+    # via
+    #   dataclasses-json
+    #   unstructured-client
+tzdata==2024.2
+    # via pandas
+unstructured==0.15.7
+    # via -r requirements.in
+unstructured-client==0.26.1
+    # via unstructured
+urllib3==2.2.3
+    # via requests
+wrapt==1.16.0
+    # via unstructured
+xlsxwriter==3.2.0
+    # via python-pptx
+xxhash==3.5.0
+    # via datasets
+yarl==1.15.4
+    # via aiohttp
tuning/tuning_embeddings_sandbox.ipynb ADDED
The diff for this file is too large to render. See raw diff