import os

import torch
from flask import Flask, request, jsonify
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    TextDataset,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

app = Flask(__name__)
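
# Load the fine-tuned checkpoint if one exists; otherwise start from the
# base GPT-2 weights. fine_tune_model() below saves to this same path.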
if os.path.exists("fine_tuned_checkpoint"):
    model = GPT2LMHeadModel.from_pretrained("fine_tuned_checkpoint")
else:
    model = GPT2LMHeadModel.from_pretrained("gpt2")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")


def fine_tune_model(chat_history):
    # Dump the chat messages to a plain-text file, one message per line, so
    # TextDataset can slice it into fixed-size training blocks.
    input_texts = [item["message"] for item in chat_history]
    with open("train.txt", "w") as f:
        f.write("\n".join(input_texts))

    # TextDataset is deprecated in recent transformers releases but still
    # functional; block_size caps each example at 128 tokens.
    dataset = TextDataset(tokenizer=tokenizer, file_path="train.txt", block_size=128)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Trainer needs TrainingArguments and a train_dataset; the output
    # directory is set here, not passed to train().
    training_args = TrainingArguments(output_dir="./training_directory", overwrite_output_dir=True)
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=dataset,
    )
    trainer.train()

    # Save to the path the app loads from on startup so the fine-tuned
    # weights are picked up after a restart.
    model.save_pretrained("fine_tuned_checkpoint")
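
# TextDataset's suggested replacement is the `datasets` library. A minimal
# sketch of that path, assuming `datasets` is installed (the collator above
# would still pad the batches and add labels):
#   from datasets import Dataset
#   ds = Dataset.from_dict({"text": input_texts})
#   ds = ds.map(lambda ex: tokenizer(ex["text"], truncation=True, max_length=128),
#               batched=True, remove_columns=["text"])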


@app.route("/chat", methods=["POST"])
def chat_with_model():
    request_data = request.get_json()
    user_input = request_data["user_input"]
    chat_history = request_data.get("chat_history", [])

    chat_history.append({"role": "user", "message": user_input})

    response = generate_response(user_input, chat_history)

    chat_history.append({"role": "bot", "message": response})

    return jsonify({"bot_response": response, "chat_history": chat_history})
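
# Example request against the endpoint above ("chat_history" is optional):
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '{"user_input": "Hello!", "chat_history": []}'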


@app.route("/train", methods=["POST"])
def train_model():
    chat_history = request.json["data"]

    fine_tune_model(chat_history)

    return "Model trained and updated successfully."


def generate_response(user_input, chat_history):
    # Keep the prompt short: only the last few messages are used as context
    # (chat_history already ends with the current user_input).
    max_history = 3

    inputs = [item["message"] for item in chat_history[-max_history:]]
    input_text = "\n".join(inputs)

    input_ids = tokenizer.encode(input_text, return_tensors="pt", add_special_tokens=True)

    # GPT-2 has no pad token, so eos is reused for padding; note that
    # max_length counts the prompt tokens as well as the generated ones.
    with torch.no_grad():
        output = model.generate(input_ids, max_length=100, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)

    # Drop the prompt tokens so only the newly generated text is returned.
    bot_response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

    return bot_response
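
# generate() warns about a missing attention mask when pad and eos ids match.
# One way to silence it, for reference (tokenizer.__call__ returns the mask):
#   enc = tokenizer(input_text, return_tensors="pt")
#   output = model.generate(enc.input_ids, attention_mask=enc.attention_mask,
#                           max_length=100, pad_token_id=tokenizer.eos_token_id)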


@app.route("/")
def index():
    return jsonify({"status": True})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
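
# Flask's built-in server is fine for local testing; for anything else a WSGI
# server is typical, e.g. (assuming this file is saved as app.py):
#   gunicorn -b 0.0.0.0:7860 app:app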