# src/llms/bert_llm.py
from typing import List, Optional

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from .base_llm import BaseLLM


class BERTLanguageModel(BaseLLM):
    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        max_length: int = 512
    ):
        """Initialize the BERT tokenizer, model, and text-generation pipeline."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # BERT is encoder-only; a sequence-classification head cannot drive a
        # text-generation pipeline, so load the LM head with is_decoder=True.
        # Generation quality is limited compared to decoder models such as GPT-2.
        self.model = AutoModelForCausalLM.from_pretrained(model_name, is_decoder=True)
        self.generator = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer
        )
        self.max_length = max_length
    def generate(
        self,
        prompt: str,
        max_tokens: Optional[int] = None,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """Generate text from the prompt using the BERT pipeline."""
        # temperature only takes effect when sampling is enabled
        kwargs.setdefault("do_sample", True)
        output = self.generator(
            prompt,
            max_length=max_tokens or self.max_length,
            temperature=temperature,
            **kwargs
        )
        return output[0]["generated_text"]
    def tokenize(self, text: str) -> List[str]:
        """Tokenize text into WordPiece tokens using the BERT tokenizer."""
        return self.tokenizer.tokenize(text)

    def count_tokens(self, text: str) -> int:
        """Count tokens in text (includes the special [CLS] and [SEP] tokens)."""
        return len(self.tokenizer.encode(text))
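

# --- Usage sketch (assumption, not part of bert_llm.py) ---------------------
# A minimal example of exercising the class above, assumed to be run from the
# repository root so that the src.llms package is importable, and assuming the
# BaseLLM interface in src/llms/base_llm.py declares generate(), tokenize(),
# and count_tokens(). Because BERT is encoder-only, generate() mainly
# exercises the interface; output quality will be poor.
from src.llms.bert_llm import BERTLanguageModel

llm = BERTLanguageModel()
print(llm.tokenize("Hello world"))                 # WordPiece tokens
print(llm.count_tokens("Hello world"))             # includes [CLS] and [SEP]
print(llm.generate("Hello world", max_tokens=32))  # short generated string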