Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
from diffusers import (
|
4 |
+
StableDiffusionXLControlNetPipeline,
|
5 |
+
ControlNetModel
|
6 |
+
)
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
# Hugging Face repo id of the SDXL line-art ControlNet used for conditioning.
# This is a placeholder for demonstration: replace it with a valid repo if the
# name below does not exist, or adjust it to your needs.
LINEART_CONTROLNET_REPO = "lllyasviel/sdxl-controlnet-lineart"  # Example placeholder
# Hugging Face repo id of the SDXL base checkpoint the pipeline is built on.
SDXL_MODEL_REPO = "RunDiffusion/Juggernaut-XL-v9"  # Or "stabilityai/stable-diffusion-xl-base-1.0"
|
13 |
+
|
14 |
+
@st.cache_resource
def load_pipeline():
    """
    Load the line-art ControlNet and the SDXL checkpoint and build the pipeline.

    The weights are loaded in float16 when a CUDA device is available and in
    float32 otherwise: the pipeline is only moved to the GPU when one exists,
    and half precision is not a dependable choice for CPU inference.

    The function is wrapped in ``st.cache_resource`` so Streamlit can reuse the
    returned pipeline across script reruns instead of reloading the weights.

    Returns:
        A ``StableDiffusionXLControlNetPipeline`` built from ``SDXL_MODEL_REPO``
        with the ControlNet from ``LINEART_CONTROLNET_REPO`` attached, placed
        on CUDA when available.
    """
    use_cuda = torch.cuda.is_available()
    # Pick the precision to match the device the pipeline will actually run on.
    dtype = torch.float16 if use_cuda else torch.float32

    controlnet = ControlNetModel.from_pretrained(
        LINEART_CONTROLNET_REPO,
        torch_dtype=dtype,
    )

    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        SDXL_MODEL_REPO,
        controlnet=controlnet,
        torch_dtype=dtype,
    )

    # Move to GPU if available
    if use_cuda:
        pipe.to("cuda")

    return pipe
|
36 |
+
|
37 |
+
def combine_lineart_and_colormask(lineart: Image.Image, color_mask: Image.Image) -> Image.Image:
    """
    Merge a line-art image and a color mask into one control image.

    This is a deliberately naive merge for demonstration: the color mask is
    resized to the line-art's size, both images are converted to RGBA, and
    the two are blended with equal weight. A more sophisticated merge, or a
    multi-ControlNet setup, may be preferable in practice.

    Args:
        lineart: The line-art sketch; its size determines the output size.
        color_mask: The color mask; resized to match ``lineart``.

    Returns:
        The 50/50 blend of the two inputs, converted to RGB.
    """
    target_size = lineart.size

    # Image.blend needs both operands in the same mode and size.
    sketch_layer = lineart.convert("RGBA")
    mask_layer = color_mask.resize(target_size).convert("RGBA")

    merged = Image.blend(sketch_layer, mask_layer, alpha=0.5)
    return merged.convert("RGB")
|
54 |
+
|
55 |
+
def main():
    """
    Render the Streamlit UI and run SDXL + ControlNet generation on demand.

    The sidebar collects the prompt, negative prompt, guidance scale and step
    count. The main area takes two uploads (a line-art sketch and a color
    mask), previews them together with their blended control image, and, when
    the "Generate" button is pressed, runs the pipeline and shows the result.
    A warning is shown until both files have been uploaded.
    """
    st.title("Line-Art + Color Mask with SDXL ControlNet")
    st.markdown(
        "Upload a **line-art sketch** and a **color mask**, then let "
        "Stable Diffusion XL (Juggernaut XL) + ControlNet (Lineart) do the rest!"
    )

    # Sidebar inputs for text prompt, etc.
    prompt = st.sidebar.text_input(
        "Prompt",
        value="A cute cartoon-style character with vibrant colors"
    )
    negative_prompt = st.sidebar.text_input(
        "Negative Prompt",
        value="ugly, deformed"
    )
    guidance_scale = st.sidebar.slider(
        "Guidance Scale (classifier-free)",
        min_value=1.0,
        max_value=20.0,
        value=9.0
    )
    num_inference_steps = st.sidebar.slider(
        "Number of Inference Steps",
        min_value=10,
        max_value=100,
        value=30
    )

    # Main area for uploading images
    lineart_file = st.file_uploader("Upload Line-Art Sketch (png/jpg)", type=["png", "jpg", "jpeg"])
    color_file = st.file_uploader("Upload Color Mask (png/jpg)", type=["png", "jpg", "jpeg"])

    if lineart_file and color_file:
        lineart_image = Image.open(lineart_file)
        color_mask = Image.open(color_file)

        st.image(lineart_image, caption="Line-Art Preview", width=300)
        st.image(color_mask, caption="Color Mask Preview", width=300)

        # Combine images into a single control image
        combined_control_image = combine_lineart_and_colormask(lineart_image, color_mask)
        st.image(combined_control_image, caption="Combined Control Image", width=300)

        # Button to run inference
        if st.button("Generate"):
            pipe = load_pipeline()

            with st.spinner("Generating image..."):
                result = pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    # The SDXL ControlNet text-to-image pipeline takes the
                    # conditioning image as `image`; `control_image` is the
                    # keyword of the img2img/inpaint ControlNet pipelines.
                    image=combined_control_image,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                ).images[0]

            st.image(result, caption="Generated Image", width=512)
    else:
        st.warning("Please upload both a line-art sketch and a color mask.")
|
117 |
+
|
118 |
+
# Start the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|