handler v11
#11
by
vikram-fresche
- opened
- handler.py +24 -4
handler.py
CHANGED
@@ -91,11 +91,31 @@ class EndpointHandler:
|
|
91 |
logger.info("Decoding response")
|
92 |
output_text = self.tokenizer.batch_decode(output_tokens)[0]
|
93 |
|
94 |
-
# Extract the assistant's response by
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
logger.info(f"Generated response: {json.dumps(response)}")
|
97 |
-
|
98 |
-
return [{"generations": [{"text": response}]}]
|
99 |
|
100 |
except Exception as e:
|
101 |
logger.error(f"Error during generation: {str(e)}", exc_info=True)
|
|
|
91 |
logger.info("Decoding response")
|
92 |
output_text = self.tokenizer.batch_decode(output_tokens)[0]
|
93 |
|
94 |
+
# Extract only the assistant's response by finding the last assistant role block
|
95 |
+
assistant_start = output_text.rfind("<|start_of_role|>assistant<|end_of_role|>")
|
96 |
+
if assistant_start != -1:
|
97 |
+
response = output_text[assistant_start + len("<|start_of_role|>assistant<|end_of_role|>"):].strip()
|
98 |
+
# Remove any trailing end_of_text marker
|
99 |
+
if "<|end_of_text|>" in response:
|
100 |
+
response = response.split("<|end_of_text|>")[0].strip()
|
101 |
+
|
102 |
+
# Check for function calling
|
103 |
+
if "Calling function:" in response:
|
104 |
+
# Split response into text and function call
|
105 |
+
parts = response.split("Calling function:", 1)
|
106 |
+
text_response = parts[0].strip()
|
107 |
+
function_call = "Calling function:" + parts[1].strip()
|
108 |
+
|
109 |
+
logger.info(f"Function call: {function_call}")
|
110 |
+
logger.info(f"Text response: {text_response}")
|
111 |
+
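# Return only the text part (the tool-message payload below is commented out). NOTE(review): the trailing comma on the return makes the result a 1-tuple, unlike the plain dict returned on the non-function-call path - confirm this is intended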
|
112 |
+
return {"result": [{"text": text_response}]},
|
113 |
+
#{"generations": [{"text": function_call, "type": "tool"}]}
|
114 |
+
else:
|
115 |
+
response = output_text
|
116 |
+
|
117 |
logger.info(f"Generated response: {json.dumps(response)}")
|
118 |
+
return {"result": [{"text": response}]}
|
|
|
119 |
|
120 |
except Exception as e:
|
121 |
logger.error(f"Error during generation: {str(e)}", exc_info=True)
|