ryanbalch commited on
Commit
bb8feed
·
1 Parent(s): f4657fa

bugfix for new gradio state pattern messing with history

Browse files
api/scripts/workflow_playground.py CHANGED
@@ -37,6 +37,7 @@ workflow_bundle, state = build_workflow_with_state(
37
  persona="Casual Fan",
38
  messages=[
39
  HumanMessage(content="tell me about some players in everglade fc"),
 
40
  ],
41
  )
42
 
 
37
  persona="Casual Fan",
38
  messages=[
39
  HumanMessage(content="tell me about some players in everglade fc"),
40
+ # HumanMessage(content="tell me about the league")
41
  ],
42
  )
43
 
api/server_gradio.py CHANGED
@@ -1,9 +1,9 @@
1
  import asyncio
2
  import gradio as gr
3
  import os
4
- from pydantic import BaseModel
5
  from threading import Thread
6
- from langchain_core.messages import HumanMessage, AIMessage
7
  from event_handlers.gradio_handler import GradioEventHandler
8
  from workflows.base import build_workflow_with_state
9
  from utils.freeplay_helpers import FreeplayClient
@@ -13,6 +13,11 @@ lorem_ipsum = """Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do
13
  show_state = True
14
  fake_response = False
15
  dev_mode = os.getenv("DEV_MODE", "").lower() == "true"
 
 
 
 
 
16
 
17
 
18
  class AppState(BaseModel):
@@ -34,6 +39,20 @@ class AppState(BaseModel):
34
  if not self.freeplay_session_id:
35
  self.freeplay_session_id = FreeplayClient().create_session().session_id
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  ### Helpers ###
38
 
39
  def submit_helper(state, handler, user_query):
@@ -41,7 +60,7 @@ def submit_helper(state, handler, user_query):
41
  state.ensure_sessions()
42
  message = HumanMessage(content=user_query)
43
  state.history.append(message)
44
- state = AppState(**state.dict())
45
  yield state, ""
46
 
47
  if fake_response:
@@ -109,7 +128,10 @@ with gr.Blocks() as demo:
109
  lines=1,
110
  interactive=True,
111
  value=state.value.last_name)
112
- llm_response = gr.Textbox(label="LLM Response", lines=10)
 
 
 
113
 
114
  with gr.Row(scale=1):
115
  with gr.Column(scale=1):
@@ -158,9 +180,9 @@ with gr.Blocks() as demo:
158
 
159
  ### Events
160
 
161
- @state.change(inputs=[state], outputs=[count_disp, persona_disp, zep_session_id_disp, freeplay_session_id_disp])
162
  def state_change(state):
163
- return state.count, state.persona, state.zep_session_id, state.freeplay_session_id
164
 
165
  @clear_state_btn.click(outputs=[state, llm_response, persona, user_query, email, first_name, last_name])
166
  def clear_state():
 
1
  import asyncio
2
  import gradio as gr
3
  import os
4
+ from pydantic import BaseModel, field_validator
5
  from threading import Thread
6
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
7
  from event_handlers.gradio_handler import GradioEventHandler
8
  from workflows.base import build_workflow_with_state
9
  from utils.freeplay_helpers import FreeplayClient
 
13
  show_state = True
14
  fake_response = False
15
  dev_mode = os.getenv("DEV_MODE", "").lower() == "true"
16
+ MESSAGE_TYPE_MAP = {
17
+ "human": HumanMessage,
18
+ "ai": AIMessage,
19
+ # Add other message types as needed
20
+ }
21
 
22
 
23
  class AppState(BaseModel):
 
39
  if not self.freeplay_session_id:
40
  self.freeplay_session_id = FreeplayClient().create_session().session_id
41
 
42
+ @field_validator("history", mode="before")
43
+ @classmethod
44
+ def validate_history(cls, v):
45
+ out = []
46
+ for item in v:
47
+ if isinstance(item, BaseMessage):
48
+ out.append(item)
49
+ elif isinstance(item, dict):
50
+ out.append(MESSAGE_TYPE_MAP[item["type"]](**item))
51
+ else:
52
+ raise TypeError(f"Invalid type in history: {type(item)}")
53
+ return out
54
+
55
+
56
  ### Helpers ###
57
 
58
  def submit_helper(state, handler, user_query):
 
60
  state.ensure_sessions()
61
  message = HumanMessage(content=user_query)
62
  state.history.append(message)
63
+ state = AppState(**state.model_dump())
64
  yield state, ""
65
 
66
  if fake_response:
 
128
  lines=1,
129
  interactive=True,
130
  value=state.value.last_name)
131
+
132
+ with gr.Row():
133
+ llm_response = gr.Textbox(label="LLM Response", lines=10)
134
+ ots_box = gr.Textbox(label="OTS", lines=10)
135
 
136
  with gr.Row(scale=1):
137
  with gr.Column(scale=1):
 
180
 
181
  ### Events
182
 
183
+ @state.change(inputs=[state], outputs=[count_disp, persona_disp, zep_session_id_disp, freeplay_session_id_disp, user_query])
184
  def state_change(state):
185
+ return state.count, state.persona, state.zep_session_id, state.freeplay_session_id, ""
186
 
187
  @clear_state_btn.click(outputs=[state, llm_response, persona, user_query, email, first_name, last_name])
188
  def clear_state():