File size: 5,466 Bytes
06dc1b7
b210243
 
 
 
 
 
 
 
 
 
 
 
 
06dc1b7
b210243
06dc1b7
 
b210243
 
 
06dc1b7
b210243
 
 
 
 
 
 
 
954011f
 
 
 
 
b210243
 
 
 
06dc1b7
b210243
 
06dc1b7
b210243
 
 
 
 
 
 
 
06dc1b7
b210243
 
 
 
 
954011f
b210243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06dc1b7
 
 
 
 
 
b210243
 
06dc1b7
b210243
 
 
 
 
 
 
 
 
 
06dc1b7
 
 
b210243
06dc1b7
 
 
 
 
 
 
 
 
 
b210243
 
 
06dc1b7
 
b210243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06dc1b7
 
 
 
 
 
 
b210243
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
from typing import Dict, List, TypedDict, Annotated, Sequence
from langgraph.graph import Graph, StateGraph, END
from langgraph.prebuilt import ToolExecutor
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.tools.tavily_search import TavilySearchResults
import models
import prompts
from helper_functions import format_docs
from operator import itemgetter

# Define the state structure shared by every node in the research/writing graphs.
class State(TypedDict):
    """Mutable state dict passed between all agents in the pipeline."""
    # Running message history exchanged between agents.
    messages: Sequence[str]
    # Subject the post is being researched and written about.
    topic: str
    # Raw findings keyed by source — nodes write "qdrant_results" and
    # "web_search_results" here.
    research_data: Dict[str, str]
    # Agent names on a team — presumably mirrors `research_members`; verify against callers.
    team_members: List[str]
    # Intermediate drafts produced by the writing team.
    draft_posts: Sequence[str]
    # Completed post after review.
    final_post: str


# Names of the agents on the research team.
research_members = ["Qdrant_researcher", "Web_researcher"]
# Research Agent Pieces.
# LCEL pipeline: fan the topic into a retriever-backed "context" plus the
# original "topic", then produce {"response": <LLM answer>, "context": <docs>}.
qdrant_research_chain = (
        {"context": itemgetter("topic") | models.compression_retriever, "topic": itemgetter("topic")}
        | RunnablePassthrough.assign(context=itemgetter("context"))
        | {"response": prompts.research_query_prompt  | models.gpt4o_mini | StrOutputParser(), "context": itemgetter("context")}
    )

# Web Search Agent Pieces.
# query_chain turns the inputs into a search query; tavily_simple runs the
# Tavily search on that query and summarizes the hits with the LLM.
tavily_tool = TavilySearchResults(max_results=3)
query_chain = ( prompts.search_query_prompt | models.gpt4o_mini | StrOutputParser() )
tavily_simple = ({"tav_results": tavily_tool} | prompts.tavily_prompt | models.gpt4o_mini | StrOutputParser())
tavily_chain = (
    {"query": query_chain} | tavily_simple
)

def query_qdrant(state: State) -> State:
    """Run the Qdrant retrieval chain for the current topic.

    The chain output is stored under ``research_data["qdrant_results"]``;
    the same (mutated) state object is returned.
    """
    findings = qdrant_research_chain.invoke({"topic": state["topic"]})
    state["research_data"]["qdrant_results"] = findings
    return state

def web_search(state: State) -> State:
    """Run the Tavily web-search chain, seeded with the Qdrant findings.

    Reads any prior ``qdrant_results`` from the state (falling back to a
    stock message when none exist), invokes the search chain, and stores the
    output under ``research_data["web_search_results"]``. Returns the same
    (mutated) state object.
    """
    prior_findings = state["research_data"].get(
        "qdrant_results", "No previous results available."
    )
    payload = {
        "topic": state["topic"],
        "qdrant_results": prior_findings,
    }
    state["research_data"]["web_search_results"] = tavily_chain.invoke(payload)
    return state

def research_supervisor(state):
    """Route the research team (placeholder).

    The research graph's conditional edges read ``state["next"]`` to pick the
    next node, but this stub never set it, so invoking the graph raised
    ``KeyError: 'next'``. Until real supervision logic is implemented, default
    the route to "FINISH" so the subgraph terminates cleanly. An existing
    "next" value is left untouched.
    """
    # TODO: implement real research supervision logic.
    state.setdefault("next", "FINISH")
    return state

def post_creation(state):
    """Draft a post from the gathered research (placeholder pass-through)."""
    # TODO: implement post creation logic.
    return state

def copy_editing(state):
    """Copy-edit the current draft (placeholder pass-through)."""
    # TODO: implement copy editing logic.
    return state

def voice_editing(state):
    """Adjust the draft's tone/voice (placeholder pass-through)."""
    # TODO: implement voice editing logic.
    return state

def post_review(state):
    """Review the finished post (placeholder pass-through)."""
    # TODO: implement post review logic.
    return state

def writing_supervisor(state):
    """Route the writing team (placeholder).

    The writing graph's conditional edges read ``state["next"]``, but this
    stub never set it, so invoking the graph raised ``KeyError: 'next'``.
    Until real supervision logic is implemented, default the route to
    "FINISH" so the subgraph terminates cleanly. An existing "next" value is
    left untouched.
    """
    # TODO: implement real writing supervision logic.
    state.setdefault("next", "FINISH")
    return state

def overall_supervisor(state):
    """Route between the research and writing teams (placeholder).

    The top-level graph's conditional edges read ``state["next"]``, but this
    stub never set it, so invoking the graph raised ``KeyError: 'next'``.
    Until real supervision logic is implemented, default the route to
    "FINISH" so the app terminates cleanly. An existing "next" value is left
    untouched.
    """
    # TODO: implement real overall supervision logic.
    state.setdefault("next", "FINISH")
    return state

# Create the research team graph: two researcher nodes coordinated by a
# supervisor that routes between them via state["next"].
research_graph = StateGraph(State)

research_graph.add_node("query_qdrant", query_qdrant)
research_graph.add_node("web_search", web_search)
research_graph.add_node("research_supervisor", research_supervisor)

# Each researcher reports back to the supervisor after running.
research_graph.add_edge("query_qdrant", "research_supervisor")
research_graph.add_edge("web_search", "research_supervisor")
# The supervisor's routing decision is read from state["next"]; "FINISH"
# ends the subgraph. NOTE(review): the research_supervisor stub does not set
# "next", so running this graph raises KeyError until that is implemented.
research_graph.add_conditional_edges(
    "research_supervisor",
    lambda x: x["next"],
    {"query_qdrant": "query_qdrant", "web_search": "web_search", "FINISH": END},
)
#research_graph.add_edge("research_supervisor", END)

# The supervisor decides the first action, so it is the entry point.
research_graph.set_entry_point("research_supervisor")
research_graph_comp = research_graph.compile()

# Create the writing team graph: four writing/editing nodes coordinated by a
# supervisor that routes between them via state["next"].
writing_graph = StateGraph(State)

writing_graph.add_node("post_creation", post_creation)
writing_graph.add_node("copy_editing", copy_editing)
writing_graph.add_node("voice_editing", voice_editing)
writing_graph.add_node("post_review", post_review)
writing_graph.add_node("writing_supervisor", writing_supervisor)

# Each worker reports back to the supervisor after running.
writing_graph.add_edge("post_creation", "writing_supervisor")
writing_graph.add_edge("copy_editing", "writing_supervisor")
writing_graph.add_edge("voice_editing", "writing_supervisor")
writing_graph.add_edge("post_review", "writing_supervisor")
# The supervisor's routing decision is read from state["next"]; "FINISH"
# ends the subgraph.
writing_graph.add_conditional_edges(
    "writing_supervisor",
    lambda x: x["next"],
    {"post_creation": "post_creation", 
     "copy_editing": "copy_editing",
     "voice_editing": "voice_editing",
     "post_review": "post_review",
     "FINISH": END},
)

# The supervisor decides the first action, so it is the entry point.
writing_graph.set_entry_point("writing_supervisor")

# BUG FIX: this previously compiled research_graph, so the "writing team"
# was actually a second copy of the research team.
writing_graph_comp = writing_graph.compile()

# Create the overall graph: the supervisor dispatches work to the research
# and writing teams until it routes to "FINISH".
overall_graph = StateGraph(State)

# Add the COMPILED team graphs as nodes. BUG FIX: previously the raw
# StateGraph builders (research_graph / writing_graph) were added, but graph
# nodes must be runnables — the compiled subgraphs are what can be invoked.
overall_graph.add_node("research_team", research_graph_comp)
overall_graph.add_node("writing_team", writing_graph_comp)

# Add the overall supervisor node.
overall_graph.add_node("overall_supervisor", overall_supervisor)

# The supervisor decides the first action, so it is the entry point.
overall_graph.set_entry_point("overall_supervisor")

# Each team reports back to the supervisor; the supervisor's routing
# decision is read from state["next"], with "FINISH" ending the run.
overall_graph.add_edge("research_team", "overall_supervisor")
overall_graph.add_edge("writing_team", "overall_supervisor")
overall_graph.add_conditional_edges(
    "overall_supervisor",
    lambda x: x["next"],
    {"research_team": "research_team", 
     "writing_team": "writing_team",
     "FINISH": END},
)

# Compile the graph into the runnable application.
app = overall_graph.compile()