Spaces:
Running
Running
Optimize entity retrieval and limit final entity processing.
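Browse files
Replaced individual entity fetches with a single bulk retrieval (`get_entities_bulk`) to improve performance. Increased the number of filtered results kept after sorting by distance to 50, but reduced the number of entities included in the final output string to the first 3 for better efficiency. Adjusted the fields returned by the entity queries and introduced this shortening in the `prepare_final_entities_str` utility function for streamlined processing.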
trauma/api/message/ai/engine.py
CHANGED
@@ -17,7 +17,7 @@ from trauma.api.message.ai.openai_request import (update_entity_data_with_ai,
|
|
17 |
set_entity_score)
|
18 |
from trauma.api.message.db_requests import (save_assistant_user_message,
|
19 |
filter_entities_by_age_location,
|
20 |
-
update_entity_data_obj,
|
21 |
from trauma.api.message.dto import Author
|
22 |
from trauma.api.message.model import MessageModel
|
23 |
from trauma.api.message.schemas import CreateMessageResponse
|
@@ -86,8 +86,8 @@ async def search_semantic_entities(
|
|
86 |
for idx, dist in zip(indices, distances)
|
87 |
if idx in entities_indexes and dist <= 1.3
|
88 |
]
|
89 |
-
filtered_results = sorted(filtered_results, key=lambda x: x["distance"])[:
|
90 |
-
final_entities = await
|
91 |
final_entities_extended = await extend_entities_with_highlights(final_entities, entity_data)
|
92 |
final_entities_scored = await set_entities_score(final_entities_extended, search_request)
|
93 |
return final_entities_scored
|
|
|
17 |
set_entity_score)
|
18 |
from trauma.api.message.db_requests import (save_assistant_user_message,
|
19 |
filter_entities_by_age_location,
|
20 |
+
update_entity_data_obj, get_entities_bulk)
|
21 |
from trauma.api.message.dto import Author
|
22 |
from trauma.api.message.model import MessageModel
|
23 |
from trauma.api.message.schemas import CreateMessageResponse
|
|
|
86 |
for idx, dist in zip(indices, distances)
|
87 |
if idx in entities_indexes and dist <= 1.3
|
88 |
]
|
89 |
+
filtered_results = sorted(filtered_results, key=lambda x: x["distance"])[:50]
|
90 |
+
final_entities = await get_entities_bulk([i['index'] for i in filtered_results])
|
91 |
final_entities_extended = await extend_entities_with_highlights(final_entities, entity_data)
|
92 |
final_entities_scored = await set_entities_score(final_entities_extended, search_request)
|
93 |
return final_entities_scored
|
trauma/api/message/db_requests.py
CHANGED
@@ -5,7 +5,7 @@ from fastapi import HTTPException
|
|
5 |
from trauma.api.account.dto import AccountType
|
6 |
from trauma.api.account.model import AccountModel
|
7 |
from trauma.api.chat.model import ChatModel
|
8 |
-
from trauma.api.data.model import EntityModel
|
9 |
from trauma.api.message.dto import Author, Feedback
|
10 |
from trauma.api.message.model import MessageModel
|
11 |
from trauma.api.message.schemas import CreateMessageRequest
|
@@ -80,7 +80,7 @@ async def filter_entities_by_age_location(entity_data: dict) -> list[int]:
|
|
80 |
"$regex": f".*{entity_data['postalCode']}.*",
|
81 |
"$options": "i"
|
82 |
}
|
83 |
-
entities = await settings.DB_CLIENT.entities.find(query, {"
|
84 |
return [entity['index'] for entity in entities]
|
85 |
|
86 |
|
@@ -89,6 +89,12 @@ async def get_entity_by_index(index: int) -> EntityModel:
|
|
89 |
return EntityModel.from_mongo(entity)
|
90 |
|
91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
async def update_message_feedback_obj(message_id: str, feedback_data: Feedback) -> MessageModel:
|
93 |
message = await settings.DB_CLIENT.messages.find_one({"id": message_id})
|
94 |
if not message:
|
|
|
5 |
from trauma.api.account.dto import AccountType
|
6 |
from trauma.api.account.model import AccountModel
|
7 |
from trauma.api.chat.model import ChatModel
|
8 |
+
from trauma.api.data.model import EntityModel
|
9 |
from trauma.api.message.dto import Author, Feedback
|
10 |
from trauma.api.message.model import MessageModel
|
11 |
from trauma.api.message.schemas import CreateMessageRequest
|
|
|
80 |
"$regex": f".*{entity_data['postalCode']}.*",
|
81 |
"$options": "i"
|
82 |
}
|
83 |
+
entities = await settings.DB_CLIENT.entities.find(query, {"index": 1, "_id": 0}).to_list(length=None)
|
84 |
return [entity['index'] for entity in entities]
|
85 |
|
86 |
|
|
|
89 |
return EntityModel.from_mongo(entity)
|
90 |
|
91 |
|
92 |
+
async def get_entities_bulk(indices: list[int]) -> list[EntityModel]:
|
93 |
+
entities = await settings.DB_CLIENT.entities.find({"index": {"$in": indices}},
|
94 |
+
{"embedding": 0}).to_list(length=None)
|
95 |
+
return [EntityModel.from_mongo(entity) for entity in entities]
|
96 |
+
|
97 |
+
|
98 |
async def update_message_feedback_obj(message_id: str, feedback_data: Feedback) -> MessageModel:
|
99 |
message = await settings.DB_CLIENT.messages.find_one({"id": message_id})
|
100 |
if not message:
|
trauma/api/message/utils.py
CHANGED
@@ -44,8 +44,9 @@ def prepare_user_messages_str(user_message: str, messages: list[MessageModel]) -
|
|
44 |
|
45 |
|
46 |
def prepare_final_entities_str(entities: list[EntityModel]) -> str:
|
|
|
47 |
entities_list = []
|
48 |
-
for entity in
|
49 |
entities_list.append(entity.model_dump(mode='json', exclude={
|
50 |
'id', 'contactDetails', "highlightedAgeGroup", "highlightedTreatmentArea", "highlightedTreatmentMethod"
|
51 |
}))
|
|
|
44 |
|
45 |
|
46 |
def prepare_final_entities_str(entities: list[EntityModel]) -> str:
|
47 |
+
shortened_entities = entities[:3]
|
48 |
entities_list = []
|
49 |
+
for entity in shortened_entities:
|
50 |
entities_list.append(entity.model_dump(mode='json', exclude={
|
51 |
'id', 'contactDetails', "highlightedAgeGroup", "highlightedTreatmentArea", "highlightedTreatmentMethod"
|
52 |
}))
|