Spaces:
Running
Running
from flask import Flask, request, send_file, abort | |
import requests | |
import io | |
from PIL import Image | |
from cachetools import TTLCache, cached | |
import random | |
import os | |
import urllib.parse | |
import hashlib | |
app = Flask(__name__) | |
# Максимальные значения для ширины и высоты | |
MAX_WIDTH = 850 | |
MAX_HEIGHT = 850 | |
# Кэш на 10 минут | |
cache = TTLCache(maxsize=100, ttl=600) | |
# Получаем ключи из переменной окружения | |
keys = os.getenv("keys", "").split(',') | |
if not keys: | |
raise ValueError("Environment variable 'keys' must be set with a comma-separated list of API keys.") | |
def get_random_key(): | |
return random.choice(keys) | |
def generate_cache_key(prompt, width, height, seed, model_name): | |
# Создаем уникальный ключ на основе всех параметров | |
return hashlib.md5(f"{prompt}_{width}_{height}_{seed}_{model_name}".encode()).hexdigest() | |
def scale_dimensions(width, height, max_width, max_height): | |
"""Масштабирует размеры изображения, сохраняя соотношение сторон.""" | |
aspect_ratio = width / height | |
if width > max_width or height > max_height: | |
if width / max_width > height / max_height: | |
width = max_width | |
height = int(width / aspect_ratio) | |
else: | |
height = max_height | |
width = int(height * aspect_ratio) | |
return width, height | |
def generate_cached_image(cache_key, prompt, width, height, seed, model_name): | |
api_key = get_random_key() | |
headers = { | |
"Authorization": f"Bearer {api_key}", | |
"Content-Type": "application/json" | |
} | |
data = { | |
"inputs": prompt, | |
"parameters": { | |
"width": width, | |
"height": height, | |
"seed": seed | |
} | |
} | |
try: | |
response = requests.post(f"https://api-inference.huggingface.co/models/{model_name}", headers=headers, json=data) | |
response.raise_for_status() | |
image_data = response.content | |
image = Image.open(io.BytesIO(image_data)) | |
return image | |
except requests.exceptions.RequestException as e: | |
app.logger.error(f"Error generating image: {e}") | |
abort(500, description="Error generating image") | |
def get_image(prompt): | |
width = request.args.get('width', type=int, default=512) | |
height = request.args.get('height', type=int, default=512) | |
seed = request.args.get('seed', type=int, default=22) | |
model_name = request.args.get('model', default="black-forest-labs_FLUX.1-dev").replace('_', '/') | |
# Декодируем URL-кодированный prompt | |
prompt = urllib.parse.unquote(prompt) | |
# Масштабируем размеры изображения, если они превышают максимальные значения | |
width, height = scale_dimensions(width, height, MAX_WIDTH, MAX_HEIGHT) | |
# Генерируем уникальный ключ для кэша | |
cache_key = generate_cache_key(prompt, width, height, seed, model_name) | |
try: | |
image = generate_cached_image(cache_key, prompt, width, height, seed, model_name) | |
except Exception as e: | |
app.logger.error(f"Error generating image: {e}") | |
abort(500, description="Error generating image") | |
img_byte_arr = io.BytesIO() | |
image.save(img_byte_arr, format='PNG') | |
img_byte_arr = img_byte_arr.getvalue() | |
return send_file( | |
io.BytesIO(img_byte_arr), | |
mimetype='image/png' | |
) | |
def health_check(): | |
return "OK", 200 | |
if __name__ == '__main__': | |
app.run(host='0.0.0.0', port=7860, debug=False) |