File size: 1,392 Bytes
41d1bc5
15d89f9
 
41d1bc5
15d89f9
 
41d1bc5
 
 
 
15d89f9
 
 
 
41d1bc5
 
 
 
 
 
 
15d89f9
 
 
 
 
 
41d1bc5
 
15d89f9
 
 
 
 
 
 
 
41d1bc5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from .py_generate import PyGenerator
from .rs_generate import RsGenerator
from .go_generate import GoGenerator
from .generator_types import Generator
from .model import CodeLlama, ModelBase, GPT4, GPT35, StarChat, GPTDavinci, Samba, GPT4o, GroqBase


def generator_factory(lang: str) -> Generator:
    """Create the code generator for a target language.

    Args:
        lang: Language identifier. Accepted aliases are ``"py"``/``"python"``,
            ``"rs"``/``"rust"`` and ``"go"``/``"golang"`` (exact, case-sensitive
            matches).

    Returns:
        A new ``PyGenerator``, ``RsGenerator`` or ``GoGenerator`` instance.

    Raises:
        ValueError: If ``lang`` is not one of the accepted aliases.
    """
    if lang in ("py", "python"):
        return PyGenerator()
    if lang in ("rs", "rust"):
        return RsGenerator()
    if lang in ("go", "golang"):
        return GoGenerator()
    raise ValueError(f"Invalid language for generator: {lang}")


def model_factory(model_name: str) -> ModelBase:
    """Instantiate the model wrapper selected by ``model_name``.

    Resolution order:
      1. Exact names: ``"gpt-4"``, ``"gpt-4o"``, ``"samba"``, ``"groq"``,
         ``"gpt-3.5-turbo-0613"``, ``"starchat"``.
      2. Names starting with ``"codellama"``. If the name contains a dash, the
         text between the first and second dash is passed as ``version``.
      3. Names starting with ``"text-davinci"``. The full name is passed to
         ``GPTDavinci``.

    Args:
        model_name: Identifier of the model to construct.

    Returns:
        A new model wrapper instance.

    Raises:
        ValueError: If ``model_name`` matches none of the rules above.
    """
    exact_matches = {
        "gpt-4": GPT4,
        "gpt-4o": GPT4o,
        "samba": Samba,
        "groq": GroqBase,
        "gpt-3.5-turbo-0613": GPT35,
        "starchat": StarChat,
    }
    model_cls = exact_matches.get(model_name)
    if model_cls is not None:
        return model_cls()

    if model_name.startswith("codellama"):
        if "-" not in model_name:
            # No dash: no version given, so CodeLlama's own default applies.
            return CodeLlama()
        # Only the segment after the first dash (up to any next dash) is used,
        # e.g. "codellama-7b-instruct" -> version "7b".
        return CodeLlama(version=model_name.split("-")[1])

    if model_name.startswith("text-davinci"):
        return GPTDavinci(model_name)

    raise ValueError(f"Invalid model name: {model_name}")