# scrAI / app.py
import os

import torch
from flask import Flask, request, jsonify
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments
app = Flask(__name__)
# Load the fine-tuned model checkpoint if available; otherwise, load the pre-trained GPT-2 model
if os.path.exists("fine_tuned_checkpoint"):
    model = GPT2LMHeadModel.from_pretrained("fine_tuned_checkpoint")
else:
    model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
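
# GPT-2 has no pad token by default; reusing EOS is a common safeguard so
# padding-aware utilities and generate() agree on a pad id. This line is
# optional here, since generate() below passes pad_token_id explicitly.
tokenizer.pad_token = tokenizer.eos_token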
# Function to fine-tune the model
def fine_tune_model(chat_history):
    # Prepare training data: one chat message per line
    input_texts = [item["message"] for item in chat_history]
    with open("train.txt", "w") as f:
        f.write("\n".join(input_texts))
    # Load the dataset and create a data collator for causal LM (mlm=False)
    dataset = TextDataset(tokenizer=tokenizer, file_path="train.txt", block_size=128)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    # Fine-tune the model; Trainer requires TrainingArguments and a
    # train_dataset, and train() does not take an output directory
    training_args = TrainingArguments(output_dir="./training_directory")
    trainer = Trainer(model=model, args=training_args, data_collator=data_collator, train_dataset=dataset)
    trainer.train()
    # Save where the startup code above looks for the checkpoint
    model.save_pretrained("fine_tuned_checkpoint")
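
# Example input for fine_tune_model (shape taken from the /train handler
# below; only the "message" field is used for training):
#
#   fine_tune_model([
#       {"role": "user", "message": "What is GPT-2?"},
#       {"role": "bot", "message": "A 2019 language model from OpenAI."},
#   ])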
@app.route("/chat", methods=["POST"])
def chat_with_model():
    request_data = request.get_json()
    user_input = request_data["user_input"]
    chat_history = request_data.get("chat_history", [])
    # Append the user message to the chat history
    chat_history.append({"role": "user", "message": user_input})
    # Generate a response conditioned on the recent history
    response = generate_response(user_input, chat_history)
    # Append the bot message to the chat history
    chat_history.append({"role": "bot", "message": response})
    return jsonify({"bot_response": response, "chat_history": chat_history})
@app.route("/train", methods=["POST"])
def train_model():
    chat_history = request.json["data"]
    # Fine-tune the model with the provided chat history
    fine_tune_model(chat_history)
    return "Model trained and updated successfully."
def generate_response(user_input, chat_history):
    # Maximum number of previous messages to condition on
    max_history = 3
    # Use the last `max_history` messages from the chat history
    # (the latest entry already contains user_input)
    inputs = [item["message"] for item in chat_history[-max_history:]]
    input_text = "\n".join(inputs)
    # Tokenize the prompt
    input_ids = tokenizer.encode(input_text, return_tensors="pt", add_special_tokens=True)
    # Generate up to 100 new tokens beyond the prompt (max_new_tokens counts
    # only generated tokens, unlike max_length, which includes the prompt and
    # can leave no room for output on long histories)
    with torch.no_grad():
        output = model.generate(input_ids, max_new_tokens=100, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)
    # Decode only the newly generated tokens (everything after the prompt)
    bot_response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
    return bot_response
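
# Note: the prompt is not truncated above, and GPT-2's context window is
# 1024 tokens. A minimal guard (a sketch, not part of the original code)
# would clamp the prompt before generation, leaving room for new tokens:
#
#   max_ctx = model.config.n_positions  # 1024 for GPT-2
#   input_ids = input_ids[:, -(max_ctx - 100):]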
@app.route("/")
def index():
    return jsonify({"status": True})
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)