import os

import torch
from flask import Flask, request, jsonify
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    TextDataset,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

app = Flask(__name__)

# Load the fine-tuned checkpoint if one exists; otherwise fall back to the
# pre-trained GPT-2 model.
if os.path.exists("fine_tuned_checkpoint"):
    model = GPT2LMHeadModel.from_pretrained("fine_tuned_checkpoint")
else:
    model = GPT2LMHeadModel.from_pretrained("gpt2")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")


def fine_tune_model(chat_history):
    """Fine-tune the model on the messages in the given chat history."""
    # Write the chat messages to a plain-text training file.
    input_texts = [item["message"] for item in chat_history]
    with open("train.txt", "w") as f:
        f.write("\n".join(input_texts))

    # Build the dataset and a causal-LM data collator (mlm=False disables
    # masked-language-modeling so labels are the shifted inputs).
    dataset = TextDataset(tokenizer=tokenizer, file_path="train.txt", block_size=128)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Fine-tune the model. Trainer requires TrainingArguments and a
    # train_dataset; the output directory is configured here rather than
    # passed to train(), whose first positional argument is actually
    # resume_from_checkpoint.
    training_args = TrainingArguments(
        output_dir="./training_directory",
        overwrite_output_dir=True,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=dataset,
    )
    trainer.train()

    # Save the fine-tuned model to the same directory that is checked at
    # startup, so it is picked up on the next restart.
    model.save_pretrained("fine_tuned_checkpoint")


@app.route("/chat", methods=["POST"])
def chat_with_model():
    request_data = request.get_json()
    user_input = request_data["user_input"]
    chat_history = request_data.get("chat_history", [])

    # Append the user message to the chat history.
    chat_history.append({"role": "user", "message": user_input})

    # Generate a response conditioned on the recent history.
    response = generate_response(user_input, chat_history)

    # Append the bot message to the chat history.
    chat_history.append({"role": "bot", "message": response})

    return jsonify({"bot_response": response, "chat_history": chat_history})


@app.route("/train", methods=["POST"])
def train_model():
    chat_history = request.json["data"]

    # Fine-tune the model with the provided data.
    fine_tune_model(chat_history)

    return "Model trained and updated successfully."


def generate_response(user_input, chat_history):
    # Use only the last `max_history` messages from the chat history as context.
    max_history = 3
    inputs = [item["message"] for item in chat_history[-max_history:]]
    input_text = "\n".join(inputs)

    # Tokenize the context.
    input_ids = tokenizer.encode(input_text, return_tensors="pt", add_special_tokens=True)

    # Generate a continuation; max_length caps prompt plus generated tokens.
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_length=100,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, stripping the prompt.
    bot_response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

    return bot_response


@app.route("/")
def index():
    return jsonify({"status": True})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
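
# Example usage: a minimal sketch of client calls against the routes defined
# above, assuming the server is running locally on port 7860. The message
# contents are hypothetical; the payload keys ("user_input", "chat_history",
# "data") match what the handlers read.
#
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"user_input": "Hello!", "chat_history": []}'
#
#   curl -X POST http://localhost:7860/train \
#        -H "Content-Type: application/json" \
#        -d '{"data": [{"role": "user", "message": "Hello!"}]}'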