any_model / app.py
trashchenkov's picture
Update app.py
9d6f756 verified
raw
history blame
5.36 kB
import gradio as gr
import torch
from huggingface_hub import HfApi, RepositoryNotFoundError
from diffusers import DiffusionPipeline
# Проверка доступности модели на Hugging Face Hub
def is_model_available(model_id):
try:
api = HfApi()
api.model_info(model_id)
return True
except RepositoryNotFoundError:
return False
except Exception:
return False
def validate_model(model_id):
if not model_id:
raise ValueError("Необходимо указать модель")
if not is_model_available(model_id):
raise ValueError(f"Модель '{model_id}' не найдена на Hugging Face Hub")
if not any(x in model_id.lower() for x in ["stable-diffusion", "sdxl"]):
raise ValueError("Поддерживаются только Stable Diffusion и SDXL модели")
# Инициализация устройства
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Глобальные переменные
current_model = None
pipe = None
def load_pipeline(model_id):
global pipe, current_model
if model_id != current_model:
validate_model(model_id)
if pipe is not None:
del pipe
torch.cuda.empty_cache()
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)
current_model = model_id
return pipe
def infer(
model_id,
prompt,
negative_prompt,
seed,
width,
height,
guidance_scale,
num_inference_steps,
progress=gr.Progress(track_tqdm=True),
):
try:
# Загрузка и проверка модели
pipeline = load_pipeline(model_id)
# Генерация изображения
generator = torch.Generator(device=device).manual_seed(seed)
result = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
).images[0]
return result, seed
except Exception as e:
raise gr.Error(f"Ошибка генерации: {str(e)}")
# Список доступных моделей по умолчанию
available_models = [
"stabilityai/stable-diffusion-2-1",
"stabilityai/sdxl-turbo",
"runwayml/stable-diffusion-v1-5",
"prompthero/openjourney-v4"
]
examples = [
["stabilityai/sdxl-turbo", "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"],
["runwayml/stable-diffusion-v1-5", "An astronaut riding a green horse"],
["prompthero/openjourney-v4", "A cyberpunk cityscape at night, neon lights, rain"],
]
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown("# 🎨 Генератор изображений по тексту")
with gr.Row():
model_id = gr.Dropdown(
label="Выберите или введите модель",
choices=available_models,
value="stabilityai/sdxl-turbo",
allow_custom_value=True,
scale=3
)
prompt = gr.Textbox(
label="Промпт",
placeholder="Введите описание изображения...",
lines=2
)
negative_prompt = gr.Textbox(
label="Негативный промпт",
placeholder="Что исключить из изображения...",
lines=2
)
with gr.Accordion("Настройки генерации", open=False):
with gr.Row():
seed = gr.Slider(0, 2147483647, value=42, label="Сид")
guidance_scale = gr.Slider(0.0, 20.0, value=7.5, label="Guidance Scale")
with gr.Row():
width = gr.Slider(256, 1024, value=512, step=64, label="Ширина")
height = gr.Slider(256, 1024, value=512, step=64, label="Высота")
num_inference_steps = gr.Slider(1, 100, value=20, step=1, label="Шаги генерации")
generate_btn = gr.Button("Сгенерировать", variant="primary")
output_image = gr.Image(label="Результат", show_label=False)
used_seed = gr.Number(label="Использованный сид", visible=True)
gr.Examples(
examples=examples,
inputs=[model_id, prompt],
outputs=[output_image, used_seed],
fn=infer,
cache_examples=True,
label="Примеры"
)
generate_btn.click(
fn=infer,
inputs=[
model_id,
prompt,
negative_prompt,
seed,
width,
height,
guidance_scale,
num_inference_steps,
],
outputs=[output_image, used_seed]
)
if __name__ == "__main__":
demo.launch()