Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -48,28 +48,75 @@ import cv2
|
|
48 |
import numpy as np
|
49 |
import sys
|
50 |
import io
|
|
|
51 |
# FluxPipeline import 부분을 수정
|
52 |
from diffusers import StableDiffusionPipeline, DiffusionPipeline
|
53 |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
-
#
|
56 |
-
pipe =
|
57 |
-
"stabilityai/stable-diffusion-
|
58 |
torch_dtype=torch.float16,
|
59 |
-
|
|
|
60 |
)
|
61 |
pipe.to("cuda")
|
62 |
|
63 |
-
# LoRA 가중치 로드
|
64 |
-
pipe.load_lora_weights(
|
65 |
-
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
|
66 |
-
)
|
67 |
-
pipe.fuse_lora(lora_scale=0.125)
|
68 |
-
|
69 |
# 안전 검사기 설정
|
70 |
pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
|
71 |
"CompVis/stable-diffusion-safety-checker"
|
72 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
logging.basicConfig(level=logging.INFO)
|
74 |
logger = logging.getLogger(__name__)
|
75 |
|
|
|
48 |
import numpy as np
|
49 |
import sys
|
50 |
import io
|
51 |
+
|
52 |
# FluxPipeline import 부분을 수정
|
53 |
from diffusers import StableDiffusionPipeline, DiffusionPipeline
|
54 |
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
55 |
+
from diffusers import AutoPipelineForText2Image
|
56 |
+
|
57 |
+
# Model initialization 부분 수정
|
58 |
+
if not path.exists(cache_path):
|
59 |
+
os.makedirs(cache_path, exist_ok=True)
|
60 |
|
61 |
+
# 파이프라인 초기화 수정
|
62 |
+
pipe = AutoPipelineForText2Image.from_pretrained(
|
63 |
+
"stabilityai/stable-diffusion-xl-base-1.0",
|
64 |
torch_dtype=torch.float16,
|
65 |
+
use_safetensors=True,
|
66 |
+
variant="fp16"
|
67 |
)
|
68 |
pipe.to("cuda")
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
# 안전 검사기 설정
|
71 |
pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
|
72 |
"CompVis/stable-diffusion-safety-checker"
|
73 |
)
|
74 |
+
|
75 |
+
# process_and_save_image 함수 수정
|
76 |
+
@spaces.GPU
|
77 |
+
def process_and_save_image(height, width, steps, scales, prompt, seed):
|
78 |
+
is_safe, processed_prompt = process_prompt(prompt)
|
79 |
+
if not is_safe:
|
80 |
+
gr.Warning("부적절한 내용이 포함된 프롬프트입니다.")
|
81 |
+
return None, load_gallery()
|
82 |
+
|
83 |
+
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
|
84 |
+
try:
|
85 |
+
generated_image = pipe(
|
86 |
+
prompt=processed_prompt,
|
87 |
+
negative_prompt="low quality, worst quality, bad anatomy, bad composition, poor, low effort",
|
88 |
+
num_inference_steps=steps,
|
89 |
+
guidance_scale=scales,
|
90 |
+
height=height,
|
91 |
+
width=width,
|
92 |
+
generator=torch.Generator("cuda").manual_seed(int(seed))
|
93 |
+
).images[0]
|
94 |
+
|
95 |
+
# PIL Image로 확실하게 변환
|
96 |
+
if not isinstance(generated_image, Image.Image):
|
97 |
+
generated_image = Image.fromarray(generated_image)
|
98 |
+
|
99 |
+
# RGB 모드로 변환
|
100 |
+
if generated_image.mode != 'RGB':
|
101 |
+
generated_image = generated_image.convert('RGB')
|
102 |
+
|
103 |
+
# 메모리에서 PNG로 변환
|
104 |
+
img_byte_arr = io.BytesIO()
|
105 |
+
generated_image.save(img_byte_arr, format='PNG')
|
106 |
+
img_byte_arr = img_byte_arr.getvalue()
|
107 |
+
|
108 |
+
# 디스크에 저장
|
109 |
+
saved_path = save_image(generated_image)
|
110 |
+
if saved_path is None:
|
111 |
+
logger.warning("Failed to save generated image")
|
112 |
+
return None, load_gallery()
|
113 |
+
|
114 |
+
# PNG 형식으로 다시 로드
|
115 |
+
return Image.open(io.BytesIO(img_byte_arr)), load_gallery()
|
116 |
+
except Exception as e:
|
117 |
+
logger.error(f"Error in image generation: {str(e)}")
|
118 |
+
return None, load_gallery()
|
119 |
+
|
120 |
logging.basicConfig(level=logging.INFO)
|
121 |
logger = logging.getLogger(__name__)
|
122 |
|