brestok commited on
Commit
2849c14
·
1 Parent(s): e754e5a

added invalid request handling

Browse files
trauma/api/message/ai/engine.py CHANGED
@@ -9,7 +9,8 @@ from trauma.api.message.ai.openai_request import (update_entity_data_with_ai,
9
  generate_next_question,
10
  generate_search_request,
11
  generate_final_response, convert_value_to_embeddings,
12
- choose_closest_treatment_method, choose_closest_treatment_area)
 
13
  from trauma.api.message.db_requests import (save_assistant_user_message,
14
  filter_entities_by_age,
15
  update_entity_data_obj, get_entity_by_index)
@@ -24,24 +25,28 @@ from trauma.core.config import settings
24
  async def search_entities(
25
  user_message: str, messages: list[dict], chat: ChatModel
26
  ) -> CreateMessageResponse:
27
- entity_data = await update_entity_data_with_ai(chat.entityData, user_message, messages[-1]['content'])
28
- asyncio.create_task(update_entity_data_obj(entity_data, chat.id))
29
-
30
- empty_field = retrieve_empty_field_from_entity_data(entity_data)
31
  final_entities = None
32
-
33
- if empty_field:
34
- empty_field_instructions = pick_empty_field_instructions(empty_field)
35
- response = await generate_next_question(empty_field, empty_field_instructions, user_message, messages)
36
  else:
37
- user_messages_str = prepare_user_messages_str(user_message, messages)
38
- possible_entity_indexes, search_request = await asyncio.gather(
39
- filter_entities_by_age(entity_data),
40
- generate_search_request(user_messages_str, entity_data)
41
- )
42
- final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
43
- final_entities_str = prepare_final_entities_str(final_entities)
44
- response = await generate_final_response(final_entities_str, user_message, messages)
 
 
 
 
 
 
45
 
46
  asyncio.create_task(save_assistant_user_message(user_message, response, chat.id))
47
  return CreateMessageResponse(text=response, entities=final_entities)
 
9
  generate_next_question,
10
  generate_search_request,
11
  generate_final_response, convert_value_to_embeddings,
12
+ choose_closest_treatment_method, choose_closest_treatment_area,
13
+ check_is_valid_request, generate_invalid_response)
14
  from trauma.api.message.db_requests import (save_assistant_user_message,
15
  filter_entities_by_age,
16
  update_entity_data_obj, get_entity_by_index)
 
25
  async def search_entities(
26
  user_message: str, messages: list[dict], chat: ChatModel
27
  ) -> CreateMessageResponse:
28
+ entity_data, is_valid = await asyncio.gather(
29
+ update_entity_data_with_ai(chat.entityData, user_message, messages[-1]['content']),
30
+ check_is_valid_request(user_message, messages[-1]['content'])
31
+ )
32
  final_entities = None
33
+ if not is_valid:
34
+ response = await generate_invalid_response(user_message, messages)
 
 
35
  else:
36
+ asyncio.create_task(update_entity_data_obj(entity_data, chat.id))
37
+ empty_field = retrieve_empty_field_from_entity_data(entity_data)
38
+ if empty_field:
39
+ empty_field_instructions = pick_empty_field_instructions(empty_field)
40
+ response = await generate_next_question(empty_field, empty_field_instructions, user_message, messages)
41
+ else:
42
+ user_messages_str = prepare_user_messages_str(user_message, messages)
43
+ possible_entity_indexes, search_request = await asyncio.gather(
44
+ filter_entities_by_age(entity_data),
45
+ generate_search_request(user_messages_str, entity_data)
46
+ )
47
+ final_entities = await search_semantic_entities(search_request, entity_data, possible_entity_indexes)
48
+ final_entities_str = prepare_final_entities_str(final_entities)
49
+ response = await generate_final_response(final_entities_str, user_message, messages)
50
 
51
  asyncio.create_task(save_assistant_user_message(user_message, response, chat.id))
52
  return CreateMessageResponse(text=response, entities=final_entities)
trauma/api/message/ai/openai_request.py CHANGED
@@ -101,3 +101,32 @@ async def choose_closest_treatment_method(treatment_methods: list[str], treatmen
101
  }
102
  ]
103
  return messages
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  }
102
  ]
103
  return messages
104
+
105
+
106
+ @openai_wrapper(is_json=True, return_='is_valid')
107
+ async def check_is_valid_request(user_message: str, assistant_message: str):
108
+ messages = [
109
+ {
110
+ "role": "system",
111
+ "content": TraumaPrompts.decide_is_valid_request
112
+ .replace("{user_message}", user_message)
113
+ .replace("{assistant_message}", assistant_message)
114
+ }
115
+ ]
116
+ return messages
117
+
118
+
119
+ @openai_wrapper(temperature=0.9)
120
+ async def generate_invalid_response(user_message: str, message_history: list[dict]):
121
+ messages = [
122
+ {
123
+ "role": "system",
124
+ "content": TraumaPrompts.generate_invalid_response
125
+ },
126
+ *message_history,
127
+ {
128
+ "role": "user",
129
+ "content": user_message
130
+ }
131
+ ]
132
+ return messages
trauma/api/message/ai/prompts.py CHANGED
@@ -109,6 +109,47 @@ Je moet een antwoord genereren aan de gebruiker waarin je aangeeft dat je geschi
109
  ## Voorbeeld van antwoorden
110
 
111
  - Gefeliciteerd! Hier is een lijst van klinieken die perfect passen bij deze aandoening. Ik heb deze klinieken aanbevolen omdat ze voldoen aan de gevraagde leeftijdsbeperkingen en gespecialiseerd zijn in de behandeling van deze aandoening met behulp van dergelijke methoden."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  generate_clinic_description = """## Taak
114
 
@@ -194,3 +235,4 @@ You must determine the most semantically similar treatment method from the list
194
  ## Instructions for filling JSON
195
 
196
  - [result]: The item from the [treatment methods] list that is most semantically similar to the requested treatment method. The treatment method name in the result field must exactly match the name as it appears in the [treatment methods] list."""
 
 
109
  ## Voorbeeld van antwoorden
110
 
111
  - Gefeliciteerd! Hier is een lijst van klinieken die perfect passen bij deze aandoening. Ik heb deze klinieken aanbevolen omdat ze voldoen aan de gevraagde leeftijdsbeperkingen en gespecialiseerd zijn in de behandeling van deze aandoening met behulp van dergelijke methoden."""
112
+ decide_is_valid_request = """## Task
113
+
114
+ You must determine whether the user's response is valid. Provide your answer in the [is_valid] field in JSON format.
115
+
116
+ ## Context
117
+
118
+ A nurse has interviewed a patient and is now searching for a suitable clinic based on the collected information. To assist with this, the nurse answers questions from an AI assistant, which gathers patient data to provide recommendations. Your task is to determine whether the user's (nurse's) message is relevant to the system's purpose. If the user's message falls outside the intended scope of the system's patient data collection, return false.
119
+
120
+ ## Data
121
+
122
+ **user query**:
123
+ ```
124
+ {user_message}
125
+ ```
126
+
127
+ **assistant question**:
128
+ ```
129
+ {assistant_message}
130
+ ```
131
+
132
+ ## JSON Response format
133
+
134
+ ```json
135
+ {
136
+ "is_valid": boolean
137
+ }
138
+ ```
139
+
140
+ ## Instructions for filling JSON
141
+
142
+ The field is considered valid (`is_valid = true`) if:
143
+ - The user has provided a logical answer to the assistant's question.
144
+ - The user's message relates to a medical topic."""
145
+ generate_invalid_response = """## Taak
146
+
147
+ Je moet een antwoord genereren voor de gebruiker waarin je aangeeft dat hun verzoek buiten jouw specificatie valt voor het verzamelen van informatie en het geven van aanbevelingen.
148
+
149
+ ## Belangrijke opmerkingen
150
+
151
+ - Je antwoord moet kort en bondig zijn, bestaande uit twee zinnen.
152
+ - Je moet de gebruiker informeren dat hun verzoek onjuist is en je vorige vraag opnieuw stellen om verder te gaan met het verzamelen van informatie over de patiënt."""
153
 
154
  generate_clinic_description = """## Taak
155
 
 
235
  ## Instructions for filling JSON
236
 
237
  - [result]: The item from the [treatment methods] list that is most semantically similar to the requested treatment method. The treatment method name in the result field must exactly match the name as it appears in the [treatment methods] list."""
238
+