File size: 8,693 Bytes
212b42a
 
6f96ca2
212b42a
5a2b2d3
 
 
 
 
 
 
 
 
8dbef5d
14d48df
5a2b2d3
 
 
14d48df
 
 
 
 
 
 
5a2b2d3
 
 
 
27306dd
14d48df
 
 
 
 
27306dd
14d48df
 
 
27306dd
 
 
 
c2c40ec
9a34296
78c941e
5a2b2d3
78c941e
 
 
 
 
 
5a2b2d3
 
78c941e
 
 
 
5a2b2d3
 
78c941e
 
5a2b2d3
 
78c941e
 
 
 
 
 
 
 
5a2b2d3
 
78c941e
 
5a2b2d3
 
 
78c941e
 
 
 
 
 
5a2b2d3
 
14d48df
 
 
 
 
 
 
 
 
 
 
 
 
6f96ca2
14d48df
 
 
6f96ca2
 
 
 
14d48df
 
 
 
 
 
 
8dbef5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a2b2d3
 
8dbef5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a2b2d3
 
212b42a
 
 
 
 
 
 
 
5a2b2d3
 
 
b1cee42
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import sys
import os
from datetime import datetime
from typing import List, Optional

from langchain_core.runnables import Runnable
from langchain_core.callbacks import BaseCallbackHandler
from fastapi import FastAPI, Request, Depends
from sse_starlette.sse import EventSourceResponse
from langserve.serialization import WellKnownLCSerializer
from sqlalchemy.orm import Session

import schemas
from chains import simple_chain, formatted_chain, history_chain, rag_chain, filtered_rag_chain
import crud, models, schemas, prompts
from database import SessionLocal, engine
from callbacks import LogResponseCallback

# Create the database tables at import time.
# - models.Base is the SQLAlchemy declarative base that the ORM models in
#   models.py inherit from (per the original note, it is created with
#   declarative_base() in database.py).
# - Base.metadata is the catalog of every table declared on that base.
# - create_all() issues CREATE TABLE only for tables that do not exist yet,
#   so running it on every start-up is safe.
# - bind=engine selects the database engine the statements are sent to.
models.Base.metadata.create_all(bind=engine)

app = FastAPI()

def get_db():
    """Yield one database session and close it afterwards.

    Intended for use as a FastAPI dependency (``Depends(get_db)``): a session
    is created from the ``SessionLocal`` factory, handed to the consumer, and
    closed when the generator is finalized, whether or not an error occurred.
    """
    session = SessionLocal()

    # try/finally guarantees close() runs even if the consumer raises, so
    # sessions are not leaked.
    try:
        yield session
    finally:
        session.close()

# "async" makes this an asynchronous generator, which is the kind of iterable
# the endpoints in this file hand to EventSourceResponse.
async def generate_stream(
    input_data: schemas.BaseModel,
    runnable: Runnable,
    callbacks: Optional[List[BaseCallbackHandler]] = None,
):
    """Stream the outputs of ``runnable`` as server-sent-event payloads.

    Args:
        input_data: Object whose ``dict()`` result is passed to the runnable
            as its input.
        runnable: Object whose ``stream()`` method produces the output chunks.
        callbacks: Handlers forwarded to the runnable through
            ``config["callbacks"]``. ``None`` (the default) means no handlers.

    Yields:
        ``{"data": <serialized chunk>, "event": "data"}`` for every chunk,
        followed by a single ``{"event": "end"}`` once the stream is exhausted.
    """
    # A fresh list per call: a mutable default ([]) would be one shared object
    # reused by every call that omits the argument.
    handlers = [] if callbacks is None else callbacks

    # The serializer is loop-invariant, so build it once instead of per chunk.
    serializer = WellKnownLCSerializer()

    # NOTE(review): stream() is a synchronous iterator, so the event loop is
    # occupied while each chunk is produced. Moving to astream() would need the
    # callbacks to be safe to run outside the request thread -- confirm first.
    for output in runnable.stream(input_data.dict(), config={"callbacks": handlers}):
        data = serializer.dumps(output).decode("utf-8")
        yield {'data': data, "event": "data"}

    # Tell the client that no more data will be sent.
    yield {"event": "end"}

# Registers simple_stream as the handler for POST requests to /simple/stream.
@app.post("/simple/stream")
async def simple_stream(request: Request):
    """Stream the simple chain's answer to the posted question.

    Reads the JSON body, builds a ``schemas.UserQuestion`` from its "input"
    field, and returns a server-sent-event response fed by ``generate_stream``
    running ``simple_chain``.
    """
    # The body is read asynchronously, hence the await.
    payload = await request.json()
    question = schemas.UserQuestion(**payload['input'])

    # EventSourceResponse sends each item yielded by the generator to the
    # client as a server-sent event.
    event_source = generate_stream(question, simple_chain)
    return EventSourceResponse(event_source)


@app.post("/formatted/stream")
async def formatted_stream(request: Request):
    """Stream the formatted chain's answer to the posted question.

    Same flow as the simple endpoint: parse the JSON body, build a
    ``schemas.UserQuestion`` from its "input" field, and stream
    ``formatted_chain`` through ``generate_stream`` as server-sent events.
    """
    payload = await request.json()
    question = schemas.UserQuestion(**payload['input'])
    event_source = generate_stream(question, formatted_chain)
    return EventSourceResponse(event_source)


async def _stream_chain_with_history(request: Request, db: Session, chain: Runnable):
    """Run ``chain`` on the posted question together with the user's chat history.

    Shared by the history, RAG and filtered-RAG endpoints, which differ only
    in the chain they stream from.

    Args:
        request: Incoming request; its JSON body must contain an "input"
            object accepted by ``schemas.UserRequest``.
        db: Database session used to read the history and store the question.
        chain: Runnable that is streamed with the resulting ``HistoryInput``.

    Returns:
        An ``EventSourceResponse`` fed by ``generate_stream``.
    """
    # Parse the request body into a user request.
    data = await request.json()
    user_request = schemas.UserRequest(**data['input'])

    # Fetch the stored history first, then record the current question.
    # This order mirrors the original endpoints; presumably it keeps the new
    # question out of the history handed to the prompt -- keep it as is.
    chat_history = crud.get_user_chat_history(db=db, username=user_request.username)
    message = schemas.MessageBase(message=user_request.question, type='User', timestamp=datetime.now())
    crud.add_message(db, message=message, username=user_request.username)

    # Combine the question with the formatted history as the chain input.
    chain_input = schemas.HistoryInput(
        question=user_request.question,
        chat_history=prompts.format_chat_history(chat_history),
    )

    # LogResponseCallback receives the user request and the session;
    # presumably it stores the generated answer (see callbacks.py).
    return EventSourceResponse(generate_stream(
        chain_input, chain, [LogResponseCallback(user_request, db)]
    ))


@app.post("/history/stream")
async def history_stream(request: Request, db: Session = Depends(get_db)):
    """Stream ``history_chain``'s answer, using the user's chat history."""
    return await _stream_chain_with_history(request, db, history_chain)


@app.post("/rag/stream")
async def rag_stream(request: Request, db: Session = Depends(get_db)):
    """Stream ``rag_chain``'s answer, using the user's chat history."""
    return await _stream_chain_with_history(request, db, rag_chain)


@app.post("/filtered_rag/stream")
async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
    """Stream ``filtered_rag_chain``'s answer, using the user's chat history."""
    return await _stream_chain_with_history(request, db, filtered_rag_chain)
    

# Running from the parent directory with a script
# uvicorn.run below receives the import string "app.main:app", which can only
# be resolved if the directory that contains the "app" package is importable.
# Appending that directory to sys.path means the same string works whether the
# code runs locally or on the Hugging Face Space.
# NOTE(review): this assumes the file lives at app/main.py -- confirm.

# Add the parent directory (two levels up from this file's path) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run as a script.
    import uvicorn
    uvicorn.run("app.main:app", host="localhost", reload=True,  port=8000)