# File size: 6,488 Bytes
# 4878ce8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# (extraction artifact removed: stray line-number gutter "1".."136" that was
#  not part of the original program)
# Silence noisy tqdm-related warnings emitted by downstream libraries before
# anything else gets imported.
import warnings
warnings.filterwarnings("ignore", message=".*TqdmWarning.*")
from dotenv import load_dotenv

# Load API keys from a local .env file (TAVILY_API_KEY is read below;
# presumably OPENAI_API_KEY as well for ChatOpenAI — confirm).
_ = load_dotenv()

from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated, List
import operator
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from tavily import TavilyClient
import os
import gradio as gr

# Shared graph state. Each node returns a partial update that LangGraph
# merges back into this state between steps.
class AgentState(TypedDict):
    """Typed state flowing through the essay-writing graph."""
    task: str                    # the user's essay topic
    lnode: str                   # name of the last node that ran
    plan: str                    # outline produced by the planner node
    research_queries: List[str]  # queries produced by the research node
    draft: str                   # current essay draft
    critique: str                # grader feedback on the draft
    content: List[str]           # NOTE(review): declared but no node here writes it — confirm intent
    revision_number: int
    max_revisions: int
    # Annotated with operator.add so each node's count is accumulated
    # (summed) rather than overwritten.
    count: Annotated[int, operator.add]

# Structured-output schema for a list of search queries.
# NOTE(review): defined but never referenced in this file — presumably meant
# for model.with_structured_output(Queries) in the research step; confirm.
class Queries(BaseModel):
    queries: List[str]

# Writer Agent Class
class Ewriter():
    """Essay-writing agent built as a linear LangGraph pipeline.

    Nodes run in order: planner -> research -> generate -> reflect -> END.
    Each node returns a partial state update; exceptions are captured into
    the corresponding state field (with count=0) instead of being raised,
    so a single failing step never aborts the whole graph run.
    """

    def __init__(self):
        self.model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
        self.PLAN_PROMPT = "You are an expert writer tasked with writing a high-level outline of a short 3-paragraph essay."
        self.RESEARCH_PROMPT = "Generate three research queries to help in writing an essay on the given topic."
        self.WRITER_PROMPT = "You are an essay assistant tasked with writing an excellent 3-paragraph essay."
        self.REFLECTION_PROMPT = "You are a teacher grading an essay. Provide critique and suggestions."
        # NOTE(review): this client is created but never queried by any node
        # below — either wire it into research_node or drop it. Confirm intent.
        self.tavily = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])

        # Build the graph. It is strictly linear: despite revision_number /
        # max_revisions existing in AgentState, there is no revise loop here.
        builder = StateGraph(AgentState)
        builder.add_node("planner", self.plan_node)
        builder.add_node("research", self.research_node)
        builder.add_node("generate", self.generation_node)
        builder.add_node("reflect", self.reflection_node)
        builder.set_entry_point("planner")
        builder.add_edge("planner", "research")
        builder.add_edge("research", "generate")
        builder.add_edge("generate", "reflect")
        builder.add_edge("reflect", END)  # Ensure reflect is not a dead-end

        self.graph = builder.compile()

    def plan_node(self, state: AgentState):
        """Produce a high-level outline for the essay topic in state['task']."""
        try:
            response = self.model.invoke([SystemMessage(content=self.PLAN_PROMPT), HumanMessage(content=state['task'])])
            return {"plan": response.content, "lnode": "planner", "count": 1}
        except Exception as e:
            return {"plan": f"Error occurred in planning: {str(e)}", "lnode": "planner", "count": 0}

    def research_node(self, state: AgentState):
        """Generate research queries (one per line of the model response)."""
        try:
            response = self.model.invoke([SystemMessage(content=self.RESEARCH_PROMPT), HumanMessage(content=state['task'])])
            return {"research_queries": response.content.split('\n'), "lnode": "research", "count": 1}
        except Exception as e:
            # Bug fix: research_queries is typed List[str]; the old error path
            # returned a bare string, which downstream iteration would walk
            # character-by-character. Wrap the message in a one-element list.
            return {"research_queries": [f"Error occurred in research: {str(e)}"], "lnode": "research", "count": 0}

    def generation_node(self, state: AgentState):
        """Write the essay draft, grounded on the planner's outline."""
        try:
            # Bug fix: plan_node's output was computed but never consumed —
            # the writer prompted on the raw task alone. Include the plan so
            # the outline actually steers the draft.
            user_content = f"{state['task']}\n\nHere is my plan:\n\n{state.get('plan', '')}"
            response = self.model.invoke([SystemMessage(content=self.WRITER_PROMPT), HumanMessage(content=user_content)])
            return {"draft": response.content, "lnode": "generate", "count": 1}
        except Exception as e:
            return {"draft": f"Error occurred in generation: {str(e)}", "lnode": "generate", "count": 0}

    def reflection_node(self, state: AgentState):
        """Critique the current draft as a grader would."""
        try:
            response = self.model.invoke([SystemMessage(content=self.REFLECTION_PROMPT), HumanMessage(content=state['draft'])])
            return {"critique": response.content, "lnode": "reflect", "count": 1}
        except Exception as e:
            return {"critique": f"Error occurred in reflection: {str(e)}", "lnode": "reflect", "count": 0}

# Gradio UI
class WriterGui():
    """Minimal Gradio front-end for the essay-writing graph.

    Wraps a compiled LangGraph `graph` and exposes two actions: a fresh run
    ("Generate Essay") and a re-run seeded with the previous node/draft
    ("Continue Essay").
    """

    def __init__(self, graph):
        self.graph = graph
        self.demo = self.create_interface()

    def _format_queries(self, queries):
        """Render research queries for a Textbox (a raw list displays poorly)."""
        if isinstance(queries, list):
            return "\n".join(str(q) for q in queries)
        return str(queries)

    def run_agent(self, topic, revision_number, max_revisions):
        """Run the graph from scratch.

        Returns (draft, last_node, count, critique, queries_text).
        """
        state = {'task': topic, 'max_revisions': max_revisions, 'revision_number': revision_number, 'lnode': "", 'count': 0}
        response = self.graph.invoke(state)
        return (response["draft"], response["lnode"], response["count"],
                response.get("critique", ""),
                self._format_queries(response.get("research_queries", [])))

    def continue_agent(self, topic, revision_number, max_revisions, last_node, current_draft):
        """Re-run the graph seeded with the previous node name and draft."""
        state = {'task': topic, 'max_revisions': max_revisions, 'revision_number': revision_number, 'lnode': last_node, 'draft': current_draft, 'count': 0}
        response = self.graph.invoke(state)
        return (response["draft"], response["lnode"], response["count"],
                response.get("critique", ""),
                self._format_queries(response.get("research_queries", [])))

    def create_interface(self):
        """Build the Blocks layout and wire the button callbacks."""
        with gr.Blocks() as demo:
            with gr.Tabs():
                with gr.Tab("Agent"):
                    topic_input = gr.Textbox(label="Essay Topic")
                    last_node = gr.Textbox(label="Last Node", interactive=False)
                    next_node = gr.Textbox(label="Next Node", interactive=False)
                    thread = gr.Textbox(label="Thread", interactive=False)
                    draft_rev = gr.Textbox(label="Draft Revision", interactive=False)
                    count = gr.Textbox(label="Count", interactive=False)
                    generate_button = gr.Button("Generate Essay", variant="primary")
                    continue_button = gr.Button("Continue Essay")

                    with gr.Row():
                        gr.Markdown("**Manage Agent**")
                    with gr.Row():
                        output_text = gr.Textbox(label="Live Agent Output", interactive=False)
                    with gr.Row():
                        critique_text = gr.Textbox(label="Critique", interactive=False)
                    with gr.Row():
                        research_text = gr.Textbox(label="Research Queries", interactive=False)

                    # Bug fix: the callbacks' third return value is the step
                    # *count*, but it was previously wired into the "Next Node"
                    # box while the "Count" box was never updated. Route it to
                    # `count`.
                    generate_button.click(fn=self.run_agent, inputs=[topic_input, gr.State(0), gr.State(2)], outputs=[output_text, last_node, count, critique_text, research_text])
                    continue_button.click(fn=self.continue_agent, inputs=[topic_input, gr.State(0), gr.State(2), last_node, draft_rev], outputs=[output_text, last_node, count, critique_text, research_text])

        return demo

    def launch(self):
        """Start the Gradio server with a public share link."""
        self.demo.launch(share=True)

# Run the App
if __name__ == "__main__":
    # Guard the entry point: importing this module for reuse should not
    # construct agents (which needs live API keys) or launch a public
    # Gradio server as an import side effect.
    MultiAgent = Ewriter()
    app = WriterGui(MultiAgent.graph)
    app.launch()