doublelotus committed on
Commit
2637637
·
verified ·
1 Parent(s): 61b8bfc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from diffusers import (
4
+ StableDiffusionXLControlNetPipeline,
5
+ ControlNetModel
6
+ )
7
+ from PIL import Image
8
+
9
# Hugging Face repository IDs for the two models this app loads.
#
# The ControlNet ID is an example placeholder: it assumes a published SDXL
# line-art ControlNet exists under this name. Replace it with a valid repo if
# it does not resolve, or adjust it to your needs.
LINEART_CONTROLNET_REPO = "lllyasviel/sdxl-controlnet-lineart"

# Base SDXL checkpoint ("stabilityai/stable-diffusion-xl-base-1.0" is an alternative).
SDXL_MODEL_REPO = "RunDiffusion/Juggernaut-XL-v9"
13
+
14
@st.cache_resource
def load_pipeline():
    """
    Build the SDXL + ControlNet pipeline and cache it for the Streamlit session.

    Loads the ControlNet weights from ``LINEART_CONTROLNET_REPO`` and the base
    Stable Diffusion XL checkpoint from ``SDXL_MODEL_REPO``, then moves the
    pipeline to the GPU when CUDA is available.

    Returns:
        A ``StableDiffusionXLControlNetPipeline`` ready for inference.
    """
    use_cuda = torch.cuda.is_available()

    # Half precision is only used on GPU. On CPU the pipeline stays where it
    # was loaded, and float16 inference there is not supported by diffusers,
    # so fall back to float32 in that case.
    dtype = torch.float16 if use_cuda else torch.float32

    controlnet = ControlNetModel.from_pretrained(
        LINEART_CONTROLNET_REPO,
        torch_dtype=dtype,
    )

    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        SDXL_MODEL_REPO,
        controlnet=controlnet,
        torch_dtype=dtype,
    )

    # Move to GPU if available
    if use_cuda:
        pipe.to("cuda")

    return pipe
36
+
37
def combine_lineart_and_colormask(lineart: Image.Image, color_mask: Image.Image) -> Image.Image:
    """
    Merge a line-art image and a color mask into one RGB control image.

    This is a deliberately naive merge for demonstration: the color mask is
    scaled to the line-art's size, both images are converted to RGBA, and the
    two are blended with equal weight. A more sophisticated merge, or a
    multi-ControlNet setup, may be preferable in practice.
    """
    # The line-art dictates the output dimensions.
    target_size = lineart.size

    sketch_layer = lineart.convert("RGBA")
    color_layer = color_mask.resize(target_size).convert("RGBA")

    # Equal-weight mix of the two layers, flattened back to RGB.
    return Image.blend(sketch_layer, color_layer, alpha=0.5).convert("RGB")
54
+
55
def main():
    """
    Render the Streamlit UI and run generation on request.

    Collects the prompt, negative prompt, guidance scale, and step count from
    the sidebar, and the line-art and color-mask images from two uploaders.
    Once both images are present it previews them together with the blended
    control image, and generates an image with the cached pipeline when the
    "Generate" button is pressed.
    """
    st.title("Line-Art + Color Mask with SDXL ControlNet")
    st.markdown(
        "Upload a **line-art sketch** and a **color mask**, then let "
        "Stable Diffusion XL (Juggernaut XL) + ControlNet (Lineart) do the rest!"
    )

    # Sidebar inputs for text prompt, etc.
    prompt = st.sidebar.text_input(
        "Prompt",
        value="A cute cartoon-style character with vibrant colors",
    )
    negative_prompt = st.sidebar.text_input(
        "Negative Prompt",
        value="ugly, deformed",
    )
    guidance_scale = st.sidebar.slider(
        "Guidance Scale (classifier-free)",
        min_value=1.0,
        max_value=20.0,
        value=9.0,
    )
    num_inference_steps = st.sidebar.slider(
        "Number of Inference Steps",
        min_value=10,
        max_value=100,
        value=30,
    )

    # Main area for uploading images
    lineart_file = st.file_uploader("Upload Line-Art Sketch (png/jpg)", type=["png", "jpg", "jpeg"])
    color_file = st.file_uploader("Upload Color Mask (png/jpg)", type=["png", "jpg", "jpeg"])

    # Nothing to preview or generate until both inputs are provided.
    if not (lineart_file and color_file):
        st.warning("Please upload both a line-art sketch and a color mask.")
        return

    lineart_image = Image.open(lineart_file)
    color_mask = Image.open(color_file)

    st.image(lineart_image, caption="Line-Art Preview", width=300)
    st.image(color_mask, caption="Color Mask Preview", width=300)

    # Combine images into a single control image
    combined_control_image = combine_lineart_and_colormask(lineart_image, color_mask)
    st.image(combined_control_image, caption="Combined Control Image", width=300)

    # Inference only runs on the rerun triggered by the button press.
    if not st.button("Generate"):
        return

    pipe = load_pipeline()

    with st.spinner("Generating image..."):
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            # The text-to-image SDXL ControlNet pipeline takes its conditioning
            # image as `image`; `control_image` belongs to the img2img/inpaint
            # variants and would leave the required input unset here.
            image=combined_control_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images[0]

    st.image(result, caption="Generated Image", width=512)


if __name__ == "__main__":
    main()