import os

import torch
from anthropic import AnthropicVertex
from dotenv import load_dotenv
from openai import OpenAI
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from vertexai.generative_models import GenerativeModel

from src.text_generation.vertexai_setup import initialize_vertexai_params

load_dotenv()
OAI_API_KEY = os.environ.get("OPENAI_API_KEY")
VERTEXAI_PROJECT = os.environ.get("VERTEXAI_PROJECTID")


class LLMBaseClass:
    """
    Base Class for text generation - user needs to provide the HF model ID while instantiating the class after which
    the generate method can be called to generate responses
    
    """

    def __init__(self, model_id) -> None:
        # model_id is a sequence whose first element is the model name string;
        # it must be stored because generate() dispatches on self.model_id
        self.model_id = model_id

        match model_id[0].lower():
            case "gpt-4o-mini":  # for open AI models
                self.api_key = OAI_API_KEY
                self.model = OpenAI(api_key=self.api_key)
            case "claude-3-5-sonnet@20240620":  # for Claude through vertexAI
                self.api_key = None
                self.model = AnthropicVertex(region="europe-west1", project_id=VERTEXAI_PROJECT)
            case "gemini-1.0-pro":
                self.api_key = None
                self.model = GenerativeModel(model_id[0].lower())
            case _:  # for HF models
                self.api_key = None
                self.tokenizer = AutoTokenizer.from_pretrained(model_id)
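                # Many causal LMs ship without a pad token; reuse EOS so batched padding works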
                self.tokenizer.pad_token = self.tokenizer.eos_token

                # Llama-2-style Jinja chat template: [INST]/[/INST] wraps user turns,
                # <<SYS>> blocks hold system prompts, [ASST]/[/ASST] wraps assistant turns
                self.tokenizer.chat_template = (
                    "{%- for message in messages %}"
                    "{%- if message['role'] == 'user' %}"
                    "{{- bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}"
                    "{%- elif message['role'] == 'system' %}"
                    "{{- '<<SYS>>\\n' + message['content'].strip() + '\\n<</SYS>>\\n\\n' }}"
                    "{%- elif message['role'] == 'assistant' %}"
                    "{{- '[ASST] ' + message['content'] + ' [/ASST]' + eos_token }}"
                    "{%- endif %}"
                    "{%- endfor %}"
                )
                # Quantize to 4-bit (NF4) to reduce GPU memory usage when CUDA is available
                if torch.cuda.is_available():
                    bnb_config = BitsAndBytesConfig(
                        load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4",
                        bnb_4bit_compute_dtype=torch.bfloat16
                    )
                else:
                    bnb_config = None
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    quantization_config=bnb_config,
                )

                self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id

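                # Candidate stop-token IDs; "<|eot_id|>" is the end-of-turn marker
                # used by Llama-3-style models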
                self.terminators = [
                    self.tokenizer.eos_token_id,
                    self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
                ]

    def generate(self, messages):
        match self.model_id[0].lower():
            case "gpt-4o-mini":
                completion = self.model.chat.completions.create(
                    model=self.model_id[0],
                    messages=messages,
                    temperature=0.6,
                    top_p=0.9,
                )
                # Return the generated content from the API response
                return completion.choices[0].message.content
            case "claude-3-5-sonnet@20240620" | "gemini-1.0-pro":
                initialize_vertexai_params()
                if "claude" in self.model_id[0].lower():
                    message = self.model.messages.create(
                        max_tokens=1024,
                        model=self.model_id[0],
                        messages=[
                            {
                                "role": "user",
                                "content": messages[0]["content"],
                            }
                        ],
                    )
                    return message.content[0].text
                else:
                    response = self.model.generate_content(messages[0]["content"])
                    # Return the generated text rather than the raw response object
                    return response.text
            case _:
                input_ids = self.tokenizer.apply_chat_template(
                    conversation=messages,
                    add_generation_prompt=True,
                    return_tensors="pt",
                    padding=True
                ).to(self.model.device)

                outputs = self.model.generate(
                    input_ids,
                    max_new_tokens=1024,
                    # eos_token_id=self.terminators,
                    pad_token_id=self.tokenizer.eos_token_id,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.9,
                )
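                # Slice off the prompt tokens so only the newly generated text is decoded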
                response = outputs[0][input_ids.shape[-1]:]

                return self.tokenizer.decode(response, skip_special_tokens=True)


# database/wikivoyage/wikivoyage_listings.lance/data/e2940f51-d754-4b54-a688-004bdb8e7aa2.lance
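

# Usage sketch (assumption: not part of the original module). Shows how the class
# above is meant to be driven; the model name and prompt are illustrative only.
if __name__ == "__main__":
    llm = LLMBaseClass(model_id=["gpt-4o-mini"])  # requires OPENAI_API_KEY in .env
    reply = llm.generate(
        messages=[{"role": "user", "content": "Suggest three day trips from Lisbon."}]
    )
    print(reply)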