import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import gradio as gr
# Load the custom model and tokenizer
model_path = 'redael/model_udc'
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
model = GPT2LMHeadModel.from_pretrained(model_path)

# GPT-2 checkpoints ship without a pad token; reuse EOS so the padded
# tokenizer call in generate_response below does not raise an error.
tokenizer.pad_token = tokenizer.eos_token
# Check if CUDA is available and use GPU if possible, enable FP16 precision
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference only: disable dropout
if device.type == 'cuda':
    model = model.half()  # Use FP16 precision on GPU
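# Optional sanity check (a minimal sketch, not part of the original app):
# run one short greedy generation to confirm the checkpoint works on `device`.
# Nothing calls this automatically; invoke it manually while debugging.
def smoke_test():
    sample = tokenizer("User: hello\nAssistant:", return_tensors='pt').to(device)
    out = model.generate(sample['input_ids'],
                         attention_mask=sample['attention_mask'],
                         max_new_tokens=20,
                         pad_token_id=tokenizer.eos_token_id)
    print(tokenizer.decode(out[0], skip_special_tokens=True))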
def generate_response(prompt, model, tokenizer, max_new_tokens=100, num_beams=1, temperature=0.7, top_p=0.9, repetition_penalty=2.0):
    # The caller is expected to pass a fully formatted conversation ending in
    # "Assistant:" (respond() below builds it), so do not wrap the prompt again.
    inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
    outputs = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_new_tokens=max_new_tokens,  # caps generated tokens, not prompt + output
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
        num_beams=num_beams,  # early_stopping is omitted: it only applies to beam search
        do_sample=True,  # temperature and top_p only take effect when sampling
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Post-processing: keep only the text after the last "Assistant:" marker
    # and drop any extra turns the model may have invented.
    response = response.split("Assistant:")[-1].strip()
    response_lines = response.split('\n')
    clean_response = []
    for line in response_lines:
        if "User:" not in line and "Assistant:" not in line:
            clean_response.append(line)
    response = ' '.join(clean_response)
    return response.strip()
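# Usage sketch (hypothetical prompt shown): generate_response expects the fully
# formatted conversation, since respond() below adds the User:/Assistant: framing.
#
#   reply = generate_response("User: How do I list installed packages?\nAssistant:",
#                             model, tokenizer, max_new_tokens=80)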
def respond(message, history):
    # Rebuild the full conversation from the chat history plus the new message
    system_message = "You are a friendly chatbot."
    conversation = system_message + "\n"
    for user_message, assistant_response in history:
        conversation += f"User: {user_message}\nAssistant: {assistant_response}\n"
    conversation += f"User: {message}\nAssistant:"

    # Fixed generation parameters
    max_tokens = 100  # upper bound on newly generated tokens
    temperature = 0.7
    top_p = 0.9

    response = generate_response(conversation, model, tokenizer, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p)
    return response
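# The tokenizer call in generate_response truncates prompts to 512 tokens from
# the right, so a long chat would lose its *newest* turns first. A minimal
# sketch of one mitigation (hypothetical helper, not wired into respond above):
def trim_history(history, max_turns=6):
    """Keep only the most recent `max_turns` (user, assistant) pairs."""
    return history[-max_turns:]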
# Gradio Chat Interface
demo = gr.ChatInterface(
respond
)
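# Note (assumption about the Gradio version in use): respond() unpacks each
# history entry as a (user_message, assistant_response) pair, i.e. tuple-style
# chat history. A Gradio version that passes message dicts instead would need
# respond() adapted accordingly.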
if __name__ == "__main__":
    demo.launch()
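# For a temporary public link (standard Gradio option): demo.launch(share=True)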