# src/llms/llama_llm.py
from typing import Optional

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

from .base_llm import BaseLLM


class LlamaLanguageModel(BaseLLM):
    def __init__(
        self,
        # The "-hf" repo hosts the transformers-format weights; the plain
        # "meta-llama/Llama-2-7b" repo contains the raw Meta checkpoint.
        model_name: str = "meta-llama/Llama-2-7b-hf",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        """Load the Llama tokenizer and model and move the model to the target device."""
        self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
        self.model = LlamaForCausalLM.from_pretrained(
            model_name,
            # float16 halves memory on GPU; fall back to float32 on CPU,
            # where half precision is poorly supported.
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        ).to(device)
        self.device = device

    def generate(
        self,
        prompt: str,
        max_tokens: Optional[int] = None,
        temperature: float = 0.7,
        **kwargs,
    ) -> str:
        """Generate a completion for `prompt` and return the decoded text."""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            **inputs,
            # max_new_tokens bounds only the generated continuation;
            # max_length would also count the prompt tokens.
            max_new_tokens=max_tokens if max_tokens else 100,
            temperature=temperature,
            do_sample=True,  # temperature only takes effect when sampling
            **kwargs,
        )
        # Note: the decoded string includes the original prompt.
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
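

# Minimal usage sketch. This assumes the package is importable (the relative
# import of BaseLLM means this module cannot be run as a script directly) and
# that your Hugging Face account has been granted access to the gated Llama-2
# weights; the prompt below is purely illustrative.
#
#     from src.llms.llama_llm import LlamaLanguageModel
#
#     llm = LlamaLanguageModel()
#     print(llm.generate("Explain beam search in one sentence.", max_tokens=64))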