Spaces:
Sleeping
Sleeping
Delete app.py
Browse files
app.py
DELETED
@@ -1,219 +0,0 @@
|
|
1 |
-
from __future__ import annotations as _annotations
|
2 |
-
|
3 |
-
import asyncio
|
4 |
-
import json
|
5 |
-
import sqlite3
|
6 |
-
from collections.abc import AsyncIterator
|
7 |
-
from concurrent.futures.thread import ThreadPoolExecutor
|
8 |
-
from contextlib import asynccontextmanager
|
9 |
-
from dataclasses import dataclass
|
10 |
-
from datetime import datetime, timezone
|
11 |
-
from functools import partial
|
12 |
-
from pathlib import Path
|
13 |
-
from typing import Annotated, Any, Callable, Literal, TypeVar
|
14 |
-
|
15 |
-
import fastapi
|
16 |
-
#import logfire
|
17 |
-
from fastapi import Depends, Request
|
18 |
-
from fastapi.responses import FileResponse, Response, StreamingResponse
|
19 |
-
from typing_extensions import LiteralString, ParamSpec, TypedDict
|
20 |
-
|
21 |
-
from pydantic_ai import Agent
|
22 |
-
from pydantic_ai.exceptions import UnexpectedModelBehavior
|
23 |
-
from pydantic_ai.messages import (
|
24 |
-
ModelMessage,
|
25 |
-
ModelMessagesTypeAdapter,
|
26 |
-
ModelRequest,
|
27 |
-
ModelResponse,
|
28 |
-
TextPart,
|
29 |
-
UserPromptPart,
|
30 |
-
)
|
31 |
-
|
32 |
-
# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
|
33 |
-
#logfire.configure(send_to_logfire='if-token-present')
|
34 |
-
|
35 |
-
agent = Agent('openai:gpt-4o')
|
36 |
-
THIS_DIR = Path(__file__).parent
|
37 |
-
|
38 |
-
|
39 |
-
@asynccontextmanager
|
40 |
-
async def lifespan(_app: fastapi.FastAPI):
|
41 |
-
async with Database.connect() as db:
|
42 |
-
yield {'db': db}
|
43 |
-
|
44 |
-
|
45 |
-
app = fastapi.FastAPI(lifespan=lifespan)
|
46 |
-
#logfire.instrument_fastapi(app)
|
47 |
-
|
48 |
-
|
49 |
-
@app.get('/')
|
50 |
-
async def index() -> FileResponse:
|
51 |
-
return FileResponse((THIS_DIR / 'chat_app.html'), media_type='text/html')
|
52 |
-
|
53 |
-
|
54 |
-
@app.get('/chat_app.ts')
|
55 |
-
async def main_ts() -> FileResponse:
|
56 |
-
"""Get the raw typescript code, it's compiled in the browser, forgive me."""
|
57 |
-
return FileResponse((THIS_DIR / 'chat_app.ts'), media_type='text/plain')
|
58 |
-
|
59 |
-
|
60 |
-
async def get_db(request: Request) -> Database:
|
61 |
-
return request.state.db
|
62 |
-
|
63 |
-
|
64 |
-
@app.get('/chat/')
|
65 |
-
async def get_chat(database: Database = Depends(get_db)) -> Response:
|
66 |
-
msgs = await database.get_messages()
|
67 |
-
return Response(
|
68 |
-
b'\n'.join(json.dumps(to_chat_message(m)).encode('utf-8') for m in msgs),
|
69 |
-
media_type='text/plain',
|
70 |
-
)
|
71 |
-
|
72 |
-
|
73 |
-
class ChatMessage(TypedDict):
|
74 |
-
"""Format of messages sent to the browser."""
|
75 |
-
|
76 |
-
role: Literal['user', 'model']
|
77 |
-
timestamp: str
|
78 |
-
content: str
|
79 |
-
|
80 |
-
|
81 |
-
def to_chat_message(m: ModelMessage) -> ChatMessage:
|
82 |
-
first_part = m.parts[0]
|
83 |
-
if isinstance(m, ModelRequest):
|
84 |
-
if isinstance(first_part, UserPromptPart):
|
85 |
-
return {
|
86 |
-
'role': 'user',
|
87 |
-
'timestamp': first_part.timestamp.isoformat(),
|
88 |
-
'content': first_part.content,
|
89 |
-
}
|
90 |
-
elif isinstance(m, ModelResponse):
|
91 |
-
if isinstance(first_part, TextPart):
|
92 |
-
return {
|
93 |
-
'role': 'model',
|
94 |
-
'timestamp': m.timestamp.isoformat(),
|
95 |
-
'content': first_part.content,
|
96 |
-
}
|
97 |
-
raise UnexpectedModelBehavior(f'Unexpected message type for chat app: {m}')
|
98 |
-
|
99 |
-
|
100 |
-
@app.post('/chat/')
|
101 |
-
async def post_chat(
|
102 |
-
prompt: Annotated[str, fastapi.Form()], database: Database = Depends(get_db)
|
103 |
-
) -> StreamingResponse:
|
104 |
-
async def stream_messages():
|
105 |
-
"""Streams new line delimited JSON `Message`s to the client."""
|
106 |
-
# stream the user prompt so that can be displayed straight away
|
107 |
-
yield (
|
108 |
-
json.dumps(
|
109 |
-
{
|
110 |
-
'role': 'user',
|
111 |
-
'timestamp': datetime.now(tz=timezone.utc).isoformat(),
|
112 |
-
'content': prompt,
|
113 |
-
}
|
114 |
-
).encode('utf-8')
|
115 |
-
+ b'\n'
|
116 |
-
)
|
117 |
-
# get the chat history so far to pass as context to the agent
|
118 |
-
messages = await database.get_messages()
|
119 |
-
# run the agent with the user prompt and the chat history
|
120 |
-
async with agent.run_stream(prompt, message_history=messages) as result:
|
121 |
-
async for text in result.stream(debounce_by=0.01):
|
122 |
-
# text here is a `str` and the frontend wants
|
123 |
-
# JSON encoded ModelResponse, so we create one
|
124 |
-
m = ModelResponse.from_text(content=text, timestamp=result.timestamp())
|
125 |
-
yield json.dumps(to_chat_message(m)).encode('utf-8') + b'\n'
|
126 |
-
|
127 |
-
# add new messages (e.g. the user prompt and the agent response in this case) to the database
|
128 |
-
await database.add_messages(result.new_messages_json())
|
129 |
-
|
130 |
-
return StreamingResponse(stream_messages(), media_type='text/plain')
|
131 |
-
|
132 |
-
|
133 |
-
P = ParamSpec('P')
|
134 |
-
R = TypeVar('R')
|
135 |
-
|
136 |
-
|
137 |
-
@dataclass
|
138 |
-
class Database:
|
139 |
-
"""Rudimentary database to store chat messages in SQLite.
|
140 |
-
|
141 |
-
The SQLite standard library package is synchronous, so we
|
142 |
-
use a thread pool executor to run queries asynchronously.
|
143 |
-
"""
|
144 |
-
|
145 |
-
con: sqlite3.Connection
|
146 |
-
_loop: asyncio.AbstractEventLoop
|
147 |
-
_executor: ThreadPoolExecutor
|
148 |
-
|
149 |
-
@classmethod
|
150 |
-
@asynccontextmanager
|
151 |
-
async def connect(
|
152 |
-
cls, file: Path = THIS_DIR / '.chat_app_messages.sqlite'
|
153 |
-
) -> AsyncIterator[Database]:
|
154 |
-
#with logfire.span('connect to DB'):
|
155 |
-
loop = asyncio.get_event_loop()
|
156 |
-
executor = ThreadPoolExecutor(max_workers=1)
|
157 |
-
con = await loop.run_in_executor(executor, cls._connect, file)
|
158 |
-
slf = cls(con, loop, executor)
|
159 |
-
try:
|
160 |
-
yield slf
|
161 |
-
finally:
|
162 |
-
await slf._asyncify(con.close)
|
163 |
-
|
164 |
-
@staticmethod
|
165 |
-
def _connect(file: Path) -> sqlite3.Connection:
|
166 |
-
con = sqlite3.connect(str(file))
|
167 |
-
#con = logfire.instrument_sqlite3(con)
|
168 |
-
cur = con.cursor()
|
169 |
-
cur.execute(
|
170 |
-
'CREATE TABLE IF NOT EXISTS messages (id INT PRIMARY KEY, message_list TEXT);'
|
171 |
-
)
|
172 |
-
con.commit()
|
173 |
-
return con
|
174 |
-
|
175 |
-
async def add_messages(self, messages: bytes):
|
176 |
-
await self._asyncify(
|
177 |
-
self._execute,
|
178 |
-
'INSERT INTO messages (message_list) VALUES (?);',
|
179 |
-
messages,
|
180 |
-
commit=True,
|
181 |
-
)
|
182 |
-
await self._asyncify(self.con.commit)
|
183 |
-
|
184 |
-
async def get_messages(self) -> list[ModelMessage]:
|
185 |
-
c = await self._asyncify(
|
186 |
-
self._execute, 'SELECT message_list FROM messages order by id'
|
187 |
-
)
|
188 |
-
rows = await self._asyncify(c.fetchall)
|
189 |
-
messages: list[ModelMessage] = []
|
190 |
-
for row in rows:
|
191 |
-
messages.extend(ModelMessagesTypeAdapter.validate_json(row[0]))
|
192 |
-
return messages
|
193 |
-
|
194 |
-
def _execute(
|
195 |
-
self, sql: LiteralString, *args: Any, commit: bool = False
|
196 |
-
) -> sqlite3.Cursor:
|
197 |
-
cur = self.con.cursor()
|
198 |
-
cur.execute(sql, args)
|
199 |
-
if commit:
|
200 |
-
self.con.commit()
|
201 |
-
return cur
|
202 |
-
|
203 |
-
async def _asyncify(
|
204 |
-
self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
|
205 |
-
) -> R:
|
206 |
-
return await self._loop.run_in_executor( # type: ignore
|
207 |
-
self._executor,
|
208 |
-
partial(func, **kwargs),
|
209 |
-
*args, # type: ignore
|
210 |
-
)
|
211 |
-
|
212 |
-
|
213 |
-
if __name__ == '__main__':
|
214 |
-
#import uvicorn
|
215 |
-
|
216 |
-
#uvicorn.run(
|
217 |
-
# 'pydantic_ai_examples.chat_app:app', reload=True, reload_dirs=[str(THIS_DIR)]
|
218 |
-
#)
|
219 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|