genevera commited on
Commit
1a8a5f1
·
1 Parent(s): 5919897

reformat app.py with black

Browse files
Files changed (1) hide show
  1. app.py +91 -38
app.py CHANGED
@@ -3,6 +3,7 @@ import numpy as np
3
  import gradio as gr
4
  from scipy import signal
5
  from diffusers.utils import logging
 
6
  logging.set_verbosity_error()
7
  from diffusers.loaders import AttnProcsLayers
8
  from transformers import CLIPTextModel, CLIPTokenizer
@@ -36,25 +37,42 @@ class AudioTokenWrapper(torch.nn.Module):
36
  lora,
37
  device,
38
  ):
39
-
40
  super().__init__()
41
  self.repo_id = repo_id
42
  # Load scheduler and models
43
  self.ddpm = DDPMScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
44
  self.ddim = DDIMScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
45
  self.pndm = PNDMScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
46
- self.lms = LMSDiscreteScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
47
- self.euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
48
- self.euler = EulerDiscreteScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
49
- self.dpm = DPMSolverMultistepScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
50
- self.dpms = DPMSolverSinglestepScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
51
- self.deis = DEISMultistepScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
52
- self.unipc = UniPCMultistepScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
53
- self.heun = HeunDiscreteScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
54
- self.kdpm2_anc = KDPM2AncestralDiscreteScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
55
- self.kdpm2 = KDPM2DiscreteScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
56
-
57
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  self.tokenizer = CLIPTokenizer.from_pretrained(
60
  self.repo_id, subfolder="tokenizer"
@@ -70,10 +88,11 @@ class AudioTokenWrapper(torch.nn.Module):
70
  )
71
 
72
  checkpoint = torch.load(
73
- 'models/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')
74
- cfg = BEATsConfig(checkpoint['cfg'])
 
75
  self.aud_encoder = BEATs(cfg)
76
- self.aud_encoder.load_state_dict(checkpoint['model'])
77
  self.aud_encoder.predictor = None
78
  input_size = 768 * 3
79
  self.embedder = FGAEmbedder(input_size=input_size, output_size=768)
@@ -87,46 +106,58 @@ class AudioTokenWrapper(torch.nn.Module):
87
  # Set correct lora layers
88
  lora_attn_procs = {}
89
  for name in self.unet.attn_processors.keys():
90
- cross_attention_dim = None if name.endswith(
91
- "attn1.processor") else self.unet.config.cross_attention_dim
 
 
 
92
  if name.startswith("mid_block"):
93
  hidden_size = self.unet.config.block_out_channels[-1]
94
  elif name.startswith("up_blocks"):
95
  block_id = int(name[len("up_blocks.")])
96
- hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
 
 
97
  elif name.startswith("down_blocks"):
98
  block_id = int(name[len("down_blocks.")])
99
  hidden_size = self.unet.config.block_out_channels[block_id]
100
 
101
- lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size,
102
- cross_attention_dim=cross_attention_dim)
 
103
 
104
  self.unet.set_attn_processor(lora_attn_procs)
105
  self.lora_layers = AttnProcsLayers(self.unet.attn_processors)
106
  self.lora_layers.eval()
107
- lora_layers_learned_embeds = 'models/lora_layers_learned_embeds.bin'
108
- self.lora_layers.load_state_dict(torch.load(lora_layers_learned_embeds, map_location=device))
 
 
109
  self.unet.load_attn_procs(lora_layers_learned_embeds)
110
 
111
  self.embedder.eval()
112
- embedder_learned_embeds = 'models/embedder_learned_embeds.bin'
113
- self.embedder.load_state_dict(torch.load(embedder_learned_embeds, map_location=device))
 
 
114
 
115
- self.placeholder_token = '<*>'
116
  num_added_tokens = self.tokenizer.add_tokens(self.placeholder_token)
117
  if num_added_tokens == 0:
118
  raise ValueError(
119
  f"The tokenizer already contains the token {self.placeholder_token}. Please pass a different"
120
  " `placeholder_token` that is not already in the tokenizer."
121
  )
122
- self.placeholder_token_id = self.tokenizer.convert_tokens_to_ids(self.placeholder_token)
 
 
123
  # Resize the token embeddings as we are adding new special tokens to the tokenizer
124
  self.text_encoder.resize_token_embeddings(len(self.tokenizer))
125
 
126
 
127
  def greet(audio, steps=25, scheduler="ddpm"):
128
  sample_rate, audio = audio
129
- audio = audio.astype(np.float32, order='C') / 32768.0
130
  desired_sample_rate = 16000
131
 
132
  match scheduler:
@@ -171,9 +202,11 @@ def greet(audio, steps=25, scheduler="ddpm"):
171
  audio = signal.resample(audio, new_length)
172
 
173
  weight_dtype = torch.float32
174
- prompt = 'a photo of <*>'
175
 
176
- audio_values = torch.unsqueeze(torch.tensor(audio), dim=0).to(device).to(dtype=weight_dtype)
 
 
177
  if audio_values.ndim == 1:
178
  audio_values = torch.unsqueeze(audio_values, dim=0)
179
 
@@ -185,22 +218,25 @@ def greet(audio, steps=25, scheduler="ddpm"):
185
 
186
  token_embeds[model.placeholder_token_id] = audio_token.clone()
187
  generator = torch.Generator(device=device)
188
- generator.manual_seed(23229249375547) # no reason this can't be input by the user!
189
  pipeline = StableDiffusionPipeline.from_pretrained(
190
  pretrained_model_name_or_path=model.repo_id,
191
  tokenizer=model.tokenizer,
192
  text_encoder=model.text_encoder,
193
  vae=model.vae,
194
  unet=model.unet,
195
- scheduler=use_sched,
196
  safety_checker=None,
197
  ).to(device)
198
  pipeline.enable_xformers_memory_efficient_attention()
199
 
200
  # print(f"taking {steps} steps using the {scheduler} scheduler")
201
- image = pipeline(prompt, num_inference_steps=steps, guidance_scale=8.5, generator=generator).images[0]
 
 
202
  return image
203
 
 
204
  lora = False
205
  repo_id = "philz1337/reliberate"
206
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -223,13 +259,30 @@ examples = [
223
  my_demo = gr.Interface(
224
  fn=greet,
225
  inputs=[
226
- "audio",
227
- gr.Slider(value=25,step=1,label="diffusion steps"),
228
- gr.Dropdown(choices=["ddim","ddpm","pndm","lms","euler_anc","euler","dpm","dpms","deis","unipc","heun","kdpm2_anc","kdpm2"],value="unipc"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  ],
230
  outputs="image",
231
- title='AudioToken',
232
  description=description,
233
- examples=examples
234
  )
235
  my_demo.launch()
 
3
  import gradio as gr
4
  from scipy import signal
5
  from diffusers.utils import logging
6
+
7
  logging.set_verbosity_error()
8
  from diffusers.loaders import AttnProcsLayers
9
  from transformers import CLIPTextModel, CLIPTokenizer
 
37
  lora,
38
  device,
39
  ):
 
40
  super().__init__()
41
  self.repo_id = repo_id
42
  # Load scheduler and models
43
  self.ddpm = DDPMScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
44
  self.ddim = DDIMScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
45
  self.pndm = PNDMScheduler.from_pretrained(self.repo_id, subfolder="scheduler")
46
+ self.lms = LMSDiscreteScheduler.from_pretrained(
47
+ self.repo_id, subfolder="scheduler"
48
+ )
49
+ self.euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(
50
+ self.repo_id, subfolder="scheduler"
51
+ )
52
+ self.euler = EulerDiscreteScheduler.from_pretrained(
53
+ self.repo_id, subfolder="scheduler"
54
+ )
55
+ self.dpm = DPMSolverMultistepScheduler.from_pretrained(
56
+ self.repo_id, subfolder="scheduler"
57
+ )
58
+ self.dpms = DPMSolverSinglestepScheduler.from_pretrained(
59
+ self.repo_id, subfolder="scheduler"
60
+ )
61
+ self.deis = DEISMultistepScheduler.from_pretrained(
62
+ self.repo_id, subfolder="scheduler"
63
+ )
64
+ self.unipc = UniPCMultistepScheduler.from_pretrained(
65
+ self.repo_id, subfolder="scheduler"
66
+ )
67
+ self.heun = HeunDiscreteScheduler.from_pretrained(
68
+ self.repo_id, subfolder="scheduler"
69
+ )
70
+ self.kdpm2_anc = KDPM2AncestralDiscreteScheduler.from_pretrained(
71
+ self.repo_id, subfolder="scheduler"
72
+ )
73
+ self.kdpm2 = KDPM2DiscreteScheduler.from_pretrained(
74
+ self.repo_id, subfolder="scheduler"
75
+ )
76
 
77
  self.tokenizer = CLIPTokenizer.from_pretrained(
78
  self.repo_id, subfolder="tokenizer"
 
88
  )
89
 
90
  checkpoint = torch.load(
91
+ "models/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt"
92
+ )
93
+ cfg = BEATsConfig(checkpoint["cfg"])
94
  self.aud_encoder = BEATs(cfg)
95
+ self.aud_encoder.load_state_dict(checkpoint["model"])
96
  self.aud_encoder.predictor = None
97
  input_size = 768 * 3
98
  self.embedder = FGAEmbedder(input_size=input_size, output_size=768)
 
106
  # Set correct lora layers
107
  lora_attn_procs = {}
108
  for name in self.unet.attn_processors.keys():
109
+ cross_attention_dim = (
110
+ None
111
+ if name.endswith("attn1.processor")
112
+ else self.unet.config.cross_attention_dim
113
+ )
114
  if name.startswith("mid_block"):
115
  hidden_size = self.unet.config.block_out_channels[-1]
116
  elif name.startswith("up_blocks"):
117
  block_id = int(name[len("up_blocks.")])
118
+ hidden_size = list(reversed(self.unet.config.block_out_channels))[
119
+ block_id
120
+ ]
121
  elif name.startswith("down_blocks"):
122
  block_id = int(name[len("down_blocks.")])
123
  hidden_size = self.unet.config.block_out_channels[block_id]
124
 
125
+ lora_attn_procs[name] = LoRAAttnProcessor(
126
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
127
+ )
128
 
129
  self.unet.set_attn_processor(lora_attn_procs)
130
  self.lora_layers = AttnProcsLayers(self.unet.attn_processors)
131
  self.lora_layers.eval()
132
+ lora_layers_learned_embeds = "models/lora_layers_learned_embeds.bin"
133
+ self.lora_layers.load_state_dict(
134
+ torch.load(lora_layers_learned_embeds, map_location=device)
135
+ )
136
  self.unet.load_attn_procs(lora_layers_learned_embeds)
137
 
138
  self.embedder.eval()
139
+ embedder_learned_embeds = "models/embedder_learned_embeds.bin"
140
+ self.embedder.load_state_dict(
141
+ torch.load(embedder_learned_embeds, map_location=device)
142
+ )
143
 
144
+ self.placeholder_token = "<*>"
145
  num_added_tokens = self.tokenizer.add_tokens(self.placeholder_token)
146
  if num_added_tokens == 0:
147
  raise ValueError(
148
  f"The tokenizer already contains the token {self.placeholder_token}. Please pass a different"
149
  " `placeholder_token` that is not already in the tokenizer."
150
  )
151
+ self.placeholder_token_id = self.tokenizer.convert_tokens_to_ids(
152
+ self.placeholder_token
153
+ )
154
  # Resize the token embeddings as we are adding new special tokens to the tokenizer
155
  self.text_encoder.resize_token_embeddings(len(self.tokenizer))
156
 
157
 
158
  def greet(audio, steps=25, scheduler="ddpm"):
159
  sample_rate, audio = audio
160
+ audio = audio.astype(np.float32, order="C") / 32768.0
161
  desired_sample_rate = 16000
162
 
163
  match scheduler:
 
202
  audio = signal.resample(audio, new_length)
203
 
204
  weight_dtype = torch.float32
205
+ prompt = "a photo of <*>"
206
 
207
+ audio_values = (
208
+ torch.unsqueeze(torch.tensor(audio), dim=0).to(device).to(dtype=weight_dtype)
209
+ )
210
  if audio_values.ndim == 1:
211
  audio_values = torch.unsqueeze(audio_values, dim=0)
212
 
 
218
 
219
  token_embeds[model.placeholder_token_id] = audio_token.clone()
220
  generator = torch.Generator(device=device)
221
+ generator.manual_seed(23229249375547) # no reason this can't be input by the user!
222
  pipeline = StableDiffusionPipeline.from_pretrained(
223
  pretrained_model_name_or_path=model.repo_id,
224
  tokenizer=model.tokenizer,
225
  text_encoder=model.text_encoder,
226
  vae=model.vae,
227
  unet=model.unet,
228
+ scheduler=use_sched,
229
  safety_checker=None,
230
  ).to(device)
231
  pipeline.enable_xformers_memory_efficient_attention()
232
 
233
  # print(f"taking {steps} steps using the {scheduler} scheduler")
234
+ image = pipeline(
235
+ prompt, num_inference_steps=steps, guidance_scale=8.5, generator=generator
236
+ ).images[0]
237
  return image
238
 
239
+
240
  lora = False
241
  repo_id = "philz1337/reliberate"
242
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
259
  my_demo = gr.Interface(
260
  fn=greet,
261
  inputs=[
262
+ "audio",
263
+ gr.Slider(value=25, step=1, label="diffusion steps"),
264
+ gr.Dropdown(
265
+ choices=[
266
+ "ddim",
267
+ "ddpm",
268
+ "pndm",
269
+ "lms",
270
+ "euler_anc",
271
+ "euler",
272
+ "dpm",
273
+ "dpms",
274
+ "deis",
275
+ "unipc",
276
+ "heun",
277
+ "kdpm2_anc",
278
+ "kdpm2",
279
+ ],
280
+ value="unipc",
281
+ ),
282
  ],
283
  outputs="image",
284
+ title="AudioToken",
285
  description=description,
286
+ examples=examples,
287
  )
288
  my_demo.launch()