import json
import torch
from flask import Flask, request, jsonify
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling, Trainer, TrainingArguments
import os

app = Flask(__name__)

# Load the fine-tuned model checkpoint if available; otherwise, load the pre-trained GPT-2 model
if os.path.exists("fine_tuned_checkpoint"):
    model = GPT2LMHeadModel.from_pretrained("fine_tuned_checkpoint")
else:
    model = GPT2LMHeadModel.from_pretrained("gpt2")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Function to fine-tune the model on the accumulated chat messages
def fine_tune_model(chat_history):
    # Prepare training data for fine-tuning: one message per line
    input_texts = [item["message"] for item in chat_history]
    with open("train.txt", "w") as f:
        f.write("\n".join(input_texts))

    # Load the dataset and create a data collator for causal language modeling
    dataset = TextDataset(tokenizer=tokenizer, file_path="train.txt", block_size=128)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Fine-tune the model; Trainer requires TrainingArguments and a train_dataset
    # (output_dir reuses the "./training_directory" path from the original call)
    training_args = TrainingArguments(output_dir="./training_directory", overwrite_output_dir=True)
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=dataset,
    )
    trainer.train()

    # Save to the same directory the app loads from at startup,
    # so the fine-tuned weights are picked up on the next restart
    model.save_pretrained("fine_tuned_checkpoint")

@app.route("/chat", methods=["POST"])
def chat_with_model():
    request_data = request.get_json()
    user_input = request_data["user_input"]
    chat_history = request_data.get("chat_history", [])

    # Append user message to the chat history
    chat_history.append({"role": "user", "message": user_input})

    # Generate response
    response = generate_response(user_input, chat_history)

    # Append bot message to the chat history
    chat_history.append({"role": "bot", "message": response})

    return jsonify({"bot_response": response, "chat_history": chat_history})

@app.route("/train", methods=["POST"])
def train_model():
    chat_history = request.json["data"]
    
    # Fine-tune the model with the provided data
    fine_tune_model(chat_history)
    
    return "Model trained and updated successfully."

def generate_response(user_input, chat_history):
    # Set the maximum number of previous messages to consider
    max_history = 3

    # Use the last `max_history` messages from the chat history
    inputs = [item["message"] for item in chat_history[-max_history:]]
    input_text = "\n".join(inputs)

    # Tokenize the input text
    input_ids = tokenizer.encode(input_text, return_tensors="pt", add_special_tokens=True)

    # Generate a response; max_new_tokens bounds the reply length regardless
    # of how long the prompt is (max_length would count prompt tokens too)
    with torch.no_grad():
        output = model.generate(input_ids, max_new_tokens=100, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)

    # Decode only the newly generated tokens, skipping the echoed prompt
    bot_response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

    return bot_response

@app.route("/")
def index():
    return jsonify({"status" : True})

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)