ironjr committed
Commit 30b6ccf · verified · 1 Parent(s): 183248b

Update model.py

Files changed (1)
  1. model.py +31 -30
model.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (c) 2024 Jaerin Lee
+# Copyright (c) 2025 Jaerin Lee
 
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
@@ -19,7 +19,7 @@
 # SOFTWARE.
 
 from transformers import Blip2Processor, Blip2ForConditionalGeneration
-from diffusers import DiffusionPipeline, LCMScheduler, DDIMScheduler, AutoencoderTiny
+from diffusers import LCMScheduler, DDIMScheduler, AutoencoderTiny
 
 import torch
 import torch.nn as nn
@@ -31,15 +31,15 @@ from typing import Tuple, List, Literal, Optional, Union
 from tqdm import tqdm
 from PIL import Image
 
-from util import gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
+from util import load_model, gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center
 
 
-class StableMultiDiffusionPipeline(nn.Module):
+class SemanticDrawPipeline(nn.Module):
     def __init__(
         self,
         device: torch.device,
         dtype: torch.dtype = torch.float16,
-        sd_version: Literal['1.5', '2.0', '2.1', 'xl'] = '1.5',
+        sd_version: Literal['1.5'] = '1.5',
         hf_key: Optional[str] = None,
         lora_key: Optional[str] = None,
         load_from_local: bool = False, # Turn on if you have already downloaded LoRA & the Hugging Face hub is down.
@@ -52,8 +52,9 @@ class StableMultiDiffusionPipeline(nn.Module):
         default_preprocess_mask_cover_alpha: float = 0.3,
         t_index_list: List[int] = [0, 4, 12, 25, 37], # [0, 5, 16, 18, 20, 37], # [0, 12, 25, 37], # Magic number.
         mask_type: Literal['discrete', 'semi-continuous', 'continuous'] = 'discrete',
+        has_i2t: bool = True,
     ) -> None:
-        r"""Stabilized MultiDiffusion for fast sampling.
+        r"""Stabilized regionally assigned texts-to-image generation for fast sampling.
 
         Accelerated region-based text-to-image synthesis with Latent Consistency
         Model while preserving mask fidelity and quality.
@@ -95,13 +96,16 @@ class StableMultiDiffusionPipeline(nn.Module):
             default_preprocess_mask_cover_alpha (float): Optional preprocessing
                 where each mask covered by other masks is reduced in its alpha
                 value by this specified factor.
-            t_index_list (List[int]): The default scheduling for LCM scheduler.
+            t_index_list (List[int]): The default scheduling for the scheduler.
            mask_type (Literal['discrete', 'semi-continuous', 'continuous']):
                defines the mask quantization modes. Details in the codes of
                `self.process_mask`. Basically, this (subtly) controls the
                smoothness of foreground-background blending. More continuous
                means more blending, but smaller generated patch depending on
                the mask standard deviation.
+            has_i2t (bool): Automatic background image to text prompt con-
+                version with the BLIP-2 model. May not be necessary for the non-
+                streaming application.
         """
         super().__init__()
 
@@ -120,30 +124,24 @@ class StableMultiDiffusionPipeline(nn.Module):
         self.mask_type = mask_type
 
         print(f'[INFO] Loading Stable Diffusion...')
-        variant = None
         lora_weight_name = None
         if self.sd_version == '1.5':
             if hf_key is not None:
-                print(f'[INFO] Using Hugging Face custom model key: {hf_key}')
+                print(f'[INFO] Using custom model key: {hf_key}')
                 model_key = hf_key
             else:
                 model_key = 'runwayml/stable-diffusion-v1-5'
-                # variant = 'fp16'
             lora_key = 'latent-consistency/lcm-lora-sdv1-5'
             lora_weight_name = 'pytorch_lora_weights.safetensors'
-        # elif self.sd_version == 'xl':
-        #     model_key = 'stabilityai/stable-diffusion-xl-base-1.0'
-        #     lora_key = 'latent-consistency/lcm-lora-sdxl'
-        #     variant = 'fp16'
-        #     lora_weight_name = 'pytorch_lora_weights.safetensors'
         else:
             raise ValueError(f'Stable Diffusion version {self.sd_version} not supported.')
 
         # Create model
-        self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
-        self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
+        if has_i2t:
+            self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
+            self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')
 
-        self.pipe = DiffusionPipeline.from_pretrained(model_key, variant=variant, torch_dtype=dtype).to(self.device)
+        self.pipe = load_model(model_key, self.sd_version, self.device, self.dtype)
         if lora_key is None:
             print(f'[INFO] LCM LoRA is not available for SD version {sd_version}. Using DDIM Scheduler instead...')
             self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
@@ -166,7 +164,7 @@ class StableMultiDiffusionPipeline(nn.Module):
         self.vae_scale_factor = self.pipe.vae_scale_factor
 
         # Prepare white background for bootstrapping.
-        # self.get_white_background(768, 768) # This causes problems in the HF ZeroGPU environment.
+        self.get_white_background(768, 768)
 
         print(f'[INFO] Model is loaded!')
 
@@ -281,11 +279,14 @@ class StableMultiDiffusionPipeline(nn.Module):
         Returns:
             A single string of text prompt.
         """
-        question = 'Question: What are in the image? Answer:'
-        inputs = self.i2t_processor(image, question, return_tensors='pt')
-        out = self.i2t_model.generate(**inputs, max_new_tokens=77)
-        prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
-        return prompt
+        if hasattr(self, 'i2t_model'):
+            question = 'Question: What are in the image? Answer:'
+            inputs = self.i2t_processor(image, question, return_tensors='pt')
+            out = self.i2t_model.generate(**inputs, max_new_tokens=77)
+            prompt = self.i2t_processor.decode(out[0], skip_special_tokens=True).strip()
+            return prompt
+        else:
+            return ''
 
     @torch.no_grad()
     def encode_imgs(
@@ -405,7 +406,7 @@ class StableMultiDiffusionPipeline(nn.Module):
             25, 37], the masks are split into binary masks whose values are
             greater than these levels. This results in gradual increase of mask
             region as the timesteps increase. Details are described in our
-            paper at https://arxiv.org/pdf/2403.09055.pdf.
+            paper.
 
         On the Three Modes of `mask_type`:
             `self.mask_type` is predefined at the initialization stage of this
@@ -609,7 +610,7 @@ class StableMultiDiffusionPipeline(nn.Module):
 
         Minimal Example:
            >>> device = torch.device('cuda:0')
-            >>> smd = StableMultiDiffusionPipeline(device)
+            >>> smd = SemanticDrawPipeline(device)
            >>> image = smd.sample('A photo of the dolomites')
            >>> image.save('my_creation.png')
 
@@ -675,7 +676,7 @@ class StableMultiDiffusionPipeline(nn.Module):
 
         Minimal Example:
            >>> device = torch.device('cuda:0')
-            >>> smd = StableMultiDiffusionPipeline(device)
+            >>> smd = SemanticDrawPipeline(device)
            >>> image = smd.sample_panorama(
            >>>     'A photo of Alps', height=512, width=3072)
            >>> image.save('my_panorama_creation.png')
@@ -792,7 +793,7 @@ class StableMultiDiffusionPipeline(nn.Module):
 
         Example:
            >>> device = torch.device('cuda:0')
-            >>> smd = StableMultiDiffusionPipeline(device)
+            >>> smd = SemanticDrawPipeline(device)
            >>> prompts = {... specify prompts}
            >>> masks = {... specify mask tensors}
            >>> height, width = masks.shape[-2:]
@@ -881,7 +882,7 @@ class StableMultiDiffusionPipeline(nn.Module):
 
         # prompts is None: return background.
         # masks is None but prompts is not None: return prompts
-        # masks is not None and prompts is not None: Do StableMultiDiffusion.
+        # masks is not None and prompts is not None: Do SemanticDraw.
 
         if prompts is None or (isinstance(prompts, (list, tuple, str)) and len(prompts) == 0):
             if background is None and background_prompt is not None:
@@ -1103,4 +1104,4 @@ class StableMultiDiffusionPipeline(nn.Module):
             image = blend(image, background[0], fg_mask)
         else:
             image = T.ToPILImage()(image)
-        return image
+        return image
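Pipeline construction now goes through util.load_model, whose body is not part of this diff. A minimal sketch of what such a helper might do, assuming it only wraps the DiffusionPipeline.from_pretrained call that the old inline code used; the actual implementation in util.py may differ:

import torch
from diffusers import DiffusionPipeline

def load_model(model_key: str, sd_version: str, device: torch.device, dtype: torch.dtype):
    # Hypothetical stand-in for util.load_model, not shown in this commit.
    # sd_version is accepted to match the call site in __init__ but is unused here.
    # Build the pipeline for the given model key and move it to the requested
    # device and dtype, as the removed DiffusionPipeline.from_pretrained(...).to(device) call did.
    pipe = DiffusionPipeline.from_pretrained(model_key, torch_dtype=dtype)
    return pipe.to(device)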
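Taken together, the renamed class keeps the usage pattern from the docstring examples above. A minimal sketch, assuming the surrounding repository (model.py and its util module) is on the path and a CUDA device is available; has_i2t=False is the new switch in this file and simply skips loading the BLIP-2 captioner:

import torch
from model import SemanticDrawPipeline

device = torch.device('cuda:0')
# Skip the BLIP-2 image-to-text model when automatic background captioning is not needed.
smd = SemanticDrawPipeline(device, has_i2t=False)
image = smd.sample('A photo of the dolomites')
image.save('my_creation.png')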