trashchenkov commited on
Commit
ba95d06
·
verified ·
1 Parent(s): d3e417f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -15
app.py CHANGED
@@ -2,9 +2,10 @@ import gradio as gr
2
  import numpy as np
3
  import torch
4
  from diffusers import DiffusionPipeline
 
5
  import re
6
 
7
- # Устройство и параметры загрузки модели
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
10
 
@@ -14,15 +15,21 @@ VALID_REPO_ID_REGEX = re.compile(r"^[a-zA-Z0-9._\-]+/[a-zA-Z0-9._\-]+$")
14
  def is_valid_repo_id(repo_id):
15
  return bool(VALID_REPO_ID_REGEX.match(repo_id)) and not repo_id.endswith(('-', '.'))
16
 
 
 
 
 
17
  # Изначально загружаем модель по умолчанию
18
  model_repo_id = "CompVis/stable-diffusion-v1-4"
19
  pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype).to(device)
20
 
21
- # --- Загрузка LoRA (1) ---
22
- pipe.load_lora_weights("AnastasiaSh/sticker-cat-lora3")
23
-
24
- MAX_SEED = np.iinfo(np.int32).max
25
- MAX_IMAGE_SIZE = 1024
 
 
26
 
27
  def infer(
28
  model,
@@ -37,25 +44,32 @@ def infer(
37
  ):
38
  global model_repo_id, pipe
39
 
40
- # Проверяем и загружаем новую модель, если она изменена
41
  if model != model_repo_id:
42
  if not is_valid_repo_id(model):
43
  raise gr.Error(f"Некорректный идентификатор модели: '{model}'. Проверьте название.")
 
44
  try:
45
  new_pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch_dtype).to(device)
46
-
47
- # --- Загрузка LoRA (2) ---
48
- new_pipe.load_lora_weights("AnastasiaSh/sticker-cat-lora3")
49
-
 
 
 
 
 
50
  pipe = new_pipe
51
  model_repo_id = model
 
52
  except Exception as e:
53
  raise gr.Error(f"Не удалось загрузить модель '{model}'.\nОшибка: {e}")
54
 
55
- # Генератор случайных чисел для детерминированности
56
  generator = torch.Generator(device=device).manual_seed(seed)
57
 
58
- # Генерация изображения
59
  try:
60
  image = pipe(
61
  prompt=prompt,
@@ -71,12 +85,14 @@ def infer(
71
 
72
  return image, seed
73
 
 
74
  examples = [
75
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
76
  "An astronaut riding a green horse",
77
  "A delicious ceviche cheesecake slice",
78
  ]
79
 
 
80
  css = """
81
  #col-container {
82
  margin: 0 auto;
@@ -84,16 +100,19 @@ css = """
84
  }
85
  """
86
 
 
87
  with gr.Blocks(css=css) as demo:
88
  with gr.Column(elem_id="col-container"):
89
  gr.Markdown("# Text-to-Image App")
90
 
 
91
  model = gr.Textbox(
92
  label="Model",
93
  value="CompVis/stable-diffusion-v1-4", # Значение по умолчанию
94
  interactive=True
95
  )
96
 
 
97
  prompt = gr.Text(
98
  label="Prompt",
99
  show_label=False,
@@ -101,7 +120,6 @@ with gr.Blocks(css=css) as demo:
101
  placeholder="Enter your prompt",
102
  container=False,
103
  )
104
-
105
  negative_prompt = gr.Text(
106
  label="Negative prompt",
107
  max_lines=1,
@@ -109,6 +127,7 @@ with gr.Blocks(css=css) as demo:
109
  visible=True,
110
  )
111
 
 
112
  seed = gr.Slider(
113
  label="Seed",
114
  minimum=0,
@@ -117,6 +136,7 @@ with gr.Blocks(css=css) as demo:
117
  value=42,
118
  )
119
 
 
120
  guidance_scale = gr.Slider(
121
  label="Guidance scale",
122
  minimum=0.0,
@@ -124,7 +144,6 @@ with gr.Blocks(css=css) as demo:
124
  step=0.1,
125
  value=7.0,
126
  )
127
-
128
  num_inference_steps = gr.Slider(
129
  label="Number of inference steps",
130
  minimum=1,
@@ -133,9 +152,13 @@ with gr.Blocks(css=css) as demo:
133
  value=20,
134
  )
135
 
 
136
  run_button = gr.Button("Run", variant="primary")
 
 
137
  result = gr.Image(label="Result", show_label=False)
138
 
 
139
  with gr.Accordion("Advanced Settings", open=False):
140
  with gr.Row():
141
  width = gr.Slider(
@@ -153,8 +176,10 @@ with gr.Blocks(css=css) as demo:
153
  value=512,
154
  )
155
 
 
156
  gr.Examples(examples=examples, inputs=[prompt])
157
 
 
158
  run_button.click(
159
  infer,
160
  inputs=[
@@ -170,5 +195,6 @@ with gr.Blocks(css=css) as demo:
170
  outputs=[result, seed],
171
  )
172
 
 
173
  if __name__ == "__main__":
174
  demo.launch()
 
2
  import numpy as np
3
  import torch
4
  from diffusers import DiffusionPipeline
5
+ from peft import PeftModel
6
  import re
7
 
8
+ # Устройство и тип данных
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
11
 
 
15
  def is_valid_repo_id(repo_id):
16
  return bool(VALID_REPO_ID_REGEX.match(repo_id)) and not repo_id.endswith(('-', '.'))
17
 
18
+ # Базовые константы
19
+ MAX_SEED = np.iinfo(np.int32).max
20
+ MAX_IMAGE_SIZE = 1024
21
+
22
  # Изначально загружаем модель по умолчанию
23
  model_repo_id = "CompVis/stable-diffusion-v1-4"
24
  pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype).to(device)
25
 
26
+ # Попробуем подгрузить LoRA-модификации (unet + text_encoder)
27
+ try:
28
+ pipe.unet = PeftModel.from_pretrained(pipe.unet, "AnastasiaSh/sticker-cat-lora3/unet")
29
+ pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, "AnastasiaSh/sticker-cat-lora3/text_encoder")
30
+ except Exception as e:
31
+ # Если не удалось, можно вывести предупреждение или поднять ошибку
32
+ print(f"Не удалось подгрузить LoRA по умолчанию: {e}")
33
 
34
  def infer(
35
  model,
 
44
  ):
45
  global model_repo_id, pipe
46
 
47
+ # Если пользователь ввёл другую модель, пробуем её загрузить с нуля
48
  if model != model_repo_id:
49
  if not is_valid_repo_id(model):
50
  raise gr.Error(f"Некорректный идентификатор модели: '{model}'. Проверьте название.")
51
+
52
  try:
53
  new_pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch_dtype).to(device)
54
+
55
+ # Повторно подгружаем LoRA для нового пайплайна
56
+ try:
57
+ new_pipe.unet = PeftModel.from_pretrained(new_pipe.unet, "AnastasiaSh/sticker-cat-lora3/unet")
58
+ new_pipe.text_encoder = PeftModel.from_pretrained(new_pipe.text_encoder, "AnastasiaSh/sticker-cat-lora3/text_encoder")
59
+ except Exception as e:
60
+ raise gr.Error(f"Не удалось подгрузить LoRA: {e}")
61
+
62
+ # Обновляем глобальные переменные
63
  pipe = new_pipe
64
  model_repo_id = model
65
+
66
  except Exception as e:
67
  raise gr.Error(f"Не удалось загрузить модель '{model}'.\nОшибка: {e}")
68
 
69
+ # Создаём генератор случайных чисел для детерминированности
70
  generator = torch.Generator(device=device).manual_seed(seed)
71
 
72
+ # Пытаемся сгенерировать изображение
73
  try:
74
  image = pipe(
75
  prompt=prompt,
 
85
 
86
  return image, seed
87
 
88
+ # Примеры для удобного тестирования
89
  examples = [
90
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
91
  "An astronaut riding a green horse",
92
  "A delicious ceviche cheesecake slice",
93
  ]
94
 
95
+ # Дополнительный CSS для оформления
96
  css = """
97
  #col-container {
98
  margin: 0 auto;
 
100
  }
101
  """
102
 
103
+ # Создаём Gradio-приложение
104
  with gr.Blocks(css=css) as demo:
105
  with gr.Column(elem_id="col-container"):
106
  gr.Markdown("# Text-to-Image App")
107
 
108
+ # Поле для ввода/смены модели
109
  model = gr.Textbox(
110
  label="Model",
111
  value="CompVis/stable-diffusion-v1-4", # Значение по умолчанию
112
  interactive=True
113
  )
114
 
115
+ # Основные поля для Prompt и Negative Prompt
116
  prompt = gr.Text(
117
  label="Prompt",
118
  show_label=False,
 
120
  placeholder="Enter your prompt",
121
  container=False,
122
  )
 
123
  negative_prompt = gr.Text(
124
  label="Negative prompt",
125
  max_lines=1,
 
127
  visible=True,
128
  )
129
 
130
+ # Слайдер для выбора seed
131
  seed = gr.Slider(
132
  label="Seed",
133
  minimum=0,
 
136
  value=42,
137
  )
138
 
139
+ # Слайдеры для guidance_scale и num_inference_steps
140
  guidance_scale = gr.Slider(
141
  label="Guidance scale",
142
  minimum=0.0,
 
144
  step=0.1,
145
  value=7.0,
146
  )
 
147
  num_inference_steps = gr.Slider(
148
  label="Number of inference steps",
149
  minimum=1,
 
152
  value=20,
153
  )
154
 
155
+ # Кнопка запуска
156
  run_button = gr.Button("Run", variant="primary")
157
+
158
+ # Поле для отображения результата
159
  result = gr.Image(label="Result", show_label=False)
160
 
161
+ # Продвинутые настройки (Accordion)
162
  with gr.Accordion("Advanced Settings", open=False):
163
  with gr.Row():
164
  width = gr.Slider(
 
176
  value=512,
177
  )
178
 
179
+ # Примеры
180
  gr.Examples(examples=examples, inputs=[prompt])
181
 
182
+ # Связка кнопки "Run" с функцией "infer"
183
  run_button.click(
184
  infer,
185
  inputs=[
 
195
  outputs=[result, seed],
196
  )
197
 
198
+ # Запуск
199
  if __name__ == "__main__":
200
  demo.launch()