roshan_project / app.py
pujan
changes
bb49f69
from flask import Flask, jsonify, request
import requests
import redis
import json
from flask_cors import CORS
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from history import load_dataset, get_unique_next_words_from_dataset
from dotenv import load_dotenv
import os
os.environ["TRANSFORMERS_CACHE"] = "/code/.cache"
from typing import List, Dict, Optional, Union
import logging
from most_repeted_sentences import sentences_name, get_most_repeated_sentences, save_most_repeated_sentences
# Load variables from a local .env file into os.environ (API key, Redis password).
load_dotenv()
app = Flask(__name__)
# CORS is enabled for every route so a front end served from another origin can call the API.
CORS(app)
# Setup logging: only ERROR and above are emitted, so logger.debug/info calls are silent.
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)
# Pixabay API setup.
# Base endpoint only: `requests` builds the query string (key, q, image_type,
# per_page) from the `params` dict. Do not embed JavaScript-style `${...}`
# placeholders here; Python never substitutes them, and they would be sent
# verbatim as extra, bogus query parameters.
PIXABAY_URL = "https://pixabay.com/api/"
# NOTE(review): the environment variable really is named "API_kEY" (lower-case k);
# kept unchanged so existing deployments keep working.
PIXABAY_API_KEY = os.getenv("API_kEY")
# Setup Redis: it holds the image cache used by /api/images.
# Host and port default to the hosted Redis Cloud instance but can be overridden
# with REDIS_HOST / REDIS_PORT; the password always comes from REDIS_PASSWORD.
redis_client = redis.Redis(
    host=os.getenv("REDIS_HOST", 'redis-18594.c301.ap-south-1-1.ec2.redns.redis-cloud.com'),
    port=int(os.getenv("REDIS_PORT", "18594")),
    # Return str instead of bytes so cached values can go straight into json.loads.
    decode_responses=True,
    username="default",
    password=os.getenv("REDIS_PASSWORD"),
)
logger.debug("Redis client configured: %s", redis_client)
# Load the GPT-2 model and tokenizer once when the app starts so every request
# reuses them; the model is placed on the GPU when CUDA is available.
# The cache directory is passed explicitly: TRANSFORMERS_CACHE is assigned to
# os.environ only after `transformers` has already been imported.
# NOTE(review): transformers reads that variable when it is imported, so the
# late assignment alone would likely be ignored — confirm for the pinned version.
_HF_CACHE_DIR = os.environ.get("TRANSFORMERS_CACHE")
model = GPT2LMHeadModel.from_pretrained("gpt2", cache_dir=_HF_CACHE_DIR).to(
    "cuda" if torch.cuda.is_available() else "cpu"
)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", cache_dir=_HF_CACHE_DIR)
# Global variables.
# NOTE(review): these are module-level, so every client served by this process
# shares one sentence-in-progress and one prediction list.
predicted_words = []  # most recent suggestions computed by /api/guu
append_list = []      # words of the sentence currently being built, in order
global_count = 0      # default page number for /api/display_words
# Starter vocabulary, served 9 words per page. Each word appears exactly once so
# that no two pages show the same button.
default_predicted_words = [
    'i', 'what', 'hello', 'where', 'who', 'how', 'can', 'is', 'are', 'could',
    'would', 'may', 'please', 'will', 'shall', 'did', 'have', 'has',
    'had', 'am', 'were', 'was', 'should', 'might', 'must', 'you',
    'he', 'she', 'they', 'it', 'this', 'that', 'these', 'those', 'let',
    'we', 'my', 'your', 'his', 'her', 'their', 'our', 'the',
    'there', 'come', 'go', 'bring', 'take', 'give', 'help', 'want',
    'need', 'eat', 'drink', 'sleep', 'play', 'run', 'walk', 'talk', 'call',
    'find', 'make', 'see', 'get', 'know',
]
# GPT-2 tokens that are never offered as suggestions: articles, stray punctuation,
# contraction fragments and the end-of-text marker.
_REMOVABLE_TOKENS = (
    ' (', ' a', "'s", ' "', ' -', ' as', " '", "the", " the", "an", " an",
    "<|endoftext|>", '’d', '’m', '’ll', 't’s',
)
# The same entries after the strip()/lower() normalisation applied to suggestions,
# so capitalised or space-less variants such as " The", " A" or "As" are caught too.
_REMOVABLE_WORDS = frozenset(token.strip().lower() for token in _REMOVABLE_TOKENS)


def generate_predicted_words(input_text, index=0, top_k=50):
    """Suggest words that could follow ``input_text``.

    Combines words that followed ``input_text`` in previously saved sentences
    (``dataset.txt``) with GPT-2's ``top_k`` most likely next tokens.

    Args:
        input_text: The sentence typed so far.
        index: Unused; kept so existing callers remain compatible.
        top_k: Number of GPT-2 candidate tokens to consider (default 50).

    Returns:
        A list of suggestions: history words first, then lower-cased model words.
        Empty strings, filler tokens and duplicates are removed; the first
        occurrence of each word wins.
    """
    # Words that followed this text in sentences saved earlier.
    dataset = load_dataset("dataset.txt")
    history_next_text = get_unique_next_words_from_dataset(input_text, dataset)

    model_words = []
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    # An empty prompt yields zero tokens, and logits[:, -1, :] below needs at least one.
    if inputs["input_ids"].shape[-1] > 0:
        with torch.no_grad():
            logits = model(**inputs, return_dict=True).logits
        # Rank candidates for the position after the last input token. softmax is
        # monotonic, so the top-k logits are exactly the top-k probabilities.
        _, top_indices = torch.topk(logits[:, -1, :], top_k)
        for idx in top_indices[0]:
            token = tokenizer.decode([idx], clean_up_tokenization_spaces=False)
            # Single-character tokens (e.g. "," or ".") are dropped.
            if len(token) == 1:
                continue
            word = token.strip().lower()
            # Whitespace-only tokens (e.g. "\n\n") normalise to "" and are skipped.
            if word and word not in _REMOVABLE_WORDS:
                model_words.append(word)

    # History first, then the model; keep only the first occurrence of each word
    # so duplicates do not waste the limited suggestion slots.
    suggestions = []
    seen = set()
    for word in [*history_next_text, *model_words]:
        if word and word not in seen:
            seen.add(word)
            suggestions.append(word)
    return suggestions
# fetch from pixabay
def fetch_images_from_pixabay(query: str, timeout: float = 10.0) -> dict:
    """Search Pixabay for ``query`` and return the decoded JSON response.

    Asks for at most 3 images of any type, authenticated with PIXABAY_API_KEY.

    Args:
        query: Search term, sent as Pixabay's ``q`` parameter.
        timeout: Seconds to wait for Pixabay before giving up (default 10).

    Returns:
        The parsed JSON body on success (callers read its ``hits`` and ``total``),
        or ``{"error": "Failed to fetch data from Pixabay"}`` in any of these cases:
        the request fails or times out, the status is not 200, or the body is
        not valid JSON.
    """
    try:
        response = requests.get(
            PIXABAY_URL,
            params={
                "key": PIXABAY_API_KEY,
                "q": query,
                "image_type": "all",
                "per_page": "3",
            },
            # Without a timeout a stalled connection would block this worker indefinitely.
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # Log only the exception type: its message can include the request URL,
        # which contains the API key.
        logger.error("Pixabay request failed: %s", type(exc).__name__)
        return {"error": "Failed to fetch data from Pixabay"}
    if response.status_code != 200:
        return {"error": "Failed to fetch data from Pixabay"}
    try:
        return response.json()
    except ValueError:
        # Body was not JSON (e.g. an HTML error page).
        return {"error": "Failed to fetch data from Pixabay"}
# fetch images api
# All Pixabay hits fetched so far live in one JSON document under this Redis key;
# each hit is tagged with the caller-supplied `id` as "query_id".
_IMAGE_CACHE_KEY = 'image_cache'
# 86400 s = 24 hours; the TTL is refreshed on every write.
_IMAGE_CACHE_TTL_SECONDS = 86400


def _load_image_cache():
    """Return the cached Pixabay payload, or None if it is missing, unreadable, or Redis fails.

    The cache is best-effort: failures are logged and treated as a cache miss,
    so an image can still be fetched from Pixabay.
    """
    try:
        raw = redis_client.get(_IMAGE_CACHE_KEY)
    except redis.RedisError as exc:
        logger.error("Image cache read failed: %s", exc)
        return None
    if not raw:
        return None
    try:
        cached = json.loads(raw)
    except ValueError:
        logger.error("Image cache holds invalid JSON; ignoring it")
        return None
    return cached if isinstance(cached, dict) else None


def _store_image_cache(payload):
    """Write ``payload`` to the image cache; failures are logged, not raised."""
    try:
        redis_client.setex(_IMAGE_CACHE_KEY, _IMAGE_CACHE_TTL_SECONDS, json.dumps(payload))
    except redis.RedisError as exc:
        logger.error("Image cache write failed: %s", exc)


@app.route('/api/images', methods=['GET'])
def get_images():
    """Return one Pixabay preview URL for ``?query=``, cached by ``?id=``.

    Query parameters:
        query: Search term (required).
        id: Caller-chosen identifier. Later requests with the same id are served
            from the cache instead of calling Pixabay.

    Returns:
        A JSON string holding the preview URL (200). Otherwise ``{"error": ...}``
        with status 400 if ``query`` is missing, 404 if Pixabay found no images,
        or 500 if the Pixabay request failed.
    """
    query = request.args.get('query')
    correspond_id = request.args.get('id')
    logger.debug("get_images id=%r query=%r", correspond_id, query)
    if not query:
        return jsonify({"error": "Query parameter is required"}), 400

    cached_images = _load_image_cache()
    cached_hits = cached_images.get('hits', []) if cached_images else []
    # Only consult the cache by id. Without an id, every hit stored with
    # query_id=None would "match" and return an image for an unrelated query.
    if correspond_id is not None:
        for hit in cached_hits:
            if hit.get('query_id') == correspond_id:
                return jsonify(hit['previewURL'])

    # Fetch from Pixabay if not in cache.
    data = fetch_images_from_pixabay(query)
    if "error" in data:
        return jsonify(data), 500
    new_hits = data.get('hits') or []
    if not new_hits:
        return jsonify({"error": "No images found"}), 404
    for hit in new_hits:
        hit['query_id'] = correspond_id
    # Return the first new hit. It is also the hit the cache lookup above finds
    # for this id next time, so repeated requests yield the same image.
    preview_url = new_hits[0]['previewURL']

    # Store previously cached hits plus the new ones.
    if cached_images:
        data['hits'] = cached_hits + new_hits
        data['total'] = cached_images.get('total', 0) + data.get('total', 0)
    _store_image_cache(data)
    return jsonify(preview_url)
@app.route('/api/display_words', methods=['GET'])
def get_display_words():
    """Return one page of 9 words from the default starter vocabulary.

    Query parameters:
        count: Zero-based page number; defaults to ``global_count``. A negative
            page, or one past the end of the list, wraps back to page 0.

    Returns:
        A JSON list of up to 9 words, or ``{"error": "Invalid count value"}``
        with status 400 when ``count`` is not an integer.
    """
    page_size = 9
    try:
        count = int(request.args.get('count', global_count))
    except ValueError:
        return jsonify({"error": "Invalid count value"}), 400

    start_index = page_size * count
    # A negative page would slice from the end of the list (or return nothing for
    # -1), so treat it like an out-of-range page and start over.
    if count < 0 or start_index >= len(default_predicted_words):
        start_index = 0
    display_words = default_predicted_words[start_index:start_index + page_size]
    logger.debug("display_words count=%s start=%s: %s", count, start_index, display_words)
    # jsonify keeps the response format consistent with the other endpoints.
    return jsonify(display_words)
# @app.route('/api/scenerio', methods=['POST'])
# # @app.route('/api/select_location', methods=['GET'])
# def scenerio():
# # Get the query parameter from the URL, e.g., /api/select_location?place=home
# place = request.args.get('place')
# if place == "home":
# display_words = default_predicted_words[start_index:end_index]
# return jsonify(display_words)
# @app.route('/api/huu', methods=['GET'])
# def fetch_most_repeated_sentences(): # Ensure the function name is unique
# try:
# with open('most_repeated_sentences.txt', 'r') as file:
# # Read the first 5 lines
# lines = []
# for _ in range(5):
# text = file.readline().strip().split(":")[0]
# print(text)
# lines.append(text)
# # lines = [file.readline().strip().split(':')[0] for _ in range(5)]
# return jsonify(lines), 200 # Return the lines as JSON with a 200 OK status
# except FileNotFoundError:
# return jsonify({"error": "File not found."}), 404 # Handle file not found error
# except Exception as e:
# return jsonify({"error": str(e)}), 500 # Handle other potential errors
@app.route('/api/most_repeated_sentence', methods=['GET'])
def fetch_most_repeated_sentences():
    """Return the five highest-count sentences from most_repeated_sentences.txt.

    Each usable line has the form ``sentence:count``. The split is on the last
    colon, so the sentence itself may contain colons. Lines without a colon, or
    whose count is not an integer, are ignored. Ties keep file order.

    Returns:
        A JSON list of at most 5 sentence strings. If the file is missing,
        ``{"error": "File not found"}`` with 404. On any other failure,
        ``{"error": "Internal server error"}`` with 500.
    """
    try:
        scored = []
        with open('most_repeated_sentences.txt', 'r') as handle:
            for raw_line in handle:
                stripped = raw_line.strip()
                if ':' not in stripped:
                    continue
                text, _, raw_count = stripped.rpartition(':')
                try:
                    score = int(raw_count.strip())
                except ValueError:
                    continue  # malformed count: skip the line
                scored.append((text.strip(), score))
        # Highest count first; sorted() is stable, so equal counts keep file order.
        ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
        return jsonify([text for text, _ in ranked[:5]])
    except FileNotFoundError:
        return jsonify({"error": "File not found"}), 404
    except Exception as e:
        logger.error(f"Error in fetch_most_repeated_sentences: {e}")
        return jsonify({"error": "Internal server error"}), 500
@app.route('/api/guu', methods=['POST'])
def predict_words():
    """Add one word to the sentence being built and return the next suggestions.

    Expects a JSON body ``{"item": "<word>"}``. The special item ``"1"`` ends the
    current sentence. It appends the sentence to ``dataset.txt`` when it has any
    words, refreshes ``most_repeated_sentences.txt``, and returns the first page
    of default words.

    Returns:
        A JSON list of up to 9 suggested words (200). Otherwise ``{"error": ...}``
        with status 400 for a malformed body or missing/invalid item, or 500 for
        unexpected failures.
    """
    # NOTE(review): module-global state, so all clients share one sentence.
    global predicted_words, append_list, global_count
    try:
        # silent=True: a missing or unparsable JSON body yields None, giving the
        # 400 below, instead of Flask raising (previously reported as a 500).
        data = request.get_json(silent=True)
        logger.debug("Received data: %s", data)
        if not isinstance(data, dict):
            return jsonify({'error': 'Invalid JSON format'}), 400
        item = data.get('item', '')
        if not isinstance(item, str):
            return jsonify({'error': 'Invalid item value'}), 400
        input_text = item.strip()

        # Reset request: persist the finished sentence and start a new one.
        if input_text == "1":
            finished_sentence = ' '.join(append_list)
            # Opening in append mode still creates dataset.txt if needed, but an
            # empty sentence is not written. Blank lines would otherwise be
            # counted as a "most repeated sentence".
            with open('dataset.txt', 'a') as file:
                if finished_sentence:
                    file.write(finished_sentence + '\n')
            append_list = []
            global_count = 0
            sentence = sentences_name('dataset.txt')
            repeated_sentences = get_most_repeated_sentences(sentence)
            logger.debug("Most repeated sentences: %s", repeated_sentences)
            save_most_repeated_sentences(repeated_sentences, 'most_repeated_sentences.txt')
            return jsonify(default_predicted_words[:9])

        if not input_text:
            return jsonify({'error': 'No input text provided'}), 400

        append_list.append(input_text)
        combined_input = ' '.join(append_list)
        logger.debug("Combined input for prediction: %s", combined_input)
        predicted_words = generate_predicted_words(combined_input)
        return jsonify(predicted_words[:9])
    except Exception:
        # Keep the traceback in the server log; do not expose internals to clients.
        logger.exception("Error in predict_words")
        return jsonify({'error': 'Internal server error'}), 500
# Module-level alias for the Flask app; presumably the name a WSGI server is
# configured to import (confirm against the deployment config).
application = app

if __name__ == '__main__':
    # Running this file directly starts Flask's built-in server with its defaults.
    app.run()