Hansimov committed
Commit 4c9b469
1 Parent(s): 82cb440

:gem: [Feature] New API with stream response via POST request

apis/chat_api.py CHANGED
@@ -1,7 +1,8 @@
+import json
 import uvicorn
 
-from fastapi import FastAPI, APIRouter, WebSocket, WebSocketDisconnect
-from fastapi.routing import APIRoute
+from fastapi import FastAPI
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, Field
 from conversations import (
     ConversationConnector,
@@ -9,6 +10,8 @@ from conversations import (
     ConversationSession,
 )
 
+from networks import StreamResponseConstructor
+

 class ChatAPIApp:
     def __init__(self):
@@ -59,9 +62,9 @@ class ChatAPIApp:
         creator.create()
         return {
             "model": item.model,
-            "conversation_id": creator.conversation_id,
-            "client_id": creator.client_id,
             "sec_access_token": creator.sec_access_token,
+            "client_id": creator.client_id,
+            "conversation_id": creator.conversation_id,
         }
 
     class ChatPostItem(BaseModel):
@@ -105,6 +108,51 @@ class ChatAPIApp:
         with session:
             session.chat(prompt=item.prompt)
 
+    class ChatCompletionsPostItem(BaseModel):
+        model: str = Field(
+            default="precise",
+            description="(str) `precise`, `balanced`, `creative`, `precise-offline`, `balanced-offline`, `creative-offline`",
+        )
+        messages: list = Field(
+            default=[{"role": "user", "content": "Hello, who are you?"}],
+            description="(list) Messages",
+        )
+        sec_access_token: str = Field(
+            default="",
+            description="(str) Sec Access Token",
+        )
+        client_id: str = Field(
+            default="",
+            description="(str) Client ID",
+        )
+        conversation_id: str = Field(
+            default="",
+            description="(str) Conversation ID",
+        )
+        invocation_id: int = Field(
+            default=0,
+            description="(int) Invocation ID",
+        )
+
+    async def chat_completions(self, item: ChatCompletionsPostItem):
+        connector = ConversationConnector(
+            conversation_style=item.model,
+            sec_access_token=item.sec_access_token,
+            client_id=item.client_id,
+            conversation_id=item.conversation_id,
+            invocation_id=item.invocation_id,
+        )
+
+        if item.invocation_id == 0:
+            # TODO: History Messages Merger
+            prompt = item.messages[-1]["content"]
+        else:
+            prompt = item.messages[-1]["content"]
+
+        return StreamingResponse(
+            connector.stream_chat(prompt=prompt, yield_output=True)
+        )
+
     def setup_routes(self):
         self.app.get(
             "/models",
@@ -121,6 +169,11 @@ class ChatAPIApp:
             summary="Chat in conversation session",
         )(self.chat)
 
+        self.app.post(
+            "/chat/completions",
+            summary="Chat completions in conversation session",
+        )(self.chat_completions)
+

 app = ChatAPIApp().app
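For reference, a minimal client sketch for the new streaming endpoint. This is not part of the commit: the base URL/port and the empty token fields are placeholders, and `requests` is just one convenient HTTP client that supports streaming. Each streamed line is one JSON object produced by ContentJSONOutputer.

import json
import requests

# Placeholder base URL; point it at wherever uvicorn serves this app.
url = "http://localhost:8000/chat/completions"
payload = {
    "model": "precise",
    "messages": [{"role": "user", "content": "Hello, who are you?"}],
    # Fill these in with the values returned by the conversation-create endpoint.
    "sec_access_token": "",
    "client_id": "",
    "conversation_id": "",
    "invocation_id": 0,
}

with requests.post(url, json=payload, stream=True) as response:
    # The endpoint streams newline-delimited JSON chunks.
    for line in response.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        print(chunk["content_type"], ":", chunk["content"])
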
conversations/conversation_connector.py CHANGED
@@ -8,7 +8,7 @@ from networks import (
     ChathubRequestPayloadConstructor,
     ConversationRequestHeadersConstructor,
 )
-from networks import MessageParser, IdleOutputer
+from networks import MessageParser, IdleOutputer, ContentJSONOutputer
 from utils.logger import logger
 
 http_proxy = "http://localhost:11111"  # Replace with yours
@@ -77,10 +77,10 @@ class ConversationConnector:
         self.connect_request_payload = payload_constructor.request_payload
         await self.wss_send(self.connect_request_payload)
 
-    async def stream_chat(self, prompt=""):
+    async def stream_chat(self, prompt="", yield_output=False):
         await self.connect()
         await self.send_chathub_request(prompt)
-        message_parser = MessageParser(outputer=IdleOutputer())
+        message_parser = MessageParser(outputer=ContentJSONOutputer())
         while not self.wss.closed:
             response_lines_str = await self.wss.receive_str()
             if isinstance(response_lines_str, str):
@@ -93,7 +93,10 @@ class ConversationConnector:
                 data = json.loads(line)
                 # Stream: Meaningful Messages
                 if data.get("type") == 1:
-                    message_parser.parse(data)
+                    if yield_output:
+                        yield message_parser.parse(data, return_output=True)
+                    else:
+                        message_parser.parse(data)
                 # Stream: List of all messages in the whole conversation
                 elif data.get("type") == 2:
                     if data.get("item"):
@@ -102,10 +105,21 @@ class ConversationConnector:
                         pass
                 # Stream: End of Conversation
                 elif data.get("type") == 3:
-                    logger.success("\n[Finished]")
+                    finished_str = "\n[Finished]"
+                    logger.success(finished_str)
                     self.invocation_id += 1
                     await self.wss.close()
                     await self.aiohttp_session.close()
+                    if yield_output:
+                        yield (
+                            json.dumps(
+                                {
+                                    "content": finished_str,
+                                    "content_type": "Finished",
+                                }
+                            )
+                            + "\n"
+                        ).encode("utf-8")
                     break
                 # Stream: Heartbeat Signal
                 elif data.get("type") == 6:
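
A sketch of driving the new generator directly, without FastAPI in between. The credential values are placeholders (in the API they come from the conversation creator), and the cookie/proxy setup the connector needs is elided:

import asyncio

from conversations import ConversationConnector


async def main():
    connector = ConversationConnector(
        conversation_style="precise",
        sec_access_token="...",  # placeholder
        client_id="...",  # placeholder
        conversation_id="...",  # placeholder
        invocation_id=0,
    )
    # With yield_output=True, stream_chat is an async generator of NDJSON bytes.
    async for chunk in connector.stream_chat(prompt="Hello!", yield_output=True):
        if chunk:
            print(chunk.decode("utf-8"), end="")


asyncio.run(main())
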
networks/__init__.py CHANGED
@@ -3,5 +3,6 @@ from .cookies_constructor import CookiesConstructor
 from .conversation_request_headers_constructor import (
     ConversationRequestHeadersConstructor,
 )
-from .message_outputer import IdleOutputer
+from .message_outputer import IdleOutputer, ContentJSONOutputer
 from .message_parser import MessageParser
+from .stream_response_constructor import StreamResponseConstructor

networks/message_outputer.py CHANGED
@@ -1,3 +1,19 @@
+import json
+
+
 class IdleOutputer:
-    def output(self, message_content=None, message_type=None):
-        pass
+    def output(self, content=None, content_type=None):
+        return json.dumps({}).encode("utf-8")
+
+
+class ContentJSONOutputer:
+    def output(self, content=None, content_type=None):
+        return (
+            json.dumps(
+                {
+                    "content": content,
+                    "content_type": content_type,
+                }
+            )
+            + "\n"
+        ).encode("utf-8")
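
A quick check of what the outputers now emit: one UTF-8-encoded JSON object per line, so the endpoint can stream the chunks verbatim.

from networks import ContentJSONOutputer, IdleOutputer

print(ContentJSONOutputer().output("Hello", content_type="Completions"))
# b'{"content": "Hello", "content_type": "Completions"}\n'

print(IdleOutputer().output("Hello", content_type="Completions"))
# b'{}'
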
networks/message_parser.py CHANGED
@@ -1,13 +1,15 @@
+import json
+
 from utils.logger import logger
-from networks import IdleOutputer
+from networks import IdleOutputer, ContentJSONOutputer
 
 
 class MessageParser:
-    def __init__(self, outputer=IdleOutputer()):
+    def __init__(self, outputer=ContentJSONOutputer()):
         self.delta_content_pointer = 0
         self.outputer = outputer
 
-    def parse(self, data):
+    def parse(self, data, return_output=False):
         arguments = data["arguments"][0]
         if arguments.get("throttling"):
             throttling = arguments.get("throttling")
@@ -20,7 +22,6 @@ class MessageParser:
                     content = message["adaptiveCards"][0]["body"][0]["text"]
                     delta_content = content[self.delta_content_pointer :]
                     logger.line(delta_content, end="")
-                    self.outputer.output(delta_content, message_type="Completions")
                     self.delta_content_pointer = len(content)
                     # Message: Suggested Questions
                     if message.get("suggestedResponses"):
@@ -28,20 +29,34 @@ class MessageParser:
                         for suggestion in message.get("suggestedResponses"):
                             suggestion_text = suggestion.get("text")
                             logger.file(f"- {suggestion_text}")
-                        self.outputer.output(
-                            message.get("suggestedResponses"),
-                            message_type="Suggestions",
+                    if return_output:
+                        output_bytes = self.outputer.output(
+                            delta_content, content_type="Completions"
                         )
+                        if message.get("suggestedResponses"):
+                            output_bytes += self.outputer.output(
+                                message.get("suggestedResponses"),
+                                content_type="SuggestedResponses",
+                            )
+                        return output_bytes
+
                 # Message: Search Query
                 elif message_type in ["InternalSearchQuery"]:
                     message_hidden_text = message["hiddenText"]
-                    logger.note(f"\n[Searching: [{message_hidden_text}]]")
-                    self.outputer.output(
-                        message_hidden_text, message_type="InternalSearchQuery"
-                    )
+                    search_str = f"\n[Searching: [{message_hidden_text}]]"
+                    logger.note(search_str)
+                    if return_output:
+                        return self.outputer.output(
+                            search_str, content_type="InternalSearchQuery"
+                        )
                 # Message: Internal Search Results
                 elif message_type in ["InternalSearchResult"]:
-                    logger.note("[Analyzing search results ...]")
+                    analysis_str = f"\n[Analyzing search results ...]"
+                    logger.note(analysis_str)
+                    if return_output:
+                        return self.outputer.output(
+                            analysis_str, content_type="InternalSearchResult"
+                        )
                 # Message: Loader status, such as "Generating Answers"
                 elif message_type in ["InternalLoaderMessage"]:
                     # logger.note("[Generating answers ...]\n")
@@ -62,3 +77,15 @@
                     raise NotImplementedError(
                         f"Not Supported Message Type: {message_type}"
                     )
+
+        return (
+            (
+                json.dumps(
+                    {
+                        "content": "",
+                        "content_type": "NotImplemented",
+                    }
+                )
+            )
+            + "\n"
+        ).encode("utf-8")
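
A hedged sketch of the new return_output path, using a hand-built frame for the InternalSearchQuery branch above. The frame is trimmed to the fields these hunks reference, and it assumes parse iterates arguments[0]["messages"] and dispatches on messageType, which this diff only shows indirectly:

from networks import MessageParser

parser = MessageParser()  # default outputer is now ContentJSONOutputer

# Minimal type-1 frame for the search-query branch; real frames carry many more fields.
data = {
    "arguments": [
        {
            "messages": [
                {
                    "messageType": "InternalSearchQuery",
                    "hiddenText": "latest news",
                }
            ]
        }
    ]
}

chunk = parser.parse(data, return_output=True)
# chunk is NDJSON bytes whose content is "\n[Searching: [latest news]]"
# and whose content_type is "InternalSearchQuery".
print(chunk.decode("utf-8"), end="")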