brestok commited on
Commit
3a56394
·
1 Parent(s): 3e8fd5d

Add entity handling to message saving and AI response logic

Browse files
trauma/api/message/ai/engine.py CHANGED
@@ -31,16 +31,10 @@ from trauma.core.config import settings
31
  async def search_entities(
32
  user_message: str, messages: list[dict], chat: ChatModel
33
  ) -> CreateMessageResponse:
34
- related_entity, entity_data, is_valid = await asyncio.gather(
35
- retrieve_semantic_answer(user_message),
36
  update_entity_data_with_ai(chat.entityData, user_message, messages[-1]['content']),
37
  check_is_valid_request(user_message, "\n".join([f"- [{i['role']}]: {i['content']}." for i in messages]))
38
  )
39
- if related_entity:
40
- response = await generate_searched_entity_response(user_message, related_entity[0])
41
- asyncio.create_task(save_assistant_user_message(user_message, response, chat.id))
42
- return CreateMessageResponse(text=response, entities=related_entity)
43
-
44
  final_entities = None
45
  if not is_valid:
46
  response = await generate_invalid_response(user_message, messages)
@@ -60,7 +54,7 @@ async def search_entities(
60
  final_entities_str = prepare_final_entities_str(final_entities)
61
  response = await generate_final_response(final_entities_str, user_message, messages)
62
 
63
- asyncio.create_task(save_assistant_user_message(user_message, response, chat.id))
64
  return CreateMessageResponse(text=response, entities=final_entities)
65
 
66
 
 
31
  async def search_entities(
32
  user_message: str, messages: list[dict], chat: ChatModel
33
  ) -> CreateMessageResponse:
34
+ entity_data, is_valid = await asyncio.gather(
 
35
  update_entity_data_with_ai(chat.entityData, user_message, messages[-1]['content']),
36
  check_is_valid_request(user_message, "\n".join([f"- [{i['role']}]: {i['content']}." for i in messages]))
37
  )
 
 
 
 
 
38
  final_entities = None
39
  if not is_valid:
40
  response = await generate_invalid_response(user_message, messages)
 
54
  final_entities_str = prepare_final_entities_str(final_entities)
55
  response = await generate_final_response(final_entities_str, user_message, messages)
56
 
57
+ asyncio.create_task(save_assistant_user_message(user_message, final_entities, response, chat.id))
58
  return CreateMessageResponse(text=response, entities=final_entities)
59
 
60
 
trauma/api/message/db_requests.py CHANGED
@@ -5,7 +5,7 @@ from fastapi import HTTPException
5
  from trauma.api.account.dto import AccountType
6
  from trauma.api.account.model import AccountModel
7
  from trauma.api.chat.model import ChatModel
8
- from trauma.api.data.model import EntityModel
9
  from trauma.api.message.dto import Author
10
  from trauma.api.message.model import MessageModel
11
  from trauma.api.message.schemas import CreateMessageRequest
@@ -53,8 +53,10 @@ async def update_entity_data_obj(entity_data: dict, chat_id: str) -> None:
53
 
54
 
55
  @background_task()
56
- async def save_assistant_user_message(user_message: str, assistant_message: str, chat_id: str) -> None:
57
- user_message = MessageModel(chatId=chat_id, author=Author.User, text=user_message)
 
 
58
  assistant_message = MessageModel(chatId=chat_id, author=Author.Assistant, text=assistant_message)
59
  await settings.DB_CLIENT.messages.insert_one(user_message.to_mongo())
60
  await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
 
5
  from trauma.api.account.dto import AccountType
6
  from trauma.api.account.model import AccountModel
7
  from trauma.api.chat.model import ChatModel
8
+ from trauma.api.data.model import EntityModel, EntityModelExtended
9
  from trauma.api.message.dto import Author
10
  from trauma.api.message.model import MessageModel
11
  from trauma.api.message.schemas import CreateMessageRequest
 
53
 
54
 
55
  @background_task()
56
+ async def save_assistant_user_message(
57
+ user_message: str, final_entities: list[EntityModelExtended], assistant_message: str, chat_id: str
58
+ ) -> None:
59
+ user_message = MessageModel(chatId=chat_id, author=Author.User, entities=final_entities, text=user_message)
60
  assistant_message = MessageModel(chatId=chat_id, author=Author.Assistant, text=assistant_message)
61
  await settings.DB_CLIENT.messages.insert_one(user_message.to_mongo())
62
  await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
trauma/api/message/model.py CHANGED
@@ -2,6 +2,7 @@ from datetime import datetime
2
 
3
  from pydantic import Field
4
 
 
5
  from trauma.api.message.dto import Author
6
  from trauma.core.database import MongoBaseModel
7
 
@@ -10,5 +11,6 @@ class MessageModel(MongoBaseModel):
10
  chatId: str
11
  author: Author
12
  text: str
 
13
  datetimeInserted: datetime = Field(default_factory=datetime.now)
14
  datetimeUpdated: datetime = Field(default_factory=datetime.now)
 
2
 
3
  from pydantic import Field
4
 
5
+ from trauma.api.data.model import EntityModelExtended
6
  from trauma.api.message.dto import Author
7
  from trauma.core.database import MongoBaseModel
8
 
 
11
  chatId: str
12
  author: Author
13
  text: str
14
+ entities: list[EntityModelExtended] | None = None
15
  datetimeInserted: datetime = Field(default_factory=datetime.now)
16
  datetimeUpdated: datetime = Field(default_factory=datetime.now)