Spaces:
Running
Running
File size: 4,492 Bytes
cf44c0c 405f02e f2f40d5 cf44c0c 550e3c0 f2f40d5 550e3c0 cf44c0c 405f02e 550e3c0 cf44c0c 405f02e cf44c0c 6a72c67 550e3c0 cf44c0c 6a72c67 cf44c0c 6a72c67 312d90f cf44c0c f2f40d5 550e3c0 405f02e cf44c0c 405f02e cf44c0c b8b68ce cf44c0c f2f40d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
from flask import Flask, request, send_file, abort
import requests
import io
from PIL import Image
from cachetools import TTLCache, cached
import random
import os
import urllib.parse
import hashlib
from deep_translator import GoogleTranslator
from langdetect import detect
app = Flask(__name__)
# Максимальные значения для ширины и высоты
#MAX_WIDTH = 1024
#MAX_HEIGHT = 1024
MAX_WIDTH = 512
MAX_HEIGHT = 512
# Кэш на 10 минут
cache = TTLCache(maxsize=100, ttl=600)
# Получаем ключи из переменной окружения
keys = os.getenv("keys", "").split(',')
if not keys:
raise ValueError("Environment variable 'keys' must be set with a comma-separated list of API keys.")
def get_random_key():
return random.choice(keys)
def generate_cache_key(prompt, width, height, seed, model_name):
# Создаем уникальный ключ на основе всех параметров
return hashlib.md5(f"{prompt}_{width}_{height}_{seed}_{model_name}".encode()).hexdigest()
def scale_dimensions(width, height, max_width, max_height):
"""Масштабирует размеры изображения, сохраняя соотношение сторон."""
aspect_ratio = width / height
if width > max_width or height > max_height:
if width / max_width > height / max_height:
width = max_width
height = int(width / aspect_ratio)
else:
height = max_height
width = int(height * aspect_ratio)
return width, height
@cached(cache)
def generate_cached_image(cache_key, prompt, width, height, seed, model_name):
api_key = get_random_key()
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
data = {
"inputs": prompt,
"parameters": {
"width": width,
"height": height,
"seed": seed
}
}
try:
response = requests.post(f"https://api-inference.huggingface.co/models/{model_name}", headers=headers, json=data)
response.raise_for_status()
image_data = response.content
image = Image.open(io.BytesIO(image_data))
return image
except requests.exceptions.RequestException as e:
app.logger.error(f"Error generating image: {e}")
abort(500, description="Error generating image")
@app.route('/prompt/<path:prompt>')
def get_image(prompt):
width = request.args.get('width', type=int, default=512)
height = request.args.get('height', type=int, default=512)
seed = request.args.get('seed', type=int, default=22)
model_name = request.args.get('model', default="black-forest-labs_FLUX.1-dev").replace('_', '/')
# Декодируем URL-кодированный prompt
prompt = urllib.parse.unquote(prompt)
# Определяем язык промпта
try:
language = detect(prompt)
except Exception as e:
app.logger.error(f"Error detecting language: {e}")
abort(500, description="Error detecting language")
# Переводим промпт, если он не на английском языке
if language != 'en':
try:
translator = GoogleTranslator(source=language, target='en')
prompt = translator.translate(prompt)
except Exception as e:
app.logger.error(f"Error translating prompt: {e}")
abort(500, description="Error translating prompt")
# Масштабируем размеры изображения, если они превышают максимальные значения
width, height = scale_dimensions(width, height, MAX_WIDTH, MAX_HEIGHT)
# Генерируем уникальный ключ для кэша
cache_key = generate_cache_key(prompt, width, height, seed, model_name)
try:
image = generate_cached_image(cache_key, prompt, width, height, seed, model_name)
except Exception as e:
app.logger.error(f"Error generating image: {e}")
abort(500, description="Error generating image")
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format='PNG')
img_byte_arr = img_byte_arr.getvalue()
return send_file(
io.BytesIO(img_byte_arr),
mimetype='image/png'
)
@app.route('/')
def health_check():
return "OK", 200
if __name__ == '__main__':
app.run(host='0.0.0.0', port=7860, debug=False)
|