Spaces:
Running
Running
import gradio as gr | |
import asyncio | |
import aiohttp | |
from PIL import Image | |
from io import BytesIO | |
from dotenv import load_dotenv | |
import os | |
# Загрузка токена из .env файла | |
load_dotenv() | |
API_TOKEN = os.getenv("HF_API_TOKEN") | |
# Конфигурация API | |
HEADERS = {"Authorization": f"Bearer {API_TOKEN}"} | |
MODELS = { | |
"Stable Diffusion v1.5": "Yntec/stable-diffusion-v1-5", | |
"Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1", | |
"Stable Diffusion v3.5 Large": "stabilityai/stable-diffusion-3.5-large", | |
"Midjourney": "Jovie/Midjourney", | |
"FLUX.1 [dev]": "black-forest-labs/FLUX.1-dev", | |
"Leonardo AI": "goofyai/Leonardo_Ai_Style_Illustration", | |
} | |
# Асинхронная функция для отправки запроса к API | |
async def query_model(session, prompt, model_name, model_url): | |
try: | |
async with session.post( | |
f"https://api-inference.huggingface.co/models/{model_url}", | |
headers=HEADERS, | |
json={"inputs": prompt}, | |
) as response: | |
if response.status == 200: | |
image_data = await response.read() | |
return model_name, Image.open(BytesIO(image_data)) | |
else: | |
error_message = await response.text() | |
print(f"Ошибка для модели {model_name}: {error_message}") | |
return model_name, None | |
except Exception as e: | |
print(f"Ошибка соединения с моделью {model_name}: {e}") | |
return model_name, None | |
# Асинхронная обработка запросов для всех моделей | |
async def handle(prompt): | |
async with aiohttp.ClientSession() as session: | |
# Создаём асинхронный поток результатов | |
results = {} | |
async for outputs in async_zip_stream( | |
*(query_model(session, prompt, model_name, model_url) for model_name, model_url in MODELS.items()) | |
): | |
for model_name, image in outputs: | |
results[model_name] = image | |
yield list(results.values()) | |
# Интерфейс Gradio | |
with gr.Blocks() as demo: | |
gr.Markdown("## Генерация изображений с помощью различных моделей нейросетей") | |
# Поле ввода | |
user_input = gr.Textbox(label="Введите описание изображения", placeholder="Например, 'Астронавт верхом на лошади'") | |
# Вывод изображений | |
with gr.Row(): | |
outputs = {name: gr.Image(label=name) for name in MODELS.keys()} | |
# Кнопка генерации | |
generate_button = gr.Button("Сгенерировать") | |
# Асинхронная обработка ввода | |
async def on_submit(prompt): | |
async for result in handle(prompt): | |
return [result.get(name, None) for name in MODELS.keys()] | |
generate_button.click( | |
fn=on_submit, | |
inputs=[user_input], | |
outputs=list(outputs.values()), | |
) | |
user_input.submit( | |
fn=on_submit, | |
inputs=[user_input], | |
outputs=list(outputs.values()), | |
) | |
# Ссылки на соцсети | |
with gr.Row(): | |
gr.Markdown( | |
""" | |
### Поддержка проекта | |
- [Telegram](https://t.me/mlphys) | |
- [GitHub](https://github.com/freQuensy23-coder) | |
""" | |
) | |
demo.launch() | |