Asaad Almutareb committed
Commit bec8a7b · 1 Parent(s): a0df48e

added callback

innovation_pathfinder_ai/backend/app/api/v1/agents/hf_mixtral_agent.py CHANGED
@@ -17,6 +17,7 @@ from app.utils import logger
 from app.utils import utils
 from langchain.globals import set_llm_cache
 from langchain.cache import SQLiteCache
+from app.utils.callback import CustomAsyncCallbackHandler
 
 set_llm_cache(SQLiteCache(database_path=".cache.db"))
 logger = logger.get_console_logger("hf_mixtral_agent")
@@ -40,6 +41,7 @@ async def websocket_endpoint(websocket: WebSocket):
         try:
             data = await websocket.receive_json()
             user_message = data["message"]
+            chat_history = data["history"]
 
             # resp = IChatResponse(
             #     sender="you",
@@ -51,9 +53,9 @@ async def websocket_endpoint(websocket: WebSocket):
 
             # await websocket.send_json(resp.dict())
             message_id: str = utils.generate_uuid()
-            # custom_handler = CustomFinalStreamingStdOutCallbackHandler(
-            #     websocket, message_id=message_id
-            # )
+            custom_handler = CustomAsyncCallbackHandler(
+                websocket, message_id=message_id
+            )
 
             # Load the model from the Hugging Face Hub
             llm = HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
@@ -106,7 +108,7 @@ async def websocket_endpoint(websocket: WebSocket):
                 handle_parsing_errors=True,
             )
 
-            await agent_executor.arun(input=user_message) #, callbacks=[custom_handler]
+            await agent_executor.arun(input=user_message, chat_history=chat_history, callbacks=[custom_handler])
         except WebSocketDisconnect:
             logger.info("websocket disconnect")
             break
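
For orientation, here is a minimal client sketch of the contract these changes establish: the endpoint now expects a JSON payload carrying "message" and "history", and the new callback handler streams IChatResponse frames back over the socket. The URI, the read-until-"end" loop, and the script wrapper are assumptions for illustration only, not part of this commit.

# Sketch of a client for the updated endpoint (assumed URI and framing).
import asyncio
import json

import websockets  # same client library the frontend uses


async def ask(question: str, history: list) -> dict:
    uri = "ws://localhost:8000/chat/agent"  # assumed local backend address
    async with websockets.connect(uri) as ws:
        # The endpoint reads data["message"] and data["history"].
        await ws.send(json.dumps({"message": question, "history": history}))
        while True:
            frame = json.loads(await ws.recv())  # an IChatResponse as JSON
            if frame.get("type") == "end":
                return frame


if __name__ == "__main__":
    print(asyncio.run(ask("Hello", [])))
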
innovation_pathfinder_ai/backend/app/schemas/message_schema.py CHANGED
@@ -1,5 +1,7 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, validator
 from typing import List, Tuple, Optional
+from app.utils.utils import generate_uuid
+from typing import Any
 
 class InferRequest(BaseModel):
     question: str
@@ -8,3 +10,30 @@ class InferRequest(BaseModel):
 class BotRequest(BaseModel):
     history: List[Tuple[str, str]]
 
+class IChatResponse(BaseModel):
+    """Chat response schema."""
+
+    id: str
+    message_id: str
+    sender: str
+    message: Any
+    type: str
+    suggested_responses: list[str] = []
+
+    @validator("id", "message_id", pre=True, allow_reuse=True)
+    def check_ids(cls, v):
+        if v == "" or v is None:
+            return generate_uuid()
+        return v
+
+    # @validator("sender")
+    # def sender_must_be_bot_or_you(cls, v):
+    #     if v not in ["bot", "you"]:
+    #         raise ValueError("sender must be bot or you")
+    #     return v
+
+    # @validator("type")
+    # def validate_message_type(cls, v):
+    #     if v not in ["start", "stream", "end", "error", "info"]:
+    #         raise ValueError("type must be start, stream or end")
+    #     return v
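
As a quick illustration of the new schema (not part of the commit): the check_ids validator back-fills empty or missing ids with generated UUIDs, which is what lets the callback handler below pass id="" on every frame. The values here are made up for the example.

# Illustrative use of IChatResponse: empty ids are auto-filled by the validator.
from app.schemas.message_schema import IChatResponse

resp = IChatResponse(
    id="",                  # replaced by check_ids with a generated UUID
    message_id="",          # likewise auto-filled
    sender="bot",
    message="partial answer...",
    type="stream",
)
assert resp.id != "" and resp.message_id != ""
print(resp.dict())
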
innovation_pathfinder_ai/backend/app/utils/callback.py ADDED
@@ -0,0 +1,118 @@
+from app.schemas.message_schema import IChatResponse
+from langchain.callbacks.base import AsyncCallbackHandler
+from app.utils.utils import generate_uuid
+from fastapi import WebSocket
+from uuid import UUID
+from typing import Any
+from langchain.schema.agent import AgentFinish
+from langchain.schema.output import LLMResult
+
+
+DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", " Answer", ":"]
+
+
+class CustomAsyncCallbackHandler(AsyncCallbackHandler):
+    def append_to_last_tokens(self, token: str) -> None:
+        self.last_tokens.append(token)
+        self.last_tokens_stripped.append(token.strip())
+        if len(self.last_tokens) > len(self.answer_prefix_tokens):
+            self.last_tokens.pop(0)
+            self.last_tokens_stripped.pop(0)
+
+    def check_if_answer_reached(self) -> bool:
+        if self.strip_tokens:
+            return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
+        else:
+            return self.last_tokens == self.answer_prefix_tokens
+
+    def update_message_id(self, message_id: str = generate_uuid()):
+        self.message_id = message_id
+
+    def __init__(
+        self,
+        websocket: WebSocket,
+        *,
+        message_id: str = generate_uuid(),
+        answer_prefix_tokens: list[str] | None = None,
+        strip_tokens: bool = True,
+        stream_prefix: bool = False,
+    ) -> None:
+        """Instantiate FinalStreamingStdOutCallbackHandler.
+
+        Args:
+            answer_prefix_tokens: Token sequence that prefixes the answer.
+                Default is ["Final", "Answer", ":"]
+            strip_tokens: Ignore white spaces and new lines when comparing
+                answer_prefix_tokens to last tokens? (to determine if answer has been
+                reached)
+            stream_prefix: Should answer prefix itself also be streamed?
+        """
+        self.websocket: WebSocket = websocket
+        self.message_id: str = message_id
+        self.text: str = ""
+        self.started: bool = False
+
+        if answer_prefix_tokens is None:
+            self.answer_prefix_tokens = DEFAULT_ANSWER_PREFIX_TOKENS
+        else:
+            self.answer_prefix_tokens = answer_prefix_tokens
+        if strip_tokens:
+            self.answer_prefix_tokens_stripped = [
+                token.strip() for token in self.answer_prefix_tokens
+            ]
+        else:
+            self.answer_prefix_tokens_stripped = self.answer_prefix_tokens
+        self.last_tokens = [""] * len(self.answer_prefix_tokens)
+        self.last_tokens_stripped = [""] * len(self.answer_prefix_tokens)
+        self.strip_tokens = strip_tokens
+        self.stream_prefix = stream_prefix
+        self.answer_reached = False
+
+    async def on_llm_start(
+        self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
+    ) -> None:
+        """Run when LLM starts running."""
+
+        resp = IChatResponse(
+            id="",
+            message_id=self.message_id,
+            sender="bot",
+            message=self.loading_card.to_dict(),
+            type="start",
+        )
+        await self.websocket.send_json(resp.dict())
+
+    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        """Run on new LLM token. Only available when streaming is enabled."""
+        # Remember the last n tokens, where n = len(answer_prefix_tokens)
+        self.append_to_last_tokens(token)
+
+        self.text += f"{token}"
+        resp = IChatResponse(
+            # id=generate_uuid(),
+            id="",
+            message_id=self.message_id,
+            sender="bot",
+            message=self.adaptive_card.to_dict(),
+            type="stream",
+        )
+        await self.websocket.send_json(resp.dict())
+
+    async def on_llm_end(
+        self,
+        response: LLMResult,
+        *,
+        run_id: UUID,
+        parent_run_id: UUID | None = None,
+        tags: list[str] | None = None,
+        **kwargs: Any,
+    ) -> None:
+        """Run when LLM ends running."""
+        resp = IChatResponse(
+            id="",
+            message_id=self.message_id,
+            sender="bot",
+            message=self.adaptive_card.to_dict(),
+            type="end",
+        )
+        await self.websocket.send_json(resp.dict())
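
The prefix-tracking helpers (append_to_last_tokens, check_if_answer_reached) keep a sliding window of the most recent tokens and compare it, optionally whitespace-stripped, against DEFAULT_ANSWER_PREFIX_TOKENS. A standalone sketch of that idea follows; the helper name and test tokens are made up for illustration and are not part of the commit.

# Standalone sketch of the sliding-window "Final Answer:" prefix detection.
ANSWER_PREFIX = ["Final", "Answer", ":"]  # stripped form of the default tokens


def answer_reached(tokens: list[str]) -> bool:
    # Keep only the last len(ANSWER_PREFIX) stripped tokens and compare.
    window: list[str] = [""] * len(ANSWER_PREFIX)
    for tok in tokens:
        window.append(tok.strip())
        window.pop(0)
        if window == ANSWER_PREFIX:
            return True
    return False


assert answer_reached(["Thought", "Final", " Answer", ":"]) is True
assert answer_reached(["Final", "Thought", ":"]) is False
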
innovation_pathfinder_ai/frontend/app.py CHANGED
@@ -60,6 +60,20 @@ if __name__ == "__main__":
         history[-1][1] = response['output']
         # all_sources.clear()
         return history
+
+    async def ask_question_async(question, history):
+        uri = "ws://localhost:8000/chat/agent"  # Update this URI to your actual WebSocket endpoint
+        async with websockets.connect(uri) as websocket:
+            # Prepare the message to send (adjust the structure as needed for your backend)
+            message_data = {
+                "message": question,
+                "history": history
+            }
+            await websocket.send(json.dumps(message_data))
+
+            # Wait for the response
+            response_data = await websocket.recv()
+            return json.loads(response_data)
 
     def infer(question, history):
         # result = agent_executor.invoke(
@@ -69,19 +83,16 @@ if __name__ == "__main__":
         #     }
         # )
         # return result
-        async def ask_question_async(question, history):
-            uri = "ws://localhost:8000/chat/agent"  # Update this URI to your actual WebSocket endpoint
-            async with websockets.connect(uri) as websocket:
-                # Prepare the message to send (adjust the structure as needed for your backend)
-                message_data = {
-                    "message": question,
-                    "history": history
-                }
-                await websocket.send(json.dumps(message_data))
-
-                # Wait for the response
-                response_data = await websocket.recv()
-                return json.loads(response_data)
+        try:
+            # Ensure there's an event loop to run async code
+            loop = asyncio.get_event_loop()
+        except RuntimeError as ex:
+            if "There is no current event loop" in str(ex):
+                loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(loop)
+
+        result = loop.run_until_complete(ask_question_async(question, history))
+        return result
 
         # Run the asynchronous function in the synchronous context
         result = asyncio.get_event_loop().run_until_complete(ask_question_async(question, history))
@@ -113,7 +124,7 @@ if __name__ == "__main__":
         chatbot = gr.Chatbot([],
                              elem_id="AI Assistant",
                              bubble_full_width=False,
-                             avatar_images=(None, "./innovation_pathfinder_ai/assets/avatar.png"),
+                             avatar_images=(None, "./assets/avatar.png"),
                              height=480,)
         chatbot.like(vote, None, None)
        clear = gr.Button("Clear")
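
The new infer body follows the usual pattern for calling async code from a synchronous Gradio callback: reuse the thread's event loop if one exists, otherwise create and register a new one, then drive the coroutine with run_until_complete. Below is a self-contained sketch of the same pattern with a stand-in coroutine instead of ask_question_async; names here are illustrative, not project code.

# Self-contained sketch of the loop-reuse-or-create pattern used in infer().
import asyncio


async def fake_request() -> str:
    # Stand-in for ask_question_async(question, history).
    await asyncio.sleep(0)
    return "ok"


def run_sync(coro):
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No loop exists in this thread (e.g. a worker thread): create one.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(coro)


print(run_sync(fake_request()))  # -> "ok"
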