from abc import ABC, abstractmethod

import tiktoken


class EngineBase(ABC):
    """Abstract base class for AI engines.

    Subclasses implement the interface the swarm uses to call a model.
    """

    # Context window sizes (in tokens) for the supported OpenAI models.
    TOKEN_LIMITS = {
        "gpt-4": 8 * 1024,
        "gpt-4-0314": 8 * 1024,
        "gpt-4-32k": 32 * 1024,
        "gpt-4-32k-0314": 32 * 1024,
        "gpt-3.5-turbo": 4 * 1024,
        "gpt-3.5-turbo-0301": 4 * 1024,
    }

    def __init__(self, provider, model_name: str, temperature: float, max_response_tokens: int):
        self.provider = provider
        self.model_name = model_name
        self.temperature = temperature
        self.max_response_tokens = max_response_tokens
        # Tokenizer used by truncate_message; fall back to cl100k_base for
        # model names tiktoken does not recognize.
        try:
            self.tiktoken_encoding = tiktoken.encoding_for_model(model_name)
        except KeyError:
            self.tiktoken_encoding = tiktoken.get_encoding("cl100k_base")

    @abstractmethod
    def call_model(self, conversation: list) -> str:
        """Call the model with the given conversation.
        Input always in the format of openai's conversation.
        Output a string.

        Args:
            conversation (list[dict]): The conversation to be completed. Example:
                [
                    {"role": "system", "content": configuration_prompt},
                    {"role": "user", "content": prompt}
                ]

        Returns:
            str: The response from the model.
        """
        raise NotImplementedError

    def max_input_length(self) -> int:
        """Returns the maximum length of the input to the model in terms of tokens.

        Returns:
            int: The maximum number of tokens that may be sent to the model.
        """
        return self.TOKEN_LIMITS[self.model_name] - self.max_response_tokens
    
    def truncate_message(self, message, token_limit=None):
        """Truncates the message to fit the model's input budget, using tiktoken.

        Args:
            message (str): The text to truncate.
            token_limit (int, optional): An additional cap on the number of
                input tokens; the smaller of this and max_input_length() wins.

        Returns:
            str: The message, cut off at the token budget if necessary.
        """
        max_tokens = self.max_input_length()
        if token_limit is not None:
            max_tokens = min(max_tokens, token_limit)

        message_tokens = self.tiktoken_encoding.encode(message)
        if len(message_tokens) <= max_tokens:
            return message
        return self.tiktoken_encoding.decode(message_tokens[:max_tokens])
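# The budget-and-truncate pattern above can be sketched in isolation. The
# _WordEncoding class below is a hypothetical stand-in for tiktoken (one
# token per whitespace-separated word) so the sketch runs without any
# dependencies; it is not part of EngineBase.

```python
class _WordEncoding:
    """Illustrative stand-in encoder: one token per whitespace-separated word."""

    def encode(self, text):
        return text.split()

    def decode(self, tokens):
        return " ".join(tokens)


def truncate(message, context_limit, max_response_tokens, encoding=_WordEncoding()):
    # Prompt budget = model context window minus the tokens reserved for the
    # response, mirroring EngineBase.max_input_length().
    max_tokens = context_limit - max_response_tokens
    tokens = encoding.encode(message)
    if len(tokens) <= max_tokens:
        return message
    return encoding.decode(tokens[:max_tokens])


# With an 8-token window and 4 tokens reserved for the response, only the
# first 4 words of the prompt survive.
print(truncate("one two three four five six seven", 8, 4))  # one two three four
```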