ryanbalch commited on
Commit
bf716d8
·
1 Parent(s): dfbef19

cleaning up gradio messages and showing player profile

Browse files
api/event_handlers/gradio_handler.py CHANGED
@@ -5,6 +5,19 @@ from langchain_core.outputs.llm_result import LLMResult
5
  from typing import List
6
  from langchain_core.messages import BaseMessage
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  class GradioEventHandler(AsyncCallbackHandler):
10
  """
@@ -33,32 +46,43 @@ class GradioEventHandler(AsyncCallbackHandler):
33
  )
34
 
35
  async def on_chat_model_start(self, *args, **kwargs):
36
- self.info_box('[CHAT START]')
37
- self.ots_box("""
38
- <img
39
- src="https://huggingface.co/spaces/ryanbalch/IFX-huge-league/resolve/main/assets/landing.png"
40
- style="max-width: 100%; max-height: 100%; object-fit: contain; display: block; margin: 0 auto;"
41
- />
42
- """)
 
43
 
44
  async def on_llm_new_token(self, token: str, **kwargs):
45
  if token:
46
  self.queue.put(token)
47
 
48
  async def on_llm_end(self, result: LLMResult, *args, **kwargs):
49
- if self.is_chat_stream_end(result):
50
- self.queue.put(None)
 
51
 
52
  async def on_tool_end(self, output: any, **kwargs):
53
- print(f"\n{Fore.CYAN}[TOOL RESULT] {output}{Style.RESET_ALL}")
 
 
 
 
 
 
 
 
54
 
55
  async def on_tool_start(self, input: any, *args, **kwargs):
56
- self.info_box(f"[TOOL START]")
57
 
58
  async def on_workflow_end(self, state, *args, **kwargs):
59
  print(f"\n{Fore.CYAN}[WORKFLOW END]{Style.RESET_ALL}")
60
- for msg in state["messages"]:
61
- print(f'{Fore.YELLOW}{msg.content}{Style.RESET_ALL}')
 
62
 
63
  @staticmethod
64
  def is_chat_stream_end(result: LLMResult) -> bool:
@@ -67,3 +91,7 @@ class GradioEventHandler(AsyncCallbackHandler):
67
  return bool(content and content.strip())
68
  except (IndexError, AttributeError):
69
  return False
 
 
 
 
 
5
  from typing import List
6
  from langchain_core.messages import BaseMessage
7
 
8
+ image_base = """
9
+ <img
10
+ src="https://huggingface.co/spaces/ryanbalch/IFX-huge-league/resolve/main/assets/profiles/players_pics/{filename}"
11
+ style="max-width: 100%; max-height: 100%; object-fit: contain; display: block; margin: 0 auto;"
12
+ />
13
+ """
14
+ team_image_map = {
15
+ 'everglade-fc': 'Everglade_FC',
16
+ 'fraser-valley-united': 'Fraser_Valley_United',
17
+ 'tierra-alta-fc': 'Tierra_Alta_FC',
18
+ 'yucatan-force': 'Yucatan_Force',
19
+ }
20
+
21
 
22
  class GradioEventHandler(AsyncCallbackHandler):
23
  """
 
46
  )
47
 
48
  async def on_chat_model_start(self, *args, **kwargs):
49
+ pass
50
+ # self.info_box('[CHAT START]')
51
+ # self.ots_box("""
52
+ # <img
53
+ # src="https://huggingface.co/spaces/ryanbalch/IFX-huge-league/resolve/main/assets/landing.png"
54
+ # style="max-width: 100%; max-height: 100%; object-fit: contain; display: block; margin: 0 auto;"
55
+ # />
56
+ # """)
57
 
58
  async def on_llm_new_token(self, token: str, **kwargs):
59
  if token:
60
  self.queue.put(token)
61
 
62
  async def on_llm_end(self, result: LLMResult, *args, **kwargs):
63
+ pass
64
+ # if self.is_chat_stream_end(result):
65
+ # self.queue.put(None)
66
 
67
  async def on_tool_end(self, output: any, **kwargs):
68
+ print(f"\n{Fore.CYAN}[TOOL END] {output}{Style.RESET_ALL}")
69
+ for doc in output:
70
+ if True:#doc.metadata.get("show_profile_card"):
71
+ img = image_base.format(filename=self.get_image_filename(doc))
72
+ print(f"\n{Fore.YELLOW}[TOOL END] {img}{Style.RESET_ALL}")
73
+ self.ots_box(img)
74
+ break
75
+ # else:
76
+ # self.info_box(doc)
77
 
78
  async def on_tool_start(self, input: any, *args, **kwargs):
79
+ self.info_box(input.get("name", "[TOOL START]"))
80
 
81
  async def on_workflow_end(self, state, *args, **kwargs):
82
  print(f"\n{Fore.CYAN}[WORKFLOW END]{Style.RESET_ALL}")
83
+ self.queue.put(None)
84
+ # for msg in state["messages"]:
85
+ # print(f'{Fore.YELLOW}{msg.content}{Style.RESET_ALL}')
86
 
87
  @staticmethod
88
  def is_chat_stream_end(result: LLMResult) -> bool:
 
91
  return bool(content and content.strip())
92
  except (IndexError, AttributeError):
93
  return False
94
+
95
+ @staticmethod
96
+ def get_image_filename(doc):
97
+ return f'{team_image_map.get(doc.metadata.get("team"))}_{doc.metadata.get("number")}.png'
api/event_handlers/print_handler.py CHANGED
@@ -4,6 +4,19 @@ from langchain_core.outputs.llm_result import LLMResult
4
  from typing import List
5
  from langchain_core.messages import BaseMessage
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  class PrintEventHandler(AsyncCallbackHandler):
8
  """
9
  Example async event handler: prints streaming tokens and tool results.
@@ -26,7 +39,12 @@ class PrintEventHandler(AsyncCallbackHandler):
26
  print('\n[END]')
27
 
28
  async def on_tool_end(self, output: any, **kwargs):
29
- print(f"\n{Fore.CYAN}[TOOL RESULT] {output}{Style.RESET_ALL}")
 
 
 
 
 
30
 
31
  async def on_tool_start(self, input: any, *args, **kwargs):
32
  print(f"\n{Fore.CYAN}[TOOL START]{Style.RESET_ALL}")
@@ -42,6 +60,10 @@ class PrintEventHandler(AsyncCallbackHandler):
42
  except (IndexError, AttributeError):
43
  return False
44
 
 
 
 
 
45
  # def __getattribute__(self, name):
46
  # attr = super().__getattribute__(name)
47
  # if callable(attr) and name.startswith("on_"):
 
4
  from typing import List
5
  from langchain_core.messages import BaseMessage
6
 
7
+ image_base = """
8
+ <img
9
+ src="https://huggingface.co/spaces/ryanbalch/IFX-huge-league/resolve/main/assets/profiles/{filename}"
10
+ style="max-width: 100%; max-height: 100%; object-fit: contain; display: block; margin: 0 auto;"
11
+ />
12
+ """
13
+ team_image_map = {
14
+ 'everglade-fc': 'Everglade_FC',
15
+ 'fraser-valley-united': 'Fraser_Valley_United',
16
+ 'tierra-alta-fc': 'Tierra_Alta_FC',
17
+ 'yucatan-force': 'Yucatan_Force',
18
+ }
19
+
20
  class PrintEventHandler(AsyncCallbackHandler):
21
  """
22
  Example async event handler: prints streaming tokens and tool results.
 
39
  print('\n[END]')
40
 
41
  async def on_tool_end(self, output: any, **kwargs):
42
+ for doc in output:
43
+ if doc.metadata.get("show_profile_card"):
44
+ img = image_base.format(filename=self.get_image_filename(doc))
45
+ print(f"\n{Fore.CYAN}[TOOL RESULT] {img}{Style.RESET_ALL}")
46
+ else:
47
+ print(f"\n{Fore.CYAN}[TOOL RESULT] {doc}{Style.RESET_ALL}")
48
 
49
  async def on_tool_start(self, input: any, *args, **kwargs):
50
  print(f"\n{Fore.CYAN}[TOOL START]{Style.RESET_ALL}")
 
60
  except (IndexError, AttributeError):
61
  return False
62
 
63
+ @staticmethod
64
+ def get_image_filename(doc):
65
+ return f'{team_image_map.get(doc.metadata.get("team"))}_{doc.metadata.get("number")}.png'
66
+
67
  # def __getattribute__(self, name):
68
  # attr = super().__getattribute__(name)
69
  # if callable(attr) and name.startswith("on_"):
api/scripts/workflow_playground.py CHANGED
@@ -36,8 +36,9 @@ workflow_bundle, state = build_workflow_with_state(
36
  last_name="Bigly",
37
  persona="Casual Fan",
38
  messages=[
39
- HumanMessage(content="tell me about some players in everglade fc"),
40
  # HumanMessage(content="tell me about the league")
 
41
  ],
42
  )
43
 
 
36
  last_name="Bigly",
37
  persona="Casual Fan",
38
  messages=[
39
+ # HumanMessage(content="tell me about some players in everglade fc"),
40
  # HumanMessage(content="tell me about the league")
41
+ HumanMessage(content="tell me about Ryan Martinez of everglade fc")
42
  ],
43
  )
44
 
api/server_gradio.py CHANGED
@@ -114,13 +114,13 @@ def submit_helper(state, handler, user_query):
114
  gr.Info(token["message"])
115
  continue
116
  if token["type"] == "ots":
 
117
  state.ots_content = ots_default.format(content=token["message"])
118
  state = AppState(**state.model_dump())
119
- yield state, result
120
  continue
121
  result += token
122
  yield state, result
123
-
124
  state.history.append(AIMessage(content=result))
125
 
126
  ### Interface ###
@@ -210,12 +210,14 @@ with gr.Blocks() as demo:
210
 
211
  @submit_btn.click(inputs=[state, user_query], outputs=[state, llm_response])
212
  def submit(state, user_query):
213
- user_query = user_query or "tell me about some players in everglade fc"
 
214
  yield from submit_helper(state, handler, user_query)
215
 
216
  @user_query.submit(inputs=[state, user_query], outputs=[state, llm_response])
217
  def user_query_change(state, user_query):
218
- user_query = user_query or "tell me about some players in everglade fc"
 
219
  yield from submit_helper(state, handler, user_query)
220
 
221
  @persona.change(inputs=[persona, state], outputs=[persona_disp])
 
114
  gr.Info(token["message"])
115
  continue
116
  if token["type"] == "ots":
117
+ print('OTS: ' + token["message"])
118
  state.ots_content = ots_default.format(content=token["message"])
119
  state = AppState(**state.model_dump())
 
120
  continue
121
  result += token
122
  yield state, result
123
+
124
  state.history.append(AIMessage(content=result))
125
 
126
  ### Interface ###
 
210
 
211
  @submit_btn.click(inputs=[state, user_query], outputs=[state, llm_response])
212
  def submit(state, user_query):
213
+ # user_query = user_query or "tell me about some players in everglade fc"
214
+ user_query = user_query or "tell me about Ryan Martinez of everglade fc"
215
  yield from submit_helper(state, handler, user_query)
216
 
217
  @user_query.submit(inputs=[state, user_query], outputs=[state, llm_response])
218
  def user_query_change(state, user_query):
219
+ # user_query = user_query or "tell me about some players in everglade fc"
220
+ user_query = user_query or "tell me about Ryan Martinez of everglade fc"
221
  yield from submit_helper(state, handler, user_query)
222
 
223
  @persona.change(inputs=[persona, state], outputs=[persona_disp])
api/tools/player_search.py CHANGED
@@ -31,6 +31,10 @@ class PlayerSearchSchema(BaseModel):
31
  " • Everglade FC (Miami, USA): Flashy, wild, South Florida flair.\n"
32
  " • Fraser Valley United (Abbotsford, Canada): Vineyard roots, top youth academy."
33
  ))
 
 
 
 
34
 
35
 
36
  class PlayerSearchTool(BaseTool):
@@ -43,22 +47,34 @@ class PlayerSearchTool(BaseTool):
43
 
44
  def _run(self,
45
  query: str,
 
46
  run_manager: Optional[CallbackManagerForToolRun] = None,
47
  ) -> List[Document]:
48
  k = 5 if query[0] == "*" else 3
49
- return vector_store.similarity_search(
 
 
50
  query,
51
  k=k,
52
  filter=lambda doc: doc.metadata.get("type") == "player",
53
  )
 
 
 
54
 
55
  async def _arun(self,
56
  query: str,
 
57
  run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
58
  ) -> List[Document]:
59
  k = 5 if query[0] == "*" else 3
60
- return await vector_store.asimilarity_search(
 
 
61
  query,
62
  k=k,
63
  filter=lambda doc: doc.metadata.get("type") == "player",
64
  )
 
 
 
 
31
  " • Everglade FC (Miami, USA): Flashy, wild, South Florida flair.\n"
32
  " • Fraser Valley United (Abbotsford, Canada): Vineyard roots, top youth academy."
33
  ))
34
+ show_profile_card: bool = Field(description=(
35
+ "If true, only the best-matching player will be returned. The UI will display a player profile card for this result, in addition to the LLM's response."
36
+ "The LLM should use this flag when the user expects a single, specific player and a card UI."
37
+ ))
38
 
39
 
40
  class PlayerSearchTool(BaseTool):
 
47
 
48
  def _run(self,
49
  query: str,
50
+ show_profile_card: bool = False,
51
  run_manager: Optional[CallbackManagerForToolRun] = None,
52
  ) -> List[Document]:
53
  k = 5 if query[0] == "*" else 3
54
+ if show_profile_card:
55
+ k = 1
56
+ results = vector_store.similarity_search(
57
  query,
58
  k=k,
59
  filter=lambda doc: doc.metadata.get("type") == "player",
60
  )
61
+ for result in results:
62
+ result.metadata["show_profile_card"] = show_profile_card
63
+ return results
64
 
65
  async def _arun(self,
66
  query: str,
67
+ show_profile_card: bool = False,
68
  run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
69
  ) -> List[Document]:
70
  k = 5 if query[0] == "*" else 3
71
+ if show_profile_card:
72
+ k = 1
73
+ results = await vector_store.asimilarity_search(
74
  query,
75
  k=k,
76
  filter=lambda doc: doc.metadata.get("type") == "player",
77
  )
78
+ for result in results:
79
+ result.metadata["show_profile_card"] = show_profile_card
80
+ return results