TextToImages / app.py
Anonym26's picture
Update app.py
ee546da verified
raw
history blame
4.43 kB
import gradio
import gradio as gr
import aiohttp
import asyncio
from PIL import Image
from io import BytesIO
from dotenv import load_dotenv
import os
# Загрузка токена из .env файла
load_dotenv()
API_TOKEN = os.getenv("HF_API_TOKEN")
# Конфигурация API
HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}
MODELS = {
"Midjourney": "Jovie/Midjourney",
"FLUX.1 [dev]": "black-forest-labs/FLUX.1-dev",
"Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1",
"Stable Diffusion v3.5 Large": "stabilityai/stable-diffusion-3.5-large",
"Stable Diffusion v1.0 Large": "stabilityai/stable-diffusion-xl-base-1.0",
"Leonardo AI": "goofyai/Leonardo_Ai_Style_Illustration",
}
# Асинхронная функция для отправки запроса к API
async def query_model(prompt, model_name, model_url):
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"https://api-inference.huggingface.co/models/{model_url}",
headers=HEADERS,
json={"inputs": prompt},
) as response:
if response.status == 200:
image_data = await response.read()
return model_name, Image.open(BytesIO(image_data))
else:
error_message = await response.json()
warnings = error_message.get("warnings", [])
print(f"Ошибка для модели {model_name}: {error_message.get('error', 'unknown error')}")
if warnings:
print(f"Предупреждения для модели {model_name}: {warnings}")
return model_name, None
except Exception as e:
print(f"Ошибка соединения с моделью {model_name}: {e}")
return model_name, None
# Асинхронная обработка всех запросов
async def handle(prompt):
tasks = [
query_model(prompt, model_name, model_url)
for model_name, model_url in MODELS.items()
]
results = await asyncio.gather(*tasks)
return {model_name: image for model_name, image in results if image}
# Интерфейс Gradio
with gr.Blocks() as demo:
gr.Markdown("## Генерация изображений с использованием моделей Hugging Face")
# Поле ввода
user_input = gr.Textbox(label="Введите описание изображения", placeholder="Например, 'Красный автомобиль в лесу'")
# Вывод изображений
with gr.Row():
outputs = {name: gr.Image(label=name) for name in MODELS.keys()}
# Кнопка генерации
generate_button = gr.Button("Сгенерировать")
# Асинхронная обработка ввода
async def on_submit(prompt):
results = await handle(prompt)
return [results.get(name, None) for name in MODELS.keys()]
generate_button.click(
fn=on_submit,
inputs=[user_input],
outputs=list(outputs.values()),
)
user_input.submit(
fn=on_submit,
inputs=[user_input],
outputs=list(outputs.values()),
)
# Ссылки на соцсети
with gr.Row():
with gr.Column(scale=1):
gr.Image(value='icon.jpg')
with gr.Column(scale=4):
gradio.HTML("""<div style="text-align: center; font-family: 'Helvetica Neue', sans-serif; padding: 10px; color: #333333;">
<p style="font-size: 18px; font-weight: 600; margin-bottom: 8px;">
Эта демка была создана телеграм каналом <strong style="color: #007ACC;"><a href='https://t.me/mlphys'> mlphys</a></strong>. Другие мои социальные сети:
</p>
<p style="font-size: 16px;">
<a href="https://t.me/mlphys" target="_blank" style="color: #0088cc; text-decoration: none; font-weight: 500;">Telegram</a> |
<a href="https://x.com/quensy23" target="_blank" style="color: #1DA1F2; text-decoration: none; font-weight: 500;">Twitter</a> |
<a href="https://github.com/freQuensy23-coder" target="_blank" style="color: #0088cc; text-decoration: none; font-weight: 500;">GitHub</a>
</p>
</div>""")
demo.launch()