# NOTE: removed web-viewer extraction residue (page chrome, file size,
# commit hash, and gutter line numbers) that was not part of the source.
from abc import ABC, abstractmethod
class EngineBase(ABC):
    """Abstract base class for the AI engines.

    Engines define the API for the AI engines that can be used in the swarm.
    Subclasses must implement ``call_model`` and are expected to set
    ``self.tiktoken_encoding`` (a tiktoken encoding object) before
    ``truncate_message`` is called — TODO confirm against subclasses.
    """

    # Total context window per model, in tokens (input + response share it).
    # NOTE(review): values kept as-is from the original; the base "gpt-4"
    # model is commonly documented as 8k, not 16k — verify against the
    # provider's published limits before relying on these.
    TOKEN_LIMITS = {
        "gpt-4": 16 * 1024,
        "gpt-4-0314": 16 * 1024,
        "gpt-4-32k": 32 * 1024,
        "gpt-4-32k-0314": 32 * 1024,
        "gpt-3.5-turbo": 4 * 1024,
        "gpt-3.5-turbo-0301": 4 * 1024,
    }

    def __init__(self, provider, model_name: str, temperature: float, max_response_tokens: int):
        """Store the engine configuration.

        Args:
            provider: Backend/provider handle used by concrete engines.
            model_name (str): Key into ``TOKEN_LIMITS`` identifying the model.
            temperature (float): Sampling temperature passed to the model.
            max_response_tokens (int): Tokens reserved for the model's reply;
                subtracted from the context window in ``max_input_length``.
        """
        self.provider = provider
        self.model_name = model_name
        self.temperature = temperature
        self.max_response_tokens = max_response_tokens

    @abstractmethod
    def call_model(self, conversation: list) -> str:
        """Call the model with the given conversation.

        Input always in the format of openai's conversation.
        Output a string.

        Args:
            conversation (list[dict]): The conversation to be completed. Example:
                [
                    {"role": "system", "content": configuration_prompt},
                    {"role": "user", "content": prompt}
                ]

        Returns:
            str: The response from the model.
        """
        raise NotImplementedError

    # NOTE: the original file also declared ``max_input_length`` and
    # ``truncate_message`` as @abstractmethod stubs *before* these concrete
    # definitions; the later bindings silently overrode the stubs, so only
    # ``call_model`` was ever abstract. The dead stubs (which also had a
    # narrower ``truncate_message`` signature) have been removed.
    def max_input_length(self) -> int:
        """Returns the maximum length of the input to the model in terms of tokens.

        Returns:
            int: The max tokens to input to the model (context window for
                ``self.model_name`` minus the reserved response budget).

        Raises:
            KeyError: If ``self.model_name`` is not in ``TOKEN_LIMITS``.
        """
        return self.TOKEN_LIMITS[self.model_name] - self.max_response_tokens

    def truncate_message(self, message, token_limit=None):
        """Truncates the message using tiktoken.

        Args:
            message (str): Text to truncate.
            token_limit (int | None): Optional extra cap; the effective limit
                is ``min(max_input_length(), token_limit)`` when given.

        Returns:
            str: ``message`` unchanged if it fits, otherwise the decoded
                prefix of at most the effective token limit.
        """
        max_tokens = self.max_input_length()
        if token_limit is not None:
            max_tokens = min(max_tokens, token_limit)
        message_tokens = self.tiktoken_encoding.encode(message)
        if len(message_tokens) <= max_tokens:
            return message
        return self.tiktoken_encoding.decode(message_tokens[:max_tokens])