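# Flask backend that serves next-word suggestions, combining dataset history
# lookups (history.py) with GPT-2 next-token predictions.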
from flask import Flask, jsonify, request
from flask_cors import CORS
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from history import load_dataset, get_unique_next_words_from_dataset

app = Flask(__name__)
CORS(app)

# Load the model and tokenizer once when the app starts
model = GPT2LMHeadModel.from_pretrained("gpt2").to("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Global variables
predicted_words = []
append_list = []
default_predicted_words = ['i', 'what', 'hello', 'where', 'who', 'how', 'can', 'is', 'are', 'could',
                           'would', 'may', 'do', 'does', 'will', 'shall', 'did', 'have', 'has',
                           'had', 'am', 'were', 'was', 'should', 'might', 'must', 'please', 'you',
                           'he', 'she', 'they', 'it', 'this', 'that', 'these', 'those', 'let',
                           'we', 'my', 'your', 'his', 'her', 'their', 'our', 'the',
                           'there', 'come', 'go', 'bring', 'take', 'give', 'help', 'want',
                           'need', 'eat', 'drink', 'sleep', 'play', 'run', 'walk', 'talk', 'call',
                           'find', 'make', 'see', 'get', 'know']

starting_words_for_home = [
    "i", "let’s", "can", "the", "please", "it’s", "this", "i’m", "i’ll",
    "you", "we", "my", "can’t", "shall", "would", "will", "do",
    "should", "they", "let"
]
# NOTE: starting_words_for_school is used by the /api/scenerio route below but was
# not defined in this file; reusing the home list is an assumed placeholder so the
# route does not raise a NameError.
starting_words_for_school = list(starting_words_for_home)


def generate_predicted_words(input_text, dataset_name='dataset.txt'):
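    """Return candidate next words for input_text: words that followed it in the
    dataset history plus the most probable GPT-2 next tokens."""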
    # Load the dataset and collect words that historically followed the input
    dataset = load_dataset(dataset_name)
    history_next_text = get_unique_next_words_from_dataset(input_text, dataset)

    # Tokenize input
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # Forward pass through the model
    with torch.no_grad():
        outputs = model(**inputs, return_dict=True)
    logits = outputs.logits

    # Get the logits for the last token and turn them into probabilities
    last_token_logits = logits[:, -1, :]
    probabilities = torch.softmax(last_token_logits, dim=-1)

    # Get the top 50 most probable next tokens
    top_50_probs, top_50_indices = torch.topk(probabilities, 50)
    top_50_tokens = [tokenizer.decode([idx], clean_up_tokenization_spaces=False) for idx in top_50_indices[0]]

    # Drop single-character tokens and common filler tokens before suggesting words
    words = []
    removable_words = [' (', ' a', "'s", ' "', ' -', ' as', " '", "the", " the", "an", " an", "<|endoftext|>"]
    for token in top_50_tokens:
        if len(token) != 1 and token not in removable_words:
            words.append(token.strip().lower())

    return history_next_text + words


@app.route('/api/display_words', methods=['GET'])
def get_display_words():
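    """Return the page of nine default starting words selected by the ``count`` query parameter."""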
    count = int(request.args.get('count', 0))
    start_index = 9 * count
    end_index = start_index + 9
    if start_index >= len(default_predicted_words):  # Reset if out of bounds
        count = 0
        start_index = 0
        end_index = 9
    display_words = default_predicted_words[start_index:end_index]
    return jsonify(display_words)


@app.route('/api/scenerio', methods=['POST'])
# @app.route('/api/select_location', methods=['GET'])
def scenerio():
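    """Return starting words for the scenario given by the ``place`` query parameter."""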
    # Get the query parameter from the URL, e.g., /api/scenerio?place=home
    place = request.args.get('place')
    if place == "home":
        # Raw strings avoid invalid "\U..." escape sequences in the Windows paths
        dataset = r"C:\Users\bhand\OneDrive\Desktop\hackthon_ktm\scenerio\home_scenerio.txt"
        # predicted_words = generate_predicted_words(input_text, dataset_name='dataset.txt')
        return jsonify(starting_words_for_home[:9])
    elif place == 'school':
        # NOTE: this path currently points to the home scenario file
        dataset = r'C:\Users\bhand\OneDrive\Desktop\hackthon_ktm\scenerio\home_scenerio.txt'
        return jsonify(starting_words_for_school[:9])
    # display_words = default_predicted_words[start_index:end_index]
    # return jsonify(display_words)
    return jsonify({'error': 'Unknown place'}), 400


@app.route('/api/huu', methods=['GET'])
def fetch_most_repeated_sentences():  # Ensure the function name is unique
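    """Return the text before the first ':' on each of the first five lines
    of most_repeated_sentences.txt."""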
    try:
        with open('most_repeated_sentences.txt', 'r') as file:
            # Read the first 5 lines
            lines = []
            for _ in range(5):
                text = file.readline().strip().split(":")[0]
                print(text)
                lines.append(text)
            # lines = [file.readline().strip().split(':')[0] for _ in range(5)]
        return jsonify(lines), 200  # Return the lines as JSON with a 200 OK status
    except FileNotFoundError:
        return jsonify({"error": "File not found."}), 404  # Handle file not found error
    except Exception as e:
        return jsonify({"error": str(e)}), 500  # Handle other potential errors


@app.route('/api/guu', methods=['POST'])
def predict_words():
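    """Append the selected word (JSON key 'item') to the running sentence and
    return up to nine predicted next words; an input of "1" resets the sentence."""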
    global predicted_words, append_list
    try:
        data = request.get_json()
        print("Received data:", data)
        if not isinstance(data, dict):
            return jsonify({'error': 'Invalid JSON format'}), 400

        input_text = data.get('item', '').strip()  # Ensure we are checking the stripped input

        # Handle case when input_text is "1"
        if input_text == "1":
            print("Resetting append_list")
            append_list = []  # Reset the append list
            return jsonify(default_predicted_words[:9])  # Return the default words

        if not input_text:
            return jsonify({'error': 'No input text provided'}), 400

        append_list.append(input_text)
        print("Current append list:", append_list)
        combined_input = ' '.join(append_list)
        print("Combined input for prediction:", combined_input)
        predicted_words = generate_predicted_words(combined_input)
        print("Predicted words:", predicted_words)
        return jsonify(predicted_words[:9])
    except Exception as e:
        print(f"An error occurred: {str(e)}")  # Log the error message
        return jsonify({'error': str(e)}), 500
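

# Example requests for local testing (assumed to be run against http://localhost:5000):
#   curl "http://localhost:5000/api/display_words?count=0"
#   curl "http://localhost:5000/api/huu"
#   curl -X POST "http://localhost:5000/api/scenerio?place=home"
#   curl -X POST -H "Content-Type: application/json" -d '{"item": "hello"}' "http://localhost:5000/api/guu"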

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=True)