Spaces:
Running
Running
added invalid request handling
Browse files
trauma/api/message/ai/engine.py
CHANGED
@@ -9,7 +9,8 @@ from trauma.api.message.ai.openai_request import (update_entity_data_with_ai,
|
|
9 |
generate_next_question,
|
10 |
generate_search_request,
|
11 |
generate_final_response, convert_value_to_embeddings,
|
12 |
-
choose_closest_treatment_method, choose_closest_treatment_area
|
|
|
13 |
from trauma.api.message.db_requests import (save_assistant_user_message,
|
14 |
filter_entities_by_age,
|
15 |
update_entity_data_obj, get_entity_by_index)
|
@@ -24,24 +25,28 @@ from trauma.core.config import settings
|
|
24 |
async def search_entities(
|
25 |
user_message: str, messages: list[dict], chat: ChatModel
|
26 |
) -> CreateMessageResponse:
|
27 |
-
entity_data = await
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
final_entities = None
|
32 |
-
|
33 |
-
|
34 |
-
empty_field_instructions = pick_empty_field_instructions(empty_field)
|
35 |
-
response = await generate_next_question(empty_field, empty_field_instructions, user_message, messages)
|
36 |
else:
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
asyncio.create_task(save_assistant_user_message(user_message, response, chat.id))
|
47 |
return CreateMessageResponse(text=response, entities=final_entities)
|
|
|
9 |
generate_next_question,
|
10 |
generate_search_request,
|
11 |
generate_final_response, convert_value_to_embeddings,
|
12 |
+
choose_closest_treatment_method, choose_closest_treatment_area,
|
13 |
+
check_is_valid_request, generate_invalid_response)
|
14 |
from trauma.api.message.db_requests import (save_assistant_user_message,
|
15 |
filter_entities_by_age,
|
16 |
update_entity_data_obj, get_entity_by_index)
|
|
|
25 |
async def search_entities(
|
26 |
user_message: str, messages: list[dict], chat: ChatModel
|
27 |
) -> CreateMessageResponse:
|
28 |
+
entity_data, is_valid = await asyncio.gather(
|
29 |
+
update_entity_data_with_ai(chat.entityData, user_message, messages[-1]['content']),
|
30 |
+
check_is_valid_request(user_message, messages[-1]['content'])
|
31 |
+
)
|
32 |
final_entities = None
|
33 |
+
if not is_valid:
|
34 |
+
response = await generate_invalid_response(user_message, messages)
|
|
|
|
|
35 |
else:
|
36 |
+
asyncio.create_task(update_entity_data_obj(entity_data, chat.id))
|
37 |
+
empty_field = retrieve_empty_field_from_entity_data(entity_data)
|
38 |
+
if empty_field:
|
39 |
+
empty_field_instructions = pick_empty_field_instructions(empty_field)
|
40 |
+
response = await generate_next_question(empty_field, empty_field_instructions, user_message, messages)
|
41 |
+
else:
|
42 |
+
user_messages_str = prepare_user_messages_str(user_message, messages)
|
43 |
+
possible_entity_indexes, search_request = await asyncio.gather(
|
44 |
+
filter_entities_by_age(entity_data),
|
45 |
+
generate_search_request(user_messages_str, entity_data)
|
46 |
+
)
|
47 |
+
final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
|
48 |
+
final_entities_str = prepare_final_entities_str(final_entities)
|
49 |
+
response = await generate_final_response(final_entities_str, user_message, messages)
|
50 |
|
51 |
asyncio.create_task(save_assistant_user_message(user_message, response, chat.id))
|
52 |
return CreateMessageResponse(text=response, entities=final_entities)
|
trauma/api/message/ai/openai_request.py
CHANGED
@@ -101,3 +101,32 @@ async def choose_closest_treatment_method(treatment_methods: list[str], treatmen
|
|
101 |
}
|
102 |
]
|
103 |
return messages
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
}
|
102 |
]
|
103 |
return messages
|
104 |
+
|
105 |
+
|
106 |
+
@openai_wrapper(is_json=True, return_='is_valid')
|
107 |
+
async def check_is_valid_request(user_message: str, assistant_message: str):
|
108 |
+
messages = [
|
109 |
+
{
|
110 |
+
"role": "system",
|
111 |
+
"content": TraumaPrompts.decide_is_valid_request
|
112 |
+
.replace("{user_message}", user_message)
|
113 |
+
.replace("{assistant_message}", assistant_message)
|
114 |
+
}
|
115 |
+
]
|
116 |
+
return messages
|
117 |
+
|
118 |
+
|
119 |
+
@openai_wrapper(temperature=0.9)
|
120 |
+
async def generate_invalid_response(user_message: str, message_history: list[dict]):
|
121 |
+
messages = [
|
122 |
+
{
|
123 |
+
"role": "system",
|
124 |
+
"content": TraumaPrompts.generate_invalid_response
|
125 |
+
},
|
126 |
+
*message_history,
|
127 |
+
{
|
128 |
+
"role": "user",
|
129 |
+
"content": user_message
|
130 |
+
}
|
131 |
+
]
|
132 |
+
return messages
|
trauma/api/message/ai/prompts.py
CHANGED
@@ -109,6 +109,47 @@ Je moet een antwoord genereren aan de gebruiker waarin je aangeeft dat je geschi
|
|
109 |
## Voorbeeld van antwoorden
|
110 |
|
111 |
- Gefeliciteerd! Hier is een lijst van klinieken die perfect passen bij deze aandoening. Ik heb deze klinieken aanbevolen omdat ze voldoen aan de gevraagde leeftijdsbeperkingen en gespecialiseerd zijn in de behandeling van deze aandoening met behulp van dergelijke methoden."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
generate_clinic_description = """## Taak
|
114 |
|
@@ -194,3 +235,4 @@ You must determine the most semantically similar treatment method from the list
|
|
194 |
## Instructions for filling JSON
|
195 |
|
196 |
- [result]: The item from the [treatment methods] list that is most semantically similar to the requested treatment method. The treatment method name in the result field must exactly match the name as it appears in the [treatment methods] list."""
|
|
|
|
109 |
## Voorbeeld van antwoorden
|
110 |
|
111 |
- Gefeliciteerd! Hier is een lijst van klinieken die perfect passen bij deze aandoening. Ik heb deze klinieken aanbevolen omdat ze voldoen aan de gevraagde leeftijdsbeperkingen en gespecialiseerd zijn in de behandeling van deze aandoening met behulp van dergelijke methoden."""
|
112 |
+
decide_is_valid_request = """## Task
|
113 |
+
|
114 |
+
You must determine whether the user's response is valid. Provide your answer in the [is_valid] field in JSON format.
|
115 |
+
|
116 |
+
## Context
|
117 |
+
|
118 |
+
A nurse has interviewed a patient and is now searching for a suitable clinic based on the collected information. To assist with this, the nurse answers questions from an AI assistant, which gathers patient data to provide recommendations. Your task is to determine whether the user's (nurse's) message is relevant to the system's purpose. If the user's message falls outside the intended scope of the system's patient data collection, return false.
|
119 |
+
|
120 |
+
## Data
|
121 |
+
|
122 |
+
**user query**:
|
123 |
+
```
|
124 |
+
{user_message}
|
125 |
+
```
|
126 |
+
|
127 |
+
**assistant question**:
|
128 |
+
```
|
129 |
+
{assistant_message}
|
130 |
+
```
|
131 |
+
|
132 |
+
## JSON Response format
|
133 |
+
|
134 |
+
```json
|
135 |
+
{
|
136 |
+
"is_valid": boolean
|
137 |
+
}
|
138 |
+
```
|
139 |
+
|
140 |
+
## Instructions for filling JSON
|
141 |
+
|
142 |
+
The field is considered valid (`is_valid = true`) if:
|
143 |
+
- The user has provided a logical answer to the assistant's question.
|
144 |
+
- The user's message relates to a medical topic."""
|
145 |
+
generate_invalid_response = """## Taak
|
146 |
+
|
147 |
+
Je moet een antwoord genereren voor de gebruiker waarin je aangeeft dat hun verzoek buiten jouw specificatie valt voor het verzamelen van informatie en het geven van aanbevelingen.
|
148 |
+
|
149 |
+
## Belangrijke opmerkingen
|
150 |
+
|
151 |
+
- Je antwoord moet kort en bondig zijn, bestaande uit twee zinnen.
|
152 |
+
- Je moet de gebruiker informeren dat hun verzoek onjuist is en je vorige vraag opnieuw stellen om verder te gaan met het verzamelen van informatie over de patiënt."""
|
153 |
|
154 |
generate_clinic_description = """## Taak
|
155 |
|
|
|
235 |
## Instructions for filling JSON
|
236 |
|
237 |
- [result]: The item from the [treatment methods] list that is most semantically similar to the requested treatment method. The treatment method name in the result field must exactly match the name as it appears in the [treatment methods] list."""
|
238 |
+
|