# roshan_project / tempCodeRunnerFile.py
from flask import Flask, jsonify, request
from flask_cors import CORS
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from history import load_dataset, get_unique_next_words_from_dataset
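
# `history` is a sibling module that is not included in this file. Below is a
# minimal sketch of the interface assumed by the calls further down; the names
# and signatures come from how they are used here, and the real implementation
# may differ:
#
#     def load_dataset(path):
#         # Read the raw text dataset from disk
#         with open(path, "r", encoding="utf-8") as f:
#             return f.read()
#
#     def get_unique_next_words_from_dataset(input_text, dataset):
#         # Collect the unique words that follow the last word of `input_text`
#         # anywhere in the dataset, preserving first-seen order
#         words = dataset.split()
#         last = input_text.split()[-1].lower()  # assumes non-empty input
#         follow = [words[i + 1] for i in range(len(words) - 1)
#                   if words[i].lower() == last]
#         return list(dict.fromkeys(follow))
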
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Global variable to store the predicted words
predicted_words = []

def generate_predicted_words(input_text):
    # Load the model and tokenizer (reloaded on every call; caching them at
    # module level would avoid the reload cost, but this keeps the original behavior)
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # Load the dataset and collect the words that historically followed the input
    dataset_name = "dataset.txt"
    dataset = load_dataset(dataset_name)
    history_next_text = get_unique_next_words_from_dataset(input_text, dataset)

    # Tokenize the input
    inputs = tokenizer(input_text, return_tensors="pt").to("cpu")

    # Forward pass through the model
    with torch.no_grad():
        outputs = model(**inputs, return_dict=True)
        logits = outputs.logits

    # Get the logits for the last token and convert them to probabilities
    last_token_logits = logits[:, -1, :]
    probabilities = torch.softmax(last_token_logits, dim=-1)

    # Get the top 50 most probable next tokens
    top_50_probs, top_50_indices = torch.topk(probabilities, 50)
    top_50_tokens = [tokenizer.decode([idx]) for idx in top_50_indices[0]]

    # Drop single-character tokens and a few punctuation-like fragments
    words = []
    removable_words = [' (', ' a', "'s", ' "', ' -', ' as', " '"]
    for token in top_50_tokens:
        if len(token) != 1 and token not in removable_words:
            words.append(token)

    # Dataset-derived words first, then the model's suggestions
    return history_next_text + words
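
# Illustrative call (the actual tokens depend on the GPT-2 weights and on the
# contents of dataset.txt, so treat this output as a sketch, not a guarantee):
#     generate_predicted_words("Are")
#     -> ['you', 'we'] + [' you', ' we', ' they', ...]  # dataset words, then model tokens
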
@app.route('/api/display_words', methods=['GET'])
def get_display_words():
    # Get the page number from the query parameters
    count = int(request.args.get('count', 0))

    if not predicted_words:
        # Generate the list only once if it hasn't been generated yet
        input_text = "Are"  # Default input, can be changed as needed
        predicted_words.extend(generate_predicted_words(input_text))

    # Serve nine predicted words per page
    start_index = 9 * count
    end_index = start_index + 9
    if start_index >= len(predicted_words):
        # Reset to the first page if the requested slice is out of bounds
        start_index = 0
        end_index = 9

    display_words = predicted_words[start_index:end_index]
    return jsonify(display_words)
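
# Example request (assumes the server started by the __main__ block below):
#     GET http://localhost:5000/api/display_words?count=1
#     -> JSON array with items at indices 9..17 of the cached prediction list
#     (an out-of-range count wraps back to the first page of nine words)
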
@app.route('/api/guu', methods=['POST'])
def predict_words():
    try:
        # Get the JSON data from the request
        data = request.get_json()
        print("data", data)

        # Check if the JSON was parsed properly
        if not isinstance(data, dict):
            return jsonify({'error': 'Invalid JSON format'}), 400

        input_text = data.get('message', '')  # Extract the message
        if not input_text:
            return jsonify({'error': 'No input text provided'}), 400

        global predicted_words
        predicted_words = generate_predicted_words(input_text)  # Generate words based on the input

        return jsonify(predicted_words)  # Return the predicted words
    except Exception as e:
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=True)
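
# Smoke test for the POST endpoint once the server is running (the response
# contents depend on the model and dataset, so expect different words):
#     curl -X POST http://localhost:5000/api/guu \
#          -H "Content-Type: application/json" \
#          -d '{"message": "Are"}'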