Spaces:
Running
Running
Add entity handling to message saving and AI response logic
Browse files
trauma/api/message/ai/engine.py
CHANGED
@@ -31,16 +31,10 @@ from trauma.core.config import settings
|
|
31 |
async def search_entities(
|
32 |
user_message: str, messages: list[dict], chat: ChatModel
|
33 |
) -> CreateMessageResponse:
|
34 |
-
|
35 |
-
retrieve_semantic_answer(user_message),
|
36 |
update_entity_data_with_ai(chat.entityData, user_message, messages[-1]['content']),
|
37 |
check_is_valid_request(user_message, "\n".join([f"- [{i['role']}]: {i['content']}." for i in messages]))
|
38 |
)
|
39 |
-
if related_entity:
|
40 |
-
response = await generate_searched_entity_response(user_message, related_entity[0])
|
41 |
-
asyncio.create_task(save_assistant_user_message(user_message, response, chat.id))
|
42 |
-
return CreateMessageResponse(text=response, entities=related_entity)
|
43 |
-
|
44 |
final_entities = None
|
45 |
if not is_valid:
|
46 |
response = await generate_invalid_response(user_message, messages)
|
@@ -60,7 +54,7 @@ async def search_entities(
|
|
60 |
final_entities_str = prepare_final_entities_str(final_entities)
|
61 |
response = await generate_final_response(final_entities_str, user_message, messages)
|
62 |
|
63 |
-
asyncio.create_task(save_assistant_user_message(user_message, response, chat.id))
|
64 |
return CreateMessageResponse(text=response, entities=final_entities)
|
65 |
|
66 |
|
|
|
31 |
async def search_entities(
|
32 |
user_message: str, messages: list[dict], chat: ChatModel
|
33 |
) -> CreateMessageResponse:
|
34 |
+
entity_data, is_valid = await asyncio.gather(
|
|
|
35 |
update_entity_data_with_ai(chat.entityData, user_message, messages[-1]['content']),
|
36 |
check_is_valid_request(user_message, "\n".join([f"- [{i['role']}]: {i['content']}." for i in messages]))
|
37 |
)
|
|
|
|
|
|
|
|
|
|
|
38 |
final_entities = None
|
39 |
if not is_valid:
|
40 |
response = await generate_invalid_response(user_message, messages)
|
|
|
54 |
final_entities_str = prepare_final_entities_str(final_entities)
|
55 |
response = await generate_final_response(final_entities_str, user_message, messages)
|
56 |
|
57 |
+
asyncio.create_task(save_assistant_user_message(user_message, final_entities, response, chat.id))
|
58 |
return CreateMessageResponse(text=response, entities=final_entities)
|
59 |
|
60 |
|
trauma/api/message/db_requests.py
CHANGED
@@ -5,7 +5,7 @@ from fastapi import HTTPException
|
|
5 |
from trauma.api.account.dto import AccountType
|
6 |
from trauma.api.account.model import AccountModel
|
7 |
from trauma.api.chat.model import ChatModel
|
8 |
-
from trauma.api.data.model import EntityModel
|
9 |
from trauma.api.message.dto import Author
|
10 |
from trauma.api.message.model import MessageModel
|
11 |
from trauma.api.message.schemas import CreateMessageRequest
|
@@ -53,8 +53,10 @@ async def update_entity_data_obj(entity_data: dict, chat_id: str) -> None:
|
|
53 |
|
54 |
|
55 |
@background_task()
|
56 |
-
async def save_assistant_user_message(
|
57 |
-
|
|
|
|
|
58 |
assistant_message = MessageModel(chatId=chat_id, author=Author.Assistant, text=assistant_message)
|
59 |
await settings.DB_CLIENT.messages.insert_one(user_message.to_mongo())
|
60 |
await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
|
|
|
5 |
from trauma.api.account.dto import AccountType
|
6 |
from trauma.api.account.model import AccountModel
|
7 |
from trauma.api.chat.model import ChatModel
|
8 |
+
from trauma.api.data.model import EntityModel, EntityModelExtended
|
9 |
from trauma.api.message.dto import Author
|
10 |
from trauma.api.message.model import MessageModel
|
11 |
from trauma.api.message.schemas import CreateMessageRequest
|
|
|
53 |
|
54 |
|
55 |
@background_task()
|
56 |
+
async def save_assistant_user_message(
|
57 |
+
user_message: str, final_entities: list[EntityModelExtended], assistant_message: str, chat_id: str
|
58 |
+
) -> None:
|
59 |
+
user_message = MessageModel(chatId=chat_id, author=Author.User, entities=final_entities, text=user_message)
|
60 |
assistant_message = MessageModel(chatId=chat_id, author=Author.Assistant, text=assistant_message)
|
61 |
await settings.DB_CLIENT.messages.insert_one(user_message.to_mongo())
|
62 |
await settings.DB_CLIENT.messages.insert_one(assistant_message.to_mongo())
|
trauma/api/message/model.py
CHANGED
@@ -2,6 +2,7 @@ from datetime import datetime
|
|
2 |
|
3 |
from pydantic import Field
|
4 |
|
|
|
5 |
from trauma.api.message.dto import Author
|
6 |
from trauma.core.database import MongoBaseModel
|
7 |
|
@@ -10,5 +11,6 @@ class MessageModel(MongoBaseModel):
|
|
10 |
chatId: str
|
11 |
author: Author
|
12 |
text: str
|
|
|
13 |
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
14 |
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|
|
|
2 |
|
3 |
from pydantic import Field
|
4 |
|
5 |
+
from trauma.api.data.model import EntityModelExtended
|
6 |
from trauma.api.message.dto import Author
|
7 |
from trauma.core.database import MongoBaseModel
|
8 |
|
|
|
11 |
chatId: str
|
12 |
author: Author
|
13 |
text: str
|
14 |
+
entities: list[EntityModelExtended] | None = None
|
15 |
datetimeInserted: datetime = Field(default_factory=datetime.now)
|
16 |
datetimeUpdated: datetime = Field(default_factory=datetime.now)
|