from flask import Flask, jsonify, request
from flask_cors import CORS
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from history import load_dataset, get_unique_next_words_from_dataset

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Global variable to store the predicted words
predicted_words = []


def generate_predicted_words(input_text):
    # Load the model and tokenizer
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # Load the dataset
    dataset_name = "dataset.txt"
    dataset = load_dataset(dataset_name)

    history_next_text = get_unique_next_words_from_dataset(input_text, dataset)

    # Tokenize input
    inputs = tokenizer(input_text, return_tensors="pt").to("cpu")

    # Forward pass through the model
    with torch.no_grad():
        outputs = model(**inputs, return_dict=True)
        logits = outputs.logits

    # Get the logits for the last token
    last_token_logits = logits[:, -1, :]
    probabilities = torch.softmax(last_token_logits, dim=-1)

    # Get the top 50 most probable next tokens
    top_50_probs, top_50_indices = torch.topk(probabilities, 50)
    top_50_tokens = [tokenizer.decode([idx]) for idx in top_50_indices[0]]

    words = []
    removable_words = [' (', ' a', "'s", ' "', ' -', ' as', " '"]

    for token in top_50_tokens:
        if len(token) != 1 and token not in removable_words:
            words.append(token)

    return history_next_text + words  # Return combined words


@app.route('/api/display_words', methods=['GET'])
def get_display_words():
    # Get the count from query parameters
    count = int(request.args.get('count', 0))

    if not predicted_words:
        # Generate the list only once if it's not generated yet
        input_text = "Are"  # Default input, can be changed as needed
        predicted_words.extend(generate_predicted_words(input_text))

    # Serve the slice of predicted words based on the count
    start_index = 9 * count
    end_index = start_index + 9

    if start_index >= len(predicted_words):  # Reset if out of bounds
        count = 0
        start_index = 0
        end_index = 9

    display_words = predicted_words[start_index:end_index]

    return jsonify(display_words)


@app.route('/api/guu', methods=['POST'])
def predict_words():
    try:
        # Get the JSON data from the request
        data = request.get_json()
        print("data", data)

        # Check if the JSON was parsed properly
        if not isinstance(data, dict):
            return jsonify({'error': 'Invalid JSON format'}), 400

        input_text = data.get('message', '')  # Extract the message

        if not input_text:
            return jsonify({'error': 'No input text provided'}), 400

        global predicted_words
        predicted_words = generate_predicted_words(
            input_text)  # Generate words based on the input

        return jsonify(predicted_words)  # Return the predicted words

    except Exception as e:
        return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=True)