# Hugging Face Spaces page header captured with this file ("Spaces: Sleeping").
from __future__ import annotations as _annotations | |
import os | |
import asyncio | |
import json | |
import sqlite3 | |
import datetime | |
import fastapi | |
import logfire | |
import time | |
from collections.abc import AsyncIterator | |
from concurrent.futures.thread import ThreadPoolExecutor | |
from contextlib import asynccontextmanager | |
from dataclasses import dataclass | |
from datetime import datetime, timezone, date | |
from functools import partial | |
from pathlib import Path | |
from typing import Annotated, Any, Callable, Literal, TypeVar | |
from pydantic import BaseModel, Field, ValidationError, model_validator | |
from typing import List, Optional, Dict | |
from fastapi import Depends, Request | |
from fastapi.responses import FileResponse, Response, StreamingResponse | |
from typing_extensions import LiteralString, ParamSpec, TypedDict | |
from pydantic_ai import Agent | |
from pydantic_ai.exceptions import UnexpectedModelBehavior | |
from pydantic_ai.messages import ( | |
ModelMessage, | |
ModelMessagesTypeAdapter, | |
ModelRequest, | |
ModelResponse, | |
TextPart, | |
UserPromptPart, | |
) | |
from pydantic_ai.models.openai import OpenAIModel | |
# Optional model served by a local OpenAI-compatible endpoint.
# The model name, endpoint and API key can be overridden through environment
# variables; the defaults are the values previously hard-coded here, so
# behaviour is unchanged when the variables are unset.
model = OpenAIModel(
    os.environ.get('LOCAL_LLM_MODEL', 'gemma-2-2b-it'),
    base_url=os.environ.get('LOCAL_LLM_BASE_URL', 'http://localhost:1234/v1'),
    api_key=os.environ.get('LOCAL_LLM_API_KEY', 'your-local-api-key'),
)

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
logfire.configure(send_to_logfire='if-token-present')
class ClinicalNoteResult(BaseModel):
    """Result model for information extracted from a clinical note.

    NOTE(review): not referenced anywhere in this view of the file; presumably
    meant to be used as an agent result type — confirm before relying on it.
    """

    # Extracted items; the element type is unconstrained (plain ``list``).
    entities: list
    # Free-text field accompanying the extracted items.
    message: str
# System prompt that steers the model. It is in Indonesian and means roughly:
# "You are a medical doctor who helps extract information from clinical notes.
# The extraction result is to be in JSON format."
system_prompt = (
    "Anda adalah dokter medis yang membantu mengekstrak informasi dari "
    "catatan klinis. Hasil extract adalah menjadi format JSON"
)

# Enable exactly ONE of the agent definitions below.
# Hosted Gemini model (known to work):
agent = Agent('gemini-1.5-flash', system_prompt=system_prompt)
# Local model defined as `model` above (known to work).
# NOTE(review): this variant does not pass `system_prompt`.
# agent = Agent(model)

THIS_DIR = Path(__file__).parent
@asynccontextmanager
async def lifespan(_app: fastapi.FastAPI):
    """Keep the chat-message database open for the lifetime of the app.

    The yielded mapping becomes the lifespan state, which is how request
    handlers reach the database as ``request.state.db``. The
    ``asynccontextmanager`` decorator is required: FastAPI/Starlette expect
    ``lifespan`` to return an async context manager. A bare async generator
    is only accepted through a deprecated compatibility shim.
    """
    async with Database.connect() as db:
        yield {'db': db}


app = fastapi.FastAPI(lifespan=lifespan)
logfire.instrument_fastapi(app)
# NOTE(review): the route decorator was missing, so this handler was never
# registered; '/' serves the front-end page — confirm against deployment.
@app.get('/')
async def index() -> FileResponse:
    """Serve the chat front-end page, ``chat_app.html`` next to this file."""
    return FileResponse((THIS_DIR / 'chat_app.html'), media_type='text/html')
# NOTE(review): the route decorator was missing, so this handler was never
# registered. The path follows the upstream pydantic-ai chat example; confirm
# it matches the script URL referenced by chat_app.html.
@app.get('/chat_app.ts')
async def main_ts() -> FileResponse:
    """Get the raw typescript code, it's compiled in the browser, forgive me."""
    return FileResponse((THIS_DIR / 'chat_app.ts'), media_type='text/plain')
async def get_db(request: Request) -> Database:
    """FastAPI dependency returning the database kept in the request state.

    The value is whatever the app's lifespan placed under the ``db`` key.
    """
    db: Database = request.state.db
    return db
# NOTE(review): the route decorator was missing, so this handler was never
# registered. The path follows the upstream pydantic-ai chat example; confirm
# it matches what chat_app.ts fetches.
@app.get('/chat/')
async def get_chat(database: Database = Depends(get_db)) -> Response:
    """Return the stored chat history as newline-delimited JSON chat messages.

    Each stored message is converted with ``to_chat_message``. Any message it
    cannot convert raises ``UnexpectedModelBehavior``, failing the request.
    """
    msgs = await database.get_messages()
    return Response(
        b'\n'.join(json.dumps(to_chat_message(m)).encode('utf-8') for m in msgs),
        media_type='text/plain',
    )
class ChatMessage(TypedDict):
    """Shape of a single chat message as serialised for the browser."""

    # Which side of the conversation produced the message.
    role: Literal['user', 'model']
    # Timestamp string; the converters in this module fill it via ``.isoformat()``.
    timestamp: str
    # The message text.
    content: str
def to_chat_message(m: ModelMessage) -> ChatMessage:
    """Convert a pydantic-ai message into a ``ChatMessage`` for the browser.

    A request is rendered from its first ``UserPromptPart`` and a response from
    its first ``TextPart``. The part is looked up by type rather than taken from
    a fixed index. A request that starts with a system-prompt part (presumably
    the case when the agent has a ``system_prompt``) has the user prompt second.
    A request without one (e.g. an agent built without a system prompt) has it
    first. A hard-coded ``parts[1]`` raised ``IndexError`` for the latter.

    Args:
        m: the message to convert.

    Returns:
        A ``ChatMessage`` dict with ``role``, ``timestamp`` and ``content``.

    Raises:
        UnexpectedModelBehavior: if ``m`` is neither a request with a user
            prompt nor a response with a text part.
    """
    if isinstance(m, ModelRequest):
        user_part = next(
            (part for part in m.parts if isinstance(part, UserPromptPart)), None
        )
        if user_part is not None:
            return {
                'role': 'user',
                'timestamp': user_part.timestamp.isoformat(),
                'content': user_part.content,
            }
    elif isinstance(m, ModelResponse):
        text_part = next((part for part in m.parts if isinstance(part, TextPart)), None)
        if text_part is not None:
            return {
                'role': 'model',
                'timestamp': m.timestamp.isoformat(),
                'content': text_part.content,
            }
    raise UnexpectedModelBehavior(f'Unexpected message type for chat app: {m}')
def to_ds_message(m: ModelMessage) -> ChatMessage:
    """Convert a pydantic-ai message into a ``ChatMessage``-shaped dict.

    A request is rendered from its first ``UserPromptPart`` and a response from
    its first ``TextPart``, located by type. A hard-coded ``parts[1]`` raised
    ``IndexError`` for a request whose only part is the user prompt.

    NOTE(review): not called anywhere in this view of the file; the meaning of
    the ``ds`` prefix is not visible here.

    Args:
        m: the message to convert.

    Returns:
        A dict with ``role``, ``timestamp`` and ``content``.

    Raises:
        UnexpectedModelBehavior: if ``m`` is neither a request with a user
            prompt nor a response with a text part.
    """
    if isinstance(m, ModelRequest):
        user_part = next(
            (part for part in m.parts if isinstance(part, UserPromptPart)), None
        )
        if user_part is not None:
            return {
                'role': 'user',
                'timestamp': user_part.timestamp.isoformat(),
                'content': user_part.content,
            }
    elif isinstance(m, ModelResponse):
        text_part = next((part for part in m.parts if isinstance(part, TextPart)), None)
        if text_part is not None:
            return {
                'role': 'model',
                'timestamp': m.timestamp.isoformat(),
                'content': text_part.content,
            }
    raise UnexpectedModelBehavior(f'Unexpected ds-message type for chat app: {m}')
# NOTE(review): the route decorator was missing, so this handler was never
# registered. The path follows the upstream pydantic-ai chat example; confirm
# it matches what chat_app.ts posts to.
@app.post('/chat/')
async def post_chat(
    prompt: Annotated[str, fastapi.Form()], database: Database = Depends(get_db)
) -> StreamingResponse:
    """Answer ``prompt`` with the agent and stream the exchange to the browser.

    The response body is newline-delimited JSON chat messages. It starts with
    an echo of the user prompt, followed by model messages built from the
    agent's streamed text. When the stream finishes, the run's new messages
    are stored in ``database``.

    A single leading ``@`` is removed from the prompt before it is echoed and
    sent to the agent. NOTE(review): what the ``@`` marker signals is not
    visible in this file.
    """
    # `startswith` rather than `prompt[0]`, so an empty prompt cannot raise
    # IndexError. Both former branches returned the same response, so one
    # return at the end suffices.
    if prompt.startswith('@'):
        prompt = prompt[1:]
        print(">>>", prompt, "<<<")

    async def stream_messages():
        """Streams new line delimited JSON `Message`s to the client."""
        # stream the user prompt so that can be displayed straight away
        yield (
            json.dumps(
                {
                    'role': 'user',
                    'timestamp': datetime.now(tz=timezone.utc).isoformat(),
                    'content': prompt,
                }
            ).encode('utf-8')
            + b'\n'
        )
        # NOTE: earlier chat history is deliberately NOT passed to the agent
        # (the `message_history` wiring is disabled), so each prompt is
        # answered on its own.
        async with agent.run_stream(prompt) as result:
            async for text in result.stream(debounce_by=0.01):
                # text here is a `str` and the frontend wants
                # JSON encoded ModelResponse, so we create one
                m = ModelResponse.from_text(content=text, timestamp=result.timestamp())
                yield json.dumps(to_chat_message(m)).encode('utf-8') + b'\n'
        # Debug dump of this run's new messages.
        # NOTE(review): prints full conversation content to stdout; consider
        # removing or redacting for clinical data.
        print("---", result.new_messages_json(), "---")
        # add new messages (e.g. the user prompt and the agent response in this case) to the database
        await database.add_messages(result.new_messages_json())

    return StreamingResponse(stream_messages(), media_type='text/plain')
P = ParamSpec('P')
R = TypeVar('R')


@dataclass
class Database:
    """Rudimentary database to store chat messages in SQLite.

    The SQLite standard library package is synchronous, so we
    use a thread pool executor to run queries asynchronously.
    The executor has a single worker, so every SQLite call (including
    opening and closing the connection) runs on that one thread.
    """

    # Open connection; only touched from the executor thread.
    con: sqlite3.Connection
    # Event loop used to dispatch work to the executor.
    _loop: asyncio.AbstractEventLoop
    # Single-worker executor that runs all SQLite calls.
    _executor: ThreadPoolExecutor

    @classmethod
    @asynccontextmanager
    async def connect(
        cls, file: Path = THIS_DIR / '.chat_messages.sqlite'
    ) -> AsyncIterator[Database]:
        """Open (creating if needed) the SQLite ``file`` and yield a ``Database``.

        On exit the connection is closed and the worker thread's executor is
        shut down. The executor is also shut down if opening the database
        fails, so its thread is never leaked.
        """
        loop = asyncio.get_running_loop()
        executor = ThreadPoolExecutor(max_workers=1)
        try:
            with logfire.span('connect to DB'):
                con = await loop.run_in_executor(executor, cls._connect, file)
                slf = cls(con, loop, executor)
            try:
                yield slf
            finally:
                await slf._asyncify(con.close)
        finally:
            # The connection is already closed (or was never opened), so no
            # work is pending; don't block the event loop waiting for the thread.
            executor.shutdown(wait=False)

    @staticmethod
    def _connect(file: Path) -> sqlite3.Connection:
        """Open ``file`` and ensure the ``messages`` table exists (executor thread)."""
        con = sqlite3.connect(str(file))
        con = logfire.instrument_sqlite3(con)
        cur = con.cursor()
        # NOTE: `INT PRIMARY KEY` (unlike `INTEGER PRIMARY KEY`) does not alias
        # the rowid, so `id` stays NULL on insert. Readers order by rowid instead.
        # The schema is left unchanged for compatibility with existing files.
        cur.execute(
            'CREATE TABLE IF NOT EXISTS messages (id INT PRIMARY KEY, message_list TEXT);'
        )
        con.commit()
        return con

    async def add_messages(self, messages: bytes):
        """Store one serialised batch of messages as a new row and commit."""
        # `commit=True` already commits; no second commit is needed.
        await self._asyncify(
            self._execute,
            'INSERT INTO messages (message_list) VALUES (?);',
            messages,
            commit=True,
        )

    async def get_messages(self) -> list[ModelMessage]:
        """Return all stored messages, batches in insertion order."""
        # Order by rowid: `id` is never populated (see `_connect`), so ordering
        # by it left the row order undefined.
        c = await self._asyncify(
            self._execute, 'SELECT message_list FROM messages ORDER BY rowid ASC'
        )
        rows = await self._asyncify(c.fetchall)
        messages: list[ModelMessage] = []
        for row in rows:
            messages.extend(ModelMessagesTypeAdapter.validate_json(row[0]))
        return messages

    def _execute(
        self, sql: LiteralString, *args: Any, commit: bool = False
    ) -> sqlite3.Cursor:
        """Run ``sql`` with positional parameters ``args``; optionally commit."""
        cur = self.con.cursor()
        cur.execute(sql, args)
        if commit:
            self.con.commit()
        return cur

    async def _asyncify(
        self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
    ) -> R:
        """Run ``func(*args, **kwargs)`` on the executor thread and await it."""
        # run_in_executor only forwards positional args, so bind kwargs first.
        return await self._loop.run_in_executor(  # type: ignore
            self._executor,
            partial(func, **kwargs),
            *args,  # type: ignore
        )
if __name__ == '__main__':
    import uvicorn

    # Development server: listen on all interfaces on port 7860 and reload on
    # changes to files in this directory. 'app:app' presumes this module is
    # importable as `app` — confirm against the deployed filename.
    watched_dirs = [str(THIS_DIR)]
    uvicorn.run(
        'app:app',
        host="0.0.0.0",
        port=7860,
        reload=True,
        reload_dirs=watched_dirs,
    )