File size: 2,250 Bytes
8320d4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from dataclasses import dataclass
from typing import Dict, Any, Union

from constants import (
    FIM_MIDDLE,
    FIM_PREFIX,
    FIM_SUFFIX,
    MIN_TEMPERATURE,
)
from settings import (
    FIM_INDICATOR,
)

@dataclass
class StarCoderRequestConfig:
    """Generation parameters sent along with a StarCoder request.

    The constructor arguments are coerced to their numeric types after
    initialisation, and the temperature is raised to ``MIN_TEMPERATURE``
    whenever a smaller value is supplied. ``do_sample`` and ``seed`` are
    fixed instance attributes set in ``__post_init__``; they are not
    constructor arguments.
    """

    temperature: float
    max_new_tokens: int
    top_p: float
    repetition_penalty: float
    version: str

    def __post_init__(self):
        """Coerce the numeric fields and attach the fixed sampling attributes."""
        # MIN_TEMPERATURE is a lower bound, so clamp from below with max().
        # Using min() here would cap every request at the minimum and throw
        # away any larger temperature the caller asked for.
        self.temperature = max(float(self.temperature), MIN_TEMPERATURE)
        self.max_new_tokens = int(self.max_new_tokens)
        self.top_p = float(self.top_p)
        self.repetition_penalty = float(self.repetition_penalty)
        # Not dataclass fields: always set to these values on every instance.
        self.do_sample = True
        self.seed = 42

    def __repr__(self) -> str:
        """Return a compact string representation using short key names."""
        values = dict(
            model=self.version,
            temp=self.temperature,
            tokens=self.max_new_tokens,
            p=self.top_p,
            penalty=self.repetition_penalty,
            sample=self.do_sample,
            seed=self.seed,
        )
        return f"StarCoderRequestConfig({values})"

    def kwargs(self) -> Dict[str, Any]:
        """Return the instance attributes as a new dict, without ``version``.

        The result is built from ``vars(self)``, so it also contains
        ``do_sample`` and ``seed``. Mutating the returned dict does not
        affect the instance.
        """
        return {name: value for name, value in vars(self).items() if name != "version"}

@dataclass
class StarCoderRequest:
    """A prompt together with its generation settings.

    After initialisation the instance also carries:

    * ``fim_mode`` -- whether ``FIM_INDICATOR`` occurs in the prompt.
    * ``prefix`` / ``suffix`` -- the text before and after the indicator,
      or ``None`` for both when ``fim_mode`` is false.

    When ``fim_mode`` is true, ``prompt`` is rewritten as
    ``FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_MIDDLE``; otherwise it
    is left unchanged.

    Raises:
        ValueError: if splitting the prompt on ``FIM_INDICATOR`` does not
            yield exactly two parts (for example, the indicator appears
            more than once).
    """

    prompt: str
    settings: StarCoderRequestConfig

    def __post_init__(self):
        """Detect the fill-in-the-middle indicator and rebuild the prompt."""
        self.fim_mode = FIM_INDICATOR in self.prompt
        self.prefix, self.suffix = None, None
        if not self.fim_mode:
            return

        # Only the unpack can fail here: a ValueError is raised when the
        # split does not produce exactly two pieces. Catching just that
        # keeps unrelated errors from being reported as a prompt problem.
        try:
            prefix, suffix = self.prompt.split(FIM_INDICATOR)
        except ValueError as err:
            raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!") from err

        self.prefix, self.suffix = prefix, suffix
        self.prompt = f"{FIM_PREFIX}{self.prefix}{FIM_SUFFIX}{self.suffix}{FIM_MIDDLE}"

    def __repr__(self) -> str:
        """Return a string representation showing the prompt and settings."""
        values = dict(
            prompt=self.prompt,
            configuration=self.settings,
        )
        return f"StarCoderRequest({values})"