Artiprocher commited on
Commit
b0ab4d3
1 Parent(s): dddb151
Files changed (4) hide show
  1. LdmZhPipeline.py +1036 -0
  2. README.md +5 -5
  3. app.py +36 -0
  4. requirements.txt +6 -0
LdmZhPipeline.py ADDED
@@ -0,0 +1,1036 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+
3
+ import importlib
4
+ import inspect
5
+ import os
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, List, Optional, Union
8
+ from collections import OrderedDict
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch import nn
13
+ import functools
14
+
15
+ import diffusers
16
+ import PIL
17
+ from accelerate.utils.versions import is_torch_version
18
+ from huggingface_hub import snapshot_download
19
+ from packaging import version
20
+ from PIL import Image
21
+ from tqdm.auto import tqdm
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.dynamic_modules_utils import get_class_from_dynamic_module
25
+ from diffusers.modeling_utils import ModelMixin
26
+ from diffusers.hub_utils import http_user_agent
27
+ from diffusers.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
28
+ from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
29
+ from diffusers.utils import (
30
+ CONFIG_NAME,
31
+ DIFFUSERS_CACHE,
32
+ ONNX_WEIGHTS_NAME,
33
+ WEIGHTS_NAME,
34
+ BaseOutput,
35
+ deprecate,
36
+ is_transformers_available,
37
+ logging,
38
+ )
39
+
40
+
41
+ if is_transformers_available():
42
+ import transformers
43
+ from transformers import PreTrainedModel
44
+
45
+
46
+ INDEX_FILE = "diffusion_pytorch_model.bin"
47
+ CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
48
+ DUMMY_MODULES_FOLDER = "diffusers.utils"
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+
54
+ LOADABLE_CLASSES = {
55
+ "diffusers": {
56
+ "ModelMixin": ["save_pretrained", "from_pretrained"],
57
+ "SchedulerMixin": ["save_config", "from_config"],
58
+ "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
59
+ "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
60
+ },
61
+ "transformers": {
62
+ "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
63
+ "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
64
+ "PreTrainedModel": ["save_pretrained", "from_pretrained"],
65
+ "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
66
+ },
67
+ "LdmZhPipeline": {
68
+ "WukongClipTextEncoder": ["save_pretrained", "from_pretrained"],
69
+ "ESRGAN": ["save_pretrained", "from_pretrained"],
70
+ },
71
+ }
72
+
73
+ ALL_IMPORTABLE_CLASSES = {}
74
+ for library in LOADABLE_CLASSES:
75
+ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
76
+
77
+
78
+ @dataclass
79
+ class ImagePipelineOutput(BaseOutput):
80
+ """
81
+ Output class for image pipelines.
82
+
83
+ Args:
84
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
85
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
86
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
87
+ """
88
+
89
+ images: Union[List[PIL.Image.Image], np.ndarray]
90
+
91
+
92
+ @dataclass
93
+ class AudioPipelineOutput(BaseOutput):
94
+ """
95
+ Output class for audio pipelines.
96
+
97
+ Args:
98
+ audios (`np.ndarray`)
99
+ List of denoised samples of shape `(batch_size, num_channels, sample_rate)`. Numpy array present the
100
+ denoised audio samples of the diffusion pipeline.
101
+ """
102
+
103
+ audios: np.ndarray
104
+
105
+
106
+ class DiffusionPipeline(ConfigMixin):
107
+ r"""
108
+ Base class for all models.
109
+
110
+ [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines
111
+ and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to:
112
+
113
+ - move all PyTorch modules to the device of your choice
114
+ - enabling/disabling the progress bar for the denoising iteration
115
+
116
+ Class attributes:
117
+
118
+ - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
119
+ components of the diffusion pipeline.
120
+ """
121
+ config_name = "model_index.json"
122
+
123
+ def register_modules(self, **kwargs):
124
+ # import it here to avoid circular import
125
+ from diffusers import pipelines
126
+
127
+ for name, module in kwargs.items():
128
+ # retrieve library
129
+ if module is None:
130
+ register_dict = {name: (None, None)}
131
+ else:
132
+ library = module.__module__.split(".")[0]
133
+
134
+ # check if the module is a pipeline module
135
+ pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None
136
+ path = module.__module__.split(".")
137
+ is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
138
+
139
+ # if library is not in LOADABLE_CLASSES, then it is a custom module.
140
+ # Or if it's a pipeline module, then the module is inside the pipeline
141
+ # folder so we set the library to module name.
142
+ if library not in LOADABLE_CLASSES or is_pipeline_module:
143
+ library = pipeline_dir
144
+
145
+ # retrieve class_name
146
+ class_name = module.__class__.__name__
147
+
148
+ register_dict = {name: (library, class_name)}
149
+
150
+ # save model index config
151
+ self.register_to_config(**register_dict)
152
+
153
+ # set models
154
+ setattr(self, name, module)
155
+
156
+ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
157
+ """
158
+ Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
159
+ a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
160
+ method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method.
161
+
162
+ Arguments:
163
+ save_directory (`str` or `os.PathLike`):
164
+ Directory to which to save. Will be created if it doesn't exist.
165
+ """
166
+ self.save_config(save_directory)
167
+
168
+ model_index_dict = dict(self.config)
169
+ model_index_dict.pop("_class_name")
170
+ model_index_dict.pop("_diffusers_version")
171
+ model_index_dict.pop("_module", None)
172
+
173
+ for pipeline_component_name in model_index_dict.keys():
174
+ sub_model = getattr(self, pipeline_component_name)
175
+ if sub_model is None:
176
+ # edge case for saving a pipeline with safety_checker=None
177
+ continue
178
+
179
+ model_cls = sub_model.__class__
180
+
181
+ save_method_name = None
182
+ # search for the model's base class in LOADABLE_CLASSES
183
+ for library_name, library_classes in LOADABLE_CLASSES.items():
184
+ library = importlib.import_module(library_name)
185
+ for base_class, save_load_methods in library_classes.items():
186
+ class_candidate = getattr(library, base_class)
187
+ if issubclass(model_cls, class_candidate):
188
+ # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
189
+ save_method_name = save_load_methods[0]
190
+ break
191
+ if save_method_name is not None:
192
+ break
193
+
194
+ save_method = getattr(sub_model, save_method_name)
195
+ save_method(os.path.join(save_directory, pipeline_component_name))
196
+
197
+ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
198
+ if torch_device is None:
199
+ return self
200
+
201
+ module_names, _ = self.extract_init_dict(dict(self.config))
202
+ for name in module_names.keys():
203
+ module = getattr(self, name)
204
+ if isinstance(module, torch.nn.Module):
205
+ if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
206
+ logger.warning(
207
+ "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It"
208
+ " is not recommended to move them to `cpu` or `mps` as running them will fail. Please make"
209
+ " sure to use a `cuda` device to run the pipeline in inference. due to the lack of support for"
210
+ " `float16` operations on those devices in PyTorch. Please remove the"
211
+ " `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference."
212
+ )
213
+ module.to(torch_device)
214
+ return self
215
+
216
+ @property
217
+ def device(self) -> torch.device:
218
+ r"""
219
+ Returns:
220
+ `torch.device`: The torch device on which the pipeline is located.
221
+ """
222
+ module_names, _ = self.extract_init_dict(dict(self.config))
223
+ for name in module_names.keys():
224
+ module = getattr(self, name)
225
+ if isinstance(module, torch.nn.Module):
226
+ # if module.device == torch.device("meta"):
227
+ # return torch.device("cpu")
228
+ return module.device
229
+ return torch.device("cpu")
230
+
231
+ @classmethod
232
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
233
+ r"""
234
+ Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
235
+
236
+ The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
237
+
238
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
239
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
240
+ task.
241
+
242
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
243
+ weights are discarded.
244
+
245
+ Parameters:
246
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
247
+ Can be either:
248
+
249
+ - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
250
+ https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
251
+ `CompVis/ldm-text2im-large-256`.
252
+ - A path to a *directory* containing pipeline weights saved using
253
+ [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
254
+ torch_dtype (`str` or `torch.dtype`, *optional*):
255
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
256
+ will be automatically derived from the model's weights.
257
+ custom_pipeline (`str`, *optional*):
258
+
259
+ <Tip warning={true}>
260
+
261
+ This is an experimental feature and is likely to change in the future.
262
+
263
+ </Tip>
264
+
265
+ Can be either:
266
+
267
+ - A string, the *repo id* of a custom pipeline hosted inside a model repo on
268
+ https://huggingface.co/. Valid repo ids have to be located under a user or organization name,
269
+ like `hf-internal-testing/diffusers-dummy-pipeline`.
270
+
271
+ <Tip>
272
+
273
+ It is required that the model repo has a file, called `pipeline.py` that defines the custom
274
+ pipeline.
275
+
276
+ </Tip>
277
+
278
+ - A string, the *file name* of a community pipeline hosted on GitHub under
279
+ https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to
280
+ match exactly the file name without `.py` located under the above link, *e.g.*
281
+ `clip_guided_stable_diffusion`.
282
+
283
+ <Tip>
284
+
285
+ Community pipelines are always loaded from the current `main` branch of GitHub.
286
+
287
+ </Tip>
288
+
289
+ - A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`.
290
+
291
+ <Tip>
292
+
293
+ It is required that the directory has a file, called `pipeline.py` that defines the custom
294
+ pipeline.
295
+
296
+ </Tip>
297
+
298
+ For more information on how to load and create custom pipelines, please have a look at [Loading and
299
+ Creating Custom
300
+ Pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/custom_pipelines)
301
+
302
+ torch_dtype (`str` or `torch.dtype`, *optional*):
303
+ force_download (`bool`, *optional*, defaults to `False`):
304
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
305
+ cached versions if they exist.
306
+ resume_download (`bool`, *optional*, defaults to `False`):
307
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
308
+ file exists.
309
+ proxies (`Dict[str, str]`, *optional*):
310
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
311
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
312
+ output_loading_info(`bool`, *optional*, defaults to `False`):
313
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
314
+ local_files_only(`bool`, *optional*, defaults to `False`):
315
+ Whether or not to only look at local files (i.e., do not try to download the model).
316
+ use_auth_token (`str` or *bool*, *optional*):
317
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
318
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
319
+ revision (`str`, *optional*, defaults to `"main"`):
320
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
321
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
322
+ identifier allowed by git.
323
+ mirror (`str`, *optional*):
324
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
325
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
326
+ Please refer to the mirror site for more information. specify the folder name here.
327
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
328
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
329
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
330
+ same device.
331
+
332
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
333
+ more information about each option see [designing a device
334
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
335
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
336
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
337
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
338
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
339
+ setting this argument to `True` will raise an error.
340
+
341
+ kwargs (remaining dictionary of keyword arguments, *optional*):
342
+ Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
343
+ specific pipeline class. The overwritten components are then directly passed to the pipelines
344
+ `__init__` method. See example below for more information.
345
+
346
+ <Tip>
347
+
348
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
349
+ models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"`
350
+
351
+ </Tip>
352
+
353
+ <Tip>
354
+
355
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
356
+ this method in a firewalled environment.
357
+
358
+ </Tip>
359
+
360
+ Examples:
361
+
362
+ ```py
363
+ >>> from diffusers import DiffusionPipeline
364
+
365
+ >>> # Download pipeline from huggingface.co and cache.
366
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
367
+
368
+ >>> # Download pipeline that requires an authorization token
369
+ >>> # For more information on access tokens, please refer to this section
370
+ >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
371
+ >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
372
+
373
+ >>> # Download pipeline, but overwrite scheduler
374
+ >>> from diffusers import LMSDiscreteScheduler
375
+
376
+ >>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
377
+ >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
378
+ ```
379
+ """
380
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
381
+ resume_download = kwargs.pop("resume_download", False)
382
+ force_download = kwargs.pop("force_download", False)
383
+ proxies = kwargs.pop("proxies", None)
384
+ local_files_only = kwargs.pop("local_files_only", False)
385
+ use_auth_token = kwargs.pop("use_auth_token", None)
386
+ revision = kwargs.pop("revision", None)
387
+ torch_dtype = kwargs.pop("torch_dtype", None)
388
+ custom_pipeline = kwargs.pop("custom_pipeline", None)
389
+ provider = kwargs.pop("provider", None)
390
+ sess_options = kwargs.pop("sess_options", None)
391
+ device_map = kwargs.pop("device_map", None)
392
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
393
+
394
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
395
+ raise NotImplementedError(
396
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
397
+ " `device_map=None`."
398
+ )
399
+
400
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
401
+ raise NotImplementedError(
402
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
403
+ " `low_cpu_mem_usage=False`."
404
+ )
405
+
406
+ if low_cpu_mem_usage is False and device_map is not None:
407
+ raise ValueError(
408
+ f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
409
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
410
+ )
411
+
412
+ # 1. Download the checkpoints and configs
413
+ # use snapshot download here to get it working from from_pretrained
414
+ if not os.path.isdir(pretrained_model_name_or_path):
415
+ config_dict = cls.get_config_dict(
416
+ pretrained_model_name_or_path,
417
+ cache_dir=cache_dir,
418
+ resume_download=resume_download,
419
+ force_download=force_download,
420
+ proxies=proxies,
421
+ local_files_only=local_files_only,
422
+ use_auth_token=use_auth_token,
423
+ revision=revision,
424
+ )
425
+ # make sure we only download sub-folders and `diffusers` filenames
426
+ folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
427
+ allow_patterns = [os.path.join(k, "*") for k in folder_names]
428
+ allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
429
+
430
+ # make sure we don't download flax weights
431
+ ignore_patterns = "*.msgpack"
432
+
433
+ if custom_pipeline is not None:
434
+ allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
435
+
436
+ if cls != DiffusionPipeline:
437
+ requested_pipeline_class = cls.__name__
438
+ else:
439
+ requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
440
+ user_agent = {"pipeline_class": requested_pipeline_class}
441
+ if custom_pipeline is not None:
442
+ user_agent["custom_pipeline"] = custom_pipeline
443
+ user_agent = http_user_agent(user_agent)
444
+
445
+ # download all allow_patterns
446
+ cached_folder = snapshot_download(
447
+ pretrained_model_name_or_path,
448
+ cache_dir=cache_dir,
449
+ resume_download=resume_download,
450
+ proxies=proxies,
451
+ local_files_only=local_files_only,
452
+ use_auth_token=use_auth_token,
453
+ revision=revision,
454
+ allow_patterns=allow_patterns,
455
+ ignore_patterns=ignore_patterns,
456
+ user_agent=user_agent,
457
+ )
458
+ else:
459
+ cached_folder = pretrained_model_name_or_path
460
+
461
+ config_dict = cls.get_config_dict(cached_folder)
462
+
463
+ # 2. Load the pipeline class, if using custom module then load it from the hub
464
+ # if we load from explicit class, let's use it
465
+ if custom_pipeline is not None:
466
+ pipeline_class = get_class_from_dynamic_module(
467
+ custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline
468
+ )
469
+ elif cls != DiffusionPipeline:
470
+ pipeline_class = cls
471
+ else:
472
+ diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
473
+ pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
474
+
475
+ # To be removed in 1.0.0
476
+ if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
477
+ version.parse(config_dict["_diffusers_version"]).base_version
478
+ ) <= version.parse("0.5.1"):
479
+ from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
480
+
481
+ pipeline_class = StableDiffusionInpaintPipelineLegacy
482
+
483
+ deprecation_message = (
484
+ "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
485
+ f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
486
+ " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
487
+ " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
488
+ f" checkpoint {pretrained_model_name_or_path} to the format of"
489
+ " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
490
+ " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
491
+ )
492
+ deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
493
+
494
+ # some modules can be passed directly to the init
495
+ # in this case they are already instantiated in `kwargs`
496
+ # extract them here
497
+ expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
498
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
499
+
500
+ init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs)
501
+
502
+ if len(unused_kwargs) > 0:
503
+ logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
504
+
505
+ init_kwargs = {}
506
+
507
+ # import it here to avoid circular import
508
+ from diffusers import pipelines
509
+
510
+ # 3. Load each module in the pipeline
511
+ for name, (library_name, class_name) in init_dict.items():
512
+ if class_name is None:
513
+ # edge case for when the pipeline was saved with safety_checker=None
514
+ init_kwargs[name] = None
515
+ continue
516
+
517
+ # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
518
+ if class_name.startswith("Flax"):
519
+ class_name = class_name[4:]
520
+
521
+ is_pipeline_module = hasattr(pipelines, library_name)
522
+ loaded_sub_model = None
523
+ sub_model_should_be_defined = True
524
+
525
+ # if the model is in a pipeline module, then we load it from the pipeline
526
+ if name in passed_class_obj:
527
+ # 1. check that passed_class_obj has correct parent class
528
+ if not is_pipeline_module:
529
+ library = importlib.import_module(library_name)
530
+ class_obj = getattr(library, class_name)
531
+ importable_classes = LOADABLE_CLASSES[library_name]
532
+ class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
533
+
534
+ expected_class_obj = None
535
+ for class_name, class_candidate in class_candidates.items():
536
+ if issubclass(class_obj, class_candidate):
537
+ expected_class_obj = class_candidate
538
+
539
+ if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
540
+ raise ValueError(
541
+ f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
542
+ f" {expected_class_obj}"
543
+ )
544
+ elif passed_class_obj[name] is None:
545
+ logger.warn(
546
+ f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
547
+ f" that this might lead to problems when using {pipeline_class} and is not recommended."
548
+ )
549
+ sub_model_should_be_defined = False
550
+ else:
551
+ logger.warn(
552
+ f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
553
+ " has the correct type"
554
+ )
555
+
556
+ # set passed class object
557
+ loaded_sub_model = passed_class_obj[name]
558
+ elif is_pipeline_module:
559
+ pipeline_module = getattr(pipelines, library_name)
560
+ class_obj = getattr(pipeline_module, class_name)
561
+ importable_classes = ALL_IMPORTABLE_CLASSES
562
+ class_candidates = {c: class_obj for c in importable_classes.keys()}
563
+ else:
564
+ # else we just import it from the library.
565
+ library = importlib.import_module(library_name)
566
+ class_obj = getattr(library, class_name)
567
+ importable_classes = LOADABLE_CLASSES[library_name]
568
+ class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
569
+
570
+ if loaded_sub_model is None and sub_model_should_be_defined:
571
+ load_method_name = None
572
+ for class_name, class_candidate in class_candidates.items():
573
+ if issubclass(class_obj, class_candidate):
574
+ load_method_name = importable_classes[class_name][1]
575
+
576
+ if load_method_name is None:
577
+ none_module = class_obj.__module__
578
+ if none_module.startswith(DUMMY_MODULES_FOLDER) and "dummy" in none_module:
579
+ # call class_obj for nice error message of missing requirements
580
+ class_obj()
581
+
582
+ raise ValueError(
583
+ f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
584
+ f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
585
+ )
586
+
587
+ load_method = getattr(class_obj, load_method_name)
588
+ loading_kwargs = {}
589
+
590
+ if issubclass(class_obj, torch.nn.Module):
591
+ loading_kwargs["torch_dtype"] = torch_dtype
592
+ if issubclass(class_obj, diffusers.OnnxRuntimeModel):
593
+ loading_kwargs["provider"] = provider
594
+ loading_kwargs["sess_options"] = sess_options
595
+
596
+ is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
597
+ is_transformers_model = (
598
+ is_transformers_available()
599
+ and issubclass(class_obj, PreTrainedModel)
600
+ and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
601
+ )
602
+
603
+ # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
604
+ # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
605
+ # This makes sure that the weights won't be initialized which significantly speeds up loading.
606
+ if is_diffusers_model or is_transformers_model:
607
+ loading_kwargs["device_map"] = device_map
608
+ loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
609
+
610
+ # check if the module is in a subdirectory
611
+ if os.path.isdir(os.path.join(cached_folder, name)):
612
+ loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
613
+ else:
614
+ # else load from the root directory
615
+ loaded_sub_model = load_method(cached_folder, **loading_kwargs)
616
+
617
+ init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
618
+
619
+ # 4. Potentially add passed objects if expected
620
+ missing_modules = set(expected_modules) - set(init_kwargs.keys())
621
+ if len(missing_modules) > 0 and missing_modules <= set(passed_class_obj.keys()):
622
+ for module in missing_modules:
623
+ init_kwargs[module] = passed_class_obj[module]
624
+ elif len(missing_modules) > 0:
625
+ passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys()))
626
+ raise ValueError(
627
+ f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
628
+ )
629
+
630
+ # 5. Instantiate the pipeline
631
+ model = pipeline_class(**init_kwargs)
632
+ return model
633
+
634
+ @property
635
+ def components(self) -> Dict[str, Any]:
636
+ r"""
637
+
638
+ The `self.components` property can be useful to run different pipelines with the same weights and
639
+ configurations to not have to re-allocate memory.
640
+
641
+ Examples:
642
+
643
+ ```py
644
+ >>> from diffusers import (
645
+ ... StableDiffusionPipeline,
646
+ ... StableDiffusionImg2ImgPipeline,
647
+ ... StableDiffusionInpaintPipeline,
648
+ ... )
649
+
650
+ >>> img2text = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
651
+ >>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components)
652
+ >>> inpaint = StableDiffusionInpaintPipeline(**img2text.components)
653
+ ```
654
+
655
+ Returns:
656
+ A dictionaly containing all the modules needed to initialize the pipeline.
657
+ """
658
+ components = {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
659
+ expected_modules = set(inspect.signature(self.__init__).parameters.keys()) - set(["self"])
660
+
661
+ if set(components.keys()) != expected_modules:
662
+ raise ValueError(
663
+ f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
664
+ f" {expected_modules} to be defined, but {components} are defined."
665
+ )
666
+
667
+ return components
668
+
669
+ @staticmethod
670
+ def numpy_to_pil(images):
671
+ """
672
+ Convert a numpy image or a batch of images to a PIL image.
673
+ """
674
+ if images.ndim == 3:
675
+ images = images[None, ...]
676
+ images = (images * 255).round().astype("uint8")
677
+ if images.shape[-1] == 1:
678
+ # special case for grayscale (single channel) images
679
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
680
+ else:
681
+ pil_images = [Image.fromarray(image) for image in images]
682
+
683
+ return pil_images
684
+
685
+ def progress_bar(self, iterable):
686
+ if not hasattr(self, "_progress_bar_config"):
687
+ self._progress_bar_config = {}
688
+ elif not isinstance(self._progress_bar_config, dict):
689
+ raise ValueError(
690
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
691
+ )
692
+
693
+ return tqdm(iterable, **self._progress_bar_config)
694
+
695
+ def set_progress_bar_config(self, **kwargs):
696
+ self._progress_bar_config = kwargs
697
+
698
+
699
+ class LDMZhTextToImagePipeline(DiffusionPipeline):
700
+
701
+ def __init__(
702
+ self,
703
+ vqvae,
704
+ bert,
705
+ tokenizer,
706
+ unet,
707
+ scheduler,
708
+ sr,
709
+ ):
710
+ super().__init__()
711
+ self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler, sr=sr)
712
+
713
+ @torch.no_grad()
714
+ def __call__(
715
+ self,
716
+ prompt: Union[str, List[str]],
717
+ height: Optional[int] = 256,
718
+ width: Optional[int] = 256,
719
+ num_inference_steps: Optional[int] = 50,
720
+ guidance_scale: Optional[float] = 5.0,
721
+ eta: Optional[float] = 0.0,
722
+ generator: Optional[torch.Generator] = None,
723
+ output_type: Optional[str] = "pil",
724
+ return_dict: bool = True,
725
+ use_sr: bool = False,
726
+ **kwargs,
727
+ ):
728
+ r"""
729
+ Args:
730
+ prompt (`str` or `List[str]`):
731
+ The prompt or prompts to guide the image generation.
732
+ height (`int`, *optional*, defaults to 256):
733
+ The height in pixels of the generated image.
734
+ width (`int`, *optional*, defaults to 256):
735
+ The width in pixels of the generated image.
736
+ num_inference_steps (`int`, *optional*, defaults to 50):
737
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
738
+ expense of slower inference.
739
+ guidance_scale (`float`, *optional*, defaults to 1.0):
740
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
741
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
742
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
743
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt` at
744
+ the, usually at the expense of lower image quality.
745
+ generator (`torch.Generator`, *optional*):
746
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
747
+ deterministic.
748
+ output_type (`str`, *optional*, defaults to `"pil"`):
749
+ The output format of the generate image. Choose between
750
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
751
+ return_dict (`bool`, *optional*):
752
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
753
+
754
+ Returns:
755
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
756
+ `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
757
+ generated images.
758
+ """
759
+
760
+ if isinstance(prompt, str):
761
+ batch_size = 1
762
+ elif isinstance(prompt, list):
763
+ batch_size = len(prompt)
764
+ else:
765
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
766
+
767
+ if height % 8 != 0 or width % 8 != 0:
768
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
769
+
770
+ # get unconditional embeddings for classifier free guidance
771
+ if guidance_scale != 1.0:
772
+ uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=32, return_tensors="pt")
773
+ uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))
774
+
775
+ # get prompt text embeddings
776
+ text_input = self.tokenizer(prompt, padding="max_length", max_length=32, return_tensors="pt")
777
+ text_embeddings = self.bert(text_input.input_ids.to(self.device))
778
+
779
+ latents = torch.randn(
780
+ (batch_size, self.unet.in_channels, height // 8, width // 8),
781
+ generator=generator,
782
+ )
783
+ latents = latents.to(self.device)
784
+
785
+ self.scheduler.set_timesteps(num_inference_steps)
786
+
787
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
788
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
789
+
790
+ extra_kwargs = {}
791
+ if accepts_eta:
792
+ extra_kwargs["eta"] = eta
793
+
794
+ for t in self.progress_bar(self.scheduler.timesteps):
795
+ if guidance_scale == 1.0:
796
+ # guidance_scale of 1 means no guidance
797
+ latents_input = latents
798
+ context = text_embeddings
799
+ else:
800
+ # For classifier free guidance, we need to do two forward passes.
801
+ # Here we concatenate the unconditional and text embeddings into a single batch
802
+ # to avoid doing two forward passes
803
+ latents_input = torch.cat([latents] * 2)
804
+ context = torch.cat([uncond_embeddings, text_embeddings])
805
+
806
+ # predict the noise residual
807
+ noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
808
+ # perform guidance
809
+ if guidance_scale != 1.0:
810
+ noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
811
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
812
+
813
+ # compute the previous noisy sample x_t -> x_t-1
814
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
815
+
816
+ # scale and decode the image latents with vae
817
+ latents = 1 / 0.18215 * latents
818
+ image = self.vqvae.decode(latents).sample
819
+
820
+ image = (image / 2 + 0.5).clamp(0, 1)
821
+ if use_sr:
822
+ image = self.sr(image)
823
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
824
+ if output_type == "pil":
825
+ image = self.numpy_to_pil(image)
826
+
827
+ if not return_dict:
828
+ return (image,)
829
+
830
+ return ImagePipelineOutput(images=image)
831
+
832
+
833
+ class QuickGELU(nn.Module):
834
+ def forward(self, x: torch.Tensor):
835
+ return x * torch.sigmoid(1.702 * x)
836
+
837
+
838
+ class ResidualAttentionBlock(nn.Module):
839
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
840
+ super().__init__()
841
+ self.attn = nn.MultiheadAttention(d_model, n_head)
842
+ self.ln_1 = nn.LayerNorm(d_model,eps=1e-07)
843
+ self.mlp = nn.Sequential(OrderedDict([
844
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
845
+ ("gelu", QuickGELU()),
846
+ ("c_proj", nn.Linear(d_model * 4, d_model))
847
+ ]))
848
+ self.ln_2 = nn.LayerNorm(d_model,eps=1e-07)
849
+ self.attn_mask = attn_mask
850
+ def attention(self, x: torch.Tensor):
851
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
852
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
853
+ def forward(self, x: torch.Tensor):
854
+ x = x + self.attention(self.ln_1(x))
855
+ x = x + self.mlp(self.ln_2(x))
856
+ return x
857
+
858
+
859
+ class Transformer(nn.Module):
860
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
861
+ super().__init__()
862
+ self.width = width
863
+ self.layers = layers
864
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
865
+
866
+ def forward(self, x: torch.Tensor):
867
+ return self.resblocks(x)
868
+
869
+
870
+ class TextTransformer(nn.Module):
871
+ def __init__(self,
872
+ context_length = 32,
873
+ vocab_size = 21128,
874
+ output_dim = 768,
875
+ width = 768,
876
+ layers = 12,
877
+ heads = 12,
878
+ return_full_embed = False):
879
+ super(TextTransformer, self).__init__()
880
+ self.width = width
881
+ self.layers = layers
882
+ self.vocab_size = vocab_size
883
+ self.return_full_embed = return_full_embed
884
+
885
+ self.transformer = Transformer(width, layers, heads, self.build_attntion_mask(context_length))
886
+ self.text_projection = torch.nn.Parameter(
887
+ torch.tensor(np.random.normal(0, self.width ** -0.5, size=(self.width, output_dim)).astype(np.float32)))
888
+ self.ln_final = nn.LayerNorm(width,eps=1e-07)
889
+
890
+ # https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/27
891
+ # https://github.com/pytorch/pytorch/blob/a40812de534b42fcf0eb57a5cecbfdc7a70100cf/torch/nn/init.py#L22
892
+ self.embedding_table = nn.Parameter(nn.init.trunc_normal_(torch.empty(vocab_size, width),std=0.02))
893
+ # self.embedding_table = nn.Embedding.from_pretrained(nn.init.trunc_normal_(torch.empty(vocab_size, width),std=0.02))
894
+ self.positional_embedding = nn.Parameter(nn.init.trunc_normal_(torch.empty(context_length, width),std=0.01))
895
+ # self.positional_embedding = nn.Embedding.from_pretrained(nn.init.trunc_normal_(torch.empty(context_length, width),std=0.01))
896
+
897
+ self.index_select=torch.index_select
898
+ self.reshape=torch.reshape
899
+
900
+ @staticmethod
901
+ def build_attntion_mask(context_length):
902
+ mask = np.triu(np.full((context_length, context_length), -np.inf).astype(np.float32), 1)
903
+ mask = torch.tensor(mask)
904
+ return mask
905
+
906
+ def forward(self, x: torch.Tensor):
907
+
908
+ tail_token=(x==102).nonzero(as_tuple=True)
909
+ bsz, ctx_len = x.shape
910
+ flatten_id = x.flatten()
911
+ index_select_result = self.index_select(self.embedding_table,0, flatten_id)
912
+ x = self.reshape(index_select_result, (bsz, ctx_len, -1))
913
+ x = x + self.positional_embedding
914
+ x = x.permute(1, 0, 2) # NLD -> LND
915
+ x = self.transformer(x)
916
+ x = x.permute(1, 0, 2) # LND -> NLD
917
+ x = self.ln_final(x)
918
+ x=x[tail_token]
919
+ x = x @ self.text_projection
920
+ return x
921
+
922
+
923
+ class WukongClipTextEncoder(ModelMixin, ConfigMixin):
924
+
925
+ @register_to_config
926
+ def __init__(
927
+ self,
928
+ ):
929
+ super().__init__()
930
+ self.model = TextTransformer()
931
+
932
+ def forward(
933
+ self,
934
+ tokens
935
+ ):
936
+ z = self.model(tokens)
937
+ z = z / torch.linalg.norm(z, dim=-1, keepdim=True)
938
+ if z.ndim==2:
939
+ z = z.view((z.shape[0], 1, z.shape[1]))
940
+ return z
941
+
942
+
943
+ def make_layer(block, n_layers):
944
+ layers = []
945
+ for _ in range(n_layers):
946
+ layers.append(block())
947
+ return nn.Sequential(*layers)
948
+
949
+
950
+ class ResidualDenseBlock_5C(nn.Module):
951
+ def __init__(self, nf=64, gc=32, bias=True):
952
+ super(ResidualDenseBlock_5C, self).__init__()
953
+ # gc: growth channel, i.e. intermediate channels
954
+ self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
955
+ self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
956
+ self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
957
+ self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
958
+ self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
959
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
960
+
961
+ # initialization
962
+ # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
963
+
964
+ def forward(self, x):
965
+ x1 = self.lrelu(self.conv1(x))
966
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
967
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
968
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
969
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
970
+ return x5 * 0.2 + x
971
+
972
+
973
+ class RRDB(nn.Module):
974
+ '''Residual in Residual Dense Block'''
975
+
976
+ def __init__(self, nf, gc=32):
977
+ super(RRDB, self).__init__()
978
+ self.RDB1 = ResidualDenseBlock_5C(nf, gc)
979
+ self.RDB2 = ResidualDenseBlock_5C(nf, gc)
980
+ self.RDB3 = ResidualDenseBlock_5C(nf, gc)
981
+
982
+ def forward(self, x):
983
+ out = self.RDB1(x)
984
+ out = self.RDB2(out)
985
+ out = self.RDB3(out)
986
+ return out * 0.2 + x
987
+
988
+
989
+ class RRDBNet(nn.Module):
990
+ def __init__(self, in_nc, out_nc, nf, nb, gc=32):
991
+ super(RRDBNet, self).__init__()
992
+ RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
993
+
994
+ self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
995
+ self.RRDB_trunk = make_layer(RRDB_block_f, nb)
996
+ self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
997
+ #### upsampling
998
+ self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
999
+ self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
1000
+ self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
1001
+ self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
1002
+
1003
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
1004
+
1005
+ def forward(self, x):
1006
+ fea = self.conv_first(x)
1007
+ trunk = self.trunk_conv(self.RRDB_trunk(fea))
1008
+ fea = fea + trunk
1009
+
1010
+ fea = self.lrelu(self.upconv1(torch.nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
1011
+ fea = self.lrelu(self.upconv2(torch.nn.functional.interpolate(fea, scale_factor=2, mode='nearest')))
1012
+ out = self.conv_last(self.lrelu(self.HRconv(fea)))
1013
+
1014
+ return out
1015
+
1016
+
1017
+ class ESRGAN(ModelMixin, ConfigMixin):
1018
+
1019
+ @register_to_config
1020
+ def __init__(
1021
+ self,
1022
+ ):
1023
+ super().__init__()
1024
+ self.model = RRDBNet(3, 3, 64, 23, gc=32)
1025
+
1026
+ def forward(
1027
+ self,
1028
+ img_LR
1029
+ ):
1030
+ img_LR = img_LR[:,[2,1,0],:,:]
1031
+ img_LR = img_LR.to(self.device)
1032
+ with torch.no_grad():
1033
+ output = self.model(img_LR)
1034
+ output = output.data.float().clamp_(0, 1)
1035
+ output = output[:,[2,1,0],:,:]
1036
+ return output
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Pai Diffusion Food Large Zh
3
- emoji: 📉
4
- colorFrom: green
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 3.12.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
1
  ---
2
+ title: PAI Diffusion (Poem)
3
+ emoji: 🌖
4
+ colorFrom: gray
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 3.11.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from LdmZhPipeline import LDMZhTextToImagePipeline
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+ model_id = "alibaba-pai/pai-diffusion-food-large-zh"
9
+
10
+ pipe_text2img = LDMZhTextToImagePipeline.from_pretrained(model_id, use_auth_token="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")
11
+ pipe_text2img = pipe_text2img.to(device)
12
+
13
+ def infer_text2img(prompt, guide, steps):
14
+ output = pipe_text2img([prompt]*9, guidance_scale=guide, num_inference_steps=steps, use_sr=True)
15
+ images = output.images[0]
16
+ return images
17
+
18
+ with gr.Blocks() as demo:
19
+ examples = [
20
+ ["番茄炒蛋"],
21
+ ["草莓披萨"],
22
+ ["韩式炸鸡"],
23
+ ]
24
+ with gr.Row():
25
+ with gr.Column(scale=1, ):
26
+ image_out = gr.Image(label = '输出(output)')
27
+ with gr.Column(scale=1, ):
28
+ prompt = gr.Textbox(label = '提示词(prompt)')
29
+ submit_btn = gr.Button("生成图像(Generate)")
30
+ with gr.Row(scale=0.5 ):
31
+ guide = gr.Slider(2, 15, value = 7, label = '文本引导强度(guidance scale)')
32
+ steps = gr.Slider(10, 50, value = 20, step = 1, label = '迭代次数(inference steps)')
33
+ ex = gr.Examples(examples, fn=infer_text2img, inputs=[prompt, guide, steps], outputs=image_out)
34
+ submit_btn.click(fn = infer_text2img, inputs = [prompt, guide, steps], outputs = image_out)
35
+
36
+ demo.queue(concurrency_count=1, max_size=8).launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu113
2
+ torch
3
+ torchvision
4
+ diffusers==0.7.2
5
+ transformers
6
+ accelerate