# model.py

import logging

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

# Logger configuration
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s [%(levelname)s] %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)

# model_path = "/opt/Llama-2-13B-chat-GPTQ"


class Model:
    def __init__(self, model_path):
        self.model_name = model_path
        self.model = None
        self.tokenizer = None
        self.loaded = False

    def load(self, precision='fp16'):
        try:
            # Check if CUDA is available
            if not torch.cuda.is_available():
                raise EnvironmentError("CUDA not available.")

            # Set precision settings
            if precision == 'fp16':
                torch_dtype = torch.float16
            else:
                torch_dtype = torch.float32

            # Initialize tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            # Set up model configuration
            config = AutoConfig.from_pretrained(self.model_name)
            # config.quantization_config["disable_exllama"] = False
            # config.quantization_config["use_exllama"] = True
            # config.quantization_config["exllama_config"] = {"version": 2}

            # Load model with configuration and precision
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                config=config,
                device_map="cuda:0",  # Set to GPU 0
                torch_dtype=torch_dtype
            )

            self.loaded = True
            logger.info(f"Model loaded successfully on GPU with {precision} precision.")
        except Exception as e:
            logger.error(f"Error loading model: {e}")

    def predict(self, input_text, max_length=50):
        if not self.loaded:
            logger.error("Model not loaded. Please load the model before prediction.")
            return None

        logger.info("========== Start Prediction ==========")
        try:
            # Ensure the input_text is a string
            if not isinstance(input_text, str):
                raise ValueError("Input text must be a string.")

            # Encode the input text
            input_ids = self.tokenizer.encode(input_text, return_tensors='pt')

            # Move input to the same device as the model
            input_ids = input_ids.to(next(self.model.parameters()).device)

            # Generate output using the model
            outputs = self.model.generate(input_ids, max_length=max_length)

            # Decode and return the generated text
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            logger.info("Response: {}".format(response))
        except Exception as e:
            logger.error(f"Error during prediction: {e}")
            response = None

        logger.info("========== End Prediction ==========")
        return response
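

# Minimal usage sketch (assumption: the path below, taken from the commented-out
# model_path above, points to a locally available checkpoint or can be replaced
# with any Hugging Face model ID; the prompt text is illustrative only).
if __name__ == "__main__":
    model = Model("/opt/Llama-2-13B-chat-GPTQ")
    model.load(precision='fp16')
    answer = model.predict("What is the capital of France?", max_length=50)
    print(answer)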