vikram-fresche committed
Commit d0b295c · verified · 1 parent: 8ccf043

handler v9 (#9)


- added custom handler v9 (cd11daeba612cd8893250017ccb612aaf4528395)

Files changed (1): handler.py (+24, -7)
handler.py CHANGED

```diff
@@ -1,7 +1,9 @@
 from typing import Dict, List, Any
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+import json
 import logging
+import time
 
 # Configure logging
 logging.basicConfig(
@@ -72,12 +74,10 @@ class EndpointHandler:
                 tokenize=False,
                 add_generation_prompt=True
             )
-            logger.info(f"Generated chat prompt: {prompt}")
+            logger.info(f"Generated chat prompt: {json.dumps(prompt)}")
 
             # Tokenize the prompt
-            logger.info("Tokenizing input")
             inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
-            logger.info(f"Input shape: {inputs.input_ids.shape}")
 
             # Generate response
             logger.info("Generating response")
@@ -86,7 +86,6 @@ class EndpointHandler:
                 **inputs,
                 **gen_params
             )
-            logger.info(f"Output shape: {output_tokens.shape}")
 
             # Decode the response
             logger.info("Decoding response")
@@ -94,11 +93,29 @@
 
             # Extract the assistant's response by removing the input prompt
             response = output_text#[len(prompt):].strip()
-            logger.info(f"Generated response length: {len(response)}")
-            logger.info(f"Generated response: {response}")
+            logger.info(f"Generated response: {json.dumps(response)}")
 
             #return [{"role": "assistant", "content": response}]
-            return [{"result": response, "error": None}]
+            #return {"result": response, "error": None}
+            return {
+                "id": "cmpl-" + str(hash(response))[:10],  # Generate a unique ID
+                "object": "chat.completion",
+                "created": int(time.time()),
+                "model": self.model.config.name_or_path,
+                "choices": [{
+                    "index": 0,
+                    "message": {
+                        "role": "assistant",
+                        "content": response
+                    },
+                    "finish_reason": "stop"
+                }],
+                "usage": {
+                    "prompt_tokens": len(inputs["input_ids"][0]),
+                    "completion_tokens": len(output_tokens[0]) - len(inputs["input_ids"][0]),
+                    "total_tokens": len(output_tokens[0])
+                }
+            }
 
         except Exception as e:
             logger.error(f"Error during generation: {str(e)}", exc_info=True)
```
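The notable change in v9 is the return shape: instead of the previous `[{"result": response, "error": None}]` list, the handler now returns an OpenAI-style `chat.completion` object. Below is a minimal sketch of the caller-side view, assuming the usual Inference Endpoints custom-handler convention of `EndpointHandler(path=...)` and an `{"inputs": ...}` request body; neither is confirmed by this diff.

```python
# Hypothetical caller-side view of the v9 return shape. The request schema
# ({"inputs": [...]}) and the path argument are assumptions, not shown here.
from handler import EndpointHandler

handler = EndpointHandler(path=".")
result = handler({"inputs": [{"role": "user", "content": "Hello!"}]})

# v9 returns an OpenAI-style chat.completion dict rather than a list.
assert result["object"] == "chat.completion"
print(result["choices"][0]["message"]["content"])  # the assistant's reply
print(result["usage"])  # prompt_tokens / completion_tokens / total_tokens
```

One design note: the `id` is derived from Python's string `hash()`, which is salted per process, so IDs are only stable within a single run, and `str(hash(response))[:10]` can start with a minus sign; a UUID would be the conventional alternative.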
 
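The switch from logging the raw prompt and response to `json.dumps(...)` likely serves log readability: escaping newlines and quotes keeps each multi-line string on a single log record. A small illustration, where the prompt string is a hypothetical chat-template output:

```python
import json
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

prompt = "<|user|>\nHello\n<|assistant|>\n"  # hypothetical chat-template output

logger.info(f"Generated chat prompt: {prompt}")              # record spans four lines
logger.info(f"Generated chat prompt: {json.dumps(prompt)}")  # one line, newlines escaped as \n
```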