Anonym26 commited on
Commit
c0a9da8
·
verified ·
1 Parent(s): a320656

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -36
app.py CHANGED
@@ -1,11 +1,11 @@
1
  import gradio as gr
2
- import asyncio
3
  import aiohttp
 
4
  from PIL import Image
5
  from io import BytesIO
 
6
  from dotenv import load_dotenv
7
  import os
8
-
9
  # Загрузка токена из .env файла
10
  load_dotenv()
11
  API_TOKEN = os.getenv("HF_API_TOKEN")
@@ -16,48 +16,54 @@ MODELS = {
16
  "Stable Diffusion v1.5": "Yntec/stable-diffusion-v1-5",
17
  "Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1",
18
  "Stable Diffusion v3.5 Large": "stabilityai/stable-diffusion-3.5-large",
19
- "Midjourney": "Jovie/Midjourney",
20
- "FLUX.1 [dev]": "black-forest-labs/FLUX.1-dev",
21
- "Leonardo AI": "goofyai/Leonardo_Ai_Style_Illustration",
22
  }
23
 
 
 
 
 
24
  # Асинхронная функция для отправки запроса к API
25
- async def query_model(session, prompt, model_name, model_url):
26
- try:
27
- async with session.post(
28
- f"https://api-inference.huggingface.co/models/{model_url}",
29
- headers=HEADERS,
30
- json={"inputs": prompt},
31
- ) as response:
32
- if response.status == 200:
33
- image_data = await response.read()
34
- return model_name, Image.open(BytesIO(image_data))
35
- else:
36
- error_message = await response.text()
37
- print(f"Ошибка для модели {model_name}: {error_message}")
38
- return model_name, None
39
- except Exception as e:
40
- print(f"Ошибка соединения с моделью {model_name}: {e}")
41
- return model_name, None
 
 
 
 
 
 
42
 
43
- # Асинхронная обработка запросов для всех моделей
44
  async def handle(prompt):
45
- async with aiohttp.ClientSession() as session:
46
- # Создаём асинхронный поток результатов
47
- results = {}
48
- async for outputs in async_zip_stream(
49
- *(query_model(session, prompt, model_name, model_url) for model_name, model_url in MODELS.items())
50
- ):
51
- for model_name, image in outputs:
52
- results[model_name] = image
53
- yield list(results.values())
54
 
55
  # Интерфейс Gradio
56
  with gr.Blocks() as demo:
57
- gr.Markdown("## Генерация изображений с помощью различных моделей нейросетей")
58
 
59
  # Поле ввода
60
- user_input = gr.Textbox(label="Введите описание изображения", placeholder="Например, 'Астронавт верхом на лошади'")
61
 
62
  # Вывод изображений
63
  with gr.Row():
@@ -68,8 +74,8 @@ with gr.Blocks() as demo:
68
 
69
  # Асинхронная обработка ввода
70
  async def on_submit(prompt):
71
- async for result in handle(prompt):
72
- return [result.get(name, None) for name in MODELS.keys()]
73
 
74
  generate_button.click(
75
  fn=on_submit,
 
1
  import gradio as gr
 
2
  import aiohttp
3
+ import asyncio
4
  from PIL import Image
5
  from io import BytesIO
6
+ from asyncio import Semaphore
7
  from dotenv import load_dotenv
8
  import os
 
9
  # Загрузка токена из .env файла
10
  load_dotenv()
11
  API_TOKEN = os.getenv("HF_API_TOKEN")
 
16
  "Stable Diffusion v1.5": "Yntec/stable-diffusion-v1-5",
17
  "Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1",
18
  "Stable Diffusion v3.5 Large": "stabilityai/stable-diffusion-3.5-large",
 
 
 
19
  }
20
 
21
+ # Настройки
22
+ MAX_CONCURRENT_REQUESTS = 3
23
+
24
+
25
  # Асинхронная функция для отправки запроса к API
26
+ async def query_model(prompt, model_name, model_url, semaphore):
27
+ async with semaphore: # Ограничиваем количество одновременно выполняемых задач
28
+ try:
29
+ async with aiohttp.ClientSession() as session:
30
+ async with session.post(
31
+ f"https://api-inference.huggingface.co/models/{model_url}",
32
+ headers=HEADERS,
33
+ json={"inputs": prompt},
34
+ ) as response:
35
+ if response.status == 200:
36
+ image_data = await response.read()
37
+ return model_name, Image.open(BytesIO(image_data))
38
+ else:
39
+ error_message = await response.json()
40
+ warnings = error_message.get("warnings", [])
41
+ print(f"Ошибка для модели {model_name}: {error_message.get('error', 'unknown error')}")
42
+ if warnings:
43
+ print(f"Предупреждения для модели {model_name}: {warnings}")
44
+ return model_name, None
45
+ except Exception as e:
46
+ print(f"Ошибка соединения с моделью {model_name}: {e}")
47
+ return model_name, None
48
+
49
 
50
+ # Асинхронная обработка запросов
51
  async def handle(prompt):
52
+ semaphore = Semaphore(MAX_CONCURRENT_REQUESTS) # Создаём локальный семафор
53
+ tasks = [
54
+ query_model(prompt, model_name, model_url, semaphore)
55
+ for model_name, model_url in MODELS.items()
56
+ ]
57
+ results = await asyncio.gather(*tasks)
58
+ return {model_name: image for model_name, image in results if image}
59
+
 
60
 
61
  # Интерфейс Gradio
62
  with gr.Blocks() as demo:
63
+ gr.Markdown("## Генерация изображений с использованием моделей Hugging Face")
64
 
65
  # Поле ввода
66
+ user_input = gr.Textbox(label="Введите описание изображения", placeholder="Например, 'Красный автомобиль в лесу'")
67
 
68
  # Вывод изображений
69
  with gr.Row():
 
74
 
75
  # Асинхронная обработка ввода
76
  async def on_submit(prompt):
77
+ results = await handle(prompt)
78
+ return [results.get(name, None) for name in MODELS.keys()]
79
 
80
  generate_button.click(
81
  fn=on_submit,