NightRaven109 committed
Commit
7c89d3a
1 Parent(s): 6ecc7d4

Upload 2 files

Files changed (2)
  1. app.py +212 -0
  2. requirements.txt +15 -14
app.py ADDED
@@ -0,0 +1,212 @@
+ import os
+ import torch
+ import gradio as gr
+ import spaces
+ import numpy as np
+ from PIL import Image
+ import safetensors.torch
+ from huggingface_hub import hf_hub_download
+ from accelerate import Accelerator
+ from accelerate.utils import set_seed
+ from diffusers import (
+     AutoencoderKL,
+     DDPMScheduler,
+     UNet2DConditionModel,
+ )
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
+ from models.controlnet import ControlNetModel
+ from pipelines.pipeline_ccsr import StableDiffusionControlNetPipeline
+ from myutils.wavelet_color_fix import wavelet_color_fix, adain_color_fix
+
+ # Global model handles, populated lazily on the first request
+ pipeline = None
+ generator = None
+ accelerator = None
+
+ @spaces.GPU
+ def initialize_models():
+     global pipeline, generator, accelerator
+
+     # Initialize accelerator
+     accelerator = Accelerator(
+         mixed_precision="fp16",
+         gradient_accumulation_steps=1
+     )
+
+     try:
+         # Download and load models, authenticating with the Space's "Read" token secret
+         scheduler = DDPMScheduler.from_pretrained(
+             "NightRaven109/CCSRModels",
+             subfolder="stable-diffusion-2-1-base/scheduler",
+             use_auth_token=os.environ['Read']
+         )
+
+         text_encoder = CLIPTextModel.from_pretrained(
+             "NightRaven109/CCSRModels",
+             subfolder="stable-diffusion-2-1-base/text_encoder",
+             use_auth_token=os.environ['Read']
+         )
+
+         tokenizer = CLIPTokenizer.from_pretrained(
+             "NightRaven109/CCSRModels",
+             subfolder="stable-diffusion-2-1-base/tokenizer",
+             use_auth_token=os.environ['Read']
+         )
+
+         feature_extractor = CLIPImageProcessor.from_pretrained(
+             "NightRaven109/CCSRModels",
+             subfolder="stable-diffusion-2-1-base/feature_extractor",
+             use_auth_token=os.environ['Read']
+         )
+
+         unet = UNet2DConditionModel.from_pretrained(
+             "NightRaven109/CCSRModels",
+             subfolder="stable-diffusion-2-1-base/unet",
+             use_auth_token=os.environ['Read']
+         )
+
+         controlnet = ControlNetModel.from_pretrained(
+             "NightRaven109/CCSRModels",
+             subfolder="Controlnet",
+             use_auth_token=os.environ['Read']
+         )
+
+         vae = AutoencoderKL.from_pretrained(
+             "NightRaven109/CCSRModels",
+             subfolder="vae",
+             use_auth_token=os.environ['Read']
+         )
+
+         # Freeze models (inference only)
+         for model in [vae, text_encoder, unet, controlnet]:
+             model.requires_grad_(False)
+
+         # Initialize pipeline
+         pipeline = StableDiffusionControlNetPipeline(
+             vae=vae,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             feature_extractor=feature_extractor,
+             unet=unet,
+             controlnet=controlnet,
+             scheduler=scheduler,
+             safety_checker=None,
+             requires_safety_checker=False,
+         )
+
+         # Get weight dtype based on mixed precision
+         weight_dtype = torch.float32
+         if accelerator.mixed_precision == "fp16":
+             weight_dtype = torch.float16
+         elif accelerator.mixed_precision == "bf16":
+             weight_dtype = torch.bfloat16
+
+         # Move models to device with the appropriate dtype
+         for model in [text_encoder, vae, unet, controlnet]:
+             model.to(accelerator.device, dtype=weight_dtype)
+
+         # Initialize generator
+         generator = torch.Generator(device=accelerator.device)
+
+         return True
+
+     except Exception as e:
+         print(f"Error initializing models: {str(e)}")
+         return False
+
+ @spaces.GPU
+ def process_image(
+     input_image,
+     prompt="clean, high-resolution, 8k",
+     negative_prompt="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
+     guidance_scale=1.0,
+     conditioning_scale=1.0,
+     num_inference_steps=20,
+     seed=42,
+     upscale_factor=2,
+     color_fix_method="adain"
+ ):
+     global pipeline, generator, accelerator
+
+     if pipeline is None:
+         if not initialize_models():
+             return None
+
+     try:
+         # Set seed (gr.Number yields a float, so cast before seeding)
+         if seed is not None:
+             generator.manual_seed(int(seed))
+
+         # Process input image
+         input_pil = Image.fromarray(input_image)
+         width, height = input_pil.size
+
+         # Target size: upscale, then round down to a multiple of 8
+         target_width = width * upscale_factor
+         target_height = height * upscale_factor
+         target_width = target_width - (target_width % 8)
+         target_height = target_height - (target_height % 8)
+
+         # Move pipeline to GPU for processing
+         pipeline.to(accelerator.device)
+
+         # Generate image
+         with torch.no_grad():
+             output = pipeline(
+                 t_max=0.6666,
+                 t_min=0.0,
+                 tile_diffusion=False,
+                 added_prompt=prompt,
+                 image=input_pil,
+                 num_inference_steps=num_inference_steps,
+                 generator=generator,
+                 height=target_height,
+                 width=target_width,
+                 guidance_scale=guidance_scale,
+                 negative_prompt=negative_prompt,
+                 conditioning_scale=conditioning_scale,
+             )
+
+         generated_image = output.images[0]
+
+         # Apply color fixing if specified
+         if color_fix_method != "none":
+             fix_func = wavelet_color_fix if color_fix_method == "wavelet" else adain_color_fix
+             generated_image = fix_func(generated_image, input_pil)
+
+         # Move pipeline back to CPU and free GPU memory
+         pipeline.to("cpu")
+         torch.cuda.empty_cache()
+
+         return generated_image
+
+     except Exception as e:
+         print(f"Error processing image: {str(e)}")
+         return None
+
+ # Create Gradio interface
+ iface = gr.Interface(
+     fn=process_image,
+     inputs=[
+         gr.Image(label="Input Image"),
+         gr.Textbox(label="Prompt", value="clean, high-resolution, 8k"),
+         gr.Textbox(label="Negative Prompt", value="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed"),
+         gr.Slider(minimum=1.0, maximum=20.0, value=1.0, label="Guidance Scale"),
+         gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Conditioning Scale"),
+         gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Number of Steps"),
+         gr.Number(label="Seed", value=42),
+         gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Upscale Factor"),
+         gr.Radio(["none", "wavelet", "adain"], label="Color Fix Method", value="adain"),
+     ],
+     outputs=gr.Image(label="Generated Image"),
+     title="Controllable Conditional Super-Resolution",
+     description="Upload an image to enhance its resolution using CCSR.",
+     examples=[
+         ["example1.jpg", "clean, sharp, detailed", "blurry, noise", 1.0, 1.0, 20, 42, 2, "adain"],
+         ["example2.jpg", "high-resolution, pristine", "artifacts, pixelated", 1.5, 1.0, 30, 123, 2, "wavelet"],
+     ]
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
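
For quick verification outside the Gradio UI, process_image can also be called directly, since it accepts a NumPy array just as Gradio passes one. The sketch below is illustrative and not part of the commit: it assumes the repo's models/, pipelines/, and myutils/ packages are importable, the Hugging Face token is exported as the 'Read' environment variable, and "input.jpg" exists (all hypothetical names).

# Hypothetical smoke test for app.process_image (not part of this commit).
# Assumes the repo's models/, pipelines/ and myutils/ packages are on the
# path, the HF token is set in the 'Read' env var, and "input.jpg" exists.
import numpy as np
from PIL import Image

from app import process_image  # the function defined in app.py above

lr = np.array(Image.open("input.jpg").convert("RGB"))  # HWC uint8, as Gradio passes it
sr = process_image(lr, seed=123, upscale_factor=2, color_fix_method="wavelet")
if sr is not None:  # process_image returns None on failure
    sr.save("output_x2.png")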
requirements.txt CHANGED
@@ -1,14 +1,15 @@
- diffusers==0.21.0
- torch==2.0.1
- pytorch_lightning
- accelerate==1.2.0
- transformers==4.25.0
- xformers==0.0.22
- loralib
- fairscale==0.4.13
- basicsr==1.4.2
- timm==0.9.5
- pydantic==1.10.11
- huggingface_hub==0.25.2
- opencv-python-headless
- lpips
+ diffusers==0.21.0
+ torch==2.0.1
+ pytorch_lightning
+ accelerate==1.2.0
+ transformers==4.25.0
+ xformers==0.0.22
+ loralib
+ fairscale==0.4.13
+ basicsr==1.4.2
+ timm==0.9.5
+ pydantic==1.10.11
+ huggingface_hub==0.25.2
+ opencv-python-headless
+ lpips
+ einops
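
Several of these pins interact (for example, xformers 0.0.22 is built against torch 2.0.1), so a small startup check can catch a broken environment before model loading fails. The snippet below is an optional, illustrative addition rather than part of the commit.

# Optional environment sanity check (illustrative, not part of this commit).
# Warns if installed versions drift from the requirements.txt pins; uses
# startswith because torch wheels report suffixes like "2.0.1+cu118".
from importlib.metadata import version

expected = {"diffusers": "0.21.0", "torch": "2.0.1", "transformers": "4.25.0"}
for pkg, want in expected.items():
    got = version(pkg)
    if not got.startswith(want):
        print(f"warning: {pkg}=={got}, requirements pin {want}")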