import os

import torch
from anthropic import AnthropicVertex
from dotenv import load_dotenv
from openai import OpenAI
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from vertexai.generative_models import GenerativeModel

from src.text_generation.vertexai_setup import initialize_vertexai_params

load_dotenv()

# Missing keys resolve to None rather than leaving the names undefined.
OAI_API_KEY = os.environ.get("OPENAI_API_KEY")
VERTEXAI_PROJECT = os.environ.get("VERTEXAI_PROJECTID")
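
# Example .env entries read above (variable names from this file; the values
# are placeholders, not real credentials):
#   OPENAI_API_KEY=sk-...
#   VERTEXAI_PROJECTID=my-gcp-project-id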


class LLMBaseClass:
    """
    Base class for text generation. Instantiate it with a model ID (a sequence
    whose first element is the model name, e.g. a Hugging Face model ID), then
    call `generate` to produce responses.
    """

    def __init__(self, model_id) -> None:
        self.model_id = model_id
        match model_id[0].lower():
            case "gpt-4o-mini":  # OpenAI models
                self.api_key = OAI_API_KEY
                self.model = OpenAI(api_key=self.api_key)
            case "claude-3-5-sonnet@20240620":  # Claude via Vertex AI
                self.api_key = None
                self.model = AnthropicVertex(region="europe-west1", project_id=VERTEXAI_PROJECT)
            case "gemini-1.0-pro":
                self.api_key = None
                self.model = GenerativeModel(model_id[0].lower())
            case _:  # Hugging Face models
                self.api_key = None
                self.tokenizer = AutoTokenizer.from_pretrained(model_id[0])
                self.tokenizer.pad_token = self.tokenizer.eos_token
                # Llama-2-style chat template: [INST] blocks for user turns,
                # <<SYS>> blocks for system prompts, [ASST] blocks for replies.
                self.tokenizer.chat_template = (
                    "{%- for message in messages %}"
                    "{%- if message['role'] == 'user' %}"
                    "{{- bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}"
                    "{%- elif message['role'] == 'system' %}"
                    "{{- '<<SYS>>\\n' + message['content'].strip() + '\\n<</SYS>>\\n\\n' }}"
                    "{%- elif message['role'] == 'assistant' %}"
                    "{{- '[ASST] ' + message['content'] + ' [/ASST]' + eos_token }}"
                    "{%- endif %}"
                    "{%- endfor %}"
                )
                # 4-bit NF4 quantization to cut GPU memory use (CUDA only)
                if torch.cuda.is_available():
                    bnb_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4",
                        bnb_4bit_compute_dtype=torch.bfloat16,
                    )
                else:
                    bnb_config = None
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_id[0],
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    quantization_config=bnb_config,
                )
                self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
                self.terminators = [
                    self.tokenizer.eos_token_id,
                    self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
                ]

    def generate(self, messages):
        match self.model_id[0].lower():
            case "gpt-4o-mini":
                completion = self.model.chat.completions.create(
                    model=self.model_id[0],
                    messages=messages,
                    temperature=0.6,
                    top_p=0.9,
                )
                # Return the generated content from the API response
                return completion.choices[0].message.content
case "claude-3-5-sonnet@20240620" | "gemini-1.0-pro":
initialize_vertexai_params()
if "claude" in self.model_id[0].lower():
message = self.model.messages.create(
max_tokens=1024,
model=self.model_id[0],
messages=[
{
"role": "user",
"content": messages[0]["content"],
}
],
)
return message.content[0].text
else:
response = self.model.generate_content(messages[0]["content"])
return response
            case _:
                input_ids = self.tokenizer.apply_chat_template(
                    conversation=messages,
                    add_generation_prompt=True,
                    return_tensors="pt",
                    padding=True,
                ).to(self.model.device)
                outputs = self.model.generate(
                    input_ids,
                    max_new_tokens=1024,
                    # eos_token_id=self.terminators,
                    pad_token_id=self.tokenizer.eos_token_id,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.9,
                )
                # Keep only the newly generated tokens, then decode them
                response = outputs[0][input_ids.shape[-1]:]
                return self.tokenizer.decode(response, skip_special_tokens=True)
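

# Minimal usage sketch (assumption for illustration: the model name is passed
# as a one-element list, matching the `model_id[0]` indexing above):
if __name__ == "__main__":
    llm = LLMBaseClass(model_id=["gpt-4o-mini"])
    reply = llm.generate(
        messages=[{"role": "user", "content": "Suggest three cities to visit in autumn."}]
    )
    print(reply)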