File size: 1,853 Bytes
188dc40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.config import Config
from utils.logger import Logger

# Module-level logger, requested from the project's Logger helper using this
# module's ``__name__``.
logger = Logger.get_logger(__name__)


class ModelGenerator:
    """
    Singleton class responsible for generating text using a specified language model.

    The first call to ``ModelGenerator(...)`` builds the one shared instance:
    it selects a device, loads the causal language model and its tokenizer,
    and caches the instance on the class. Every later call returns that same
    instance without reloading anything.

    Text generation and code-block extraction are carried out by visitor
    objects passed to the ``accept*`` methods; this class owns the model
    resources and dispatches to ``visitor.visit``.

    Attributes:
        device (torch.device): Device the model runs on ("cuda" when
            available, otherwise "cpu").
        model (AutoModelForCausalLM): Language model for text generation,
            moved to ``device``.
        tokenizer (AutoTokenizer): Tokenizer corresponding to the language model.
        model_name (str): Name or path the model and tokenizer were loaded from.

    Methods:
        acceptTextGenerator(self, visitor, *args, **kwargs):
            Accepts a visitor that generates text from the given input using
            this model generator.
        acceptExtractCodeBlock(self, visitor, *args, **kwargs):
            Accepts a visitor that extracts code blocks from generated text.
    """
    # The cached singleton; None until the first successful construction.
    _instance = None
    # strftime-style date/time format string. NOTE(review): not referenced by
    # this class itself; presumably read by visitors/callers elsewhere. The
    # name looks like a typo for "date_time" but is kept for compatibility.
    _format_data_time = "%Y-%m-%d %H:%M:%S"

    def __new__(cls, model_name=None):
        """
        Return the shared ModelGenerator instance, creating it on first use.

        Args:
            model_name: Model name or path accepted by ``from_pretrained``.
                When ``None`` (the default), ``Config.read('app', 'model')`` is
                used. The configuration is read at call time, not at import
                time, so importing this module does not require the config.

        Returns:
            The singleton ModelGenerator instance.

        Raises:
            Any exception raised while loading the model or tokenizer. In that
            case no instance is cached, so a later call retries the load
            instead of returning a half-initialized object.
        """
        if cls._instance is not None:
            loaded = getattr(cls._instance, "model_name", None)
            # The singleton cannot switch models; make a conflicting request visible.
            if model_name is not None and loaded is not None and model_name != loaded:
                logger.warning(
                    "ModelGenerator already initialized with model %r; "
                    "ignoring requested model %r", loaded, model_name)
            return cls._instance

        if model_name is None:
            model_name = Config.read('app', 'model')

        # Build the instance fully before publishing it, so a failed load
        # does not leave a broken object cached on the class.
        instance = super().__new__(cls)
        instance._initialize(model_name)
        cls._instance = instance
        return instance

    def _initialize(self, model_name):
        """Load the model and tokenizer for ``model_name`` onto the selected device."""
        self.model_name = model_name
        # Prefer the GPU when torch can see one; otherwise fall back to CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def acceptTextGenerator(self, visitor, *args, **kwargs):
        """Dispatch to ``visitor.visit(self, *args, **kwargs)`` and return its result."""
        return visitor.visit(self, *args, **kwargs)

    def acceptExtractCodeBlock(self, visitor, *args, **kwargs):
        """Dispatch to ``visitor.visit(self, *args, **kwargs)`` and return its result."""
        return visitor.visit(self, *args, **kwargs)