ChiKyi committed on
Commit e5282e5 · 1 Parent(s): 01c3f1c

add new model

Files changed (2)
  1. app.py +133 -51
  2. stable.py +115 -0
app.py CHANGED
@@ -5,13 +5,73 @@ from torchvision import transforms
  from matplotlib import pyplot as plt
  import gradio as gr

+ import transformers
+ transformers.utils.move_cache()
+
+ from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
+ from transformers import BlipProcessor, BlipForConditionalGeneration
+ from accelerate import Accelerator
+
+ import warnings
+ warnings.filterwarnings("ignore")
+
+
  from models import MainModel, UNetAuto, Autoencoder
  from utils import lab_to_rgb, build_res_unet, build_mobilenet_unet  # Utility to convert LAB to RGB
+ from stable import blip_image_captioning, apply_color

  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

- # Function to load models
- def load_unet_model(auto_model_path):
+
+ # Stable diffusion
+
+ accelerator = Accelerator(
+     mixed_precision="fp16"
+ )
+
+ controlnet = ControlNetModel.from_pretrained(
+     pretrained_model_name_or_path="nickpai/sdxl_light_caption_output",
+     subfolder="checkpoint-30000/controlnet",
+ )
+
+ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+     pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
+     controlnet=controlnet
+ )
+ blip_processor = BlipProcessor.from_pretrained(
+     "Salesforce/blip-image-captioning-large",
+ )
+ blip_generator = BlipForConditionalGeneration.from_pretrained(
+     "Salesforce/blip-image-captioning-large",
+ )
+
+ pipe.to(accelerator.device)
+ blip_generator.to(accelerator.device)
+
+ def colorize_single_image(image, positive_prompt, negative_prompt, caption_generate):
+     image = PIL.Image.fromarray(image)
+
+     torch.cuda.empty_cache()
+     if caption_generate:
+         caption = blip_image_captioning(image=image, device=accelerator.device, processor=blip_processor, generator=blip_generator)
+     else:
+         caption = ""
+
+     original_size = image.size
+     control_image = image.convert("L").convert("RGB").resize((512, 512))
+     prompt = [positive_prompt + ", " + caption]
+
+     colorized_image = pipe(prompt=prompt,
+                            num_inference_steps=5,
+                            generator=torch.manual_seed(0),
+                            image=control_image,
+                            negative_prompt=negative_prompt).images[0]
+     result_image = apply_color(control_image, colorized_image)
+     result_image = result_image.resize(original_size)
+     return result_image, caption if caption_generate else gr.update(visible=False)
+
+ # Function to load models for the autoencoder and GAN
+ def load_autoencoder_model(auto_model_path):
      unet = UNetAuto(in_channels=1, out_channels=2).to(device)
      model = Autoencoder(unet).to(device)
      model.load_state_dict(torch.load(auto_model_path, map_location=device))
@@ -44,7 +104,7 @@ mobilenet_model = load_model(
      model_type='mobilenet'
  )

- unet_model = load_unet_model("weight/autoencoder.pt")
+ autoencoder_model = load_autoencoder_model("weight/autoencoder.pt")

  # Transformations
  def preprocess_image(image):
@@ -67,68 +127,90 @@ def colorize_image(input_image, mode):
      with torch.no_grad():
          resnet_output = resnet_model.net_G(grayscale.unsqueeze(0))
          mobilenet_output = mobilenet_model.net_G(grayscale.unsqueeze(0))
-         unet_output = unet_model(grayscale.unsqueeze(0))
+         autoencoder_output = autoencoder_model(grayscale.unsqueeze(0))

      # Resize outputs to match the original size
      resnet_colorized = postprocess_image(grayscale, resnet_output, original_size)
      mobilenet_colorized = postprocess_image(grayscale, mobilenet_output, original_size)
-     unet_colorized = postprocess_image(grayscale, unet_output, original_size)
+     autoencoder_colorized = postprocess_image(grayscale, autoencoder_output, original_size)

      if mode == "ResNet":
          return resnet_colorized, None, None
      elif mode == "MobileNet":
          return None, mobilenet_colorized, None
      elif mode == "Unet":
-         return None, None, unet_colorized
+         return None, None, autoencoder_colorized
      elif mode == "Comparison":
-         return resnet_colorized, mobilenet_colorized, unet_colorized
+         return resnet_colorized, mobilenet_colorized, autoencoder_colorized

+ def colorize_single_image(input_image, positive_prompt, negative_prompt, generate_caption):
+
+     caption = "Generated Caption Example" if generate_caption else ""
+     return input_image, caption

- # Gradio Interface
  def gradio_interface():
-     with gr.Blocks() as demo:
-         # Input components
-         input_image = gr.Image(type="numpy", label="Upload an Image")
-         output_modes = gr.Radio(
-             choices=["ResNet", "MobileNet", "Unet", "Comparison"],
-             value="ResNet",
-             label="Output Mode"
-         )
-
-         submit_button = gr.Button("Submit")
-
-         # Output components
-         with gr.Row():  # Place output images in a single row
-             resnet_output = gr.Image(label="Colorized Image (ResNet18)", visible=False)
-             mobilenet_output = gr.Image(label="Colorized Image (MobileNet)", visible=False)
-             unet_output = gr.Image(label="Colorized Image (Unet)", visible=False)
-
-         # Output mode logic
-         def update_visibility(mode):
-             if mode == "ResNet":
-                 return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
-             elif mode == "MobileNet":
-                 return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
-             elif mode == "Unet":
-                 return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
-             elif mode == "Comparison":
-                 return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
-
-         # Dynamic event listener for output mode changes
-         output_modes.change(
-             fn=update_visibility,
-             inputs=[output_modes],
-             outputs=[resnet_output, mobilenet_output, unet_output]
-         )
-
-         # Submit logic
-         submit_button.click(
-             fn=colorize_image,
-             inputs=[input_image, output_modes],
-             outputs=[resnet_output, mobilenet_output, unet_output]
-         )
-
-     return demo
+     with gr.Blocks() as app:
+         with gr.Tab("Mode Colorization no Prompting"):
+             with gr.Blocks():
+                 input_image = gr.Image(type="numpy", label="Upload an Image")
+                 output_modes = gr.Radio(
+                     choices=["ResNet", "MobileNet", "Unet", "Comparison"],
+                     value="ResNet",
+                     label="Output Mode"
+                 )
+
+                 submit_button = gr.Button("Submit")
+
+                 with gr.Row():  # Place output images in a single row
+                     resnet_output = gr.Image(label="Colorized Image (ResNet18)", visible=False)
+                     mobilenet_output = gr.Image(label="Colorized Image (MobileNet)", visible=False)
+                     autoencoder_output = gr.Image(label="Colorized Image (Unet)", visible=False)
+
+                 def update_visibility(mode):
+                     if mode == "ResNet":
+                         return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
+                     elif mode == "MobileNet":
+                         return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
+                     elif mode == "Unet":
+                         return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
+                     elif mode == "Comparison":
+                         return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
+
+                 output_modes.change(
+                     fn=update_visibility,
+                     inputs=[output_modes],
+                     outputs=[resnet_output, mobilenet_output, autoencoder_output]
+                 )
+
+                 submit_button.click(
+                     fn=colorize_image,
+                     inputs=[input_image, output_modes],
+                     outputs=[resnet_output, mobilenet_output, autoencoder_output]
+                 )
+
+         with gr.Tab("Stable Diffusion"):
+             with gr.Blocks():
+                 sd_image = gr.Image(label="Upload a Color Image")
+                 positive_prompt = gr.Textbox(label="Positive Prompt", placeholder="Text for positive prompt")
+                 negative_prompt = gr.Textbox(
+                     value="low quality, bad quality, low contrast, black and white, bw, monochrome, grainy, blurry, historical, restored, desaturate",
+                     label="Negative Prompt", placeholder="Text for negative prompt"
+                 )
+                 generate_caption = gr.Checkbox(label="Generate Caption?", value=False)
+                 submit_sd = gr.Button("Generate")
+
+                 sd_output_image = gr.Image(label="Colorized Image")
+                 sd_caption = gr.Textbox(label="Captioning Result", show_copy_button=True, visible=False)
+
+                 submit_sd.click(
+                     fn=colorize_single_image,
+                     inputs=[sd_image, positive_prompt, negative_prompt, generate_caption],
+                     outputs=[sd_output_image, sd_caption]
+                 )
+
+     return app
+
+


  # Launch
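
Note: the hunk context stops at the unchanged "# Launch" comment, so the launch call itself is outside this diff. For orientation only, a minimal sketch assuming the standard Gradio pattern (hypothetical, not part of the commit):

    # Hypothetical launch lines; the actual "# Launch" section is not shown in the hunk.
    app = gradio_interface()
    app.launch()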
stable.py ADDED
@@ -0,0 +1,115 @@
+ # for image captioning
+ import PIL
+ import torch
+ from torchvision import transforms
+ import gradio as gr
+
+ import transformers
+ transformers.utils.move_cache()
+
+ from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
+ from transformers import BlipProcessor, BlipForConditionalGeneration
+ from accelerate import Accelerator
+
+ def remove_unlikely_words(prompt: str) -> str:
+     """
+     Removes unlikely words from a prompt.
+
+     Args:
+         prompt: The text prompt to be cleaned.
+
+     Returns:
+         The cleaned prompt with unlikely words removed.
+     """
+     unlikely_words = []
+
+     a1_list = [f'{i}s' for i in range(1900, 2000)]
+     a2_list = [f'{i}' for i in range(1900, 2000)]
+     a3_list = [f'year {i}' for i in range(1900, 2000)]
+     a4_list = [f'circa {i}' for i in range(1900, 2000)]
+     b1_list = [f"{year[0]} {year[1]} {year[2]} {year[3]} s" for year in a1_list]
+     b2_list = [f"{year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
+     b3_list = [f"year {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
+     b4_list = [f"circa {year[0]} {year[1]} {year[2]} {year[3]}" for year in a1_list]
+
+     words_list = [
+         "black and white,", "black and white", "black & white,", "black & white", "circa",
+         "balck and white,", "monochrome,", "black-and-white,", "black-and-white photography,",
+         "black - and - white photography,", "monochrome bw,", "black white,", "black an white,",
+         "grainy footage,", "grainy footage", "grainy photo,", "grainy photo", "b&w photo",
+         "back and white", "back and white,", "monochrome contrast", "monochrome", "grainy",
+         "grainy photograph,", "grainy photograph", "low contrast,", "low contrast", "b & w",
+         "grainy black-and-white photo,", "bw", "bw,", "grainy black-and-white photo",
+         "b & w,", "b&w,", "b&w!,", "b&w", "black - and - white,", "bw photo,", "grainy photo,",
+         "black-and-white photo,", "black-and-white photo", "black - and - white photography",
+         "b&w photo,", "monochromatic photo,", "grainy monochrome photo,", "monochromatic",
+         "blurry photo,", "blurry,", "blurry photography,", "monochromatic photo",
+         "black - and - white photograph,", "black - and - white photograph", "black on white,",
+         "black on white", "black-and-white", "historical image,", "historical picture,",
+         "historical photo,", "historical photograph,", "archival photo,", "taken in the early",
+         "taken in the late", "taken in the", "historic photograph,", "restored,", "restored",
+         "historical photo", "historical setting,",
+         "historic photo,", "historic", "desaturated!!,", "desaturated!,", "desaturated,", "desaturated",
+         "taken in", "shot on leica", "shot on leica sl2", "sl2",
+         "taken with a leica camera", "taken with a leica camera", "leica sl2", "leica", "setting",
+         "overcast day", "overcast weather", "slight overcast", "overcast",
+         "picture taken in", "photo taken in",
+         ", photo", ", photo", ", photo", ", photo", ", photograph",
+         ",,", ",,,", ",,,,", " ,", " ,", " ,", " ,",
+     ]
+
+     unlikely_words.extend(a1_list)
+     unlikely_words.extend(a2_list)
+     unlikely_words.extend(a3_list)
+     unlikely_words.extend(a4_list)
+     unlikely_words.extend(b1_list)
+     unlikely_words.extend(b2_list)
+     unlikely_words.extend(b3_list)
+     unlikely_words.extend(b4_list)
+     unlikely_words.extend(words_list)
+
+     for word in unlikely_words:
+         prompt = prompt.replace(word, "")
+     return prompt
+
+ def blip_image_captioning(image, device, processor, generator, conditional="a photography of"):
+     # Load the processor and model
+     if processor is None:
+         processor = BlipProcessor.from_pretrained(
+             "Salesforce/blip-image-captioning-large"
+         )
+     if generator is None:
+         model = BlipForConditionalGeneration.from_pretrained(
+             "Salesforce/blip-image-captioning-large",
+             torch_dtype=torch.float16
+         ).to(device)
+
+     # Prepare inputs
+     inputs = processor(
+         image,
+         text=conditional,
+         return_tensors="pt"
+     ).to(device)
+
+     # Generate the caption
+     out = generator.generate(**inputs, max_new_tokens=20)  # Use max_new_tokens for better clarity
+     caption = processor.decode(out[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+     caption = remove_unlikely_words(caption)
+
+     return caption
+
+ def apply_color(image: PIL.Image.Image, color_map: PIL.Image.Image) -> PIL.Image.Image:
+     # Convert input images to LAB color space
+     image_lab = image.convert('LAB')
+     color_map_lab = color_map.convert('LAB')
+
+     # Split LAB channels
+     l, a, b = image_lab.split()
+     _, a_map, b_map = color_map_lab.split()
+
+     # Merge LAB channels with color map
+     merged_lab = PIL.Image.merge('LAB', (l, a_map, b_map))
+
+     # Convert merged LAB image back to RGB color space
+     result_rgb = merged_lab.convert('RGB')
+     return result_rgb
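
For reference, a minimal sketch of how the helpers in stable.py are meant to be combined, mirroring the Stable Diffusion tab in app.py. The input path "photo.jpg" is hypothetical, the checkpoints are downloaded on first use, and the LAB round trip relies on Pillow accepting the same convert('LAB') calls that apply_color already makes:

    # Sketch: caption an image with BLIP, then re-apply its colors via apply_color.
    import PIL.Image
    import torch
    from transformers import BlipProcessor, BlipForConditionalGeneration
    from stable import blip_image_captioning, apply_color

    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    generator = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

    image = PIL.Image.open("photo.jpg").convert("RGB")  # hypothetical input file
    caption = blip_image_captioning(image, device, processor, generator)
    print("cleaned caption:", caption)

    # apply_color keeps the lightness (L) channel of its first argument and takes the
    # a/b chrominance from its second, so recoloring a grayscale copy from the original
    # approximately restores the original colors.
    grayscale = image.convert("L").convert("RGB")
    recolored = apply_color(grayscale, image)
    recolored.save("recolored.jpg")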