# Tj's picture
# Duplicate from Fisharp/starcoder-playground
# 8320d4e
from dataclasses import dataclass
from typing import Dict, Any, Union
from constants import (
FIM_MIDDLE,
FIM_PREFIX,
FIM_SUFFIX,
MIN_TEMPERATURE,
)
from settings import (
FIM_INDICATOR,
)
@dataclass
class StarCoderRequestConfig:
    """Generation settings that accompany a StarCoder request.

    The numeric constructor arguments are coerced with ``float()`` / ``int()``
    in ``__post_init__``, so any value those builtins accept may be passed.
    Two attributes that are not constructor arguments are set on every
    instance: ``do_sample`` (always ``True``) and ``seed`` (always ``42``).
    """

    temperature: float
    max_new_tokens: int
    top_p: float
    repetition_penalty: float
    version: str

    def __post_init__(self):
        """Coerce the numeric fields and attach the fixed sampling attributes."""
        # MIN_TEMPERATURE is applied as a floor: values below it are raised to
        # it, values above it are kept. The previous min() call did the
        # opposite and capped every request at the floor.
        self.temperature = max(float(self.temperature), MIN_TEMPERATURE)
        self.max_new_tokens = int(self.max_new_tokens)
        self.top_p = float(self.top_p)
        self.repetition_penalty = float(self.repetition_penalty)
        # Not dataclass fields: set here so they still appear in vars(self)
        # and therefore in kwargs().
        self.do_sample = True
        self.seed = 42

    def __repr__(self) -> str:
        """Return a compact string representation using short key names."""
        values = {
            "model": self.version,
            "temp": self.temperature,
            "tokens": self.max_new_tokens,
            "p": self.top_p,
            "penalty": self.repetition_penalty,
            "sample": self.do_sample,
            "seed": self.seed,
        }
        return f"StarCoderRequestConfig({values})"

    def kwargs(self) -> Dict[str, Union[Any, float, int]]:
        """Return the instance attributes as a new dict, without ``version``.

        The result is a shallow copy of ``vars(self)``, so mutating it does
        not affect the instance.
        """
        values = vars(self).copy()
        values.pop("version")
        return values
@dataclass
class StarCoderRequest:
    """A prompt paired with the generation settings to use for it.

    If the prompt contains ``FIM_INDICATOR``, it is split around that marker
    and rewritten as
    ``FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_MIDDLE``; the two halves
    are kept on ``prefix`` / ``suffix``. Otherwise the prompt is left as given
    and both attributes are ``None``.
    """

    prompt: str
    settings: StarCoderRequestConfig

    def __post_init__(self):
        """Detect the indicator in the prompt and rewrite the prompt if found.

        Raises:
            ValueError: if splitting the prompt on ``FIM_INDICATOR`` fails to
                produce exactly two parts.
        """
        self.fim_mode = FIM_INDICATOR in self.prompt
        self.prefix = None
        self.suffix = None
        if not self.fim_mode:
            return

        try:
            # Unpacking into two names fails when the indicator occurs more
            # than once, which is what triggers the error below.
            before, after = self.prompt.split(FIM_INDICATOR)
        except Exception as err:
            print(str(err))
            raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!") from err

        self.prefix = before
        self.suffix = after
        self.prompt = f"{FIM_PREFIX}{before}{FIM_SUFFIX}{after}{FIM_MIDDLE}"

    def __repr__(self) -> str:
        """Return a string showing the stored prompt and its settings."""
        shown = {"prompt": self.prompt, "configuration": self.settings}
        return f"StarCoderRequest({shown})"