darsoarafa committed on
Commit
91cfc88
·
verified ·
1 Parent(s): 626298d

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -219
app.py DELETED
@@ -1,219 +0,0 @@
1
- from __future__ import annotations as _annotations
2
-
3
- import asyncio
4
- import json
5
- import sqlite3
6
- from collections.abc import AsyncIterator
7
- from concurrent.futures.thread import ThreadPoolExecutor
8
- from contextlib import asynccontextmanager
9
- from dataclasses import dataclass
10
- from datetime import datetime, timezone
11
- from functools import partial
12
- from pathlib import Path
13
- from typing import Annotated, Any, Callable, Literal, TypeVar
14
-
15
- import fastapi
16
- #import logfire
17
- from fastapi import Depends, Request
18
- from fastapi.responses import FileResponse, Response, StreamingResponse
19
- from typing_extensions import LiteralString, ParamSpec, TypedDict
20
-
21
- from pydantic_ai import Agent
22
- from pydantic_ai.exceptions import UnexpectedModelBehavior
23
- from pydantic_ai.messages import (
24
- ModelMessage,
25
- ModelMessagesTypeAdapter,
26
- ModelRequest,
27
- ModelResponse,
28
- TextPart,
29
- UserPromptPart,
30
- )
31
-
32
# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
#logfire.configure(send_to_logfire='if-token-present')

# Single module-level agent backed by OpenAI's gpt-4o model; shared by all requests.
agent = Agent('openai:gpt-4o')
# Directory containing this file; used to locate the static assets and the SQLite DB.
THIS_DIR = Path(__file__).parent
37
-
38
-
39
@asynccontextmanager
async def lifespan(_app: fastapi.FastAPI):
    """FastAPI lifespan handler: open the SQLite-backed Database for the
    whole lifetime of the app and expose it as ``request.state.db``."""
    async with Database.connect() as db:
        yield {'db': db}
43
-
44
-
45
# The FastAPI application; `lifespan` opens the database for the app's
# lifetime and makes it available via request state.
app = fastapi.FastAPI(lifespan=lifespan)
#logfire.instrument_fastapi(app)
47
-
48
-
49
@app.get('/')
async def index() -> FileResponse:
    """Serve the chat app's single HTML page."""
    html_path = THIS_DIR / 'chat_app.html'
    return FileResponse(html_path, media_type='text/html')
52
-
53
-
54
@app.get('/chat_app.ts')
async def main_ts() -> FileResponse:
    """Get the raw typescript code, it's compiled in the browser, forgive me."""
    ts_path = THIS_DIR / 'chat_app.ts'
    return FileResponse(ts_path, media_type='text/plain')
58
-
59
-
60
async def get_db(request: Request) -> Database:
    """Dependency: return the Database that `lifespan` stored in app state."""
    return request.state.db
62
-
63
-
64
@app.get('/chat/')
async def get_chat(database: Database = Depends(get_db)) -> Response:
    """Return the stored chat history as newline-delimited JSON objects."""
    history = await database.get_messages()
    encoded = [json.dumps(to_chat_message(msg)).encode('utf-8') for msg in history]
    return Response(b'\n'.join(encoded), media_type='text/plain')
71
-
72
-
73
class ChatMessage(TypedDict):
    """Format of messages sent to the browser."""

    # Who produced the message: the human user or the LLM.
    role: Literal['user', 'model']
    # ISO-8601 timestamp of when the message was created.
    timestamp: str
    # Plain-text body of the message.
    content: str
79
-
80
-
81
def to_chat_message(m: ModelMessage) -> ChatMessage:
    """Convert a pydantic-ai message into the browser-facing `ChatMessage` dict.

    Only the first part of the message is inspected; anything other than a
    user prompt in a request or a text part in a response is rejected.
    """
    part = m.parts[0]
    if isinstance(m, ModelRequest) and isinstance(part, UserPromptPart):
        return {
            'role': 'user',
            'timestamp': part.timestamp.isoformat(),
            'content': part.content,
        }
    if isinstance(m, ModelResponse) and isinstance(part, TextPart):
        return {
            'role': 'model',
            'timestamp': m.timestamp.isoformat(),
            'content': part.content,
        }
    raise UnexpectedModelBehavior(f'Unexpected message type for chat app: {m}')
98
-
99
-
100
@app.post('/chat/')
async def post_chat(
    prompt: Annotated[str, fastapi.Form()], database: Database = Depends(get_db)
) -> StreamingResponse:
    """Accept a user prompt and stream the model's reply back.

    The response body is newline-delimited JSON: first the echoed user
    message, then one `ChatMessage` per streamed model update.
    """

    async def stream_messages():
        """Streams new line delimited JSON `Message`s to the client."""
        # stream the user prompt so that it can be displayed straight away
        yield (
            json.dumps(
                {
                    'role': 'user',
                    'timestamp': datetime.now(tz=timezone.utc).isoformat(),
                    'content': prompt,
                }
            ).encode('utf-8')
            + b'\n'
        )
        # get the chat history so far to pass as context to the agent
        messages = await database.get_messages()
        # run the agent with the user prompt and the chat history
        async with agent.run_stream(prompt, message_history=messages) as result:
            # debounce_by=0.01 batches very rapid token updates into ~10ms chunks
            async for text in result.stream(debounce_by=0.01):
                # text here is a `str` and the frontend wants
                # JSON encoded ModelResponse, so we create one
                m = ModelResponse.from_text(content=text, timestamp=result.timestamp())
                yield json.dumps(to_chat_message(m)).encode('utf-8') + b'\n'

        # add new messages (e.g. the user prompt and the agent response in this case) to the database
        await database.add_messages(result.new_messages_json())

    return StreamingResponse(stream_messages(), media_type='text/plain')
131
-
132
-
133
# Typing helpers used by Database._asyncify to preserve the wrapped
# function's parameter signature (ParamSpec) and return type (TypeVar).
P = ParamSpec('P')
R = TypeVar('R')
135
-
136
-
137
@dataclass
class Database:
    """Rudimentary database to store chat messages in SQLite.

    The SQLite standard library package is synchronous, so we
    use a thread pool executor to run queries asynchronously.
    """

    # Open SQLite connection; only ever used from the executor's single thread.
    con: sqlite3.Connection
    # Event loop the run_in_executor futures are attached to.
    _loop: asyncio.AbstractEventLoop
    # Single-worker pool so sqlite3 is only touched from one thread.
    _executor: ThreadPoolExecutor

    @classmethod
    @asynccontextmanager
    async def connect(
        cls, file: Path | None = None
    ) -> AsyncIterator[Database]:
        """Open the messages database, yielding a connected `Database`.

        Args:
            file: path of the SQLite file; defaults to
                ``.chat_app_messages.sqlite`` next to this module. The default
                is now resolved lazily at call time instead of being baked in
                at class-definition time.
        """
        if file is None:
            file = THIS_DIR / '.chat_app_messages.sqlite'
        # get_running_loop() instead of the (deprecated-in-coroutines)
        # get_event_loop().
        loop = asyncio.get_running_loop()
        # max_workers=1: a sqlite3 connection must only be used from the
        # thread that created it, so all DB work funnels through one thread.
        executor = ThreadPoolExecutor(max_workers=1)
        con = await loop.run_in_executor(executor, cls._connect, file)
        slf = cls(con, loop, executor)
        try:
            yield slf
        finally:
            await slf._asyncify(con.close)
            # Previously the executor was leaked; shut it down so its worker
            # thread does not linger after the app stops.
            executor.shutdown(wait=True)

    @staticmethod
    def _connect(file: Path) -> sqlite3.Connection:
        """Open the file and ensure the `messages` table exists."""
        con = sqlite3.connect(str(file))
        #con = logfire.instrument_sqlite3(con)
        cur = con.cursor()
        # INTEGER PRIMARY KEY (not INT) makes `id` an alias for the rowid, so
        # it is auto-assigned on insert and `ORDER BY id` reflects insertion
        # order; with `INT PRIMARY KEY` the column stayed NULL on every row.
        # Existing databases are untouched (CREATE TABLE IF NOT EXISTS).
        cur.execute(
            'CREATE TABLE IF NOT EXISTS messages (id INTEGER PRIMARY KEY, message_list TEXT);'
        )
        con.commit()
        return con

    async def add_messages(self, messages: bytes):
        """Append one JSON-encoded message list as a new row."""
        await self._asyncify(
            self._execute,
            'INSERT INTO messages (message_list) VALUES (?);',
            messages,
            commit=True,
        )
        # NOTE: the redundant second commit that used to follow here was
        # removed — commit=True above already commits the transaction.

    async def get_messages(self) -> list[ModelMessage]:
        """Load and decode every stored message list, oldest first."""
        c = await self._asyncify(
            self._execute, 'SELECT message_list FROM messages order by id'
        )
        rows = await self._asyncify(c.fetchall)
        messages: list[ModelMessage] = []
        for row in rows:
            messages.extend(ModelMessagesTypeAdapter.validate_json(row[0]))
        return messages

    def _execute(
        self, sql: LiteralString, *args: Any, commit: bool = False
    ) -> sqlite3.Cursor:
        """Run `sql` with `args` on a fresh cursor; optionally commit."""
        cur = self.con.cursor()
        cur.execute(sql, args)
        if commit:
            self.con.commit()
        return cur

    async def _asyncify(
        self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
    ) -> R:
        """Run the synchronous `func` on this DB's single-thread executor."""
        return await self._loop.run_in_executor(  # type: ignore
            self._executor,
            partial(func, **kwargs),
            *args,  # type: ignore
        )
211
-
212
-
213
if __name__ == '__main__':
    # BUG FIX: the previous body called `demo.launch()`, but no `demo` is
    # defined anywhere in this file (a Gradio leftover), so running the
    # script raised NameError. Restore the uvicorn entry point instead.
    import uvicorn

    uvicorn.run('app:app', reload=True, reload_dirs=[str(THIS_DIR)])