Spaces:
Runtime error
Runtime error
from abc import ABC, abstractmethod | |
class EngineBase(ABC):
    """Abstract base class for the AI engines.

    Engines define the API for the AI engines that can be used in the swarm.

    Subclasses implement ``call_model``. Before ``truncate_message`` is used,
    a subclass must also set ``self.tiktoken_encoding`` to an object exposing
    ``encode(str) -> list`` and ``decode(list) -> str``; this base class never
    sets that attribute itself.
    """

    # Context-window size (in tokens) for each supported model name.
    # The window is shared between the prompt and the response.
    TOKEN_LIMITS = {
        "gpt-4": 8 * 1024,
        "gpt-4-0314": 8 * 1024,
        "gpt-4-32k": 32 * 1024,
        "gpt-4-32k-0314": 32 * 1024,
        "gpt-3.5-turbo": 4 * 1024,
        "gpt-3.5-turbo-0301": 4 * 1024,
    }

    def __init__(self, provider, model_name: str, temperature: float, max_response_tokens: int):
        """Store the engine configuration.

        Args:
            provider: The model provider this engine talks to.
            model_name (str): Model identifier; used as the key into
                ``TOKEN_LIMITS``.
            temperature (float): Sampling temperature passed on by subclasses.
            max_response_tokens (int): Tokens reserved for the model's reply;
                subtracted from the context window in ``max_input_length``.
        """
        self.provider = provider
        self.model_name = model_name
        self.temperature = temperature
        self.max_response_tokens = max_response_tokens

    def call_model(self, conversation: list) -> str:
        """Call the model with the given conversation.

        Input is always in the format of OpenAI's conversation.
        Output is a string.

        Args:
            conversation (list[dict]): The conversation to be completed. Example:
                [
                    {"role": "system", "content": configuration_prompt},
                    {"role": "user", "content": prompt}
                ]

        Returns:
            str: The response from the model.

        Raises:
            NotImplementedError: Always, in this base class; subclasses override it.
        """
        raise NotImplementedError

    def max_input_length(self) -> int:
        """Return the maximum length of the input to the model in terms of tokens.

        This is the model's context window minus the tokens reserved for the
        response. It may be zero or negative when ``max_response_tokens``
        reaches or exceeds the context window.

        Returns:
            int: The max tokens to input to the model.

        Raises:
            KeyError: If ``model_name`` has no entry in ``TOKEN_LIMITS``.
        """
        if self.model_name not in self.TOKEN_LIMITS:
            raise KeyError(f"Unknown token limit for model {self.model_name!r}")
        return self.TOKEN_LIMITS[self.model_name] - self.max_response_tokens

    def truncate_message(self, message, token_limit=None):
        """Truncate the message using the engine's tiktoken encoding.

        Args:
            message (str): The text to truncate.
            token_limit (int, optional): An extra upper bound on the number of
                tokens to keep. The smaller of this and ``max_input_length()``
                is used.

        Returns:
            str: ``message`` unchanged if it already fits the token budget;
            otherwise the text decoded from its first ``budget`` tokens (an
            empty budget yields the decoding of no tokens).
        """
        max_tokens = self.max_input_length()
        if token_limit is not None:
            max_tokens = min(max_tokens, token_limit)
        # A negative budget would make the slice below drop tokens from the
        # *end* (tokens[:-k]) instead of keeping none, so clamp it to zero.
        max_tokens = max(0, max_tokens)
        message_tokens = self.tiktoken_encoding.encode(message)
        if len(message_tokens) <= max_tokens:
            return message
        return self.tiktoken_encoding.decode(message_tokens[:max_tokens])