import json
import time
from typing import Callable

import litellm
from litellm.types.utils import ModelResponse, Usage


class BasetenError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


def validate_environment(api_key):
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
    }
    if api_key:
        headers["Authorization"] = f"Api-Key {api_key}"
    return headers


def completion(
    model: str,
    messages: list,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    api_key,
    logging_obj,
    optional_params: dict,
    litellm_params=None,
    logger_fn=None,
):
    headers = validate_environment(api_key)
    completion_url = f"https://app.baseten.co/models/{model}/predict"
    # Baseten expects a single prompt string, so concatenate the message
    # contents regardless of role.
    prompt = ""
    for message in messages:
        prompt += f"{message['content']}"
    stream = optional_params.get("stream") is True
    data = {
        "inputs": prompt,
        "prompt": prompt,
        "parameters": optional_params,
        "stream": stream,
    }

    ## LOGGING
    logging_obj.pre_call(
        input=prompt,
        api_key=api_key,
        additional_args={"complete_input_dict": data},
    )

    ## COMPLETION CALL
    response = litellm.module_level_client.post(
        completion_url,
        headers=headers,
        data=json.dumps(data),
        stream=stream,
    )
if "text/event-stream" in response.headers["Content-Type"] or ( | |
"stream" in optional_params and optional_params["stream"] is True | |
): | |
return response.iter_lines() | |
else: | |
    ## LOGGING
    logging_obj.post_call(
        input=prompt,
        api_key=api_key,
        original_response=response.text,
        additional_args={"complete_input_dict": data},
    )
    print_verbose(f"raw model_response: {response.text}")

    ## RESPONSE OBJECT
    completion_response = response.json()
if "error" in completion_response: | |
raise BasetenError( | |
message=completion_response["error"], | |
status_code=response.status_code, | |
) | |
else: | |
if "model_output" in completion_response: | |
if ( | |
isinstance(completion_response["model_output"], dict) | |
and "data" in completion_response["model_output"] | |
and isinstance(completion_response["model_output"]["data"], list) | |
): | |
model_response.choices[0].message.content = completion_response[ # type: ignore | |
"model_output" | |
][ | |
"data" | |
][ | |
0 | |
] | |
elif isinstance(completion_response["model_output"], str): | |
model_response.choices[0].message.content = completion_response[ # type: ignore | |
"model_output" | |
] | |
elif "completion" in completion_response and isinstance( | |
completion_response["completion"], str | |
): | |
model_response.choices[0].message.content = completion_response[ # type: ignore | |
"completion" | |
] | |
    elif isinstance(completion_response, list) and len(completion_response) > 0:
        # Check the first element of the list, not the list itself.
        if "generated_text" not in completion_response[0]:
            raise BasetenError(
                message=f"Unable to parse response. Original response: {response.text}",
                status_code=response.status_code,
            )
        model_response.choices[0].message.content = completion_response[0][  # type: ignore
            "generated_text"
        ]
        ## GETTING LOGPROBS
        if (
            "details" in completion_response[0]
            and "tokens" in completion_response[0]["details"]
        ):
            model_response.choices[0].finish_reason = completion_response[0][
                "details"
            ]["finish_reason"]
            sum_logprob = 0
            for token in completion_response[0]["details"]["tokens"]:
                sum_logprob += token["logprob"]
            model_response.choices[0].logprobs = sum_logprob  # type: ignore
    else:
        raise BasetenError(
            message=f"Unable to parse response. Original response: {response.text}",
            status_code=response.status_code,
        )

    ## CALCULATING USAGE - Baseten bills on compute time, not tokens, so these
    ## token counts only approximate usage for cost mapping.
    prompt_tokens = len(encoding.encode(prompt))
    completion_tokens = len(
        encoding.encode(model_response["choices"][0]["message"]["content"])
    )
    model_response.created = int(time.time())
    model_response.model = model
    usage = Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    setattr(model_response, "usage", usage)
    return model_response


def embedding():
    # logic for parsing in - calling - parsing out model embedding calls
    pass
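

# --- Usage sketch (illustrative, not part of the module) --------------------
# A minimal sketch of how `completion` might be driven, assuming tiktoken is
# available for token counting. The stub logger, the default model id
# "MODEL_VERSION_ID", and the BASETEN_MODEL_ID env var are hypothetical; only
# BASETEN_API_KEY mirrors the key this provider expects.
if __name__ == "__main__":
    import os

    import tiktoken

    class _StubLogger:
        """Hypothetical no-op logger satisfying the pre_call/post_call interface."""

        def pre_call(self, input, api_key, additional_args=None):
            pass

        def post_call(self, input, api_key, original_response, additional_args=None):
            pass

    result = completion(
        model=os.environ.get("BASETEN_MODEL_ID", "MODEL_VERSION_ID"),
        messages=[{"role": "user", "content": "Hello, world"}],
        model_response=ModelResponse(),
        print_verbose=print,
        encoding=tiktoken.get_encoding("cl100k_base"),
        api_key=os.environ.get("BASETEN_API_KEY"),
        logging_obj=_StubLogger(),
        optional_params={},
    )
    print(result)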