File size: 4,775 Bytes
e02f42e
8f6011c
e81e176
c0a9da8
8f6011c
 
c0a9da8
8f6011c
 
 
 
 
 
e81e176
 
 
8f6011c
e81e176
 
8f6011c
 
c0a9da8
 
 
 
e81e176
c0a9da8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e81e176
c0a9da8
e81e176
c0a9da8
 
 
 
 
 
 
 
e81e176
 
8f6011c
c0a9da8
e81e176
 
c0a9da8
e81e176
 
8f6011c
e81e176
 
 
8f6011c
 
e81e176
 
c0a9da8
 
e81e176
8f6011c
e81e176
8f6011c
e81e176
8f6011c
e81e176
 
 
 
 
 
 
 
e02f42e
 
 
 
 
 
 
 
 
 
 
 
 
8f6011c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio
import gradio as gr
import aiohttp
import asyncio
from PIL import Image
from io import BytesIO
from asyncio import Semaphore
from dotenv import load_dotenv
import os
# Загрузка токена из .env файла
load_dotenv()
API_TOKEN = os.getenv("HF_API_TOKEN")

# Конфигурация API
HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}
MODELS = {
    "Stable Diffusion v1.5": "Yntec/stable-diffusion-v1-5",
    "Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1",
    "Stable Diffusion v3.5 Large": "stabilityai/stable-diffusion-3.5-large",
}

# Настройки
MAX_CONCURRENT_REQUESTS = 3


# Асинхронная функция для отправки запроса к API
async def query_model(prompt, model_name, model_url, semaphore):
    async with semaphore:  # Ограничиваем количество одновременно выполняемых задач
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"https://api-inference.huggingface.co/models/{model_url}",
                    headers=HEADERS,
                    json={"inputs": prompt},
                ) as response:
                    if response.status == 200:
                        image_data = await response.read()
                        return model_name, Image.open(BytesIO(image_data))
                    else:
                        error_message = await response.json()
                        warnings = error_message.get("warnings", [])
                        print(f"Ошибка для модели {model_name}: {error_message.get('error', 'unknown error')}")
                        if warnings:
                            print(f"Предупреждения для модели {model_name}: {warnings}")
                        return model_name, None
        except Exception as e:
            print(f"Ошибка соединения с моделью {model_name}: {e}")
            return model_name, None


# Асинхронная обработка запросов
async def handle(prompt):
    semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)  # Создаём локальный семафор
    tasks = [
        query_model(prompt, model_name, model_url, semaphore)
        for model_name, model_url in MODELS.items()
    ]
    results = await asyncio.gather(*tasks)
    return {model_name: image for model_name, image in results if image}


# Интерфейс Gradio
with gr.Blocks() as demo:
    gr.Markdown("## Генерация изображений с использованием моделей Hugging Face")

    # Поле ввода
    user_input = gr.Textbox(label="Введите описание изображения", placeholder="Например, 'Красный автомобиль в лесу'")

    # Вывод изображений
    with gr.Row():
        outputs = {name: gr.Image(label=name) for name in MODELS.keys()}

    # Кнопка генерации
    generate_button = gr.Button("Сгенерировать")

    # Асинхронная обработка ввода
    async def on_submit(prompt):
        results = await handle(prompt)
        return [results.get(name, None) for name in MODELS.keys()]

    generate_button.click(
        fn=on_submit,
        inputs=[user_input],
        outputs=list(outputs.values()),
    )
    user_input.submit(
        fn=on_submit,
        inputs=[user_input],
        outputs=list(outputs.values()),
    )

    # Ссылки на соцсети
    with gr.Row():
        with gr.Column(scale=1):
            gr.Image(value='icon.jpg')
        with gr.Column(scale=4):
            gradio.HTML("""<div style="text-align: center; font-family: 'Helvetica Neue', sans-serif; padding: 10px; color: #333333;">

        <p style="font-size: 18px; font-weight: 600; margin-bottom: 8px;">

            Эта демка была создана телеграм каналом <strong style="color: #007ACC;"><a href='https://t.me/mlphys'> mlphys</a></strong>. Другие мои социальные сети:

        </p>

        <p style="font-size: 16px;">

            <a href="https://t.me/mlphys" target="_blank" style="color: #0088cc; text-decoration: none; font-weight: 500;">Telegram</a> |

            <a href="https://x.com/quensy23" target="_blank" style="color: #1DA1F2; text-decoration: none; font-weight: 500;">Twitter</a> |

            <a href="https://github.com/freQuensy23-coder"  target="_blank" style="color: #0088cc; text-decoration: none; font-weight: 500;">GitHub</a>

        </p>

    </div>""")

demo.launch()