import os

import torch
from flask import Flask, request, jsonify
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments

app = Flask(__name__)

# Load the fine-tuned checkpoint if one exists; otherwise fall back to the
# pre-trained GPT-2 weights
if os.path.exists("fine_tuned_checkpoint"):
    model = GPT2LMHeadModel.from_pretrained("fine_tuned_checkpoint")
else:
    model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Fine-tune the model on the messages in a chat history
def fine_tune_model(chat_history):
    # Write the training data to disk, one message per line
    input_texts = [item["message"] for item in chat_history]
    with open("train.txt", "w") as f:
        f.write("\n".join(input_texts))
    # Load the dataset and create a causal-LM data collator (mlm=False);
    # note that TextDataset is deprecated in recent transformers releases
    dataset = TextDataset(tokenizer=tokenizer, file_path="train.txt", block_size=128)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    # Fine-tune the model; Trainer needs TrainingArguments and a train_dataset,
    # and train() takes no output path (its argument is a checkpoint to resume from)
    training_args = TrainingArguments(output_dir="./training_directory")
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=dataset,
    )
    trainer.train()
    # Save to the directory the startup code checks for a fine-tuned checkpoint
    model.save_pretrained("fine_tuned_checkpoint")
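
# Expected chat-history payload shape (an assumption inferred from how the
# handlers below build and read it):
#   [{"role": "user", "message": "Hi"}, {"role": "bot", "message": "Hello!"}]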
@app.route("/chat", methods=["POST"])
def chat_with_model():
request_data = request.get_json()
user_input = request_data["user_input"]
chat_history = request_data.get("chat_history", [])
# Append user message to the chat history
chat_history.append({"role": "user", "message": user_input})
# Generate response
response = generate_response(user_input, chat_history)
# Append bot message to the chat history
chat_history.append({"role": "bot", "message": response})
return jsonify({"bot_response": response, "chat_history": chat_history})
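
# Example request (a sketch; assumes the server is running locally on the
# port set in app.run below):
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"user_input": "Hello!", "chat_history": []}'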
@app.route("/train", methods=["POST"])
def train_model():
chat_history = request.json["data"]
# Fine-tune the model with the provided data
fine_tune_model(chat_history)
return "Model trained and updated successfully."

def generate_response(user_input, chat_history):
    # Consider at most the last `max_history` messages; the user's latest
    # message is already the final entry of chat_history
    max_history = 3
    inputs = [item["message"] for item in chat_history[-max_history:]]
    input_text = "\n".join(inputs)
    # Tokenize the input text
    input_ids = tokenizer.encode(input_text, return_tensors="pt", add_special_tokens=True)
    # Generate a continuation; GPT-2 has no pad token, so EOS is reused for
    # padding, and max_new_tokens caps the reply length regardless of how
    # long the prompt is (max_length would count the prompt tokens too)
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=100,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens (everything after the prompt)
    bot_response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
    return bot_response
@app.route("/")
def index():
return jsonify({"status" : True})
if __name__ == "__main__":
app.run(host="0.0.0.0", port=7860)