File size: 1,062 Bytes
5a2b2d3
6f96ca2
5a2b2d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f96ca2
 
14d48df
5a2b2d3
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from datetime import datetime
from typing import Any, Dict, List

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult

import schemas
import crud


class LogResponseCallback(BaseCallbackHandler):
    """LangChain callback that stores each LLM response as an 'AI' message.

    When the LLM finishes, the text of its first generation is saved through
    ``crud.add_message``, attributed to the username of the request this
    callback was created for. Prompts sent to the LLM are echoed to stdout.
    """

    def __init__(self, user_request: schemas.UserRequest, db):
        """Bind the callback to one user request and a database handle.

        Args:
            user_request: The request being served; its ``username`` is used
                to attribute the stored messages.
            db: Database handle forwarded unchanged to ``crud.add_message``
                (presumably an ORM session — confirm against ``crud``).
        """
        super().__init__()
        self.user_request = user_request
        self.db = db

    def on_llm_end(self, outputs: LLMResult, **kwargs: Any) -> None:
        """Persist the LLM's response text when the LLM finishes running.

        Args:
            outputs: The result of the completed LLM call. ``generations``
                holds one list of candidates per prompt; only the first
                candidate of the first prompt is stored.
            **kwargs: Additional callback keyword arguments; unused.
        """
        generations = outputs.generations
        # Guard against an empty result: indexing [0][0] would otherwise
        # raise IndexError from inside the callback.
        if not generations or not generations[0]:
            return
        message = schemas.MessageBase(
            message=generations[0][0].text,
            type='AI',
            # Naive local time, as produced by datetime.now().
            timestamp=datetime.now(),
        )
        crud.add_message(self.db, message=message, username=self.user_request.username)

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Print every prompt sent to the LLM to stdout (debugging aid).

        Args:
            serialized: Serialized description of the LLM; unused.
            prompts: The prompt strings about to be sent to the LLM.
            **kwargs: Additional callback keyword arguments; unused.
        """
        for prompt in prompts:
            print(prompt)