import logging
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Checkpoint paths
model_checkpoint_path = "model_checkpoint.pth"
tokenizer_checkpoint_path = "tokenizer_checkpoint.pth"

# Load model and tokenizer from checkpoint if they exist
if os.path.exists(model_checkpoint_path) and os.path.exists(tokenizer_checkpoint_path):
    try:
        # The checkpoints hold pickled objects, so allow full unpickling
        # (newer PyTorch versions default weights_only to True).
        model = torch.load(model_checkpoint_path, weights_only=False)
        tokenizer = torch.load(tokenizer_checkpoint_path, weights_only=False)
        logger.info("Model and tokenizer loaded from checkpoint.")
    except Exception as e:
        logger.error(f"Failed to load model or tokenizer from checkpoint: {e}")
        raise
else:
    # Otherwise download the model and tokenizer from the Hugging Face Hub
    try:
        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
        model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
        logger.info("Model and tokenizer loaded successfully.")
    except Exception as e:
        logger.error(f"Failed to load model or tokenizer: {e}")
        raise


def respond(user_input, history, system_message, max_tokens=20, temperature=0.9, top_p=0.9):
    # Build the conversation: system message, then history, then the new user turn
    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)
    messages.append({"role": "user", "content": user_input})

    # Flatten the messages into a single prompt string
    input_text = " ".join(msg["content"] for msg in messages)

    # Tokenize the prompt
    inputs = tokenizer(input_text, return_tensors="pt")
    attention_mask = inputs["attention_mask"]

    # Generate a continuation; max_new_tokens counts only generated tokens,
    # so the prompt length does not eat into the generation budget
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
    )

    # Decode only the newly generated tokens, not the echoed prompt
    response = tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)
    return response


if __name__ == "__main__":
    print("Welcome to the Chatbot!")
    while True:
        user_input = input("You: ")
        system_message = "Chatbot: "
        history = [{"role": "assistant", "content": "Hello, how can I assist you today?"}]
        response = respond(user_input, history, system_message)
        print(response)
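
# Note on the checkpoint branch above: the script only ever reads
# model_checkpoint.pth and tokenizer_checkpoint.pth; it assumes they were
# produced beforehand by pickling the full objects with torch.save.
# A minimal sketch of creating them (run once after the initial
# from_pretrained load, filenames matching the paths above):
#
#     torch.save(model, model_checkpoint_path)
#     torch.save(tokenizer, tokenizer_checkpoint_path)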