import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import gradio as gr

# Load the custom model and tokenizer
model_path = 'redael/model_udc'
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
model = GPT2LMHeadModel.from_pretrained(model_path)

# GPT-2 has no pad token by default; reuse the EOS token so padding works.
# Truncate from the left so the most recent turns (and the final
# "Assistant:" cue) survive when a long conversation exceeds 512 tokens.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.truncation_side = 'left'

# Use the GPU if available, with FP16 precision
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
if device.type == 'cuda':
    model = model.half()  # FP16 only makes sense on CUDA


def generate_response(prompt, model, tokenizer, max_new_tokens=100, num_beams=1,
                      temperature=0.7, top_p=0.9, repetition_penalty=2.0):
    # The caller supplies a fully formatted conversation that already ends
    # with "Assistant:", so it is not re-wrapped here.
    inputs = tokenizer(prompt, return_tensors='pt', padding=True,
                       truncation=True, max_length=512).to(device)
    outputs = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_new_tokens=max_new_tokens,  # limits generated tokens only, not prompt + output
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
        num_beams=num_beams,
        do_sample=True,  # required for temperature/top_p to have any effect
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Post-processing: keep only the text after the final "Assistant:" and
    # drop any hallucinated follow-up dialogue turns
    response = response.split("Assistant:")[-1].strip()
    clean_response = [
        line for line in response.split('\n')
        if "User:" not in line and "Assistant:" not in line
    ]
    return ' '.join(clean_response).strip()


def respond(message, history):
    # Rebuild the full conversation from the chat history (Gradio's default
    # tuple format: a list of (user_message, assistant_response) pairs)
    # and append the new message
    system_message = "You are a friendly chatbot."
    conversation = system_message + "\n"
    for user_message, assistant_response in history:
        conversation += f"User: {user_message}\nAssistant: {assistant_response}\n"
    conversation += f"User: {message}\nAssistant:"

    # Fixed generation parameters
    return generate_response(
        conversation, model, tokenizer,
        max_new_tokens=100, temperature=0.7, top_p=0.9,
    )


# Gradio chat interface
demo = gr.ChatInterface(respond)

if __name__ == "__main__":
    demo.launch()
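
# Usage note (a sketch; the filename app.py and the prompt text are
# illustrative assumptions, not given in the source):
#
#   pip install torch transformers gradio
#   python app.py
#
# This serves a local Gradio chat UI (http://127.0.0.1:7860 by default).
# To sanity-check generation without the UI, a hypothetical REPL session:
#
#   >>> from app import generate_response, model, tokenizer
#   >>> print(generate_response("You are a friendly chatbot.\nUser: Hello!\nAssistant:",
#   ...                         model, tokenizer, max_new_tokens=50))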