trashchenkov commited on
Commit
02fc3d4
·
verified ·
1 Parent(s): 9d6f756

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -116
app.py CHANGED
@@ -1,51 +1,24 @@
1
  import gradio as gr
 
2
  import torch
3
- from huggingface_hub import HfApi, RepositoryNotFoundError
4
  from diffusers import DiffusionPipeline
5
 
6
- # Проверка доступности модели на Hugging Face Hub
7
- def is_model_available(model_id):
8
- try:
9
- api = HfApi()
10
- api.model_info(model_id)
11
- return True
12
- except RepositoryNotFoundError:
13
- return False
14
- except Exception:
15
- return False
16
-
17
- def validate_model(model_id):
18
- if not model_id:
19
- raise ValueError("Необходимо указать модель")
20
-
21
- if not is_model_available(model_id):
22
- raise ValueError(f"Модель '{model_id}' не найдена на Hugging Face Hub")
23
-
24
- if not any(x in model_id.lower() for x in ["stable-diffusion", "sdxl"]):
25
- raise ValueError("Поддерживаются только Stable Diffusion и SDXL модели")
26
-
27
- # Инициализация устройства
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
30
-
31
- # Глобальные переменные
32
- current_model = None
33
- pipe = None
34
-
35
- def load_pipeline(model_id):
36
- global pipe, current_model
37
- if model_id != current_model:
38
- validate_model(model_id)
39
- if pipe is not None:
40
- del pipe
41
- torch.cuda.empty_cache()
42
- pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
43
- pipe = pipe.to(device)
44
- current_model = model_id
45
- return pipe
46
 
47
  def infer(
48
- model_id,
49
  prompt,
50
  negative_prompt,
51
  seed,
@@ -55,40 +28,38 @@ def infer(
55
  num_inference_steps,
56
  progress=gr.Progress(track_tqdm=True),
57
  ):
58
- try:
59
- # Загрузка и проверка модели
60
- pipeline = load_pipeline(model_id)
61
-
62
- # Генерация изображения
63
- generator = torch.Generator(device=device).manual_seed(seed)
64
-
65
- result = pipeline(
66
- prompt=prompt,
67
- negative_prompt=negative_prompt,
68
- width=width,
69
- height=height,
70
- guidance_scale=guidance_scale,
71
- num_inference_steps=num_inference_steps,
72
- generator=generator,
73
- ).images[0]
74
-
75
- return result, seed
76
 
77
- except Exception as e:
78
- raise gr.Error(f"Ошибка генерации: {str(e)}")
79
-
80
- # Список доступных моделей по умолчанию
81
- available_models = [
82
- "stabilityai/stable-diffusion-2-1",
83
- "stabilityai/sdxl-turbo",
84
- "runwayml/stable-diffusion-v1-5",
85
- "prompthero/openjourney-v4"
86
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  examples = [
89
- ["stabilityai/sdxl-turbo", "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"],
90
- ["runwayml/stable-diffusion-v1-5", "An astronaut riding a green horse"],
91
- ["prompthero/openjourney-v4", "A cyberpunk cityscape at night, neon lights, rain"],
92
  ]
93
 
94
  css = """
@@ -100,58 +71,81 @@ css = """
100
 
101
  with gr.Blocks(css=css) as demo:
102
  with gr.Column(elem_id="col-container"):
103
- gr.Markdown("# 🎨 Генератор изображений по тексту")
 
 
 
 
 
 
 
104
 
105
- with gr.Row():
106
- model_id = gr.Dropdown(
107
- label="Выберите или введите модель",
108
- choices=available_models,
109
- value="stabilityai/sdxl-turbo",
110
- allow_custom_value=True,
111
- scale=3
112
- )
113
-
114
- prompt = gr.Textbox(
115
- label="Промпт",
116
- placeholder="Введите описание изображения...",
117
- lines=2
118
  )
119
 
120
- negative_prompt = gr.Textbox(
121
- label="Негативный промпт",
122
- placeholder="Что исключить из изображения...",
123
- lines=2
 
124
  )
125
 
126
- with gr.Accordion("Настройки генерации", open=False):
127
- with gr.Row():
128
- seed = gr.Slider(0, 2147483647, value=42, label="Сид")
129
- guidance_scale = gr.Slider(0.0, 20.0, value=7.5, label="Guidance Scale")
130
-
131
- with gr.Row():
132
- width = gr.Slider(256, 1024, value=512, step=64, label="Ширина")
133
- height = gr.Slider(256, 1024, value=512, step=64, label="Высота")
134
-
135
- num_inference_steps = gr.Slider(1, 100, value=20, step=1, label="Шаги генерации")
136
 
137
- generate_btn = gr.Button("Сгенерировать", variant="primary")
138
-
139
- output_image = gr.Image(label="Результат", show_label=False)
140
- used_seed = gr.Number(label="Использованный сид", visible=True)
141
-
142
- gr.Examples(
143
- examples=examples,
144
- inputs=[model_id, prompt],
145
- outputs=[output_image, used_seed],
146
- fn=infer,
147
- cache_examples=True,
148
- label="Примеры"
149
  )
150
 
151
- generate_btn.click(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  fn=infer,
153
  inputs=[
154
- model_id,
155
  prompt,
156
  negative_prompt,
157
  seed,
@@ -160,8 +154,8 @@ with gr.Blocks(css=css) as demo:
160
  guidance_scale,
161
  num_inference_steps,
162
  ],
163
- outputs=[output_image, used_seed]
164
  )
165
 
166
  if __name__ == "__main__":
167
- demo.launch()
 
1
  import gradio as gr
2
+ import numpy as np
3
  import torch
 
4
  from diffusers import DiffusionPipeline
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  device = "cuda" if torch.cuda.is_available() else "cpu"
7
+ model_repo_id = "stabilityai/sdxl-turbo" # Текущая/последняя загруженная модель
8
+ if torch.cuda.is_available():
9
+ torch_dtype = torch.float16
10
+ else:
11
+ torch_dtype = torch.float32
12
+
13
+ # Изначально загружаем модель по умолчанию
14
+ pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
15
+ pipe = pipe.to(device)
16
+
17
+ MAX_SEED = np.iinfo(np.int32).max
18
+ MAX_IMAGE_SIZE = 1024
 
 
 
 
 
19
 
20
  def infer(
21
+ model,
22
  prompt,
23
  negative_prompt,
24
  seed,
 
28
  num_inference_steps,
29
  progress=gr.Progress(track_tqdm=True),
30
  ):
31
+ global model_repo_id, pipe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ # Проверяем, нужно ли менять модель
34
+ if model != model_repo_id:
35
+ try:
36
+ # Пробуем загрузить новую модель
37
+ new_pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch_dtype)
38
+ new_pipe = new_pipe.to(device)
39
+ # Если успешно, то обновляем pipe и модель
40
+ pipe = new_pipe
41
+ model_repo_id = model
42
+ except Exception as e:
43
+ raise gr.Error(f"Не удалось загрузить модель {model}. Ошибка: {str(e)}")
44
+
45
+ generator = torch.Generator(device=device).manual_seed(seed)
46
+
47
+ image = pipe(
48
+ prompt=prompt,
49
+ negative_prompt=negative_prompt,
50
+ guidance_scale=guidance_scale,
51
+ num_inference_steps=num_inference_steps,
52
+ width=width,
53
+ height=height,
54
+ generator=generator,
55
+ ).images[0]
56
+
57
+ return image, seed
58
 
59
  examples = [
60
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
61
+ "An astronaut riding a green horse",
62
+ "A delicious ceviche cheesecake slice",
63
  ]
64
 
65
  css = """
 
71
 
72
  with gr.Blocks(css=css) as demo:
73
  with gr.Column(elem_id="col-container"):
74
+ gr.Markdown(" # Text-to-Image App")
75
+
76
+ # Вместо выпадающего списка — текстовое поле для ввода модели
77
+ model = gr.Textbox(
78
+ label="Model name or path",
79
+ value="stabilityai/sdxl-turbo", # Значение по умолчанию
80
+ interactive=True
81
+ )
82
 
83
+ prompt = gr.Text(
84
+ label="Prompt",
85
+ show_label=False,
86
+ max_lines=1,
87
+ placeholder="Enter your prompt",
88
+ container=False,
 
 
 
 
 
 
 
89
  )
90
 
91
+ negative_prompt = gr.Text(
92
+ label="Negative prompt",
93
+ max_lines=1,
94
+ placeholder="Enter a negative prompt",
95
+ visible=True,
96
  )
97
 
98
+ seed = gr.Slider(
99
+ label="Seed",
100
+ minimum=0,
101
+ maximum=MAX_SEED,
102
+ step=1,
103
+ value=42,
104
+ )
 
 
 
105
 
106
+ guidance_scale = gr.Slider(
107
+ label="Guidance scale",
108
+ minimum=0.0,
109
+ maximum=10.0,
110
+ step=0.1,
111
+ value=7.0,
 
 
 
 
 
 
112
  )
113
 
114
+ num_inference_steps = gr.Slider(
115
+ label="Number of inference steps",
116
+ minimum=1,
117
+ maximum=50,
118
+ step=1,
119
+ value=20,
120
+ )
121
+
122
+ run_button = gr.Button("Run", scale=0, variant="primary")
123
+ result = gr.Image(label="Result", show_label=False)
124
+
125
+ with gr.Accordion("Advanced Settings", open=False):
126
+ with gr.Row():
127
+ width = gr.Slider(
128
+ label="Width",
129
+ minimum=256,
130
+ maximum=MAX_IMAGE_SIZE,
131
+ step=32,
132
+ value=1024,
133
+ )
134
+ height = gr.Slider(
135
+ label="Height",
136
+ minimum=256,
137
+ maximum=MAX_IMAGE_SIZE,
138
+ step=32,
139
+ value=1024,
140
+ )
141
+
142
+ gr.Examples(examples=examples, inputs=[prompt])
143
+
144
+ gr.on(
145
+ triggers=[run_button.click, prompt.submit],
146
  fn=infer,
147
  inputs=[
148
+ model,
149
  prompt,
150
  negative_prompt,
151
  seed,
 
154
  guidance_scale,
155
  num_inference_steps,
156
  ],
157
+ outputs=[result, seed],
158
  )
159
 
160
  if __name__ == "__main__":
161
+ demo.launch()