bachephysicdun commited on
Commit
6f96ca2
·
1 Parent(s): 14d48df

fixed bugs for chat_history version

Browse files
Files changed (5) hide show
  1. app/callbacks.py +3 -2
  2. app/main.py +6 -2
  3. app/models.py +1 -1
  4. app/prompts.py +4 -2
  5. app/schemas.py +4 -3
app/callbacks.py CHANGED
@@ -1,4 +1,5 @@
1
  from typing import Dict, Any, List
 
2
  from langchain_core.callbacks import BaseCallbackHandler
3
  import schemas
4
  import crud
@@ -15,8 +16,8 @@ class LogResponseCallback(BaseCallbackHandler):
15
  """Run when llm ends running."""
16
  # TODO: The function on_llm_end is going to be called when the LLM stops sending
17
  # the response. Use the crud.add_message function to capture that response.
18
- print(outputs)
19
- message = schemas.MessageBase(message=outputs.get('text'), type='AI')
20
  crud.add_message(self.db, message=message, username=self.user_request.username)
21
 
22
  def on_llm_start(
 
1
  from typing import Dict, Any, List
2
+ from datetime import datetime
3
  from langchain_core.callbacks import BaseCallbackHandler
4
  import schemas
5
  import crud
 
16
  """Run when llm ends running."""
17
  # TODO: The function on_llm_end is going to be called when the LLM stops sending
18
  # the response. Use the crud.add_message function to capture that response.
19
+ # print("Full outputs object:", outputs)
20
+ message = schemas.MessageBase(message=outputs.generations[0][0].text, type='AI', timestamp=datetime.now())
21
  crud.add_message(self.db, message=message, username=self.user_request.username)
22
 
23
  def on_llm_start(
app/main.py CHANGED
@@ -1,5 +1,6 @@
1
  import sys
2
  import os
 
3
 
4
  from langchain_core.runnables import Runnable
5
  from langchain_core.callbacks import BaseCallbackHandler
@@ -98,11 +99,14 @@ async def history_stream(request: Request, db: Session = Depends(get_db)):
98
  chat_history = crud.get_user_chat_history(db=db, username=user_request.username)
99
 
100
  # - We add as part of the user history the current question by using add_message.
101
- message = schemas.MessageBase(message=user_request.question, type='User')
102
  crud.add_message(db, message=message, username=user_request.username)
103
 
104
  # - We create an instance of HistoryInput by using format_chat_history.
105
- history_input = schemas.HistoryInput(question=user_request.username, chat_history=chat_history)
 
 
 
106
 
107
  # - We use the history input within the history chain.
108
  return EventSourceResponse(generate_stream(
 
1
  import sys
2
  import os
3
+ from datetime import datetime
4
 
5
  from langchain_core.runnables import Runnable
6
  from langchain_core.callbacks import BaseCallbackHandler
 
99
  chat_history = crud.get_user_chat_history(db=db, username=user_request.username)
100
 
101
  # - We add as part of the user history the current question by using add_message.
102
+ message = schemas.MessageBase(message=user_request.question, type='User', timestamp=datetime.now())
103
  crud.add_message(db, message=message, username=user_request.username)
104
 
105
  # - We create an instance of HistoryInput by using format_chat_history.
106
+ history_input = schemas.HistoryInput(
107
+ question=user_request.question,
108
+ chat_history=prompts.format_chat_history(chat_history)
109
+ )
110
 
111
  # - We use the history input within the history chain.
112
  return EventSourceResponse(generate_stream(
app/models.py CHANGED
@@ -50,6 +50,6 @@ class Message(Base):
50
  user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
51
  message = Column(String, nullable=False)
52
  type = Column(String, nullable=False)
53
- timestamp = Column(DateTime, default=datetime.utcnow, nullable=False)
54
 
55
  user = relationship("User", back_populates="messages")
 
50
  user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
51
  message = Column(String, nullable=False)
52
  type = Column(String, nullable=False)
53
+ timestamp = Column(DateTime, default=datetime.now(), nullable=False)
54
 
55
  user = relationship("User", back_populates="messages")
app/prompts.py CHANGED
@@ -37,12 +37,14 @@ def format_chat_history(messages: List[models.Message]):
37
  # TODO: implement format_chat_history to format
38
  # the list of Message into a text of chat history.
39
 
 
 
40
  return '\n'.join([
41
  '[{}] {}: {}'.format(
42
  message.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
43
  message.type,
44
  message.message
45
- ) for message in messages
46
  ])
47
 
48
 
@@ -88,7 +90,7 @@ raw_prompt_formatted = format_prompt(prompt)
88
 
89
 
90
  # TODO: use format_prompt to create history_prompt_formatted
91
- history_prompt_formatted = format_prompt(history_prompt)
92
 
93
  # TODO: use format_prompt to create standalone_prompt_formatted
94
  standalone_prompt_formatted: PromptTemplate = None
 
37
  # TODO: implement format_chat_history to format
38
  # the list of Message into a text of chat history.
39
 
40
+ # Sort messages by timestamp using a lambda function
41
+ ordered_messages = sorted(messages, key=lambda m: m.timestamp, reverse=False)
42
  return '\n'.join([
43
  '[{}] {}: {}'.format(
44
  message.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
45
  message.type,
46
  message.message
47
+ ) for message in ordered_messages
48
  ])
49
 
50
 
 
90
 
91
 
92
  # TODO: use format_prompt to create history_prompt_formatted
93
+ history_prompt_formatted = format_prompt(history_prompt)
94
 
95
  # TODO: use format_prompt to create standalone_prompt_formatted
96
  standalone_prompt_formatted: PromptTemplate = None
app/schemas.py CHANGED
@@ -18,8 +18,9 @@ class UserRequest(BaseModel):
18
  # TODO: implement MessageBase as a schema mapping from the database model to the
19
  # FastAPI data model. Basically MessageBase should have the same attributes as models.Message
20
  class MessageBase(BaseModel):
21
- id: int
22
- user_id: int
23
  message: str
24
  type: str
25
- timestamp: datetime
 
 
18
  # TODO: implement MessageBase as a schema mapping from the database model to the
19
  # FastAPI data model. Basically MessageBase should have the same attributes as models.Message
20
  class MessageBase(BaseModel):
21
+ # id: int
22
+ # user_id: int
23
  message: str
24
  type: str
25
+ timestamp: datetime
26
+ # user: str