import os
import gradio as gr
import numpy as np
import torch
import safetensors.torch as sf
from PIL import Image
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
from diffusers.models.attention_processor import AttnProcessor2_0
from transformers import CLIPTextModel, CLIPTokenizer
from briarmbg import BriaRMBG
from huggingface_hub import hf_hub_download
# Check GPU availability and pick a suitable compute dtype.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    # Prefer bfloat16 where the hardware supports it (e.g. A100); otherwise float16.
    # (Checking the capability directly is more robust than matching the device name,
    # which reads like "NVIDIA A100-SXM4-80GB" and never starts with "A100".)
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        print(f"✅ {num_gpus} CUDA GPUs are available!")
        for i in range(num_gpus):
            print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        print(f"✅ CUDA is available! GPU: {torch.cuda.get_device_name(0)}")
else:
    # Half precision is poorly supported on CPU, so fall back to float32.
    dtype = torch.float32
    print("❌ No CUDA GPUs are available. Switching to CPU.")
# Download the IC-Light offset weights if they are not already cached locally.
model_path = './models/iclight_sd15_fc.safetensors'
if not os.path.exists(model_path):
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    hf_hub_download(
        repo_id='lllyasviel/ic-light',
        filename='iclight_sd15_fc.safetensors',
        local_dir='./models',
        local_dir_use_symlinks=False
    )
# Load the base Stable Diffusion 1.5 components.
sd15_name = 'stablediffusionapi/realistic-vision-v51'
tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder").to(device=device, dtype=dtype)
vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae").to(device=device, dtype=dtype)
unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet").to(device=device, dtype=dtype)
rmbg = BriaRMBG.from_pretrained("briaai/RMBG-1.4").to(device=device, dtype=torch.float32)
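# NOTE: the file downloaded above stores the IC-Light weights as *offsets* relative to
# the base SD 1.5 UNet, and this script never applies them, so the pipelines below run
# the plain realistic-vision weights. The helper below is a minimal sketch of the merge
# step, following the upstream lllyasviel/IC-Light demo (the helper name is ours): the
# FC model conditions on 4 extra latent channels, so conv_in is widened from 4 to 8
# inputs (new channels zero-initialized) before the offsets are added. It is
# deliberately not called here, because the stock diffusers pipelines feed only 4
# latent channels; using the merged UNet would also require a forward hook that
# concatenates the conditioning latent, as in the upstream demo.
def apply_iclight_offsets(unet: UNet2DConditionModel, offsets_path: str) -> None:
    with torch.no_grad():
        # Widen conv_in from 4 to 8 input channels, copying the original kernel into
        # the first 4 channels and zero-initializing the new ones.
        old = unet.conv_in
        new_conv_in = torch.nn.Conv2d(
            8, old.out_channels, old.kernel_size, old.stride, old.padding,
            device=old.weight.device, dtype=old.weight.dtype
        )
        new_conv_in.weight.zero_()
        new_conv_in.weight[:, :4, :, :].copy_(old.weight)
        new_conv_in.bias.copy_(old.bias)
        unet.conv_in = new_conv_in
        # Add the stored offsets onto the (widened) base weights.
        sd_offset = sf.load_file(offsets_path)
        sd_origin = unet.state_dict()
        sd_merged = {k: sd_origin[k] + sd_offset[k].to(sd_origin[k]) for k in sd_origin}
        unet.load_state_dict(sd_merged, strict=True)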
# Use PyTorch 2.0 scaled-dot-product attention for faster inference.
unet.set_attn_processor(AttnProcessor2_0())
vae.set_attn_processor(AttnProcessor2_0())
# Configure the sampler.
scheduler = DPMSolverMultistepScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    algorithm_type="sde-dpmsolver++",
    use_karras_sigmas=True,
    steps_offset=1
)
# Build the text-to-image and image-to-image pipelines.
t2i_pipe = StableDiffusionPipeline.from_pretrained(
    sd15_name,
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    safety_checker=None,
    requires_safety_checker=False,
    feature_extractor=None,
    image_encoder=None
).to(device)
i2i_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    sd15_name,
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    safety_checker=None,
    requires_safety_checker=False,
    feature_extractor=None,
    image_encoder=None
).to(device)
# Release cached GPU memory to reduce the risk of out-of-memory errors.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
# Text-to-image generation used by the UI.
def process_image(prompt: str, width: int, height: int, num_steps: int, guidance: float):
    # Fixed seed so repeated runs with the same inputs are reproducible.
    rng = torch.Generator(device=device).manual_seed(42)
    result = t2i_pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_steps,
        guidance_scale=guidance,
        generator=rng
    ).images[0]
    return result
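# The RMBG model and the img2img pipeline loaded above are never reached from the UI.
# Below is a minimal sketch of how they could be combined for IC-Light-style
# relighting: matte out the background with BriaRMBG, then re-render the subject with
# img2img. The helper names (run_rmbg, process_relight) and the RMBG preprocessing
# constants (1024x1024 input, mean-0.5 normalization, per the briaai/RMBG-1.4 model
# card) are assumptions, not part of the original app.
@torch.inference_mode()
def run_rmbg(img: np.ndarray) -> np.ndarray:
    # img: HxWx3 uint8. Returns the subject composited over flat gray.
    h, w = img.shape[:2]
    feed = torch.from_numpy(img).float().permute(2, 0, 1).unsqueeze(0) / 255.0
    feed = torch.nn.functional.interpolate(feed, size=(1024, 1024), mode='bilinear')
    feed = (feed - 0.5) / 1.0  # normalization used in the RMBG-1.4 example code
    alpha = rmbg(feed.to(device=device, dtype=torch.float32))[0][0]
    alpha = torch.nn.functional.interpolate(alpha, size=(h, w), mode='bilinear')
    alpha = alpha[0, 0].clamp(0, 1).float().cpu().numpy()[..., None]
    result = 127 + (img.astype(np.float32) - 127) * alpha
    return result.clip(0, 255).astype(np.uint8)

def process_relight(image: Image.Image, prompt: str, num_steps: int, guidance: float,
                    strength: float = 0.6):
    # Background removal followed by img2img re-rendering (not wired into the UI).
    fg = run_rmbg(np.array(image.convert('RGB')))
    rng = torch.Generator(device=device).manual_seed(42)
    return i2i_pipe(
        prompt=prompt,
        image=Image.fromarray(fg),
        strength=strength,
        num_inference_steps=num_steps,
        guidance_scale=guidance,
        generator=rng
    ).images[0]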
# Build the Gradio UI.
block = gr.Blocks()
with block:
    with gr.Row():
        gr.Markdown("## IC-Light - Image Processing")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            width = gr.Slider(256, 1024, value=512, step=64, label="Width")
            height = gr.Slider(256, 1024, value=640, step=64, label="Height")
            steps = gr.Slider(10, 100, value=50, step=5, label="Inference Steps")
            guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="Guidance Scale")
            submit = gr.Button("Generate Image")
        with gr.Column():
            result_image = gr.Image(type="pil", label="Generated Image")
    submit.click(fn=process_image, inputs=[prompt, width, height, steps, guidance], outputs=result_image)

block.launch(server_name='0.0.0.0', server_port=7860)