# fine-tuned-models / chatbot.py
# Scraped page header: author arcsu1, commit "update" (507cd03)
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
from datasets import load_dataset
import pandas as pd
import re
class ChatBot:
    """Wrap a causal language model and its tokenizer as a turn-based chatbot.

    Prompts are built as ``<user> {message}{eos}<AI> {reply}{eos}`` for each
    previous exchange, followed by ``<user> {latest message}{eos}<AI>``. The
    model's continuation after that final ``<AI>`` marker is the reply.
    """

    def __init__(self, dir, tokenizer, model, device):
        """Store the model/tokenizer pair and move the model onto ``device``.

        Args:
            dir: Directory associated with the model; stored as
                ``self.directory`` and not used by the methods of this class.
                The name shadows the ``dir`` builtin but is kept so keyword
                callers keep working.
            tokenizer: Tokenizer exposing ``eos_token``, ``eos_token_id``,
                ``encode`` and ``decode`` (e.g. a Hugging Face tokenizer).
            model: Model exposing ``to`` and ``generate`` (e.g. a Hugging Face
                causal LM).
            device: Torch device (or device string) used for inference.
        """
        self.directory = dir
        self.tokenizer = tokenizer
        self.model = model
        self.device = device
        self.model.to(self.device)

    def generate_response(self, history, *, max_turns=6, max_new_tokens=20):
        """Generate the AI reply to the most recent user message.

        Args:
            history: Object with list-like ``user`` and ``ai`` attributes.
                ``user[i]`` is paired with ``ai[i]``; the last entry of
                ``user`` is the message awaiting a reply. ``history`` is read
                only and is not modified.
            max_turns: Number of most recent (user, ai) exchanges included as
                context before the latest user message. Defaults to 6, the
                window the original hard-coded slicing produced.
            max_new_tokens: Maximum number of tokens to generate.

        Returns:
            The generated continuation up to (not including) its first ``.``,
            with surrounding whitespace stripped. Returns ``""`` when
            ``history.user`` is empty.
        """
        users = list(history.user)
        if not users:
            # Nothing to answer; an empty prompt would make generate() fail.
            return ""

        # Use the EOS *string* (e.g. "<|endoftext|>"). Interpolating
        # eos_token_id would insert the digits of the id into the text.
        eos = self.tokenizer.eos_token

        # Pair messages positionally, then keep only the most recent exchanges.
        # This builds a local window instead of truncating the caller's lists.
        pairs = list(zip(users, history.ai))
        turns = pairs[-max_turns:] if max_turns > 0 else []

        parts = [f"<user> {user_msg}{eos}<AI> {ai_msg}{eos}" for user_msg, ai_msg in turns]
        parts.append(f"<user> {users[-1]}{eos}<AI>")
        prompt = "".join(parts)

        inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        attention_mask = torch.ones_like(inputs)
        outputs = self.model.generate(
            inputs,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=2,
            # NOTE: per the transformers generation docs, temperature/top_k/
            # top_p only apply when do_sample=True; this call does not set it,
            # so they are kept only to make enabling sampling a one-line change.
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.2,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # generate() returns prompt + continuation for decoder-only models;
        # decode only the new tokens rather than string-replacing the prompt.
        new_tokens = outputs[0][inputs.shape[-1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return response.split(".")[0].strip()