Spaces:
No application file
No application file
cleaning up gradio messages and showing player profile
Browse files- api/event_handlers/gradio_handler.py +41 -13
- api/event_handlers/print_handler.py +23 -1
- api/scripts/workflow_playground.py +2 -1
- api/server_gradio.py +6 -4
- api/tools/player_search.py +18 -2
api/event_handlers/gradio_handler.py
CHANGED
@@ -5,6 +5,19 @@ from langchain_core.outputs.llm_result import LLMResult
|
|
5 |
from typing import List
|
6 |
from langchain_core.messages import BaseMessage
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
class GradioEventHandler(AsyncCallbackHandler):
|
10 |
"""
|
@@ -33,32 +46,43 @@ class GradioEventHandler(AsyncCallbackHandler):
|
|
33 |
)
|
34 |
|
35 |
async def on_chat_model_start(self, *args, **kwargs):
|
36 |
-
|
37 |
-
self.
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
43 |
|
44 |
async def on_llm_new_token(self, token: str, **kwargs):
|
45 |
if token:
|
46 |
self.queue.put(token)
|
47 |
|
48 |
async def on_llm_end(self, result: LLMResult, *args, **kwargs):
|
49 |
-
|
50 |
-
|
|
|
51 |
|
52 |
async def on_tool_end(self, output: any, **kwargs):
|
53 |
-
print(f"\n{Fore.CYAN}[TOOL
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
async def on_tool_start(self, input: any, *args, **kwargs):
|
56 |
-
self.info_box(
|
57 |
|
58 |
async def on_workflow_end(self, state, *args, **kwargs):
|
59 |
print(f"\n{Fore.CYAN}[WORKFLOW END]{Style.RESET_ALL}")
|
60 |
-
|
61 |
-
|
|
|
62 |
|
63 |
@staticmethod
|
64 |
def is_chat_stream_end(result: LLMResult) -> bool:
|
@@ -67,3 +91,7 @@ class GradioEventHandler(AsyncCallbackHandler):
|
|
67 |
return bool(content and content.strip())
|
68 |
except (IndexError, AttributeError):
|
69 |
return False
|
|
|
|
|
|
|
|
|
|
5 |
from typing import List
|
6 |
from langchain_core.messages import BaseMessage
|
7 |
|
8 |
+
image_base = """
|
9 |
+
<img
|
10 |
+
src="https://huggingface.co/spaces/ryanbalch/IFX-huge-league/resolve/main/assets/profiles/players_pics/{filename}"
|
11 |
+
style="max-width: 100%; max-height: 100%; object-fit: contain; display: block; margin: 0 auto;"
|
12 |
+
/>
|
13 |
+
"""
|
14 |
+
team_image_map = {
|
15 |
+
'everglade-fc': 'Everglade_FC',
|
16 |
+
'fraser-valley-united': 'Fraser_Valley_United',
|
17 |
+
'tierra-alta-fc': 'Tierra_Alta_FC',
|
18 |
+
'yucatan-force': 'Yucatan_Force',
|
19 |
+
}
|
20 |
+
|
21 |
|
22 |
class GradioEventHandler(AsyncCallbackHandler):
|
23 |
"""
|
|
|
46 |
)
|
47 |
|
48 |
async def on_chat_model_start(self, *args, **kwargs):
|
49 |
+
pass
|
50 |
+
# self.info_box('[CHAT START]')
|
51 |
+
# self.ots_box("""
|
52 |
+
# <img
|
53 |
+
# src="https://huggingface.co/spaces/ryanbalch/IFX-huge-league/resolve/main/assets/landing.png"
|
54 |
+
# style="max-width: 100%; max-height: 100%; object-fit: contain; display: block; margin: 0 auto;"
|
55 |
+
# />
|
56 |
+
# """)
|
57 |
|
58 |
async def on_llm_new_token(self, token: str, **kwargs):
|
59 |
if token:
|
60 |
self.queue.put(token)
|
61 |
|
62 |
async def on_llm_end(self, result: LLMResult, *args, **kwargs):
|
63 |
+
pass
|
64 |
+
# if self.is_chat_stream_end(result):
|
65 |
+
# self.queue.put(None)
|
66 |
|
67 |
async def on_tool_end(self, output: any, **kwargs):
|
68 |
+
print(f"\n{Fore.CYAN}[TOOL END] {output}{Style.RESET_ALL}")
|
69 |
+
for doc in output:
|
70 |
+
if True:#doc.metadata.get("show_profile_card"):
|
71 |
+
img = image_base.format(filename=self.get_image_filename(doc))
|
72 |
+
print(f"\n{Fore.YELLOW}[TOOL END] {img}{Style.RESET_ALL}")
|
73 |
+
self.ots_box(img)
|
74 |
+
break
|
75 |
+
# else:
|
76 |
+
# self.info_box(doc)
|
77 |
|
78 |
async def on_tool_start(self, input: any, *args, **kwargs):
|
79 |
+
self.info_box(input.get("name", "[TOOL START]"))
|
80 |
|
81 |
async def on_workflow_end(self, state, *args, **kwargs):
|
82 |
print(f"\n{Fore.CYAN}[WORKFLOW END]{Style.RESET_ALL}")
|
83 |
+
self.queue.put(None)
|
84 |
+
# for msg in state["messages"]:
|
85 |
+
# print(f'{Fore.YELLOW}{msg.content}{Style.RESET_ALL}')
|
86 |
|
87 |
@staticmethod
|
88 |
def is_chat_stream_end(result: LLMResult) -> bool:
|
|
|
91 |
return bool(content and content.strip())
|
92 |
except (IndexError, AttributeError):
|
93 |
return False
|
94 |
+
|
95 |
+
@staticmethod
|
96 |
+
def get_image_filename(doc):
|
97 |
+
return f'{team_image_map.get(doc.metadata.get("team"))}_{doc.metadata.get("number")}.png'
|
api/event_handlers/print_handler.py
CHANGED
@@ -4,6 +4,19 @@ from langchain_core.outputs.llm_result import LLMResult
|
|
4 |
from typing import List
|
5 |
from langchain_core.messages import BaseMessage
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
class PrintEventHandler(AsyncCallbackHandler):
|
8 |
"""
|
9 |
Example async event handler: prints streaming tokens and tool results.
|
@@ -26,7 +39,12 @@ class PrintEventHandler(AsyncCallbackHandler):
|
|
26 |
print('\n[END]')
|
27 |
|
28 |
async def on_tool_end(self, output: any, **kwargs):
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
async def on_tool_start(self, input: any, *args, **kwargs):
|
32 |
print(f"\n{Fore.CYAN}[TOOL START]{Style.RESET_ALL}")
|
@@ -42,6 +60,10 @@ class PrintEventHandler(AsyncCallbackHandler):
|
|
42 |
except (IndexError, AttributeError):
|
43 |
return False
|
44 |
|
|
|
|
|
|
|
|
|
45 |
# def __getattribute__(self, name):
|
46 |
# attr = super().__getattribute__(name)
|
47 |
# if callable(attr) and name.startswith("on_"):
|
|
|
4 |
from typing import List
|
5 |
from langchain_core.messages import BaseMessage
|
6 |
|
7 |
+
image_base = """
|
8 |
+
<img
|
9 |
+
src="https://huggingface.co/spaces/ryanbalch/IFX-huge-league/resolve/main/assets/profiles/{filename}"
|
10 |
+
style="max-width: 100%; max-height: 100%; object-fit: contain; display: block; margin: 0 auto;"
|
11 |
+
/>
|
12 |
+
"""
|
13 |
+
team_image_map = {
|
14 |
+
'everglade-fc': 'Everglade_FC',
|
15 |
+
'fraser-valley-united': 'Fraser_Valley_United',
|
16 |
+
'tierra-alta-fc': 'Tierra_Alta_FC',
|
17 |
+
'yucatan-force': 'Yucatan_Force',
|
18 |
+
}
|
19 |
+
|
20 |
class PrintEventHandler(AsyncCallbackHandler):
|
21 |
"""
|
22 |
Example async event handler: prints streaming tokens and tool results.
|
|
|
39 |
print('\n[END]')
|
40 |
|
41 |
async def on_tool_end(self, output: any, **kwargs):
|
42 |
+
for doc in output:
|
43 |
+
if doc.metadata.get("show_profile_card"):
|
44 |
+
img = image_base.format(filename=self.get_image_filename(doc))
|
45 |
+
print(f"\n{Fore.CYAN}[TOOL RESULT] {img}{Style.RESET_ALL}")
|
46 |
+
else:
|
47 |
+
print(f"\n{Fore.CYAN}[TOOL RESULT] {doc}{Style.RESET_ALL}")
|
48 |
|
49 |
async def on_tool_start(self, input: any, *args, **kwargs):
|
50 |
print(f"\n{Fore.CYAN}[TOOL START]{Style.RESET_ALL}")
|
|
|
60 |
except (IndexError, AttributeError):
|
61 |
return False
|
62 |
|
63 |
+
@staticmethod
|
64 |
+
def get_image_filename(doc):
|
65 |
+
return f'{team_image_map.get(doc.metadata.get("team"))}_{doc.metadata.get("number")}.png'
|
66 |
+
|
67 |
# def __getattribute__(self, name):
|
68 |
# attr = super().__getattribute__(name)
|
69 |
# if callable(attr) and name.startswith("on_"):
|
api/scripts/workflow_playground.py
CHANGED
@@ -36,8 +36,9 @@ workflow_bundle, state = build_workflow_with_state(
|
|
36 |
last_name="Bigly",
|
37 |
persona="Casual Fan",
|
38 |
messages=[
|
39 |
-
HumanMessage(content="tell me about some players in everglade fc"),
|
40 |
# HumanMessage(content="tell me about the league")
|
|
|
41 |
],
|
42 |
)
|
43 |
|
|
|
36 |
last_name="Bigly",
|
37 |
persona="Casual Fan",
|
38 |
messages=[
|
39 |
+
# HumanMessage(content="tell me about some players in everglade fc"),
|
40 |
# HumanMessage(content="tell me about the league")
|
41 |
+
HumanMessage(content="tell me about Ryan Martinez of everglade fc")
|
42 |
],
|
43 |
)
|
44 |
|
api/server_gradio.py
CHANGED
@@ -114,13 +114,13 @@ def submit_helper(state, handler, user_query):
|
|
114 |
gr.Info(token["message"])
|
115 |
continue
|
116 |
if token["type"] == "ots":
|
|
|
117 |
state.ots_content = ots_default.format(content=token["message"])
|
118 |
state = AppState(**state.model_dump())
|
119 |
-
yield state, result
|
120 |
continue
|
121 |
result += token
|
122 |
yield state, result
|
123 |
-
|
124 |
state.history.append(AIMessage(content=result))
|
125 |
|
126 |
### Interface ###
|
@@ -210,12 +210,14 @@ with gr.Blocks() as demo:
|
|
210 |
|
211 |
@submit_btn.click(inputs=[state, user_query], outputs=[state, llm_response])
|
212 |
def submit(state, user_query):
|
213 |
-
user_query = user_query or "tell me about some players in everglade fc"
|
|
|
214 |
yield from submit_helper(state, handler, user_query)
|
215 |
|
216 |
@user_query.submit(inputs=[state, user_query], outputs=[state, llm_response])
|
217 |
def user_query_change(state, user_query):
|
218 |
-
user_query = user_query or "tell me about some players in everglade fc"
|
|
|
219 |
yield from submit_helper(state, handler, user_query)
|
220 |
|
221 |
@persona.change(inputs=[persona, state], outputs=[persona_disp])
|
|
|
114 |
gr.Info(token["message"])
|
115 |
continue
|
116 |
if token["type"] == "ots":
|
117 |
+
print('OTS: ' + token["message"])
|
118 |
state.ots_content = ots_default.format(content=token["message"])
|
119 |
state = AppState(**state.model_dump())
|
|
|
120 |
continue
|
121 |
result += token
|
122 |
yield state, result
|
123 |
+
|
124 |
state.history.append(AIMessage(content=result))
|
125 |
|
126 |
### Interface ###
|
|
|
210 |
|
211 |
@submit_btn.click(inputs=[state, user_query], outputs=[state, llm_response])
|
212 |
def submit(state, user_query):
|
213 |
+
# user_query = user_query or "tell me about some players in everglade fc"
|
214 |
+
user_query = user_query or "tell me about Ryan Martinez of everglade fc"
|
215 |
yield from submit_helper(state, handler, user_query)
|
216 |
|
217 |
@user_query.submit(inputs=[state, user_query], outputs=[state, llm_response])
|
218 |
def user_query_change(state, user_query):
|
219 |
+
# user_query = user_query or "tell me about some players in everglade fc"
|
220 |
+
user_query = user_query or "tell me about Ryan Martinez of everglade fc"
|
221 |
yield from submit_helper(state, handler, user_query)
|
222 |
|
223 |
@persona.change(inputs=[persona, state], outputs=[persona_disp])
|
api/tools/player_search.py
CHANGED
@@ -31,6 +31,10 @@ class PlayerSearchSchema(BaseModel):
|
|
31 |
" • Everglade FC (Miami, USA): Flashy, wild, South Florida flair.\n"
|
32 |
" • Fraser Valley United (Abbotsford, Canada): Vineyard roots, top youth academy."
|
33 |
))
|
|
|
|
|
|
|
|
|
34 |
|
35 |
|
36 |
class PlayerSearchTool(BaseTool):
|
@@ -43,22 +47,34 @@ class PlayerSearchTool(BaseTool):
|
|
43 |
|
44 |
def _run(self,
|
45 |
query: str,
|
|
|
46 |
run_manager: Optional[CallbackManagerForToolRun] = None,
|
47 |
) -> List[Document]:
|
48 |
k = 5 if query[0] == "*" else 3
|
49 |
-
|
|
|
|
|
50 |
query,
|
51 |
k=k,
|
52 |
filter=lambda doc: doc.metadata.get("type") == "player",
|
53 |
)
|
|
|
|
|
|
|
54 |
|
55 |
async def _arun(self,
|
56 |
query: str,
|
|
|
57 |
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
58 |
) -> List[Document]:
|
59 |
k = 5 if query[0] == "*" else 3
|
60 |
-
|
|
|
|
|
61 |
query,
|
62 |
k=k,
|
63 |
filter=lambda doc: doc.metadata.get("type") == "player",
|
64 |
)
|
|
|
|
|
|
|
|
31 |
" • Everglade FC (Miami, USA): Flashy, wild, South Florida flair.\n"
|
32 |
" • Fraser Valley United (Abbotsford, Canada): Vineyard roots, top youth academy."
|
33 |
))
|
34 |
+
show_profile_card: bool = Field(description=(
|
35 |
+
"If true, only the best-matching player will be returned. The UI will display a player profile card for this result, in addition to the LLM's response."
|
36 |
+
"The LLM should use this flag when the user expects a single, specific player and a card UI."
|
37 |
+
))
|
38 |
|
39 |
|
40 |
class PlayerSearchTool(BaseTool):
|
|
|
47 |
|
48 |
def _run(self,
|
49 |
query: str,
|
50 |
+
show_profile_card: bool = False,
|
51 |
run_manager: Optional[CallbackManagerForToolRun] = None,
|
52 |
) -> List[Document]:
|
53 |
k = 5 if query[0] == "*" else 3
|
54 |
+
if show_profile_card:
|
55 |
+
k = 1
|
56 |
+
results = vector_store.similarity_search(
|
57 |
query,
|
58 |
k=k,
|
59 |
filter=lambda doc: doc.metadata.get("type") == "player",
|
60 |
)
|
61 |
+
for result in results:
|
62 |
+
result.metadata["show_profile_card"] = show_profile_card
|
63 |
+
return results
|
64 |
|
65 |
async def _arun(self,
|
66 |
query: str,
|
67 |
+
show_profile_card: bool = False,
|
68 |
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
69 |
) -> List[Document]:
|
70 |
k = 5 if query[0] == "*" else 3
|
71 |
+
if show_profile_card:
|
72 |
+
k = 1
|
73 |
+
results = await vector_store.asimilarity_search(
|
74 |
query,
|
75 |
k=k,
|
76 |
filter=lambda doc: doc.metadata.get("type") == "player",
|
77 |
)
|
78 |
+
for result in results:
|
79 |
+
result.metadata["show_profile_card"] = show_profile_card
|
80 |
+
return results
|