Update model.py
model.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c)
+# Copyright (c) 2025 Jaerin Lee
 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
@@ -19,7 +19,7 @@
 # SOFTWARE.
 
 from transformers import Blip2Processor, Blip2ForConditionalGeneration
-from diffusers import
+from diffusers import LCMScheduler, DDIMScheduler, AutoencoderTiny
 
 import torch
 import torch.nn as nn
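The new imports bring in the LCM scheduler used together with the LCM-LoRA configured below, plus `AutoencoderTiny` as a lightweight VAE option. For reference, the usual diffusers idiom for this scheduler-plus-LoRA setup looks roughly like the following (a generic sketch using the model and LoRA keys that appear later in this diff, not code from this file):

```python
from diffusers import LCMScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5')
# Swap the default scheduler for the LCM scheduler and attach the LCM-LoRA.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.load_lora_weights('latent-consistency/lcm-lora-sdv1-5',
                       weight_name='pytorch_lora_weights.safetensors')
```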
@@ -31,15 +31,15 @@ from typing import Tuple, List, Literal, Optional, Union
 from tqdm import tqdm
 from PIL import Image
 
-from util import gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
+from util import load_model, gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
 
 
-class StableMultiDiffusionPipeline(nn.Module):
+class SemanticDrawPipeline(nn.Module):
     def __init__(
         self,
         device: torch.device,
         dtype: torch.dtype = torch.float16,
-        sd_version: Literal['1.5'
+        sd_version: Literal['1.5'] = '1.5',
         hf_key: Optional[str] = None,
         lora_key: Optional[str] = None,
         load_from_local: bool = False, # Turn on if you have already downloaded LoRA & Hugging Face hub is down.
@@ -52,8 +52,9 @@ class StableMultiDiffusionPipeline(nn.Module):
         default_preprocess_mask_cover_alpha: float = 0.3,
         t_index_list: List[int] = [0, 4, 12, 25, 37], # [0, 5, 16, 18, 20, 37], # [0, 12, 25, 37], # Magic number.
         mask_type: Literal['discrete', 'semi-continuous', 'continuous'] = 'discrete',
+        has_i2t: bool = True,
     ) -> None:
-        r"""Stabilized
+        r"""Stabilized regionally assigned texts-to-image generation for fast sampling.
 
         Accelerated region-based text-to-image synthesis with Latent Consistency
         Model while preserving mask fidelity and quality.
@@ -95,13 +96,16 @@ class StableMultiDiffusionPipeline(nn.Module):
             default_preprocess_mask_cover_alpha (float): Optional preprocessing
                 where each mask covered by other masks is reduced in its alpha
                 value by this specified factor.
-            t_index_list (List[int]): The default scheduling for
+            t_index_list (List[int]): The default scheduling for the scheduler.
             mask_type (Literal['discrete', 'semi-continuous', 'continuous']):
                 defines the mask quantization modes. Details in the codes of
                 `self.process_mask`. Basically, this (subtly) controls the
                 smoothness of foreground-background blending. More continuous
                 means more blending, but smaller generated patch depending on
                 the mask standard deviation.
+            has_i2t (bool): Automatic background image to text prompt con-
+                version with BLIP-2 model. May not be necessary for the non-
+                streaming application.
         """
         super().__init__()
 
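The new `has_i2t` flag gates the BLIP-2 image-to-text models loaded further down in this hunk's neighborhood. A minimal construction under the new signature, as a sketch (the import path assumes `model.py` is on the working path; the signature itself is taken from the diff):

```python
import torch
from model import SemanticDrawPipeline  # defined in this file

device = torch.device('cuda:0')
# Skip the BLIP-2 captioner for non-streaming use, per the docstring above.
smd = SemanticDrawPipeline(device, dtype=torch.float16, has_i2t=False)
```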
@@ -120,30 +124,24 @@ class StableMultiDiffusionPipeline(nn.Module):
         self.mask_type = mask_type
 
         print(f'[INFO] Loading Stable Diffusion...')
-        variant = None
         lora_weight_name = None
         if self.sd_version == '1.5':
             if hf_key is not None:
-                print(f'[INFO] Using
+                print(f'[INFO] Using custom model key: {hf_key}')
                 model_key = hf_key
             else:
                 model_key = 'runwayml/stable-diffusion-v1-5'
-                # variant = 'fp16'
             lora_key = 'latent-consistency/lcm-lora-sdv1-5'
             lora_weight_name = 'pytorch_lora_weights.safetensors'
-        # elif self.sd_version == 'xl':
-        #     model_key = 'stabilityai/stable-diffusion-xl-base-1.0'
-        #     lora_key = 'latent-consistency/lcm-lora-sdxl'
-        #     variant = 'fp16'
-        #     lora_weight_name = 'pytorch_lora_weights.safetensors'
         else:
             raise ValueError(f'Stable Diffusion version {self.sd_version} not supported.')
 
         # Create model
-
-
+        if has_i2t:
+            self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
+            self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
 
-        self.pipe =
+        self.pipe = load_model(model_key, self.sd_version, self.device, self.dtype)
         if lora_key is None:
             print(f'[INFO] LCM LoRA is not available for SD version {sd_version}. Using DDIM Scheduler instead...')
             self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
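The inline pipeline construction is replaced by the `load_model` helper imported from `util`, whose body is not part of this diff. A plausible sketch of what it might do for SD 1.5, assuming diffusers' standard `StableDiffusionPipeline` loader (only the name and call signature come from the call site above; the rest is an assumption):

```python
import torch
from diffusers import StableDiffusionPipeline

def load_model(model_key: str, sd_version: str, device: torch.device, dtype: torch.dtype):
    # Hypothetical stand-in for util.load_model; the real helper is not shown here.
    if sd_version != '1.5':
        raise ValueError(f'Stable Diffusion version {sd_version} not supported.')
    pipe = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=dtype)
    return pipe.to(device)
```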
@@ -166,7 +164,7 @@ class StableMultiDiffusionPipeline(nn.Module):
         self.vae_scale_factor = self.pipe.vae_scale_factor
 
         # Prepare white background for bootstrapping.
-
+        self.get_white_background(768, 768)
 
         print(f'[INFO] Model is loaded!')
 
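`get_white_background` itself is not shown in this diff; presumably it pre-encodes a 768x768 white canvas into VAE latents for the mask bootstrapping mentioned in the comment. A hedged sketch of that idea (the caching attribute and exact behavior are assumptions):

```python
import torch

@torch.no_grad()
def get_white_background(self, height: int, width: int) -> torch.Tensor:
    # Hypothetical: encode a plain white canvas into latent space once and
    # cache it for bootstrapping; encode_imgs is defined later in model.py.
    white = torch.ones(1, 3, height, width, device=self.device, dtype=self.dtype)
    self.white = self.encode_imgs(white)
    return self.white
```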
@@ -281,11 +279,14 @@ class StableMultiDiffusionPipeline(nn.Module):
         Returns:
             A single string of text prompt.
         """
-
-
-
-
-
+        if hasattr(self, 'i2t_model'):
+            question = 'Question: What are in the image? Answer:'
+            inputs = self.i2t_processor(image, question, return_tensors='pt')
+            out = self.i2t_model.generate(**inputs, max_new_tokens=77)
+            prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
+            return prompt
+        else:
+            return ''
 
     @torch.no_grad()
     def encode_imgs(
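The rewritten body degrades gracefully: when BLIP-2 was not loaded (`has_i2t=False`), captioning returns an empty prompt instead of raising. The same call pattern can be checked standalone with the exact model key from the diff (the image path here is a placeholder):

```python
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')

image = Image.open('background.png').convert('RGB')  # any RGB image works here
question = 'Question: What are in the image? Answer:'
inputs = processor(image, question, return_tensors='pt')
out = model.generate(**inputs, max_new_tokens=77)
print(processor.decode(out[0], skip_special_tokens=True).strip())
```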
@@ -405,7 +406,7 @@ class StableMultiDiffusionPipeline(nn.Module):
                 25, 37], the masks are split into binary masks whose values are
                 greater than these levels. This results in gradual increase of mask
                 region as the timesteps increase. Details are described in our
-                paper
+                paper.
 
         On the Three Modes of `mask_type`:
             `self.mask_type` is predefined at the initialization stage of this
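The quantization described here, binary masks thresholded at increasing levels across timesteps, can be illustrated with a toy sketch. The real logic lives in `self.process_mask`; this only approximates the 'discrete' mode:

```python
import torch
from typing import List

def quantize_mask_discrete(mask: torch.Tensor, levels: List[float]) -> torch.Tensor:
    # Toy illustration: threshold a smooth (e.g. Gaussian-blurred) mask at
    # per-timestep levels, so the active region grows as denoising proceeds.
    return torch.stack([(mask > level).float() for level in levels])
```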
@@ -609,7 +610,7 @@ class StableMultiDiffusionPipeline(nn.Module):
 
         Minimal Example:
             >>> device = torch.device('cuda:0')
-            >>> smd = StableMultiDiffusionPipeline(device)
+            >>> smd = SemanticDrawPipeline(device)
             >>> image = smd.sample('A photo of the dolomites')
             >>> image.save('my_creation.png')
 
@@ -675,7 +676,7 @@ class StableMultiDiffusionPipeline(nn.Module):
 
         Minimal Example:
             >>> device = torch.device('cuda:0')
-            >>> smd = StableMultiDiffusionPipeline(device)
+            >>> smd = SemanticDrawPipeline(device)
             >>> image = smd.sample_panorama(
             >>>     'A photo of Alps', height=512, width=3072)
             >>> image.save('my_panorama_creation.png')
@@ -792,7 +793,7 @@ class StableMultiDiffusionPipeline(nn.Module):
 
         Example:
             >>> device = torch.device('cuda:0')
-            >>> smd = StableMultiDiffusionPipeline(device)
+            >>> smd = SemanticDrawPipeline(device)
             >>> prompts = {... specify prompts}
             >>> masks = {... specify mask tensors}
             >>> height, width = masks.shape[-2:]
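To make the placeholder dicts in the example concrete, one way to set up region prompts and masks (shapes and layout here are illustrative assumptions, not from the source):

```python
import torch

prompts = ['a red sports car', 'a snow-covered mountain']
masks = torch.zeros(2, 512, 512)
masks[0, 256:, :256] = 1.0  # bottom-left region gets the first prompt
masks[1, :256, 256:] = 1.0  # top-right region gets the second prompt
height, width = masks.shape[-2:]
```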
@@ -881,7 +882,7 @@ class StableMultiDiffusionPipeline(nn.Module):
 
         # prompts is None: return background.
         # masks is None but prompts is not None: return prompts
-        # masks is not None and prompts is not None: Do
+        # masks is not None and prompts is not None: Do SemanticDraw.
 
         if prompts is None or (isinstance(prompts, (list, tuple, str)) and len(prompts) == 0):
             if background is None and background_prompt is not None:
@@ -1103,4 +1104,4 @@ class StableMultiDiffusionPipeline(nn.Module):
             image = blend(image, background[0], fg_mask)
         else:
             image = T.ToPILImage()(image)
-        return image
+        return image