import sys
import os
from datetime import datetime
from langchain_core.runnables import Runnable
from langchain_core.callbacks import BaseCallbackHandler
from fastapi import FastAPI, Request, Depends
from sse_starlette.sse import EventSourceResponse
from langserve.serialization import WellKnownLCSerializer
from typing import List
from sqlalchemy.orm import Session
import crud, models, schemas, prompts
from chains import simple_chain, formatted_chain, history_chain, rag_chain, filtered_rag_chain
from database import SessionLocal, engine
from callbacks import LogResponseCallback
# models.Base comes from SQLAlchemy’s declarative_base() in database.py.
# It acts as the base class for all ORM models (defined in models.py).
# .metadata.create_all(): Tells SQLAlchemy to create all the tables defined
# in the models module if they don’t already exist in the database.
# -> metadata is a catalog of all the tables and other schema constructs in your database.
# -> create_all() method creates all the tables that don't exist yet in the database.
# -> bind=engine specifies which database engine to use for this operation.
models.Base.metadata.create_all(bind=engine)
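# For orientation only: database.py is assumed to define engine, SessionLocal, and Base
# roughly along these lines (a minimal sketch, not the actual file; the database URL is
# a placeholder):
#
#   from sqlalchemy import create_engine
#   from sqlalchemy.orm import sessionmaker, declarative_base
#
#   engine = create_engine("sqlite:///./test.db", connect_args={"check_same_thread": False})
#   SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
#   Base = declarative_base()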
app = FastAPI()
def get_db():
"""This is a dependency function used to create and provide a
database session to various endpoints in the FastAPI app.
"""
# A new SQLAlchemy session is created using the SessionLocal session factory.
# This session will be used for database transactions.
db = SessionLocal()
# This pattern ensures that each request gets its own database session and that
# the session is properly closed when the request is finished, preventing resource leaks.
try:
yield db
finally:
db.close()
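# Endpoints consume this dependency by declaring a parameter such as
#   db: Session = Depends(get_db)
# (see the /history/stream, /rag/stream, and /filtered_rag/stream handlers below).
# FastAPI runs the finally block, closing the session, once the response has been sent.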
# "async" marks the function as asynchronous, allowing it to pause and resume during operations like streaming or I/O.
async def generate_stream(input_data: schemas.BaseModel, runnable: Runnable, callbacks: List[BaseCallbackHandler] = []):
"""generate_stream is an asynchronous generator that processes input data,
streams output data from a runnable object, serializes each output, and yields
it to the client in real-time as part of a server-sent event (SSE) stream.
It uses callbacks to customize the processing, serializes each piece of output
using WellKnownLCSerializer, and indicates the end of the stream with a final “end” event.
"""
for output in runnable.stream(input_data.dict(), config={"callbacks": callbacks}):
data = WellKnownLCSerializer().dumps(output).decode("utf-8")
yield {'data': data, "event": "data"}
# After all the data has been streamed and the loop is complete, the function yields a final event to signal
# the end of the stream. This sends an {"event": "end"} message to the client, letting them know that no more
# data will be sent.
yield {"event": "end"}
# This registers the function simple_stream as a handler for HTTP POST requests at the URL endpoint /simple/stream.
# It means that when a client sends a POST request to this endpoint, this function will be triggered.
@app.post("/simple/stream")
async def simple_stream(request: Request):
"""the function handles a POST request at the /simple/stream endpoint,
extracts the JSON body, unpacks the "input" field, and then uses it to
initialize a UserQuestion schema object (which performs validation
and data transformation) and then initiates a server-sent event response
to stream data back to the client based on the user’s question.
"""
    # request.json() is a coroutine: reading the request body is asynchronous I/O,
    # so it must be awaited.
data = await request.json()
user_question = schemas.UserQuestion(**data['input'])
# This line returns an EventSourceResponse, which is typically used to handle server-sent events (SSE).
# It’s a special kind of response that streams data back to the client in real time.
return EventSourceResponse(generate_stream(user_question, simple_chain))
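# Illustrative client call (assumes schemas.UserQuestion exposes a "question" field;
# adjust the payload to match the actual schema):
#
#   curl -N -X POST http://localhost:8000/simple/stream \
#        -H "Content-Type: application/json" \
#        -d '{"input": {"question": "What is LangChain?"}}'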
@app.post("/formatted/stream")
async def formatted_stream(request: Request):
    # Same as "/simple/stream", but streams the UserQuestion through formatted_chain.
data = await request.json()
user_question = schemas.UserQuestion(**data['input'])
return EventSourceResponse(generate_stream(user_question, formatted_chain))
@app.post("/history/stream")
async def history_stream(request: Request, db: Session = Depends(get_db)):
# TODO: Let's implement the "/history/stream" endpoint. The endpoint should follow those steps:
# - The endpoint receives the request
data = await request.json()
# - The request is parsed into a user request
user_request = schemas.UserRequest(**data['input'])
# - The user request is used to pull the chat history of the user
chat_history = crud.get_user_chat_history(db=db, username=user_request.username)
# - We add as part of the user history the current question by using add_message.
message = schemas.MessageBase(message=user_request.question, type='User', timestamp=datetime.now())
crud.add_message(db, message=message, username=user_request.username)
# - We create an instance of HistoryInput by using format_chat_history.
history_input = schemas.HistoryInput(
question=user_request.question,
chat_history=prompts.format_chat_history(chat_history)
)
# - We use the history input within the history chain.
return EventSourceResponse(generate_stream(
history_input, history_chain, [LogResponseCallback(user_request, db)]
))
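# Illustrative request body for this endpoint (schemas.UserRequest is assumed to carry
# at least the "username" and "question" fields used above):
#
#   {"input": {"username": "alice", "question": "What did I ask you earlier?"}}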
@app.post("/rag/stream")
async def rag_stream(request: Request, db: Session = Depends(get_db)):
# TODO: Let's implement the "/rag/stream" endpoint. The endpoint should follow those steps:
# - The endpoint receives the request
# - The request is parsed into a user request
# - The user request is used to pull the chat history of the user
# - We add as part of the user history the current question by using add_message.
# - We create an instance of HistoryInput by using format_chat_history.
# - We use the history input within the rag chain.
data = await request.json()
user_request = schemas.UserRequest(**data['input'])
chat_history = crud.get_user_chat_history(db=db, username=user_request.username)
message = schemas.MessageBase(message=user_request.question, type='User', timestamp=datetime.now())
crud.add_message(db, message=message, username=user_request.username)
rag_input = schemas.HistoryInput(
question=user_request.question,
chat_history=prompts.format_chat_history(chat_history)
)
return EventSourceResponse(generate_stream(
rag_input, rag_chain, [LogResponseCallback(user_request, db)]
))
@app.post("/filtered_rag/stream")
async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
# TODO: Let's implement the "/filtered_rag/stream" endpoint. The endpoint should follow those steps:
# - The endpoint receives the request
# - The request is parsed into a user request
# - The user request is used to pull the chat history of the user
# - We add as part of the user history the current question by using add_message.
# - We create an instance of HistoryInput by using format_chat_history.
# - We use the history input within the filtered rag chain.
data = await request.json()
user_request = schemas.UserRequest(**data['input'])
chat_history = crud.get_user_chat_history(db=db, username=user_request.username)
message = schemas.MessageBase(message=user_request.question, type='User', timestamp=datetime.now())
crud.add_message(db, message=message, username=user_request.username)
rag_input = schemas.HistoryInput(
question=user_request.question,
chat_history=prompts.format_chat_history(chat_history)
)
return EventSourceResponse(generate_stream(
rag_input, filtered_rag_chain, [LogResponseCallback(user_request, db)]
))
# Run from the parent directory with this script.
# To call uvicorn.run with the import string "app.main:app" from within this script,
# the parent directory must be on sys.path. That way, whether the code runs locally
# or on the Hugging Face Space, "app.main:app" can always be passed to uvicorn.run.
# Add the parent directory to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if __name__ == "__main__":
import uvicorn
uvicorn.run("app.main:app", host="localhost", reload=True, port=8000) |
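# Equivalent CLI invocation, run from the parent directory (the one containing the
# "app" package):
#
#   uvicorn app.main:app --host localhost --port 8000 --reload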