Spaces:
Running
on
Zero
Running
on
Zero
init
Browse files- app.py +143 -0
- imagenet_en_cn.py +1002 -0
- pixelflow/data_in1k.py +158 -0
- pixelflow/model.py +449 -0
- pixelflow/pipeline_pixelflow.py +276 -0
- pixelflow/scheduling_pixelflow.py +106 -0
- pixelflow/solver_ode_wrapper.py +50 -0
- pixelflow/utils/config.py +27 -0
- pixelflow/utils/logger.py +27 -0
- pixelflow/utils/misc.py +50 -0
- requirements.txt +7 -0
app.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from PIL import Image
|
4 |
+
import gradio as gr
|
5 |
+
import spaces
|
6 |
+
from imagenet_en_cn import IMAGENET_1K_CLASSES
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from huggingface_hub import snapshot_download
|
9 |
+
|
10 |
+
import torch
|
11 |
+
# from transformers import T5EncoderModel, AutoTokenizer
|
12 |
+
|
13 |
+
from pixelflow.scheduling_pixelflow import PixelFlowScheduler
|
14 |
+
from pixelflow.pipeline_pixelflow import PixelFlowPipeline
|
15 |
+
from pixelflow.utils import config as config_utils
|
16 |
+
from pixelflow.utils.misc import seed_everything
|
17 |
+
|
18 |
+
|
19 |
+
parser = argparse.ArgumentParser(description='Gradio Demo', add_help=False)
|
20 |
+
parser.add_argument('--checkpoint', type=str, help='checkpoint folder path')
|
21 |
+
parser.add_argument('--class_cond', action='store_true', help='use class conditional generation')
|
22 |
+
args = parser.parse_args()
|
23 |
+
|
24 |
+
# deploy
|
25 |
+
args.checkpoint = "pixelflow_c2i"
|
26 |
+
args.class_cond = True
|
27 |
+
|
28 |
+
|
29 |
+
if args.class_cond:
|
30 |
+
output_dir = args.checkpoint
|
31 |
+
if not os.path.exists(output_dir):
|
32 |
+
snapshot_download(repo_id="ShoufaChen/PixelFlow-Class2Image", local_dir=output_dir)
|
33 |
+
config = OmegaConf.load(f"{output_dir}/config.yaml")
|
34 |
+
model = config_utils.instantiate_from_config(config.model)
|
35 |
+
print(f"Num of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
36 |
+
ckpt = torch.load(f"{output_dir}/model.pt", map_location="cpu", weights_only=True)
|
37 |
+
text_encoder = None
|
38 |
+
tokenizer = None
|
39 |
+
resolution = 256
|
40 |
+
NUM_EXAMPLES = 4
|
41 |
+
else:
|
42 |
+
raise NotImplementedError("Please run locally.")
|
43 |
+
config = OmegaConf.load(f"{output_dir}/config.yaml")
|
44 |
+
model = config_utils.instantiate_from_config(config.model).to(device)
|
45 |
+
print(f"Num of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
46 |
+
ckpt = torch.load(f"{output_dir}/model.pt", map_location="cpu", weights_only=True)
|
47 |
+
text_encoder = T5EncoderModel.from_pretrained("google/flan-t5-xl").to(device)
|
48 |
+
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
|
49 |
+
resolution = 1024
|
50 |
+
NUM_EXAMPLES = 1
|
51 |
+
model.load_state_dict(ckpt, strict=True)
|
52 |
+
model.eval()
|
53 |
+
|
54 |
+
print(f"outside space.GPU. {torch.cuda.is_available()=}")
|
55 |
+
if torch.cuda.is_available():
|
56 |
+
model = model.cuda()
|
57 |
+
device = torch.device("cuda")
|
58 |
+
else:
|
59 |
+
raise ValueError("No GPU")
|
60 |
+
|
61 |
+
scheduler = PixelFlowScheduler(config.scheduler.num_train_timesteps, num_stages=config.scheduler.num_stages, gamma=-1/3)
|
62 |
+
|
63 |
+
pipeline = PixelFlowPipeline(
|
64 |
+
scheduler,
|
65 |
+
model,
|
66 |
+
text_encoder=text_encoder,
|
67 |
+
tokenizer=tokenizer,
|
68 |
+
max_token_length=512,
|
69 |
+
)
|
70 |
+
|
71 |
+
@spaces.GPU
|
72 |
+
def infer(use_ode_dopri5, noise_shift, cfg_scale, class_label, seed, *num_steps_per_stage):
|
73 |
+
print(f"inside space.GPU. {torch.cuda.is_available()=}")
|
74 |
+
seed_everything(seed)
|
75 |
+
with torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad():
|
76 |
+
samples = pipeline(
|
77 |
+
prompt=[class_label] * NUM_EXAMPLES,
|
78 |
+
height=resolution,
|
79 |
+
width=resolution,
|
80 |
+
num_inference_steps=list(num_steps_per_stage),
|
81 |
+
guidance_scale=cfg_scale, # The guidance for the first frame, set it to 7 for 384p variant
|
82 |
+
device=device,
|
83 |
+
shift=noise_shift,
|
84 |
+
use_ode_dopri5=use_ode_dopri5,
|
85 |
+
)
|
86 |
+
samples = (samples * 255).round().astype("uint8")
|
87 |
+
samples = [Image.fromarray(sample) for sample in samples]
|
88 |
+
return samples
|
89 |
+
|
90 |
+
|
91 |
+
css = """
|
92 |
+
h1 {
|
93 |
+
text-align: center;
|
94 |
+
display: block;
|
95 |
+
}
|
96 |
+
|
97 |
+
.follow-link {
|
98 |
+
margin-top: 0.8em;
|
99 |
+
font-size: 1em;
|
100 |
+
text-align: center;
|
101 |
+
}
|
102 |
+
"""
|
103 |
+
|
104 |
+
|
105 |
+
with gr.Blocks(css=css) as demo:
|
106 |
+
gr.Markdown("# PixelFlow: Pixel-Space Generative Models with Flow")
|
107 |
+
gr.HTML("""
|
108 |
+
<div class="follow-link">
|
109 |
+
For text-to-image generation, please follow
|
110 |
+
<a href="https://github.com/ShoufaChen/PixelFlow/tree/main?tab=readme-ov-file#demo">text-to-image</a>.
|
111 |
+
For more details, refer to our
|
112 |
+
<a href="https://arxiv.org/abs/2504.07963">arXiv paper</a> and <a href="https://github.com/ShoufaChen/PixelFlow">GitHub repo</a>.
|
113 |
+
</div>
|
114 |
+
""")
|
115 |
+
|
116 |
+
with gr.Tabs():
|
117 |
+
with gr.TabItem('Generate'):
|
118 |
+
with gr.Row():
|
119 |
+
with gr.Column():
|
120 |
+
with gr.Row():
|
121 |
+
if args.class_cond:
|
122 |
+
user_input = gr.Dropdown(
|
123 |
+
list(IMAGENET_1K_CLASSES.values()),
|
124 |
+
value='daisy [雏菊]',
|
125 |
+
type="index", label='ImageNet-1K Class'
|
126 |
+
)
|
127 |
+
else:
|
128 |
+
# text input
|
129 |
+
user_input = gr.Textbox(label='Enter your prompt', show_label=False, max_lines=1, placeholder="Enter your prompt",)
|
130 |
+
ode_dopri5 = gr.Checkbox(label="Dopri5 ODE", info="Use Dopri5 ODE solver")
|
131 |
+
noise_shift = gr.Slider(minimum=1.0, maximum=100.0, step=1, value=1.0, label='Noise Shift')
|
132 |
+
cfg_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=4.0, label='Classifier-free Guidance Scale')
|
133 |
+
num_steps_per_stage = []
|
134 |
+
for stage_idx in range(config.scheduler.num_stages):
|
135 |
+
num_steps = gr.Slider(minimum=1, maximum=100, step=1, value=10, label=f'Num Inference Steps (Stage {stage_idx})')
|
136 |
+
num_steps_per_stage.append(num_steps)
|
137 |
+
seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed')
|
138 |
+
button = gr.Button("Generate", variant="primary")
|
139 |
+
with gr.Column():
|
140 |
+
output = gr.Gallery(label='Generated Images', height=700)
|
141 |
+
button.click(infer, inputs=[ode_dopri5, noise_shift, cfg_scale, user_input, seed, *num_steps_per_stage], outputs=[output])
|
142 |
+
demo.queue()
|
143 |
+
demo.launch(share=True, debug=True)
|
imagenet_en_cn.py
ADDED
@@ -0,0 +1,1002 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
IMAGENET_1K_CLASSES = {
|
2 |
+
0: 'tench, Tinca tinca [丁鲷]',
|
3 |
+
1: 'goldfish, Carassius auratus [金鱼]',
|
4 |
+
2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias [大白鲨]',
|
5 |
+
3: 'tiger shark, Galeocerdo cuvieri [虎鲨]',
|
6 |
+
4: 'hammerhead, hammerhead shark [锤头鲨]',
|
7 |
+
5: 'electric ray, crampfish, numbfish, torpedo [电鳐]',
|
8 |
+
6: 'stingray [黄貂鱼]',
|
9 |
+
7: 'cock [公鸡]',
|
10 |
+
8: 'hen [母鸡]',
|
11 |
+
9: 'ostrich, Struthio camelus [鸵鸟]',
|
12 |
+
10: 'brambling, Fringilla montifringilla [燕雀]',
|
13 |
+
11: 'goldfinch, Carduelis carduelis [金翅雀]',
|
14 |
+
12: 'house finch, linnet, Carpodacus mexicanus [家朱雀]',
|
15 |
+
13: 'junco, snowbird [灯芯草雀]',
|
16 |
+
14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea [靛蓝雀,靛蓝鸟]',
|
17 |
+
15: 'robin, American robin, Turdus migratorius [蓝鹀]',
|
18 |
+
16: 'bulbul [夜莺]',
|
19 |
+
17: 'jay [松鸦]',
|
20 |
+
18: 'magpie [喜鹊]',
|
21 |
+
19: 'chickadee [山雀]',
|
22 |
+
20: 'water ouzel, dipper [河鸟]',
|
23 |
+
21: 'kite [鸢(猛禽)]',
|
24 |
+
22: 'bald eagle, American eagle, Haliaeetus leucocephalus [秃头鹰]',
|
25 |
+
23: 'vulture [秃鹫]',
|
26 |
+
24: 'great grey owl, great gray owl, Strix nebulosa [大灰猫头鹰]',
|
27 |
+
25: 'European fire salamander, Salamandra salamandra [欧洲火蝾螈]',
|
28 |
+
26: 'common newt, Triturus vulgaris [普通蝾螈]',
|
29 |
+
27: 'eft [水蜥]',
|
30 |
+
28: 'spotted salamander, Ambystoma maculatum [斑点蝾螈]',
|
31 |
+
29: 'axolotl, mud puppy, Ambystoma mexicanum [蝾螈,泥狗]',
|
32 |
+
30: 'bullfrog, Rana catesbeiana [牛蛙]',
|
33 |
+
31: 'tree frog, tree-frog [树蛙]',
|
34 |
+
32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui [尾蛙,铃蟾蜍,肋蟾蜍,尾蟾蜍]',
|
35 |
+
33: 'loggerhead, loggerhead turtle, Caretta caretta [红海龟]',
|
36 |
+
34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea [皮革龟]',
|
37 |
+
35: 'mud turtle [泥龟]',
|
38 |
+
36: 'terrapin [淡水龟]',
|
39 |
+
37: 'box turtle, box tortoise [箱龟]',
|
40 |
+
38: 'banded gecko [带状壁虎]',
|
41 |
+
39: 'common iguana, iguana, Iguana iguana [普通鬣蜥]',
|
42 |
+
40: 'American chameleon, anole, Anolis carolinensis [美国变色龙]',
|
43 |
+
41: 'whiptail, whiptail lizard [鞭尾蜥蜴]',
|
44 |
+
42: 'agama [飞龙科蜥蜴]',
|
45 |
+
43: 'frilled lizard, Chlamydosaurus kingi [褶边蜥蜴]',
|
46 |
+
44: 'alligator lizard [鳄鱼蜥蜴]',
|
47 |
+
45: 'Gila monster, Heloderma suspectum [毒蜥]',
|
48 |
+
46: 'green lizard, Lacerta viridis [绿蜥蜴]',
|
49 |
+
47: 'African chameleon, Chamaeleo chamaeleon [非洲变色龙]',
|
50 |
+
48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis [科莫多蜥蜴]',
|
51 |
+
49: 'African crocodile, Nile crocodile, Crocodylus niloticus [非洲鳄,尼罗河鳄鱼]',
|
52 |
+
50: 'American alligator, Alligator mississipiensis [美国鳄鱼,鳄鱼]',
|
53 |
+
51: 'triceratops [三角龙]',
|
54 |
+
52: 'thunder snake, worm snake, Carphophis amoenus [雷蛇,蠕虫蛇]',
|
55 |
+
53: 'ringneck snake, ring-necked snake, ring snake [环蛇,环颈蛇]',
|
56 |
+
54: 'hognose snake, puff adder, sand viper [希腊蛇]',
|
57 |
+
55: 'green snake, grass snake [绿蛇,草蛇]',
|
58 |
+
56: 'king snake, kingsnake [国王蛇]',
|
59 |
+
57: 'garter snake, grass snake [袜带蛇,草蛇]',
|
60 |
+
58: 'water snake [水蛇]',
|
61 |
+
59: 'vine snake [藤蛇]',
|
62 |
+
60: 'night snake, Hypsiglena torquata [夜蛇]',
|
63 |
+
61: 'boa constrictor, Constrictor constrictor [大蟒蛇]',
|
64 |
+
62: 'rock python, rock snake, Python sebae [岩石蟒蛇,岩蛇,蟒蛇]',
|
65 |
+
63: 'Indian cobra, Naja naja [印度眼镜蛇]',
|
66 |
+
64: 'green mamba [绿曼巴]',
|
67 |
+
65: 'sea snake [海蛇]',
|
68 |
+
66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus [角腹蛇]',
|
69 |
+
67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus [菱纹响尾蛇]',
|
70 |
+
68: 'sidewinder, horned rattlesnake, Crotalus cerastes [角响尾蛇]',
|
71 |
+
69: 'trilobite [三叶虫]',
|
72 |
+
70: 'harvestman, daddy longlegs, Phalangium opilio [盲蜘蛛]',
|
73 |
+
71: 'scorpion [蝎子]',
|
74 |
+
72: 'black and gold garden spider, Argiope aurantia [黑金花园蜘蛛]',
|
75 |
+
73: 'barn spider, Araneus cavaticus [谷仓蜘蛛]',
|
76 |
+
74: 'garden spider, Aranea diademata [花园蜘蛛]',
|
77 |
+
75: 'black widow, Latrodectus mactans [黑寡妇蜘蛛]',
|
78 |
+
76: 'tarantula [狼蛛]',
|
79 |
+
77: 'wolf spider, hunting spider [狼蜘蛛,狩猎蜘蛛]',
|
80 |
+
78: 'tick [壁虱]',
|
81 |
+
79: 'centipede [蜈蚣]',
|
82 |
+
80: 'black grouse [黑松鸡]',
|
83 |
+
81: 'ptarmigan [松鸡,雷鸟]',
|
84 |
+
82: 'ruffed grouse, partridge, Bonasa umbellus [披肩鸡,披肩榛鸡]',
|
85 |
+
83: 'prairie chicken, prairie grouse, prairie fowl [草原鸡,草原松鸡]',
|
86 |
+
84: 'peacock [孔雀]',
|
87 |
+
85: 'quail [鹌鹑]',
|
88 |
+
86: 'partridge [鹧鸪]',
|
89 |
+
87: 'African grey, African gray, Psittacus erithacus [非洲灰鹦鹉]',
|
90 |
+
88: 'macaw [金刚鹦鹉]',
|
91 |
+
89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita [硫冠鹦鹉]',
|
92 |
+
90: 'lorikeet [短尾鹦鹉]',
|
93 |
+
91: 'coucal [褐翅鸦鹃]',
|
94 |
+
92: 'bee eater [蜜蜂]',
|
95 |
+
93: 'hornbill [犀鸟]',
|
96 |
+
94: 'hummingbird [蜂鸟]',
|
97 |
+
95: 'jacamar [鹟䴕]',
|
98 |
+
96: 'toucan [犀鸟]',
|
99 |
+
97: 'drake [野鸭]',
|
100 |
+
98: 'red-breasted merganser, Mergus serrator [���胸秋沙鸭]',
|
101 |
+
99: 'goose [鹅]',
|
102 |
+
100: 'black swan, Cygnus atratus [黑天鹅]',
|
103 |
+
101: 'tusker [大象]',
|
104 |
+
102: 'echidna, spiny anteater, anteater [针鼹鼠]',
|
105 |
+
103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus [鸭嘴兽]',
|
106 |
+
104: 'wallaby, brush kangaroo [沙袋鼠]',
|
107 |
+
105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus [考拉,考拉熊]',
|
108 |
+
106: 'wombat [袋熊]',
|
109 |
+
107: 'jellyfish [水母]',
|
110 |
+
108: 'sea anemone, anemone [海葵]',
|
111 |
+
109: 'brain coral [脑珊瑚]',
|
112 |
+
110: 'flatworm, platyhelminth [扁形虫扁虫]',
|
113 |
+
111: 'nematode, nematode worm, roundworm [线虫,蛔虫]',
|
114 |
+
112: 'conch [海螺]',
|
115 |
+
113: 'snail [蜗牛]',
|
116 |
+
114: 'slug [鼻涕虫]',
|
117 |
+
115: 'sea slug, nudibranch [海参]',
|
118 |
+
116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore [石鳖]',
|
119 |
+
117: 'chambered nautilus, pearly nautilus, nautilus [鹦鹉螺]',
|
120 |
+
118: 'Dungeness crab, Cancer magister [珍宝蟹]',
|
121 |
+
119: 'rock crab, Cancer irroratus [石蟹]',
|
122 |
+
120: 'fiddler crab [招潮蟹]',
|
123 |
+
121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica [帝王蟹,阿拉斯加蟹,阿拉斯加帝王蟹]',
|
124 |
+
122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus [美国龙虾,缅因州龙虾]',
|
125 |
+
123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish [大螯虾]',
|
126 |
+
124: 'crayfish, crawfish, crawdad, crawdaddy [小龙虾]',
|
127 |
+
125: 'hermit crab [寄居蟹]',
|
128 |
+
126: 'isopod [等足目动物(明虾和螃蟹近亲)]',
|
129 |
+
127: 'white stork, Ciconia ciconia [白鹳]',
|
130 |
+
128: 'black stork, Ciconia nigra [黑鹳]',
|
131 |
+
129: 'spoonbill [鹭]',
|
132 |
+
130: 'flamingo [火烈鸟]',
|
133 |
+
131: 'little blue heron, Egretta caerulea [小蓝鹭]',
|
134 |
+
132: 'American egret, great white heron, Egretta albus [美国鹭,大白鹭]',
|
135 |
+
133: 'bittern [麻鸦]',
|
136 |
+
134: 'crane [鹤]',
|
137 |
+
135: 'limpkin, Aramus pictus [秧鹤]',
|
138 |
+
136: 'European gallinule, Porphyrio porphyrio [欧洲水鸡,紫水鸡]',
|
139 |
+
137: 'American coot, marsh hen, mud hen, water hen, Fulica americana [沼泽泥母鸡,水母鸡]',
|
140 |
+
138: 'bustard [鸨]',
|
141 |
+
139: 'ruddy turnstone, Arenaria interpres [红翻石鹬]',
|
142 |
+
140: 'red-backed sandpiper, dunlin, Erolia alpina [红背鹬,黑腹滨鹬]',
|
143 |
+
141: 'redshank, Tringa totanus [红脚鹬]',
|
144 |
+
142: 'dowitcher [半蹼鹬]',
|
145 |
+
143: 'oystercatcher, oyster catcher [蛎鹬]',
|
146 |
+
144: 'pelican [鹈鹕]',
|
147 |
+
145: 'king penguin, Aptenodytes patagonica [国王企鹅]',
|
148 |
+
146: 'albatross, mollymawk [信天翁,大海鸟]',
|
149 |
+
147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus [灰鲸]',
|
150 |
+
148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca [杀人鲸,逆戟鲸,虎鲸]',
|
151 |
+
149: 'dugong, Dugong dugon [海牛]',
|
152 |
+
150: 'sea lion [海狮]',
|
153 |
+
151: 'Chihuahua [奇瓦瓦]',
|
154 |
+
152: 'Japanese spaniel [日本猎犬]',
|
155 |
+
153: 'Maltese dog, Maltese terrier, Maltese [马尔济斯犬]',
|
156 |
+
154: 'Pekinese, Pekingese, Peke [狮子狗]',
|
157 |
+
155: 'Shih-Tzu [西施犬]',
|
158 |
+
156: 'Blenheim spaniel [布莱尼姆猎犬]',
|
159 |
+
157: 'papillon [巴比狗]',
|
160 |
+
158: 'toy terrier [玩具犬]',
|
161 |
+
159: 'Rhodesian ridgeback [罗得西亚长背猎狗]',
|
162 |
+
160: 'Afghan hound, Afghan [阿富汗猎犬]',
|
163 |
+
161: 'basset, basset hound [猎犬]',
|
164 |
+
162: 'beagle [比格犬,猎兔犬]',
|
165 |
+
163: 'bloodhound, sleuthhound [侦探犬]',
|
166 |
+
164: 'bluetick [蓝色快狗]',
|
167 |
+
165: 'black-and-tan coonhound [黑褐猎浣熊犬]',
|
168 |
+
166: 'Walker hound, Walker foxhound [沃克猎犬]',
|
169 |
+
167: 'English foxhound [英国猎狐犬]',
|
170 |
+
168: 'redbone [美洲赤狗]',
|
171 |
+
169: 'borzoi, Russian wolfhound [俄罗斯猎狼犬]',
|
172 |
+
170: 'Irish wolfhound [爱尔兰猎狼犬]',
|
173 |
+
171: 'Italian greyhound [意大利灰狗]',
|
174 |
+
172: 'whippet [惠比特犬]',
|
175 |
+
173: 'Ibizan hound, Ibizan Podenco [依比沙猎犬]',
|
176 |
+
174: 'Norwegian elkhound, elkhound [挪威猎犬]',
|
177 |
+
175: 'otterhound, otter hound [奥达猎犬,水獭猎犬]',
|
178 |
+
176: 'Saluki, gazelle hound [沙克犬,瞪羚猎犬]',
|
179 |
+
177: 'Scottish deerhound, deerhound [苏格兰猎鹿犬,猎鹿犬]',
|
180 |
+
178: 'Weimaraner [威玛猎犬]',
|
181 |
+
179: 'Staffordshire bullterrier, Staffordshire bull terrier [斯塔福德郡牛头梗,斯塔福德郡斗牛梗]',
|
182 |
+
180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier [美国斯塔福德郡梗,美国比特斗牛梗,斗牛梗]',
|
183 |
+
181: 'Bedlington terrier [贝德灵顿梗]',
|
184 |
+
182: 'Border terrier [边境梗]',
|
185 |
+
183: 'Kerry blue terrier [凯丽蓝梗]',
|
186 |
+
184: 'Irish terrier [爱尔兰梗]',
|
187 |
+
185: 'Norfolk terrier [诺福克梗]',
|
188 |
+
186: 'Norwich terrier [诺维奇梗]',
|
189 |
+
187: 'Yorkshire terrier [约克郡梗]',
|
190 |
+
188: 'wire-haired fox terrier [刚毛猎狐梗]',
|
191 |
+
189: 'Lakeland terrier [莱克兰梗]',
|
192 |
+
190: 'Sealyham terrier, Sealyham [锡利哈姆梗]',
|
193 |
+
191: 'Airedale, Airedale terrier [艾尔谷犬]',
|
194 |
+
192: 'cairn, cairn terrier [凯恩梗]',
|
195 |
+
193: 'Australian terrier [澳大利亚梗]',
|
196 |
+
194: 'Dandie Dinmont, Dandie Dinmont terrier [丹迪丁蒙梗]',
|
197 |
+
195: 'Boston bull, Boston terrier [波士顿梗]',
|
198 |
+
196: 'miniature schnauzer [迷你雪纳瑞犬]',
|
199 |
+
197: 'giant schnauzer [巨型雪纳瑞犬]',
|
200 |
+
198: 'standard schnauzer [标准雪纳瑞犬]',
|
201 |
+
199: 'Scotch terrier, Scottish terrier, Scottie [苏格兰梗]',
|
202 |
+
200: 'Tibetan terrier, chrysanthemum dog [西藏梗,菊花狗]',
|
203 |
+
201: 'silky terrier, Sydney silky [丝毛梗]',
|
204 |
+
202: 'soft-coated wheaten terrier [软毛麦色梗]',
|
205 |
+
203: 'West Highland white terrier [西高地白梗]',
|
206 |
+
204: 'Lhasa, Lhasa apso [拉萨阿普索犬]',
|
207 |
+
205: 'flat-coated retriever [平毛寻回犬]',
|
208 |
+
206: 'curly-coated retriever [卷毛寻回犬]',
|
209 |
+
207: 'golden retriever [金毛猎犬]',
|
210 |
+
208: 'Labrador retriever [拉布拉多猎犬]',
|
211 |
+
209: 'Chesapeake Bay retriever [乞沙比克猎犬]',
|
212 |
+
210: 'German short-haired pointer [德国短毛猎犬]',
|
213 |
+
211: 'vizsla, Hungarian pointer [维兹拉犬]',
|
214 |
+
212: 'English setter [英国谍犬]',
|
215 |
+
213: 'Irish setter, red setter [爱尔兰雪达犬,红色猎犬]',
|
216 |
+
214: 'Gordon setter [戈登雪达犬]',
|
217 |
+
215: 'Brittany spaniel [布列塔尼犬猎犬]',
|
218 |
+
216: 'clumber, clumber spaniel [黄毛,黄毛猎犬]',
|
219 |
+
217: 'English springer, English springer spaniel [英国史宾格犬]',
|
220 |
+
218: 'Welsh springer spaniel [威尔士史宾格犬]',
|
221 |
+
219: 'cocker spaniel, English cocker spaniel, cocker [可卡犬,英国可卡犬]',
|
222 |
+
220: 'Sussex spaniel [萨塞克斯猎犬]',
|
223 |
+
221: 'Irish water spaniel [爱尔兰水猎犬]',
|
224 |
+
222: 'kuvasz [哥威斯犬]',
|
225 |
+
223: 'schipperke [舒柏奇犬]',
|
226 |
+
224: 'groenendael [比利时牧羊犬]',
|
227 |
+
225: 'malinois [马里努阿犬]',
|
228 |
+
226: 'briard [伯瑞犬]',
|
229 |
+
227: 'kelpie [凯尔皮犬]',
|
230 |
+
228: 'komondor [匈牙利牧羊犬]',
|
231 |
+
229: 'Old English sheepdog, bobtail [老英国牧羊犬]',
|
232 |
+
230: 'Shetland sheepdog, Shetland sheep dog, Shetland [喜乐蒂牧羊犬]',
|
233 |
+
231: 'collie [牧羊犬]',
|
234 |
+
232: 'Border collie [边境牧羊犬]',
|
235 |
+
233: 'Bouvier des Flandres, Bouviers des Flandres [法兰德斯牧牛狗]',
|
236 |
+
234: 'Rottweiler [罗特韦尔犬]',
|
237 |
+
235: 'German shepherd, German shepherd dog, German police dog, alsatian [德国牧羊犬,德国警犬,阿尔萨斯]',
|
238 |
+
236: 'Doberman, Doberman pinscher [多伯曼犬,杜宾犬]',
|
239 |
+
237: 'miniature pinscher [迷你杜宾犬]',
|
240 |
+
238: 'Greater Swiss Mountain dog [大瑞士山地犬]',
|
241 |
+
239: 'Bernese mountain dog [伯恩山犬]',
|
242 |
+
240: 'Appenzeller [Appenzeller狗]',
|
243 |
+
241: 'EntleBucher [EntleBucher狗]',
|
244 |
+
242: 'boxer [拳师狗]',
|
245 |
+
243: 'bull mastiff [斗牛獒]',
|
246 |
+
244: 'Tibetan mastiff [藏獒]',
|
247 |
+
245: 'French bulldog [法国斗牛犬]',
|
248 |
+
246: 'Great Dane [大丹犬]',
|
249 |
+
247: 'Saint Bernard, St Bernard [圣伯纳德狗]',
|
250 |
+
248: 'Eskimo dog, husky [爱斯基摩犬,哈士奇]',
|
251 |
+
249: 'malamute, malemute, Alaskan malamute [雪橇犬,阿拉斯加爱斯基摩狗]',
|
252 |
+
250: 'Siberian husky [哈士奇]',
|
253 |
+
251: 'dalmatian, coach dog, carriage dog [达尔马提亚,教练车狗]',
|
254 |
+
252: 'affenpinscher, monkey pinscher, monkey dog [狮毛狗]',
|
255 |
+
253: 'basenji [巴辛吉狗]',
|
256 |
+
254: 'pug, pug-dog [哈巴狗,狮子狗]',
|
257 |
+
255: 'Leonberg [莱昂贝格狗]',
|
258 |
+
256: 'Newfoundland, Newfoundland dog [纽芬兰岛狗]',
|
259 |
+
257: 'Great Pyrenees [大白熊犬]',
|
260 |
+
258: 'Samoyed, Samoyede [萨摩耶犬]',
|
261 |
+
259: 'Pomeranian [博美犬]',
|
262 |
+
260: 'chow, chow chow [松狮,松狮]',
|
263 |
+
261: 'keeshond [荷兰卷尾狮毛狗]',
|
264 |
+
262: 'Brabancon griffon [布鲁塞尔格林芬犬]',
|
265 |
+
263: 'Pembroke, Pembroke Welsh corgi [彭布洛克威尔士科基犬]',
|
266 |
+
264: 'Cardigan, Cardigan Welsh corgi [威尔士柯基犬]',
|
267 |
+
265: 'toy poodle [玩具贵宾犬]',
|
268 |
+
266: 'miniature poodle [迷你贵宾犬]',
|
269 |
+
267: 'standard poodle [标准贵宾犬]',
|
270 |
+
268: 'Mexican hairless [墨西哥无毛犬]',
|
271 |
+
269: 'timber wolf, grey wolf, gray wolf, Canis lupus [灰狼]',
|
272 |
+
270: 'white wolf, Arctic wolf, Canis lupus tundrarum [白狼,北极狼]',
|
273 |
+
271: 'red wolf, maned wolf, Canis rufus, Canis niger [红太狼,鬃狼,犬犬鲁弗斯]',
|
274 |
+
272: 'coyote, prairie wolf, brush wolf, Canis latrans [狼,草原狼,刷狼,郊狼]',
|
275 |
+
273: 'dingo, warrigal, warragal, Canis dingo [澳洲野狗,澳大利亚野犬]',
|
276 |
+
274: 'dhole, Cuon alpinus [豺]',
|
277 |
+
275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus [非洲猎犬,土狼犬]',
|
278 |
+
276: 'hyena, hyaena [鬣狗]',
|
279 |
+
277: 'red fox, Vulpes vulpes [红狐狸]',
|
280 |
+
278: 'kit fox, Vulpes macrotis [沙狐]',
|
281 |
+
279: 'Arctic fox, white fox, Alopex lagopus [北极狐狸,白狐狸]',
|
282 |
+
280: 'grey fox, gray fox, Urocyon cinereoargenteus [灰狐狸]',
|
283 |
+
281: 'tabby, tabby cat [虎斑猫]',
|
284 |
+
282: 'tiger cat [山猫,虎猫]',
|
285 |
+
283: 'Persian cat [波斯猫]',
|
286 |
+
284: 'Siamese cat, Siamese [暹罗暹罗猫,]',
|
287 |
+
285: 'Egyptian cat [埃及猫]',
|
288 |
+
286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor [美洲狮,美洲豹]',
|
289 |
+
287: 'lynx, catamount [猞猁,山猫]',
|
290 |
+
288: 'leopard, Panthera pardus [豹子]',
|
291 |
+
289: 'snow leopard, ounce, Panthera uncia [雪豹]',
|
292 |
+
290: 'jaguar, panther, Panthera onca, Felis onca [美洲虎]',
|
293 |
+
291: 'lion, king of beasts, Panthera leo [狮子]',
|
294 |
+
292: 'tiger, Panthera tigris [老虎]',
|
295 |
+
293: 'cheetah, chetah, Acinonyx jubatus [猎豹]',
|
296 |
+
294: 'brown bear, bruin, Ursus arctos [棕熊]',
|
297 |
+
295: 'American black bear, black bear, Ursus americanus, Euarctos americanus [美洲黑熊]',
|
298 |
+
296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus [冰熊,北极熊]',
|
299 |
+
297: 'sloth bear, Melursus ursinus, Ursus ursinus [懒熊]',
|
300 |
+
298: 'mongoose [猫鼬]',
|
301 |
+
299: 'meerkat, mierkat [猫鼬,海猫]',
|
302 |
+
300: 'tiger beetle [虎甲虫]',
|
303 |
+
301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle [瓢虫]',
|
304 |
+
302: 'ground beetle, carabid beetle [土鳖虫]',
|
305 |
+
303: 'long-horned beetle, longicorn, longicorn beetle [天牛]',
|
306 |
+
304: 'leaf beetle, chrysomelid [龟甲虫]',
|
307 |
+
305: 'dung beetle [粪甲虫]',
|
308 |
+
306: 'rhinoceros beetle [犀牛甲虫]',
|
309 |
+
307: 'weevil [象甲]',
|
310 |
+
308: 'fly [苍蝇]',
|
311 |
+
309: 'bee [蜜蜂]',
|
312 |
+
310: 'ant, emmet, pismire [蚂蚁]',
|
313 |
+
311: 'grasshopper, hopper [蚱蜢]',
|
314 |
+
312: 'cricket [蟋蟀]',
|
315 |
+
313: 'walking stick, walkingstick, stick insect [竹节虫]',
|
316 |
+
314: 'cockroach, roach [蟑螂]',
|
317 |
+
315: 'mantis, mantid [螳螂]',
|
318 |
+
316: 'cicada, cicala [蝉]',
|
319 |
+
317: 'leafhopper [叶蝉]',
|
320 |
+
318: 'lacewing, lacewing fly [草蜻蛉]',
|
321 |
+
319: 'dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk [蜻蜓]',
|
322 |
+
320: 'damselfly [豆娘,蜻蛉]',
|
323 |
+
321: 'admiral [优红蛱蝶]',
|
324 |
+
322: 'ringlet, ringlet butterfly [小环蝴蝶]',
|
325 |
+
323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus [君主蝴蝶,大斑蝶]',
|
326 |
+
324: 'cabbage butterfly [菜粉蝶]',
|
327 |
+
325: 'sulphur butterfly, sulfur butterfly [白蝴蝶]',
|
328 |
+
326: 'lycaenid, lycaenid butterfly [灰蝶]',
|
329 |
+
327: 'starfish, sea star [海星]',
|
330 |
+
328: 'sea urchin [海胆]',
|
331 |
+
329: 'sea cucumber, holothurian [海参,海黄瓜]',
|
332 |
+
330: 'wood rabbit, cottontail, cottontail rabbit [野兔]',
|
333 |
+
331: 'hare [兔]',
|
334 |
+
332: 'Angora, Angora rabbit [安哥拉兔]',
|
335 |
+
333: 'hamster [仓鼠]',
|
336 |
+
334: 'porcupine, hedgehog [刺猬,豪猪,]',
|
337 |
+
335: 'fox squirrel, eastern fox squirrel, Sciurus niger [黑松鼠]',
|
338 |
+
336: 'marmot [土拨鼠]',
|
339 |
+
337: 'beaver [海狸]',
|
340 |
+
338: 'guinea pig, Cavia cobaya [豚鼠,豚鼠]',
|
341 |
+
339: 'sorrel [栗色马]',
|
342 |
+
340: 'zebra [斑马]',
|
343 |
+
341: 'hog, pig, grunter, squealer, Sus scrofa [猪]',
|
344 |
+
342: 'wild boar, boar, Sus scrofa [野猪]',
|
345 |
+
343: 'warthog [疣猪]',
|
346 |
+
344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius [河马]',
|
347 |
+
345: 'ox [牛]',
|
348 |
+
346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis [水牛,亚洲水牛]',
|
349 |
+
347: 'bison [野牛]',
|
350 |
+
348: 'ram, tup [公羊]',
|
351 |
+
349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis [大角羊,洛矶山大角羊]',
|
352 |
+
350: 'ibex, Capra ibex [山羊]',
|
353 |
+
351: 'hartebeest [狷羚]',
|
354 |
+
352: 'impala, Aepyceros melampus [黑斑羚]',
|
355 |
+
353: 'gazelle [瞪羚]',
|
356 |
+
354: 'Arabian camel, dromedary, Camelus dromedarius [阿拉伯单峰骆驼,骆驼]',
|
357 |
+
355: 'llama [羊驼]',
|
358 |
+
356: 'weasel [黄鼠狼]',
|
359 |
+
357: 'mink [水貂]',
|
360 |
+
358: 'polecat, fitch, foulmart, foumart, Mustela putorius [臭猫]',
|
361 |
+
359: 'black-footed ferret, ferret, Mustela nigripes [黑足鼬]',
|
362 |
+
360: 'otter [水獭]',
|
363 |
+
361: 'skunk, polecat, wood pussy [臭鼬,木猫]',
|
364 |
+
362: 'badger [獾]',
|
365 |
+
363: 'armadillo [犰狳]',
|
366 |
+
364: 'three-toed sloth, ai, Bradypus tridactylus [树懒]',
|
367 |
+
365: 'orangutan, orang, orangutang, Pongo pygmaeus [猩猩,婆罗洲猩猩]',
|
368 |
+
366: 'gorilla, Gorilla gorilla [大猩猩]',
|
369 |
+
367: 'chimpanzee, chimp, Pan troglodytes [黑猩猩]',
|
370 |
+
368: 'gibbon, Hylobates lar [长臂猿]',
|
371 |
+
369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus [合趾猿长臂猿,合趾猿]',
|
372 |
+
370: 'guenon, guenon monkey [长尾猴]',
|
373 |
+
371: 'patas, hussar monkey, Erythrocebus patas [赤猴]',
|
374 |
+
372: 'baboon [狒狒]',
|
375 |
+
373: 'macaque [恒河猴,猕猴]',
|
376 |
+
374: 'langur [白头叶猴]',
|
377 |
+
375: 'colobus, colobus monkey [疣猴]',
|
378 |
+
376: 'proboscis monkey, Nasalis larvatus [长鼻猴]',
|
379 |
+
377: 'marmoset [狨(美洲产小型长尾猴)]',
|
380 |
+
378: 'capuchin, ringtail, Cebus capucinus [卷尾猴]',
|
381 |
+
379: 'howler monkey, howler [吼猴]',
|
382 |
+
380: 'titi, titi monkey [伶猴]',
|
383 |
+
381: 'spider monkey, Ateles geoffroyi [蜘蛛猴]',
|
384 |
+
382: 'squirrel monkey, Saimiri sciureus [松鼠猴]',
|
385 |
+
383: 'Madagascar cat, ring-tailed lemur, Lemur catta [马达加斯加环尾狐猴,鼠狐猴]',
|
386 |
+
384: 'indri, indris, Indri indri, Indri brevicaudatus [大狐猴,马达加斯加大狐猴]',
|
387 |
+
385: 'Indian elephant, Elephas maximus [印度大象,亚洲象]',
|
388 |
+
386: 'African elephant, Loxodonta africana [非洲象,非洲象]',
|
389 |
+
387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens [小熊猫]',
|
390 |
+
388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca [大熊猫]',
|
391 |
+
389: 'barracouta, snoek [杖鱼]',
|
392 |
+
390: 'eel [鳗鱼]',
|
393 |
+
391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch [银鲑,银鲑鱼]',
|
394 |
+
392: 'rock beauty, Holocanthus tricolor [三色刺蝶鱼]',
|
395 |
+
393: 'anemone fish [海葵鱼]',
|
396 |
+
394: 'sturgeon [鲟鱼]',
|
397 |
+
395: 'gar, garfish, garpike, billfish, Lepisosteus osseus [雀鳝]',
|
398 |
+
396: 'lionfish [狮子鱼]',
|
399 |
+
397: 'puffer, pufferfish, blowfish, globefish [河豚]',
|
400 |
+
398: 'abacus [算盘]',
|
401 |
+
399: 'abaya [长袍]',
|
402 |
+
400: 'academic gown, academic robe, judge robe [学位袍]',
|
403 |
+
401: 'accordion, piano accordion, squeeze box [手风琴]',
|
404 |
+
402: 'acoustic guitar [原声吉他]',
|
405 |
+
403: 'aircraft carrier, carrier, flattop, attack aircraft carrier [航空母舰]',
|
406 |
+
404: 'airliner [客机]',
|
407 |
+
405: 'airship, dirigible [飞艇]',
|
408 |
+
406: 'altar [祭坛]',
|
409 |
+
407: 'ambulance [救护车]',
|
410 |
+
408: 'amphibian, amphibious vehicle [水陆两用车]',
|
411 |
+
409: 'analog clock [模拟时钟]',
|
412 |
+
410: 'apiary, bee house [蜂房]',
|
413 |
+
411: 'apron [围裙]',
|
414 |
+
412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin [垃圾桶]',
|
415 |
+
413: 'assault rifle, assault gun [攻击步枪,枪]',
|
416 |
+
414: 'backpack, back pack, knapsack, packsack, rucksack, haversack [背包]',
|
417 |
+
415: 'bakery, bakeshop, bakehouse [面包店,面包铺,]',
|
418 |
+
416: 'balance beam, beam [平衡木]',
|
419 |
+
417: 'balloon [热气球]',
|
420 |
+
418: 'ballpoint, ballpoint pen, ballpen, Biro [圆珠笔]',
|
421 |
+
419: 'Band Aid [创可贴]',
|
422 |
+
420: 'banjo [班卓琴]',
|
423 |
+
421: 'bannister, banister, balustrade, balusters, handrail [栏杆,楼梯扶手]',
|
424 |
+
422: 'barbell [杠铃]',
|
425 |
+
423: 'barber chair [理发师的椅子]',
|
426 |
+
424: 'barbershop [理发店]',
|
427 |
+
425: 'barn [牲口棚]',
|
428 |
+
426: 'barometer [晴雨表]',
|
429 |
+
427: 'barrel, cask [圆筒]',
|
430 |
+
428: 'barrow, garden cart, lawn cart, wheelbarrow [园地小车,手推车]',
|
431 |
+
429: 'baseball [棒球]',
|
432 |
+
430: 'basketball [篮球]',
|
433 |
+
431: 'bassinet [婴儿床]',
|
434 |
+
432: 'bassoon [巴松管,低音管]',
|
435 |
+
433: 'bathing cap, swimming cap [游泳帽]',
|
436 |
+
434: 'bath towel [沐浴毛巾]',
|
437 |
+
435: 'bathtub, bathing tub, bath, tub [浴缸,澡盆]',
|
438 |
+
436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon [沙滩车,旅行车]',
|
439 |
+
437: 'beacon, lighthouse, beacon light, pharos [灯塔]',
|
440 |
+
438: 'beaker [高脚杯]',
|
441 |
+
439: 'bearskin, busby, shako [熊皮高帽]',
|
442 |
+
440: 'beer bottle [啤酒瓶]',
|
443 |
+
441: 'beer glass [啤酒杯]',
|
444 |
+
442: 'bell cote, bell cot [钟塔]',
|
445 |
+
443: 'bib [(小儿用的)围嘴]',
|
446 |
+
444: 'bicycle-built-for-two, tandem bicycle, tandem [串联自行车,]',
|
447 |
+
445: 'bikini, two-piece [比基尼]',
|
448 |
+
446: 'binder, ring-binder [装订册]',
|
449 |
+
447: 'binoculars, field glasses, opera glasses [双筒望远镜]',
|
450 |
+
448: 'birdhouse [鸟舍]',
|
451 |
+
449: 'boathouse [船库]',
|
452 |
+
450: 'bobsled, bobsleigh, bob [雪橇]',
|
453 |
+
451: 'bolo tie, bolo, bola tie, bola [饰扣式领带]',
|
454 |
+
452: 'bonnet, poke bonnet [阔边女帽]',
|
455 |
+
453: 'bookcase [书橱]',
|
456 |
+
454: 'bookshop, bookstore, bookstall [书店,书摊]',
|
457 |
+
455: 'bottlecap [瓶盖]',
|
458 |
+
456: 'bow [弓箭]',
|
459 |
+
457: 'bow tie, bow-tie, bowtie [蝴蝶结领结]',
|
460 |
+
458: 'brass, memorial tablet, plaque [铜制牌位]',
|
461 |
+
459: 'brassiere, bra, bandeau [奶罩]',
|
462 |
+
460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty [防波堤,海堤]',
|
463 |
+
461: 'breastplate, aegis, egis [铠甲]',
|
464 |
+
462: 'broom [扫帚]',
|
465 |
+
463: 'bucket, pail [桶]',
|
466 |
+
464: 'buckle [扣环]',
|
467 |
+
465: 'bulletproof vest [防弹背心]',
|
468 |
+
466: 'bullet train, bullet [动车,子弹头列车]',
|
469 |
+
467: 'butcher shop, meat market [肉铺,肉菜市场]',
|
470 |
+
468: 'cab, hack, taxi, taxicab [出租车]',
|
471 |
+
469: 'caldron, cauldron [大锅]',
|
472 |
+
470: 'candle, taper, wax light [蜡烛]',
|
473 |
+
471: 'cannon [大炮]',
|
474 |
+
472: 'canoe [独木舟]',
|
475 |
+
473: 'can opener, tin opener [开瓶器,开罐器]',
|
476 |
+
474: 'cardigan [开衫]',
|
477 |
+
475: 'car mirror [车镜]',
|
478 |
+
476: 'carousel, carrousel, merry-go-round, roundabout, whirligig [旋转木马]',
|
479 |
+
477: 'carpenters kit, tool kit [木匠的工具包,工具包]',
|
480 |
+
478: 'carton [纸箱]',
|
481 |
+
479: 'car wheel [车轮]',
|
482 |
+
480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM [取款机,自动取款机]',
|
483 |
+
481: 'cassette [盒式录音带]',
|
484 |
+
482: 'cassette player [卡带播放器]',
|
485 |
+
483: 'castle [城堡]',
|
486 |
+
484: 'catamaran [双体船]',
|
487 |
+
485: 'CD player [CD播放器]',
|
488 |
+
486: 'cello, violoncello [大提琴]',
|
489 |
+
487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone [移动电话,手机]',
|
490 |
+
488: 'chain [铁链]',
|
491 |
+
489: 'chainlink fence [围栏]',
|
492 |
+
490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour [链甲]',
|
493 |
+
491: 'chain saw, chainsaw [电锯,油锯]',
|
494 |
+
492: 'chest [箱子]',
|
495 |
+
493: 'chiffonier, commode [衣柜,洗脸台]',
|
496 |
+
494: 'chime, bell, gong [编钟,钟,锣]',
|
497 |
+
495: 'china cabinet, china closet [中国橱柜]',
|
498 |
+
496: 'Christmas stocking [圣诞袜]',
|
499 |
+
497: 'church, church building [教堂,教堂建筑]',
|
500 |
+
498: 'cinema, movie theater, movie theatre, movie house, picture palace [电影院,剧场]',
|
501 |
+
499: 'cleaver, meat cleaver, chopper [切肉刀,菜刀]',
|
502 |
+
500: 'cliff dwelling [悬崖屋]',
|
503 |
+
501: 'cloak [斗篷]',
|
504 |
+
502: 'clog, geta, patten, sabot [木屐,木鞋]',
|
505 |
+
503: 'cocktail shaker [鸡尾酒调酒器]',
|
506 |
+
504: 'coffee mug [咖啡杯]',
|
507 |
+
505: 'coffeepot [咖啡壶]',
|
508 |
+
506: 'coil, spiral, volute, whorl, helix [螺旋结构(楼梯)]',
|
509 |
+
507: 'combination lock [组合锁]',
|
510 |
+
508: 'computer keyboard, keypad [电脑键盘,键盘]',
|
511 |
+
509: 'confectionery, confectionary, candy store [糖果,糖果店]',
|
512 |
+
510: 'container ship, containership, container vessel [集装箱船]',
|
513 |
+
511: 'convertible [敞篷车]',
|
514 |
+
512: 'corkscrew, bottle screw [开瓶器,瓶螺杆]',
|
515 |
+
513: 'cornet, horn, trumpet, trump [短号,喇叭]',
|
516 |
+
514: 'cowboy boot [牛仔靴]',
|
517 |
+
515: 'cowboy hat, ten-gallon hat [牛仔帽]',
|
518 |
+
516: 'cradle [摇篮]',
|
519 |
+
517: 'crane [起重机]',
|
520 |
+
518: 'crash helmet [头盔]',
|
521 |
+
519: 'crate [板条箱]',
|
522 |
+
520: 'crib, cot [小儿床]',
|
523 |
+
521: 'Crock Pot [砂锅]',
|
524 |
+
522: 'croquet ball [槌球]',
|
525 |
+
523: 'crutch [拐杖]',
|
526 |
+
524: 'cuirass [胸甲]',
|
527 |
+
525: 'dam, dike, dyke [大坝,堤防]',
|
528 |
+
526: 'desk [书桌]',
|
529 |
+
527: 'desktop computer [台式电脑]',
|
530 |
+
528: 'dial telephone, dial phone [有线电话]',
|
531 |
+
529: 'diaper, nappy, napkin [尿布湿]',
|
532 |
+
530: 'digital clock [数字时钟]',
|
533 |
+
531: 'digital watch [数字手表]',
|
534 |
+
532: 'dining table, board [餐桌板]',
|
535 |
+
533: 'dishrag, dishcloth [抹布]',
|
536 |
+
534: 'dishwasher, dish washer, dishwashing machine [洗碗机,洗碟机]',
|
537 |
+
535: 'disk brake, disc brake [盘式制动器]',
|
538 |
+
536: 'dock, dockage, docking facility [码头,船坞,码头设施]',
|
539 |
+
537: 'dogsled, dog sled, dog sleigh [狗拉雪橇]',
|
540 |
+
538: 'dome [圆顶]',
|
541 |
+
539: 'doormat, welcome mat [门垫,垫子]',
|
542 |
+
540: 'drilling platform, offshore rig [钻井平台,海上钻井]',
|
543 |
+
541: 'drum, membranophone, tympan [鼓,乐器,鼓膜]',
|
544 |
+
542: 'drumstick [鼓槌]',
|
545 |
+
543: 'dumbbell [哑铃]',
|
546 |
+
544: 'Dutch oven [荷兰烤箱]',
|
547 |
+
545: 'electric fan, blower [电风扇,鼓风机]',
|
548 |
+
546: 'electric guitar [电吉他]',
|
549 |
+
547: 'electric locomotive [电力机车]',
|
550 |
+
548: 'entertainment center [电视,电视柜]',
|
551 |
+
549: 'envelope [信封]',
|
552 |
+
550: 'espresso maker [浓缩咖啡机]',
|
553 |
+
551: 'face powder [扑面粉]',
|
554 |
+
552: 'feather boa, boa [女用长围巾]',
|
555 |
+
553: 'file, file cabinet, filing cabinet [文件,文件柜,档案柜]',
|
556 |
+
554: 'fireboat [消防船]',
|
557 |
+
555: 'fire engine, fire truck [消防车]',
|
558 |
+
556: 'fire screen, fireguard [火炉栏]',
|
559 |
+
557: 'flagpole, flagstaff [旗杆]',
|
560 |
+
558: 'flute, transverse flute [长笛]',
|
561 |
+
559: 'folding chair [折叠椅]',
|
562 |
+
560: 'football helmet [橄榄球头盔]',
|
563 |
+
561: 'forklift [叉车]',
|
564 |
+
562: 'fountain [喷泉]',
|
565 |
+
563: 'fountain pen [钢笔]',
|
566 |
+
564: 'four-poster [有四根帷柱的床]',
|
567 |
+
565: 'freight car [运货车厢]',
|
568 |
+
566: 'French horn, horn [圆号,喇叭]',
|
569 |
+
567: 'frying pan, frypan, skillet [煎锅]',
|
570 |
+
568: 'fur coat [裘皮大衣]',
|
571 |
+
569: 'garbage truck, dustcart [垃圾车]',
|
572 |
+
570: 'gasmask, respirator, gas helmet [防毒面具,呼吸器]',
|
573 |
+
571: 'gas pump, gasoline pump, petrol pump, island dispenser [汽油泵]',
|
574 |
+
572: 'goblet [高脚杯]',
|
575 |
+
573: 'go-kart [卡丁车]',
|
576 |
+
574: 'golf ball [高尔夫球]',
|
577 |
+
575: 'golfcart, golf cart [高尔夫球车]',
|
578 |
+
576: 'gondola [狭长小船]',
|
579 |
+
577: 'gong, tam-tam [锣]',
|
580 |
+
578: 'gown [礼服]',
|
581 |
+
579: 'grand piano, grand [钢琴]',
|
582 |
+
580: 'greenhouse, nursery, glasshouse [温室,苗圃]',
|
583 |
+
581: 'grille, radiator grille [散热器格栅]',
|
584 |
+
582: 'grocery store, grocery, food market, market [杂货店,食品市场]',
|
585 |
+
583: 'guillotine [断头台]',
|
586 |
+
584: 'hair slide [小发夹]',
|
587 |
+
585: 'hair spray [头发喷雾]',
|
588 |
+
586: 'half track [半履带装甲车]',
|
589 |
+
587: 'hammer [锤子]',
|
590 |
+
588: 'hamper [大篮子]',
|
591 |
+
589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier [手摇鼓风机,吹风机]',
|
592 |
+
590: 'hand-held computer, hand-held microcomputer [手提电脑]',
|
593 |
+
591: 'handkerchief, hankie, hanky, hankey [手帕]',
|
594 |
+
592: 'hard disc, hard disk, fixed disk [硬盘]',
|
595 |
+
593: 'harmonica, mouth organ, harp, mouth harp [口琴,口风琴]',
|
596 |
+
594: 'harp [竖琴]',
|
597 |
+
595: 'harvester, reaper [收割机]',
|
598 |
+
596: 'hatchet [斧头]',
|
599 |
+
597: 'holster [手枪皮套]',
|
600 |
+
598: 'home theater, home theatre [家庭影院]',
|
601 |
+
599: 'honeycomb [蜂窝]',
|
602 |
+
600: 'hook, claw [钩爪]',
|
603 |
+
601: 'hoopskirt, crinoline [衬裙]',
|
604 |
+
602: 'horizontal bar, high bar [单杠]',
|
605 |
+
603: 'horse cart, horse-cart [马车]',
|
606 |
+
604: 'hourglass [沙漏]',
|
607 |
+
605: 'iPod [手机,iPad]',
|
608 |
+
606: 'iron, smoothing iron [熨斗]',
|
609 |
+
607: 'jack-o-lantern [南瓜灯笼]',
|
610 |
+
608: 'jean, blue jean, denim [牛仔裤,蓝色牛仔裤]',
|
611 |
+
609: 'jeep, landrover [吉普车]',
|
612 |
+
610: 'jersey, T-shirt, tee shirt [运动衫,T恤]',
|
613 |
+
611: 'jigsaw puzzle [拼图]',
|
614 |
+
612: 'jinrikisha, ricksha, rickshaw [人力车]',
|
615 |
+
613: 'joystick [操纵杆]',
|
616 |
+
614: 'kimono [和服]',
|
617 |
+
615: 'knee pad [护膝]',
|
618 |
+
616: 'knot [蝴蝶结]',
|
619 |
+
617: 'lab coat, laboratory coat [大褂,实验室外套]',
|
620 |
+
618: 'ladle [长柄勺]',
|
621 |
+
619: 'lampshade, lamp shade [灯罩]',
|
622 |
+
620: 'laptop, laptop computer [笔记本电脑]',
|
623 |
+
621: 'lawn mower, mower [割草机]',
|
624 |
+
622: 'lens cap, lens cover [镜头盖]',
|
625 |
+
623: 'letter opener, paper knife, paperknife [开信刀,裁纸刀]',
|
626 |
+
624: 'library [图书馆]',
|
627 |
+
625: 'lifeboat [救生艇]',
|
628 |
+
626: 'lighter, light, igniter, ignitor [点火器,打火机]',
|
629 |
+
627: 'limousine, limo [豪华轿车]',
|
630 |
+
628: 'liner, ocean liner [远洋班轮]',
|
631 |
+
629: 'lipstick, lip rouge [唇膏,口红]',
|
632 |
+
630: 'Loafer [平底便鞋]',
|
633 |
+
631: 'lotion [洗剂]',
|
634 |
+
632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system [扬声器]',
|
635 |
+
633: 'loupe, jewelers loupe [放大镜]',
|
636 |
+
634: 'lumbermill, sawmill [锯木厂]',
|
637 |
+
635: 'magnetic compass [磁罗盘]',
|
638 |
+
636: 'mailbag, postbag [邮袋]',
|
639 |
+
637: 'mailbox, letter box [信箱]',
|
640 |
+
638: 'maillot [女游泳衣]',
|
641 |
+
639: 'maillot, tank suit [有肩带浴衣]',
|
642 |
+
640: 'manhole cover [窨井盖]',
|
643 |
+
641: 'maraca [沙球(一种打击乐器)]',
|
644 |
+
642: 'marimba, xylophone [马林巴木琴]',
|
645 |
+
643: 'mask [面膜]',
|
646 |
+
644: 'matchstick [火柴]',
|
647 |
+
645: 'maypole [花柱]',
|
648 |
+
646: 'maze, labyrinth [迷宫]',
|
649 |
+
647: 'measuring cup [量杯]',
|
650 |
+
648: 'medicine chest, medicine cabinet [药箱]',
|
651 |
+
649: 'megalith, megalithic structure [巨石,巨石结构]',
|
652 |
+
650: 'microphone, mike [麦克风]',
|
653 |
+
651: 'microwave, microwave oven [微波炉]',
|
654 |
+
652: 'military uniform [军装]',
|
655 |
+
653: 'milk can [奶桶]',
|
656 |
+
654: 'minibus [迷你巴士]',
|
657 |
+
655: 'miniskirt, mini [迷你裙]',
|
658 |
+
656: 'minivan [面包车]',
|
659 |
+
657: 'missile [导弹]',
|
660 |
+
658: 'mitten [连指手套]',
|
661 |
+
659: 'mixing bowl [搅拌钵]',
|
662 |
+
660: 'mobile home, manufactured home [活动房屋(由汽车拖拉的)]',
|
663 |
+
661: 'Model T [T型发动机小汽车]',
|
664 |
+
662: 'modem [调制解调器]',
|
665 |
+
663: 'monastery [修道院]',
|
666 |
+
664: 'monitor [显示器]',
|
667 |
+
665: 'moped [电瓶车]',
|
668 |
+
666: 'mortar [砂浆]',
|
669 |
+
667: 'mortarboard [学士]',
|
670 |
+
668: 'mosque [清真寺]',
|
671 |
+
669: 'mosquito net [蚊帐]',
|
672 |
+
670: 'motor scooter, scooter [摩托车]',
|
673 |
+
671: 'mountain bike, all-terrain bike, off-roader [山地自行车]',
|
674 |
+
672: 'mountain tent [登山帐]',
|
675 |
+
673: 'mouse, computer mouse [鼠标,电脑鼠标]',
|
676 |
+
674: 'mousetrap [捕鼠器]',
|
677 |
+
675: 'moving van [搬家车]',
|
678 |
+
676: 'muzzle [口套]',
|
679 |
+
677: 'nail [钉子]',
|
680 |
+
678: 'neck brace [颈托]',
|
681 |
+
679: 'necklace [项链]',
|
682 |
+
680: 'nipple [乳头(瓶)]',
|
683 |
+
681: 'notebook, notebook computer [笔记本,笔记本电脑]',
|
684 |
+
682: 'obelisk [方尖碑]',
|
685 |
+
683: 'oboe, hautboy, hautbois [双簧管]',
|
686 |
+
684: 'ocarina, sweet potato [陶笛,卵形笛]',
|
687 |
+
685: 'odometer, hodometer, mileometer, milometer [里程表]',
|
688 |
+
686: 'oil filter [滤油器]',
|
689 |
+
687: 'organ, pipe organ [风琴,管风琴]',
|
690 |
+
688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO [示波器]',
|
691 |
+
689: 'overskirt [罩裙]',
|
692 |
+
690: 'oxcart [牛车]',
|
693 |
+
691: 'oxygen mask [氧气面罩]',
|
694 |
+
692: 'packet [包装]',
|
695 |
+
693: 'paddle, boat paddle [船桨]',
|
696 |
+
694: 'paddlewheel, paddle wheel [明轮,桨轮]',
|
697 |
+
695: 'padlock [挂锁,扣锁]',
|
698 |
+
696: 'paintbrush [画笔]',
|
699 |
+
697: 'pajama, pyjama, pjs, jammies [睡衣]',
|
700 |
+
698: 'palace [宫殿]',
|
701 |
+
699: 'panpipe, pandean pipe, syrinx [排箫,鸣管]',
|
702 |
+
700: 'paper towel [纸巾]',
|
703 |
+
701: 'parachute, chute [降落伞]',
|
704 |
+
702: 'parallel bars, bars [双杠]',
|
705 |
+
703: 'park bench [公园长椅]',
|
706 |
+
704: 'parking meter [停车收费表,停车计时器]',
|
707 |
+
705: 'passenger car, coach, carriage [客车,教练车]',
|
708 |
+
706: 'patio, terrace [露台,阳台]',
|
709 |
+
707: 'pay-phone, pay-station [付费电话]',
|
710 |
+
708: 'pedestal, plinth, footstall [基座,基脚]',
|
711 |
+
709: 'pencil box, pencil case [铅笔盒]',
|
712 |
+
710: 'pencil sharpener [卷笔刀]',
|
713 |
+
711: 'perfume, essence [香水(瓶)]',
|
714 |
+
712: 'Petri dish [培养皿]',
|
715 |
+
713: 'photocopier [复印机]',
|
716 |
+
714: 'pick, plectrum, plectron [拨弦片,拨子]',
|
717 |
+
715: 'pickelhaube [尖顶头盔]',
|
718 |
+
716: 'picket fence, paling [栅栏,栅栏]',
|
719 |
+
717: 'pickup, pickup truck [皮卡,皮卡车]',
|
720 |
+
718: 'pier [桥墩]',
|
721 |
+
719: 'piggy bank, penny bank [存钱罐]',
|
722 |
+
720: 'pill bottle [药瓶]',
|
723 |
+
721: 'pillow [枕头]',
|
724 |
+
722: 'ping-pong ball [乒乓球]',
|
725 |
+
723: 'pinwheel [风车]',
|
726 |
+
724: 'pirate, pirate ship [海盗船]',
|
727 |
+
725: 'pitcher, ewer [水罐]',
|
728 |
+
726: 'plane, carpenters plane, woodworking plane [木工刨]',
|
729 |
+
727: 'planetarium [天文馆]',
|
730 |
+
728: 'plastic bag [塑料袋]',
|
731 |
+
729: 'plate rack [板架]',
|
732 |
+
730: 'plow, plough [犁型铲雪机]',
|
733 |
+
731: 'plunger, plumbers helper [手压皮碗泵]',
|
734 |
+
732: 'Polaroid camera, Polaroid Land camera [宝丽来相机]',
|
735 |
+
733: 'pole [电线杆]',
|
736 |
+
734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria [警车,巡逻车]',
|
737 |
+
735: 'poncho [雨披]',
|
738 |
+
736: 'pool table, billiard table, snooker table [台球桌]',
|
739 |
+
737: 'pop bottle, soda bottle [充气饮料瓶]',
|
740 |
+
738: 'pot, flowerpot [花盆]',
|
741 |
+
739: 'potters wheel [陶工旋盘]',
|
742 |
+
740: 'power drill [电钻]',
|
743 |
+
741: 'prayer rug, prayer mat [祈祷垫,地毯]',
|
744 |
+
742: 'printer [打印机]',
|
745 |
+
743: 'prison, prison house [监狱]',
|
746 |
+
744: 'projectile, missile [炮弹,导弹]',
|
747 |
+
745: 'projector [投影仪]',
|
748 |
+
746: 'puck, hockey puck [冰球]',
|
749 |
+
747: 'punching bag, punch bag, punching ball, punchball [沙包,吊球]',
|
750 |
+
748: 'purse [钱包]',
|
751 |
+
749: 'quill, quill pen [羽管笔]',
|
752 |
+
750: 'quilt, comforter, comfort, puff [被子]',
|
753 |
+
751: 'racer, race car, racing car [赛车]',
|
754 |
+
752: 'racket, racquet [球拍]',
|
755 |
+
753: 'radiator [散热器]',
|
756 |
+
754: 'radio, wireless [收音机]',
|
757 |
+
755: 'radio telescope, radio reflector [射电望远镜,无线电反射器]',
|
758 |
+
756: 'rain barrel [雨桶]',
|
759 |
+
757: 'recreational vehicle, RV, R.V. [休闲车,房车]',
|
760 |
+
758: 'reel [卷轴,卷筒]',
|
761 |
+
759: 'reflex camera [反射式照相机]',
|
762 |
+
760: 'refrigerator, icebox [冰箱,冰柜]',
|
763 |
+
761: 'remote control, remote [遥控器]',
|
764 |
+
762: 'restaurant, eating house, eating place, eatery [餐厅,饮食店,食堂]',
|
765 |
+
763: 'revolver, six-gun, six-shooter [左轮手枪]',
|
766 |
+
764: 'rifle [步枪]',
|
767 |
+
765: 'rocking chair, rocker [摇椅]',
|
768 |
+
766: 'rotisserie [电转烤肉架]',
|
769 |
+
767: 'rubber eraser, rubber, pencil eraser [橡皮]',
|
770 |
+
768: 'rugby ball [橄榄球]',
|
771 |
+
769: 'rule, ruler [直尺]',
|
772 |
+
770: 'running shoe [跑步鞋]',
|
773 |
+
771: 'safe [保险柜]',
|
774 |
+
772: 'safety pin [安全别针]',
|
775 |
+
773: 'saltshaker, salt shaker [盐瓶(调味用)]',
|
776 |
+
774: 'sandal [凉鞋]',
|
777 |
+
775: 'sarong [纱笼,围裙]',
|
778 |
+
776: 'sax, saxophone [萨克斯管]',
|
779 |
+
777: 'scabbard [剑鞘]',
|
780 |
+
778: 'scale, weighing machine [秤,称重机]',
|
781 |
+
779: 'school bus [校车]',
|
782 |
+
780: 'schooner [帆船]',
|
783 |
+
781: 'scoreboard [记分牌]',
|
784 |
+
782: 'screen, CRT screen [屏幕]',
|
785 |
+
783: 'screw [螺丝]',
|
786 |
+
784: 'screwdriver [螺丝刀]',
|
787 |
+
785: 'seat belt, seatbelt [安全带]',
|
788 |
+
786: 'sewing machine [缝纫机]',
|
789 |
+
787: 'shield, buckler [盾牌,盾牌]',
|
790 |
+
788: 'shoe shop, shoe-shop, shoe store [皮鞋店,鞋店]',
|
791 |
+
789: 'shoji [障子]',
|
792 |
+
790: 'shopping basket [购物篮]',
|
793 |
+
791: 'shopping cart [购物车]',
|
794 |
+
792: 'shovel [铁锹]',
|
795 |
+
793: 'shower cap [浴帽]',
|
796 |
+
794: 'shower curtain [浴帘]',
|
797 |
+
795: 'ski [滑雪板]',
|
798 |
+
796: 'ski mask [滑雪面罩]',
|
799 |
+
797: 'sleeping bag [睡袋]',
|
800 |
+
798: 'slide rule, slipstick [滑尺]',
|
801 |
+
799: 'sliding door [滑动门]',
|
802 |
+
800: 'slot, one-armed bandit [角子老虎机]',
|
803 |
+
801: 'snorkel [潜水通气管]',
|
804 |
+
802: 'snowmobile [雪橇]',
|
805 |
+
803: 'snowplow, snowplough [扫雪机,扫雪机]',
|
806 |
+
804: 'soap dispenser [皂液器]',
|
807 |
+
805: 'soccer ball [足球]',
|
808 |
+
806: 'sock [袜子]',
|
809 |
+
807: 'solar dish, solar collector, solar furnace [碟式太阳能,太阳能集热器,太阳能炉]',
|
810 |
+
808: 'sombrero [宽边帽]',
|
811 |
+
809: 'soup bowl [汤碗]',
|
812 |
+
810: 'space bar [空格键]',
|
813 |
+
811: 'space heater [空间加热器]',
|
814 |
+
812: 'space shuttle [航天飞机]',
|
815 |
+
813: 'spatula [铲(搅拌或涂敷用的)]',
|
816 |
+
814: 'speedboat [快艇]',
|
817 |
+
815: 'spider web, spiders web [蜘蛛网]',
|
818 |
+
816: 'spindle [纺锤,纱锭]',
|
819 |
+
817: 'sports car, sport car [跑车]',
|
820 |
+
818: 'spotlight, spot [聚光灯]',
|
821 |
+
819: 'stage [舞台]',
|
822 |
+
820: 'steam locomotive [蒸汽机车]',
|
823 |
+
821: 'steel arch bridge [钢拱桥]',
|
824 |
+
822: 'steel drum [钢滚筒]',
|
825 |
+
823: 'stethoscope [听诊器]',
|
826 |
+
824: 'stole [女用披肩]',
|
827 |
+
825: 'stone wall [石头墙]',
|
828 |
+
826: 'stopwatch, stop watch [秒表]',
|
829 |
+
827: 'stove [火炉]',
|
830 |
+
828: 'strainer [过滤器]',
|
831 |
+
829: 'streetcar, tram, tramcar, trolley, trolley car [有轨电车,电车]',
|
832 |
+
830: 'stretcher [担架]',
|
833 |
+
831: 'studio couch, day bed [沙发床]',
|
834 |
+
832: 'stupa, tope [佛塔]',
|
835 |
+
833: 'submarine, pigboat, sub, U-boat [潜艇,潜水艇]',
|
836 |
+
834: 'suit, suit of clothes [套装,衣服]',
|
837 |
+
835: 'sundial [日晷]',
|
838 |
+
836: 'sunglass [太阳镜]',
|
839 |
+
837: 'sunglasses, dark glasses, shades [太阳镜,墨镜]',
|
840 |
+
838: 'sunscreen, sunblock, sun blocker [防晒霜,防晒剂]',
|
841 |
+
839: 'suspension bridge [悬索桥]',
|
842 |
+
840: 'swab, swob, mop [拖把]',
|
843 |
+
841: 'sweatshirt [运动衫]',
|
844 |
+
842: 'swimming trunks, bathing trunks [游泳裤]',
|
845 |
+
843: 'swing [秋千]',
|
846 |
+
844: 'switch, electric switch, electrical switch [开关,电器开关]',
|
847 |
+
845: 'syringe [注射器]',
|
848 |
+
846: 'table lamp [台灯]',
|
849 |
+
847: 'tank, army tank, armored combat vehicle, armoured combat vehicle [坦克,装甲战车,装甲战斗车辆]',
|
850 |
+
848: 'tape player [磁带播放器]',
|
851 |
+
849: 'teapot [茶壶]',
|
852 |
+
850: 'teddy, teddy bear [泰迪,泰迪熊]',
|
853 |
+
851: 'television, television system [电视]',
|
854 |
+
852: 'tennis ball [网球]',
|
855 |
+
853: 'thatch, thatched roof [茅草,茅草屋顶]',
|
856 |
+
854: 'theater curtain, theatre curtain [幕布,剧院的帷幕]',
|
857 |
+
855: 'thimble [顶针]',
|
858 |
+
856: 'thresher, thrasher, threshing machine [脱粒机]',
|
859 |
+
857: 'throne [宝座]',
|
860 |
+
858: 'tile roof [瓦屋顶]',
|
861 |
+
859: 'toaster [烤面包机]',
|
862 |
+
860: 'tobacco shop, tobacconist shop, tobacconist [烟草店,烟草]',
|
863 |
+
861: 'toilet seat [马桶]',
|
864 |
+
862: 'torch [火炬]',
|
865 |
+
863: 'totem pole [图腾柱]',
|
866 |
+
864: 'tow truck, tow car, wrecker [拖车,牵引车,清障车]',
|
867 |
+
865: 'toyshop [玩具店]',
|
868 |
+
866: 'tractor [拖拉机]',
|
869 |
+
867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi [拖车,铰接式卡车]',
|
870 |
+
868: 'tray [托盘]',
|
871 |
+
869: 'trench coat [风衣]',
|
872 |
+
870: 'tricycle, trike, velocipede [三轮车]',
|
873 |
+
871: 'trimaran [三体船]',
|
874 |
+
872: 'tripod [三脚架]',
|
875 |
+
873: 'triumphal arch [凯旋门]',
|
876 |
+
874: 'trolleybus, trolley coach, trackless trolley [无轨电车]',
|
877 |
+
875: 'trombone [长号]',
|
878 |
+
876: 'tub, vat [浴盆,浴缸]',
|
879 |
+
877: 'turnstile [旋转式栅门]',
|
880 |
+
878: 'typewriter keyboard [打字机键盘]',
|
881 |
+
879: 'umbrella [伞]',
|
882 |
+
880: 'unicycle, monocycle [独轮车]',
|
883 |
+
881: 'upright, upright piano [直立式钢琴]',
|
884 |
+
882: 'vacuum, vacuum cleaner [真空吸尘器]',
|
885 |
+
883: 'vase [花瓶]',
|
886 |
+
884: 'vault [拱顶]',
|
887 |
+
885: 'velvet [天鹅绒]',
|
888 |
+
886: 'vending machine [自动售货机]',
|
889 |
+
887: 'vestment [祭服]',
|
890 |
+
888: 'viaduct [高架桥]',
|
891 |
+
889: 'violin, fiddle [小提琴,小提琴]',
|
892 |
+
890: 'volleyball [排球]',
|
893 |
+
891: 'waffle iron [松饼机]',
|
894 |
+
892: 'wall clock [挂钟]',
|
895 |
+
893: 'wallet, billfold, notecase, pocketbook [钱包,皮夹]',
|
896 |
+
894: 'wardrobe, closet, press [衣柜,壁橱]',
|
897 |
+
895: 'warplane, military plane [军用飞机]',
|
898 |
+
896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin [洗脸盆,洗手盆]',
|
899 |
+
897: 'washer, automatic washer, washing machine [洗衣机,自动洗衣机]',
|
900 |
+
898: 'water bottle [水瓶]',
|
901 |
+
899: 'water jug [水壶]',
|
902 |
+
900: 'water tower [水塔]',
|
903 |
+
901: 'whiskey jug [威士忌壶]',
|
904 |
+
902: 'whistle [哨子]',
|
905 |
+
903: 'wig [假发]',
|
906 |
+
904: 'window screen [纱窗]',
|
907 |
+
905: 'window shade [百叶窗]',
|
908 |
+
906: 'Windsor tie [温莎领带]',
|
909 |
+
907: 'wine bottle [葡萄酒瓶]',
|
910 |
+
908: 'wing [飞机翅膀,飞机]',
|
911 |
+
909: 'wok [炒菜锅]',
|
912 |
+
910: 'wooden spoon [木制的勺子]',
|
913 |
+
911: 'wool, woolen, woollen [毛织品,羊绒]',
|
914 |
+
912: 'worm fence, snake fence, snake-rail fence, Virginia fence [栅栏,围栏]',
|
915 |
+
913: 'wreck [沉船]',
|
916 |
+
914: 'yawl [双桅船]',
|
917 |
+
915: 'yurt [蒙古包]',
|
918 |
+
916: 'web site, website, internet site, site [网站,互联网网站]',
|
919 |
+
917: 'comic book [漫画]',
|
920 |
+
918: 'crossword puzzle, crossword [纵横字谜]',
|
921 |
+
919: 'street sign [路标]',
|
922 |
+
920: 'traffic light, traffic signal, stoplight [交通信号灯]',
|
923 |
+
921: 'book jacket, dust cover, dust jacket, dust wrapper [防尘罩,书皮]',
|
924 |
+
922: 'menu [菜单]',
|
925 |
+
923: 'plate [盘子]',
|
926 |
+
924: 'guacamole [鳄梨酱]',
|
927 |
+
925: 'consomme [清汤]',
|
928 |
+
926: 'hot pot, hotpot [罐焖土豆烧肉]',
|
929 |
+
927: 'trifle [蛋糕]',
|
930 |
+
928: 'ice cream, icecream [冰淇淋]',
|
931 |
+
929: 'ice lolly, lolly, lollipop, popsicle [雪糕,冰棍,冰棒]',
|
932 |
+
930: 'French loaf [法式面包]',
|
933 |
+
931: 'bagel, beigel [百吉饼]',
|
934 |
+
932: 'pretzel [椒盐脆饼]',
|
935 |
+
933: 'cheeseburger [芝士汉堡]',
|
936 |
+
934: 'hotdog, hot dog, red hot [热狗]',
|
937 |
+
935: 'mashed potato [土豆泥]',
|
938 |
+
936: 'head cabbage [结球甘蓝]',
|
939 |
+
937: 'broccoli [西兰花]',
|
940 |
+
938: 'cauliflower [菜花]',
|
941 |
+
939: 'zucchini, courgette [绿皮密生西葫芦]',
|
942 |
+
940: 'spaghetti squash [西葫芦]',
|
943 |
+
941: 'acorn squash [小青南瓜]',
|
944 |
+
942: 'butternut squash [南瓜]',
|
945 |
+
943: 'cucumber, cuke [黄瓜]',
|
946 |
+
944: 'artichoke, globe artichoke [朝鲜蓟]',
|
947 |
+
945: 'bell pepper [甜椒]',
|
948 |
+
946: 'cardoon [刺棘蓟]',
|
949 |
+
947: 'mushroom [蘑菇]',
|
950 |
+
948: 'Granny Smith [绿苹果]',
|
951 |
+
949: 'strawberry [草莓]',
|
952 |
+
950: 'orange [橘子]',
|
953 |
+
951: 'lemon [柠檬]',
|
954 |
+
952: 'fig [无花果]',
|
955 |
+
953: 'pineapple, ananas [菠萝]',
|
956 |
+
954: 'banana [香蕉]',
|
957 |
+
955: 'jackfruit, jak, jack [菠萝蜜]',
|
958 |
+
956: 'custard apple [蛋奶冻苹果]',
|
959 |
+
957: 'pomegranate [石榴]',
|
960 |
+
958: 'hay [干草]',
|
961 |
+
959: 'carbonara [烤面条加干酪沙司]',
|
962 |
+
960: 'chocolate sauce, chocolate syrup [巧克力酱,巧克力糖浆]',
|
963 |
+
961: 'dough [面团]',
|
964 |
+
962: 'meat loaf, meatloaf [瑞士肉包,肉饼]',
|
965 |
+
963: 'pizza, pizza pie [披萨,披萨饼]',
|
966 |
+
964: 'potpie [馅饼]',
|
967 |
+
965: 'burrito [卷饼]',
|
968 |
+
966: 'red wine [红葡萄酒]',
|
969 |
+
967: 'espresso [意大利浓咖啡]',
|
970 |
+
968: 'cup [杯子]',
|
971 |
+
969: 'eggnog [蛋酒]',
|
972 |
+
970: 'alp [高山]',
|
973 |
+
971: 'bubble [泡泡]',
|
974 |
+
972: 'cliff, drop, drop-off [悬崖]',
|
975 |
+
973: 'coral reef [珊瑚礁]',
|
976 |
+
974: 'geyser [间歇泉]',
|
977 |
+
975: 'lakeside, lakeshore [湖边,湖岸]',
|
978 |
+
976: 'promontory, headland, head, foreland [海角]',
|
979 |
+
977: 'sandbar, sand bar [沙洲,沙坝]',
|
980 |
+
978: 'seashore, coast, seacoast, sea-coast [海滨,海岸]',
|
981 |
+
979: 'valley, vale [峡谷]',
|
982 |
+
980: 'volcano [火山]',
|
983 |
+
981: 'ballplayer, baseball player [棒球,棒球运动员]',
|
984 |
+
982: 'groom, bridegroom [新郎]',
|
985 |
+
983: 'scuba diver [潜水员]',
|
986 |
+
984: 'rapeseed [油菜]',
|
987 |
+
985: 'daisy [雏菊]',
|
988 |
+
986: 'yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum [杓兰]',
|
989 |
+
987: 'corn [玉米]',
|
990 |
+
988: 'acorn [橡子]',
|
991 |
+
989: 'hip, rose hip, rosehip [玫瑰果]',
|
992 |
+
990: 'buckeye, horse chestnut, conker [七叶树果实]',
|
993 |
+
991: 'coral fungus [珊瑚菌]',
|
994 |
+
992: 'agaric [木耳]',
|
995 |
+
993: 'gyromitra [鹿花菌]',
|
996 |
+
994: 'stinkhorn, carrion fungus [鬼笔菌]',
|
997 |
+
995: 'earthstar [地星(菌类)]',
|
998 |
+
996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa [多叶奇果菌]',
|
999 |
+
997: 'bolete [牛肝菌]',
|
1000 |
+
998: 'ear, spike, capitulum [玉米穗]',
|
1001 |
+
999: 'toilet tissue, toilet paper, bathroom tissue [卫生纸]',
|
1002 |
+
}
|
pixelflow/data_in1k.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ImageNet-1K Dataset and DataLoader
|
2 |
+
|
3 |
+
from einops import rearrange
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch.utils.data.distributed import DistributedSampler
|
7 |
+
from torchvision.datasets import ImageFolder
|
8 |
+
from torchvision import transforms
|
9 |
+
from PIL import Image
|
10 |
+
import math
|
11 |
+
from functools import partial
|
12 |
+
import numpy as np
|
13 |
+
import random
|
14 |
+
|
15 |
+
from diffusers.models.embeddings import get_2d_rotary_pos_embed
|
16 |
+
|
17 |
+
# https://github.com/facebookresearch/DiT/blob/main/train.py#L85
|
18 |
+
def center_crop_arr(pil_image, image_size):
|
19 |
+
while min(*pil_image.size) >= 2 * image_size:
|
20 |
+
pil_image = pil_image.resize(
|
21 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
22 |
+
)
|
23 |
+
|
24 |
+
scale = image_size / min(*pil_image.size)
|
25 |
+
pil_image = pil_image.resize(
|
26 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
27 |
+
)
|
28 |
+
|
29 |
+
arr = np.array(pil_image)
|
30 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
31 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
32 |
+
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
|
33 |
+
|
34 |
+
|
35 |
+
def collate_fn(examples, config, noise_scheduler_copy):
|
36 |
+
patch_size = config.model.params.patch_size
|
37 |
+
pixel_values = torch.stack([eg[0] for eg in examples])
|
38 |
+
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
39 |
+
input_ids = [eg[1] for eg in examples]
|
40 |
+
|
41 |
+
batch_size = len(examples)
|
42 |
+
stage_indices = list(range(config.scheduler.num_stages)) * (batch_size // config.scheduler.num_stages + 1)
|
43 |
+
stage_indices = stage_indices[:batch_size]
|
44 |
+
|
45 |
+
random.shuffle(stage_indices)
|
46 |
+
stage_indices = torch.tensor(stage_indices, dtype=torch.int32)
|
47 |
+
orig_height, orig_width = pixel_values.shape[-2:]
|
48 |
+
timesteps = torch.randint(0, config.scheduler.num_train_timesteps, (batch_size,))
|
49 |
+
|
50 |
+
sample_list, input_ids_list, pos_embed_list, seq_len_list, target_list, timestep_list = [], [], [], [], [], []
|
51 |
+
for stage_idx in range(config.scheduler.num_stages):
|
52 |
+
corrected_stage_idx = config.scheduler.num_stages - stage_idx - 1
|
53 |
+
stage_select_indices = timesteps[stage_indices == corrected_stage_idx]
|
54 |
+
Timesteps = noise_scheduler_copy.Timesteps_per_stage[corrected_stage_idx][stage_select_indices].float()
|
55 |
+
batch_size_select = Timesteps.shape[0]
|
56 |
+
pixel_values_select = pixel_values[stage_indices == corrected_stage_idx]
|
57 |
+
input_ids_select = [input_ids[i] for i in range(batch_size) if stage_indices[i] == corrected_stage_idx]
|
58 |
+
|
59 |
+
end_height, end_width = orig_height // (2 ** stage_idx), orig_width // (2 ** stage_idx)
|
60 |
+
|
61 |
+
################ build model input ################
|
62 |
+
start_t, end_t = noise_scheduler_copy.start_t[corrected_stage_idx], noise_scheduler_copy.end_t[corrected_stage_idx]
|
63 |
+
|
64 |
+
pixel_values_end = pixel_values_select
|
65 |
+
pixel_values_start = pixel_values_select
|
66 |
+
if stage_idx > 0:
|
67 |
+
# pixel_values_end
|
68 |
+
for downsample_idx in range(1, stage_idx + 1):
|
69 |
+
pixel_values_end = F.interpolate(pixel_values_end, (orig_height // (2 ** downsample_idx), orig_width // (2 ** downsample_idx)), mode="bilinear")
|
70 |
+
|
71 |
+
# pixel_values_start
|
72 |
+
for downsample_idx in range(1, stage_idx + 2):
|
73 |
+
pixel_values_start = F.interpolate(pixel_values_start, (orig_height // (2 ** downsample_idx), orig_width // (2 ** downsample_idx)), mode="bilinear")
|
74 |
+
# upsample pixel_values_start
|
75 |
+
pixel_values_start = F.interpolate(pixel_values_start, (end_height, end_width), mode="nearest")
|
76 |
+
|
77 |
+
noise = torch.randn_like(pixel_values_end)
|
78 |
+
pixel_values_end = end_t * pixel_values_end + (1.0 - end_t) * noise
|
79 |
+
pixel_values_start = start_t * pixel_values_start + (1.0 - start_t) * noise
|
80 |
+
target = pixel_values_end - pixel_values_start
|
81 |
+
|
82 |
+
t_select = noise_scheduler_copy.t_window_per_stage[corrected_stage_idx][stage_select_indices].flatten()
|
83 |
+
while len(t_select.shape) < pixel_values_start.ndim:
|
84 |
+
t_select = t_select.unsqueeze(-1)
|
85 |
+
xt = t_select.float() * pixel_values_end + (1.0 - t_select.float()) * pixel_values_start
|
86 |
+
|
87 |
+
target = rearrange(target, 'b c (h ph) (w pw) -> (b h w) (c ph pw)', ph=patch_size, pw=patch_size)
|
88 |
+
xt = rearrange(xt, 'b c (h ph) (w pw) -> (b h w) (c ph pw)', ph=patch_size, pw=patch_size)
|
89 |
+
|
90 |
+
pos_embed = get_2d_rotary_pos_embed(
|
91 |
+
embed_dim=config.model.params.attention_head_dim,
|
92 |
+
crops_coords=((0, 0), (end_height // patch_size, end_width // patch_size)),
|
93 |
+
grid_size=(end_height // patch_size, end_width // patch_size),
|
94 |
+
)
|
95 |
+
seq_len = (end_height // patch_size) * (end_width // patch_size)
|
96 |
+
assert end_height == end_width, f"only support square image, got {seq_len}; TODO: latent_size_list"
|
97 |
+
sample_list.append(xt)
|
98 |
+
target_list.append(target)
|
99 |
+
pos_embed_list.extend([pos_embed] * batch_size_select)
|
100 |
+
seq_len_list.extend([seq_len] * batch_size_select)
|
101 |
+
timestep_list.append(Timesteps)
|
102 |
+
input_ids_list.extend(input_ids_select)
|
103 |
+
|
104 |
+
pixel_values = torch.cat(sample_list, dim=0).to(memory_format=torch.contiguous_format)
|
105 |
+
target_values = torch.cat(target_list, dim=0).to(memory_format=torch.contiguous_format)
|
106 |
+
pos_embed = torch.cat([torch.stack(one_pos_emb, -1) for one_pos_emb in pos_embed_list], dim=0).float()
|
107 |
+
cumsum_q_len = torch.cumsum(torch.tensor([0] + seq_len_list), 0).to(torch.int32)
|
108 |
+
latent_size_list = torch.tensor([int(math.sqrt(seq_len)) for seq_len in seq_len_list], dtype=torch.int32)
|
109 |
+
|
110 |
+
return {
|
111 |
+
"pixel_values": pixel_values,
|
112 |
+
"input_ids": input_ids_list,
|
113 |
+
"pos_embed": pos_embed,
|
114 |
+
"cumsum_q_len": cumsum_q_len,
|
115 |
+
"batch_latent_size": latent_size_list,
|
116 |
+
"seqlen_list_q": seq_len_list,
|
117 |
+
"cumsum_kv_len": None,
|
118 |
+
"batch_kv_len": None,
|
119 |
+
"timesteps": torch.cat(timestep_list, dim=0),
|
120 |
+
"target_values": target_values,
|
121 |
+
}
|
122 |
+
|
123 |
+
|
124 |
+
def build_imagenet_loader(config, noise_scheduler_copy):
|
125 |
+
if config.data.center_crop:
|
126 |
+
transform = transforms.Compose([
|
127 |
+
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, config.data.resolution)),
|
128 |
+
transforms.RandomHorizontalFlip(),
|
129 |
+
transforms.ToTensor(),
|
130 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
|
131 |
+
])
|
132 |
+
else:
|
133 |
+
transform = transforms.Compose([
|
134 |
+
transforms.Resize(round(config.data.resolution * config.data.expand_ratio), interpolation=transforms.InterpolationMode.LANCZOS),
|
135 |
+
transforms.RandomCrop(config.data.resolution),
|
136 |
+
transforms.RandomHorizontalFlip(),
|
137 |
+
transforms.ToTensor(),
|
138 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
|
139 |
+
])
|
140 |
+
dataset = ImageFolder(config.data.root, transform=transform)
|
141 |
+
sampler = DistributedSampler(
|
142 |
+
dataset,
|
143 |
+
num_replicas=torch.distributed.get_world_size(),
|
144 |
+
rank=torch.distributed.get_rank(),
|
145 |
+
shuffle=True,
|
146 |
+
seed=config.seed,
|
147 |
+
)
|
148 |
+
|
149 |
+
loader = torch.utils.data.DataLoader(
|
150 |
+
dataset,
|
151 |
+
batch_size=config.data.batch_size,
|
152 |
+
collate_fn=partial(collate_fn, config=config, noise_scheduler_copy=noise_scheduler_copy),
|
153 |
+
shuffle=False,
|
154 |
+
sampler=sampler,
|
155 |
+
num_workers=config.data.num_workers,
|
156 |
+
drop_last=True,
|
157 |
+
)
|
158 |
+
return loader
|
pixelflow/model.py
ADDED
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple, Union
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from diffusers.models.embeddings import Timesteps, TimestepEmbedding, LabelEmbedding
|
7 |
+
import warnings
|
8 |
+
|
9 |
+
try:
|
10 |
+
from flash_attn import flash_attn_varlen_func
|
11 |
+
except ImportError:
|
12 |
+
warnings.warn("`flash-attn` is not installed. Training mode may not work properly.", UserWarning)
|
13 |
+
flash_attn_varlen_func = None
|
14 |
+
|
15 |
+
|
16 |
+
def apply_rotary_emb(
|
17 |
+
x: torch.Tensor,
|
18 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
19 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
20 |
+
cos, sin = freqs_cis.unbind(-1)
|
21 |
+
cos = cos[None, None]
|
22 |
+
sin = sin[None, None]
|
23 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
24 |
+
|
25 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
26 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
27 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
28 |
+
|
29 |
+
return out
|
30 |
+
|
31 |
+
|
32 |
+
class PatchEmbed(nn.Module):
|
33 |
+
def __init__(self, patch_size, in_channels, embed_dim, bias=True):
|
34 |
+
super().__init__()
|
35 |
+
self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias)
|
36 |
+
|
37 |
+
def forward_unfold(self, x):
|
38 |
+
out_unfold = x.matmul(self.proj.weight.view(self.proj.weight.size(0), -1).t())
|
39 |
+
if self.proj.bias is not None:
|
40 |
+
out_unfold += self.proj.bias.to(out_unfold.dtype)
|
41 |
+
return out_unfold
|
42 |
+
|
43 |
+
# force fp32 for strict numerical reproducibility (debug only)
|
44 |
+
# @torch.autocast('cuda', enabled=False)
|
45 |
+
def forward(self, x):
|
46 |
+
if self.training:
|
47 |
+
return self.forward_unfold(x)
|
48 |
+
out = self.proj(x)
|
49 |
+
out = out.flatten(2).transpose(1, 2) # BCHW -> BNC
|
50 |
+
|
51 |
+
return out
|
52 |
+
|
53 |
+
class AdaLayerNorm(nn.Module):
|
54 |
+
def __init__(self, embedding_dim):
|
55 |
+
super().__init__()
|
56 |
+
self.embedding_dim = embedding_dim
|
57 |
+
self.silu = nn.SiLU()
|
58 |
+
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
59 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
60 |
+
|
61 |
+
def forward(self, x, timestep, seqlen_list=None):
|
62 |
+
input_dtype = x.dtype
|
63 |
+
emb = self.linear(self.silu(timestep))
|
64 |
+
|
65 |
+
if seqlen_list is not None:
|
66 |
+
# equivalent to `torch.repeat_interleave` but faster
|
67 |
+
emb = torch.cat([one_emb[None].expand(repeat_time, -1) for one_emb, repeat_time in zip(emb, seqlen_list)])
|
68 |
+
else:
|
69 |
+
emb = emb.unsqueeze(1)
|
70 |
+
|
71 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.float().chunk(6, dim=-1)
|
72 |
+
x = self.norm(x).float() * (1 + scale_msa) + shift_msa
|
73 |
+
return x.to(input_dtype), gate_msa, shift_mlp, scale_mlp, gate_mlp
|
74 |
+
|
75 |
+
|
76 |
+
class FeedForward(nn.Module):
|
77 |
+
def __init__(self, dim, dim_out=None, mult=4, inner_dim=None, bias=True):
|
78 |
+
super().__init__()
|
79 |
+
inner_dim = int(dim * mult) if inner_dim is None else inner_dim
|
80 |
+
dim_out = dim_out if dim_out is not None else dim
|
81 |
+
self.fc1 = nn.Linear(dim, inner_dim, bias=bias)
|
82 |
+
self.fc2 = nn.Linear(inner_dim, dim_out, bias=bias)
|
83 |
+
|
84 |
+
def forward(self, hidden_states):
|
85 |
+
hidden_states = self.fc1(hidden_states)
|
86 |
+
hidden_states = F.gelu(hidden_states, approximate="tanh")
|
87 |
+
hidden_states = self.fc2(hidden_states)
|
88 |
+
return hidden_states
|
89 |
+
|
90 |
+
|
91 |
+
class RMSNorm(nn.Module):
|
92 |
+
def __init__(self, dim: int, eps=1e-6):
|
93 |
+
super().__init__()
|
94 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
95 |
+
self.eps = eps
|
96 |
+
|
97 |
+
def forward(self, x):
|
98 |
+
output = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
|
99 |
+
return (self.weight * output).to(x.dtype)
|
100 |
+
|
101 |
+
|
102 |
+
class Attention(nn.Module):
|
103 |
+
def __init__(self, q_dim, kv_dim=None, heads=8, head_dim=64, dropout=0.0, bias=False):
|
104 |
+
super().__init__()
|
105 |
+
self.q_dim = q_dim
|
106 |
+
self.kv_dim = kv_dim if kv_dim is not None else q_dim
|
107 |
+
self.inner_dim = head_dim * heads
|
108 |
+
self.dropout = dropout
|
109 |
+
self.head_dim = head_dim
|
110 |
+
self.num_heads = heads
|
111 |
+
|
112 |
+
self.q_proj = nn.Linear(self.q_dim, self.inner_dim, bias=bias)
|
113 |
+
self.k_proj = nn.Linear(self.kv_dim, self.inner_dim, bias=bias)
|
114 |
+
self.v_proj = nn.Linear(self.kv_dim, self.inner_dim, bias=bias)
|
115 |
+
|
116 |
+
self.o_proj = nn.Linear(self.inner_dim, self.q_dim, bias=bias)
|
117 |
+
|
118 |
+
self.q_norm = RMSNorm(self.inner_dim)
|
119 |
+
self.k_norm = RMSNorm(self.inner_dim)
|
120 |
+
|
121 |
+
def prepare_attention_mask(
|
122 |
+
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L694
|
123 |
+
self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
|
124 |
+
):
|
125 |
+
head_size = self.num_heads
|
126 |
+
if attention_mask is None:
|
127 |
+
return attention_mask
|
128 |
+
|
129 |
+
current_length: int = attention_mask.shape[-1]
|
130 |
+
if current_length != target_length:
|
131 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
132 |
+
|
133 |
+
if out_dim == 3:
|
134 |
+
if attention_mask.shape[0] < batch_size * head_size:
|
135 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
136 |
+
elif out_dim == 4:
|
137 |
+
attention_mask = attention_mask.unsqueeze(1)
|
138 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
139 |
+
|
140 |
+
return attention_mask
|
141 |
+
|
142 |
+
def forward(
|
143 |
+
self,
|
144 |
+
inputs_q,
|
145 |
+
inputs_kv,
|
146 |
+
attention_mask=None,
|
147 |
+
cross_attention=False,
|
148 |
+
rope_pos_embed=None,
|
149 |
+
cu_seqlens_q=None,
|
150 |
+
cu_seqlens_k=None,
|
151 |
+
max_seqlen_q=None,
|
152 |
+
max_seqlen_k=None,
|
153 |
+
):
|
154 |
+
|
155 |
+
inputs_kv = inputs_q if inputs_kv is None else inputs_kv
|
156 |
+
|
157 |
+
query_states = self.q_proj(inputs_q)
|
158 |
+
key_states = self.k_proj(inputs_kv)
|
159 |
+
value_states = self.v_proj(inputs_kv)
|
160 |
+
|
161 |
+
query_states = self.q_norm(query_states)
|
162 |
+
key_states = self.k_norm(key_states)
|
163 |
+
|
164 |
+
if max_seqlen_q is None:
|
165 |
+
assert not self.training, "PixelFlow needs sequence packing for training"
|
166 |
+
|
167 |
+
bsz, q_len, _ = inputs_q.shape
|
168 |
+
_, kv_len, _ = inputs_kv.shape
|
169 |
+
|
170 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
171 |
+
key_states = key_states.view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
|
172 |
+
value_states = value_states.view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
|
173 |
+
|
174 |
+
query_states = apply_rotary_emb(query_states, rope_pos_embed)
|
175 |
+
if not cross_attention:
|
176 |
+
key_states = apply_rotary_emb(key_states, rope_pos_embed)
|
177 |
+
|
178 |
+
if attention_mask is not None:
|
179 |
+
attention_mask = self.prepare_attention_mask(attention_mask, kv_len, bsz)
|
180 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
181 |
+
# (batch, heads, source_length, target_length)
|
182 |
+
attention_mask = attention_mask.view(bsz, self.num_heads, -1, attention_mask.shape[-1])
|
183 |
+
|
184 |
+
# with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.MATH]): # strict numerical reproducibility (debug only)
|
185 |
+
attn_output = F.scaled_dot_product_attention(
|
186 |
+
query_states,
|
187 |
+
key_states,
|
188 |
+
value_states,
|
189 |
+
attn_mask=attention_mask,
|
190 |
+
dropout_p=self.dropout if self.training else 0.0,
|
191 |
+
is_causal=False,
|
192 |
+
)
|
193 |
+
|
194 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
195 |
+
attn_output = attn_output.view(bsz, q_len, self.inner_dim)
|
196 |
+
attn_output = self.o_proj(attn_output)
|
197 |
+
return attn_output
|
198 |
+
|
199 |
+
else:
|
200 |
+
# sequence packing mode
|
201 |
+
query_states = query_states.view(-1, self.num_heads, self.head_dim)
|
202 |
+
key_states = key_states.view(-1, self.num_heads, self.head_dim)
|
203 |
+
value_states = value_states.view(-1, self.num_heads, self.head_dim)
|
204 |
+
|
205 |
+
query_states = apply_rotary_emb(query_states.permute(1, 0, 2)[None], rope_pos_embed)[0].permute(1, 0, 2)
|
206 |
+
if not cross_attention:
|
207 |
+
key_states = apply_rotary_emb(key_states.permute(1, 0, 2)[None], rope_pos_embed)[0].permute(1, 0, 2)
|
208 |
+
|
209 |
+
attn_output = flash_attn_varlen_func(
|
210 |
+
query_states,
|
211 |
+
key_states,
|
212 |
+
value_states,
|
213 |
+
cu_seqlens_q=cu_seqlens_q,
|
214 |
+
cu_seqlens_k=cu_seqlens_k,
|
215 |
+
max_seqlen_q=max_seqlen_q,
|
216 |
+
max_seqlen_k=max_seqlen_k,
|
217 |
+
)
|
218 |
+
|
219 |
+
attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
|
220 |
+
attn_output = self.o_proj(attn_output)
|
221 |
+
return attn_output
|
222 |
+
|
223 |
+
|
224 |
+
class TransformerBlock(nn.Module):
|
225 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, dropout=0.0,
|
226 |
+
cross_attention_dim=None, attention_bias=False,
|
227 |
+
):
|
228 |
+
super().__init__()
|
229 |
+
self.norm1 = AdaLayerNorm(dim)
|
230 |
+
|
231 |
+
# Self Attention
|
232 |
+
self.attn1 = Attention(q_dim=dim, kv_dim=None, heads=num_attention_heads, head_dim=attention_head_dim, dropout=dropout, bias=attention_bias)
|
233 |
+
|
234 |
+
if cross_attention_dim is not None:
|
235 |
+
# Cross Attention
|
236 |
+
self.norm2 = RMSNorm(dim, eps=1e-6)
|
237 |
+
self.attn2 = Attention(q_dim=dim, kv_dim=cross_attention_dim, heads=num_attention_heads, head_dim=attention_head_dim, dropout=dropout, bias=attention_bias)
|
238 |
+
else:
|
239 |
+
self.attn2 = None
|
240 |
+
|
241 |
+
self.norm3 = RMSNorm(dim, eps=1e-6)
|
242 |
+
self.mlp = FeedForward(dim)
|
243 |
+
|
244 |
+
def forward(
|
245 |
+
self,
|
246 |
+
hidden_states,
|
247 |
+
encoder_hidden_states=None,
|
248 |
+
encoder_attention_mask=None,
|
249 |
+
timestep=None,
|
250 |
+
rope_pos_embed=None,
|
251 |
+
cu_seqlens_q=None,
|
252 |
+
cu_seqlens_k=None,
|
253 |
+
seqlen_list_q=None,
|
254 |
+
seqlen_list_k=None,
|
255 |
+
):
|
256 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, timestep, seqlen_list_q)
|
257 |
+
|
258 |
+
attn_output = self.attn1(
|
259 |
+
inputs_q=norm_hidden_states,
|
260 |
+
inputs_kv=None,
|
261 |
+
attention_mask=None,
|
262 |
+
cross_attention=False,
|
263 |
+
rope_pos_embed=rope_pos_embed,
|
264 |
+
cu_seqlens_q=cu_seqlens_q,
|
265 |
+
cu_seqlens_k=cu_seqlens_q,
|
266 |
+
max_seqlen_q=max(seqlen_list_q) if seqlen_list_q is not None else None,
|
267 |
+
max_seqlen_k=max(seqlen_list_q) if seqlen_list_q is not None else None,
|
268 |
+
)
|
269 |
+
|
270 |
+
attn_output = (gate_msa * attn_output.float()).to(attn_output.dtype)
|
271 |
+
hidden_states = attn_output + hidden_states
|
272 |
+
|
273 |
+
if self.attn2 is not None:
|
274 |
+
norm_hidden_states = self.norm2(hidden_states)
|
275 |
+
attn_output = self.attn2(
|
276 |
+
inputs_q=norm_hidden_states,
|
277 |
+
inputs_kv=encoder_hidden_states,
|
278 |
+
attention_mask=encoder_attention_mask,
|
279 |
+
cross_attention=True,
|
280 |
+
rope_pos_embed=rope_pos_embed,
|
281 |
+
cu_seqlens_q=cu_seqlens_q,
|
282 |
+
cu_seqlens_k=cu_seqlens_k,
|
283 |
+
max_seqlen_q=max(seqlen_list_q) if seqlen_list_q is not None else None,
|
284 |
+
max_seqlen_k=max(seqlen_list_k) if seqlen_list_k is not None else None,
|
285 |
+
)
|
286 |
+
hidden_states = hidden_states + attn_output
|
287 |
+
|
288 |
+
norm_hidden_states = self.norm3(hidden_states)
|
289 |
+
norm_hidden_states = (norm_hidden_states.float() * (1 + scale_mlp) + shift_mlp).to(norm_hidden_states.dtype)
|
290 |
+
ff_output = self.mlp(norm_hidden_states)
|
291 |
+
ff_output = (gate_mlp * ff_output.float()).to(ff_output.dtype)
|
292 |
+
hidden_states = ff_output + hidden_states
|
293 |
+
|
294 |
+
return hidden_states
|
295 |
+
|
296 |
+
|
297 |
+
class PixelFlowModel(torch.nn.Module):
|
298 |
+
def __init__(self, in_channels, out_channels, num_attention_heads, attention_head_dim,
|
299 |
+
depth, patch_size, dropout=0.0, cross_attention_dim=None, attention_bias=True, num_classes=0,
|
300 |
+
):
|
301 |
+
super().__init__()
|
302 |
+
self.patch_size = patch_size
|
303 |
+
self.attention_head_dim = attention_head_dim
|
304 |
+
self.num_classes = num_classes
|
305 |
+
self.out_channels = out_channels
|
306 |
+
|
307 |
+
embed_dim = num_attention_heads * attention_head_dim
|
308 |
+
self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
|
309 |
+
|
310 |
+
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
311 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embed_dim)
|
312 |
+
|
313 |
+
# [stage] embedding
|
314 |
+
self.latent_size_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embed_dim)
|
315 |
+
if self.num_classes > 0:
|
316 |
+
# class conditional
|
317 |
+
self.class_embedder = LabelEmbedding(num_classes, embed_dim, dropout_prob=0.1)
|
318 |
+
|
319 |
+
self.transformer_blocks = nn.ModuleList([
|
320 |
+
TransformerBlock(embed_dim, num_attention_heads, attention_head_dim, dropout, cross_attention_dim, attention_bias) for _ in range(depth)
|
321 |
+
])
|
322 |
+
|
323 |
+
self.norm_out = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
|
324 |
+
self.proj_out_1 = nn.Linear(embed_dim, 2 * embed_dim)
|
325 |
+
self.proj_out_2 = nn.Linear(embed_dim, patch_size * patch_size * out_channels)
|
326 |
+
|
327 |
+
self.initialize_from_scratch()
|
328 |
+
|
329 |
+
def initialize_from_scratch(self):
|
330 |
+
print("Starting Initialization...")
|
331 |
+
def _basic_init(module):
|
332 |
+
if isinstance(module, nn.Linear):
|
333 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
334 |
+
if module.bias is not None:
|
335 |
+
nn.init.constant_(module.bias, 0)
|
336 |
+
self.apply(_basic_init)
|
337 |
+
|
338 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
339 |
+
w = self.patch_embed.proj.weight.data
|
340 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
341 |
+
nn.init.constant_(self.patch_embed.proj.bias, 0)
|
342 |
+
|
343 |
+
nn.init.normal_(self.timestep_embedder.linear_1.weight, std=0.02)
|
344 |
+
nn.init.normal_(self.timestep_embedder.linear_2.weight, std=0.02)
|
345 |
+
|
346 |
+
nn.init.normal_(self.latent_size_embedder.linear_1.weight, std=0.02)
|
347 |
+
nn.init.normal_(self.latent_size_embedder.linear_2.weight, std=0.02)
|
348 |
+
|
349 |
+
if self.num_classes > 0:
|
350 |
+
nn.init.normal_(self.class_embedder.embedding_table.weight, std=0.02)
|
351 |
+
|
352 |
+
for block in self.transformer_blocks:
|
353 |
+
nn.init.constant_(block.norm1.linear.weight, 0)
|
354 |
+
nn.init.constant_(block.norm1.linear.bias, 0)
|
355 |
+
|
356 |
+
nn.init.constant_(self.proj_out_1.weight, 0)
|
357 |
+
nn.init.constant_(self.proj_out_1.bias, 0)
|
358 |
+
nn.init.constant_(self.proj_out_2.weight, 0)
|
359 |
+
nn.init.constant_(self.proj_out_2.bias, 0)
|
360 |
+
|
361 |
+
def forward(
|
362 |
+
self,
|
363 |
+
hidden_states,
|
364 |
+
encoder_hidden_states=None,
|
365 |
+
class_labels=None,
|
366 |
+
timestep=None,
|
367 |
+
latent_size=None,
|
368 |
+
encoder_attention_mask=None,
|
369 |
+
pos_embed=None,
|
370 |
+
cu_seqlens_q=None,
|
371 |
+
cu_seqlens_k=None,
|
372 |
+
seqlen_list_q=None,
|
373 |
+
seqlen_list_k=None,
|
374 |
+
):
|
375 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
376 |
+
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
377 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
378 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
379 |
+
|
380 |
+
orig_height, orig_width = hidden_states.shape[-2], hidden_states.shape[-1]
|
381 |
+
hidden_states = hidden_states.to(torch.float32)
|
382 |
+
hidden_states = self.patch_embed(hidden_states)
|
383 |
+
|
384 |
+
# timestep, class_embed, latent_size_embed
|
385 |
+
timesteps_proj = self.time_proj(timestep)
|
386 |
+
conditioning = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
|
387 |
+
|
388 |
+
if self.num_classes > 0:
|
389 |
+
class_embed = self.class_embedder(class_labels)
|
390 |
+
conditioning += class_embed
|
391 |
+
|
392 |
+
latent_size_proj = self.time_proj(latent_size)
|
393 |
+
latent_size_embed = self.latent_size_embedder(latent_size_proj.to(dtype=hidden_states.dtype))
|
394 |
+
conditioning += latent_size_embed
|
395 |
+
|
396 |
+
for block in self.transformer_blocks:
|
397 |
+
hidden_states = block(
|
398 |
+
hidden_states,
|
399 |
+
encoder_hidden_states=encoder_hidden_states,
|
400 |
+
encoder_attention_mask=encoder_attention_mask,
|
401 |
+
timestep=conditioning,
|
402 |
+
rope_pos_embed=pos_embed,
|
403 |
+
cu_seqlens_q=cu_seqlens_q,
|
404 |
+
cu_seqlens_k=cu_seqlens_k,
|
405 |
+
seqlen_list_q=seqlen_list_q,
|
406 |
+
seqlen_list_k=seqlen_list_k,
|
407 |
+
)
|
408 |
+
|
409 |
+
shift, scale = self.proj_out_1(F.silu(conditioning)).float().chunk(2, dim=1)
|
410 |
+
if seqlen_list_q is None:
|
411 |
+
shift = shift.unsqueeze(1)
|
412 |
+
scale = scale.unsqueeze(1)
|
413 |
+
else:
|
414 |
+
shift = torch.cat([shift_i[None].expand(ri, -1) for shift_i, ri in zip(shift, seqlen_list_q)])
|
415 |
+
scale = torch.cat([scale_i[None].expand(ri, -1) for scale_i, ri in zip(scale, seqlen_list_q)])
|
416 |
+
|
417 |
+
hidden_states = (self.norm_out(hidden_states).float() * (1 + scale) + shift).to(hidden_states.dtype)
|
418 |
+
hidden_states = self.proj_out_2(hidden_states)
|
419 |
+
if self.training:
|
420 |
+
hidden_states = hidden_states.reshape(hidden_states.shape[0], self.patch_size, self.patch_size, self.out_channels)
|
421 |
+
hidden_states = hidden_states.permute(0, 3, 1, 2).flatten(1)
|
422 |
+
return hidden_states
|
423 |
+
|
424 |
+
height, width = orig_height // self.patch_size, orig_width // self.patch_size
|
425 |
+
hidden_states = hidden_states.reshape(
|
426 |
+
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
427 |
+
)
|
428 |
+
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
429 |
+
output = hidden_states.reshape(
|
430 |
+
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
431 |
+
)
|
432 |
+
|
433 |
+
return output
|
434 |
+
|
435 |
+
def c2i_forward_cfg_torchdiffq(self, hidden_states, timestep, class_labels, latent_size, pos_embed, cfg_scale):
|
436 |
+
# used for evaluation with ODE ('dopri5') solver from torchdiffeq
|
437 |
+
half = hidden_states[: len(hidden_states)//2]
|
438 |
+
combined = torch.cat([half, half], dim=0)
|
439 |
+
out = self.forward(
|
440 |
+
hidden_states=combined,
|
441 |
+
timestep=timestep,
|
442 |
+
class_labels=class_labels,
|
443 |
+
latent_size=latent_size,
|
444 |
+
pos_embed=pos_embed,
|
445 |
+
)
|
446 |
+
uncond_out, cond_out = torch.split(out, len(out)//2, dim=0)
|
447 |
+
half_output = uncond_out + cfg_scale * (cond_out - uncond_out)
|
448 |
+
output = torch.cat([half_output, half_output], dim=0)
|
449 |
+
return output
|
pixelflow/pipeline_pixelflow.py
ADDED
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from einops import rearrange
|
2 |
+
import math
|
3 |
+
from typing import List, Optional, Union
|
4 |
+
import time
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from diffusers.utils.torch_utils import randn_tensor
|
9 |
+
from diffusers.models.embeddings import get_2d_rotary_pos_embed
|
10 |
+
|
11 |
+
|
12 |
+
class PixelFlowPipeline:
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
scheduler,
|
16 |
+
transformer,
|
17 |
+
text_encoder=None,
|
18 |
+
tokenizer=None,
|
19 |
+
max_token_length=512,
|
20 |
+
):
|
21 |
+
super().__init__()
|
22 |
+
self.class_cond = text_encoder is None or tokenizer is None
|
23 |
+
self.scheduler = scheduler
|
24 |
+
self.transformer = transformer
|
25 |
+
self.patch_size = transformer.patch_size
|
26 |
+
self.head_dim = transformer.attention_head_dim
|
27 |
+
self.num_stages = scheduler.num_stages
|
28 |
+
|
29 |
+
self.text_encoder = text_encoder
|
30 |
+
self.tokenizer = tokenizer
|
31 |
+
self.max_token_length = max_token_length
|
32 |
+
|
33 |
+
@torch.autocast("cuda", enabled=False)
|
34 |
+
def encode_prompt(
|
35 |
+
self,
|
36 |
+
prompt: Union[str, List[str]],
|
37 |
+
device: Optional[torch.device] = None,
|
38 |
+
num_images_per_prompt: int = 1,
|
39 |
+
do_classifier_free_guidance: bool = True,
|
40 |
+
negative_prompt: Union[str, List[str]] = "",
|
41 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
42 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
43 |
+
prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
44 |
+
negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
|
45 |
+
use_attention_mask: bool = False,
|
46 |
+
max_length: int = 512,
|
47 |
+
):
|
48 |
+
# Determine the batch size and normalize prompt input to a list
|
49 |
+
if prompt is not None:
|
50 |
+
if isinstance(prompt, str):
|
51 |
+
prompt = [prompt]
|
52 |
+
batch_size = len(prompt)
|
53 |
+
else:
|
54 |
+
batch_size = prompt_embeds.shape[0]
|
55 |
+
|
56 |
+
# Process prompt embeddings if not provided
|
57 |
+
if prompt_embeds is None:
|
58 |
+
text_inputs = self.tokenizer(
|
59 |
+
prompt,
|
60 |
+
padding="max_length",
|
61 |
+
max_length=max_length,
|
62 |
+
truncation=True,
|
63 |
+
add_special_tokens=True,
|
64 |
+
return_tensors="pt",
|
65 |
+
)
|
66 |
+
text_input_ids = text_inputs.input_ids.to(device)
|
67 |
+
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
68 |
+
prompt_embeds = self.text_encoder(
|
69 |
+
text_input_ids,
|
70 |
+
attention_mask=prompt_attention_mask if use_attention_mask else None
|
71 |
+
)[0]
|
72 |
+
|
73 |
+
# Determine dtype from available encoder
|
74 |
+
if self.text_encoder is not None:
|
75 |
+
dtype = self.text_encoder.dtype
|
76 |
+
elif self.transformer is not None:
|
77 |
+
dtype = self.transformer.dtype
|
78 |
+
else:
|
79 |
+
dtype = None
|
80 |
+
|
81 |
+
# Move prompt embeddings to desired dtype and device
|
82 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
83 |
+
|
84 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
85 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
86 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
87 |
+
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)
|
88 |
+
|
89 |
+
# Handle classifier-free guidance for negative prompts
|
90 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
91 |
+
# Normalize negative prompt to list and validate length
|
92 |
+
if isinstance(negative_prompt, str):
|
93 |
+
uncond_tokens = [negative_prompt] * batch_size
|
94 |
+
elif isinstance(negative_prompt, list):
|
95 |
+
if len(negative_prompt) != batch_size:
|
96 |
+
raise ValueError(f"The negative prompt list must have the same length as the prompt list, but got {len(negative_prompt)} and {batch_size}")
|
97 |
+
uncond_tokens = negative_prompt
|
98 |
+
else:
|
99 |
+
raise ValueError(f"Negative prompt must be a string or a list of strings, but got {type(negative_prompt)}")
|
100 |
+
|
101 |
+
# Tokenize and encode negative prompts
|
102 |
+
uncond_inputs = self.tokenizer(
|
103 |
+
uncond_tokens,
|
104 |
+
padding="max_length",
|
105 |
+
max_length=prompt_embeds.shape[1],
|
106 |
+
truncation=True,
|
107 |
+
return_attention_mask=True,
|
108 |
+
add_special_tokens=True,
|
109 |
+
return_tensors="pt",
|
110 |
+
)
|
111 |
+
negative_input_ids = uncond_inputs.input_ids.to(device)
|
112 |
+
negative_prompt_attention_mask = uncond_inputs.attention_mask.to(device)
|
113 |
+
negative_prompt_embeds = self.text_encoder(
|
114 |
+
negative_input_ids,
|
115 |
+
attention_mask=negative_prompt_attention_mask if use_attention_mask else None
|
116 |
+
)[0]
|
117 |
+
|
118 |
+
if do_classifier_free_guidance:
|
119 |
+
# Duplicate negative prompt embeddings and attention mask for each generation
|
120 |
+
seq_len_neg = negative_prompt_embeds.shape[1]
|
121 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
122 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
123 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len_neg, -1)
|
124 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)
|
125 |
+
else:
|
126 |
+
negative_prompt_embeds = None
|
127 |
+
negative_prompt_attention_mask = None
|
128 |
+
|
129 |
+
# Concatenate negative and positive embeddings and their masks
|
130 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
131 |
+
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
132 |
+
|
133 |
+
return prompt_embeds, prompt_attention_mask
|
134 |
+
|
135 |
+
def sample_block_noise(self, bs, ch, height, width, eps=1e-6)):
|
136 |
+
gamma = self.scheduler.gamma
|
137 |
+
dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(4), torch.eye(4) * (1 - gamma) + torch.ones(4, 4) * gamma + eps * torch.eye(4))
|
138 |
+
block_number = bs * ch * (height // 2) * (width // 2)
|
139 |
+
noise = torch.stack([dist.sample() for _ in range(block_number)]) # [block number, 4]
|
140 |
+
noise = rearrange(noise, '(b c h w) (p q) -> b c (h p) (w q)',b=bs,c=ch,h=height//2,w=width//2,p=2,q=2)
|
141 |
+
return noise
|
142 |
+
|
143 |
+
@torch.no_grad()
|
144 |
+
def __call__(
|
145 |
+
self,
|
146 |
+
prompt,
|
147 |
+
height,
|
148 |
+
width,
|
149 |
+
num_inference_steps=30,
|
150 |
+
guidance_scale=4.0,
|
151 |
+
num_images_per_prompt=1,
|
152 |
+
device=None,
|
153 |
+
shift=1.0,
|
154 |
+
use_ode_dopri5=False,
|
155 |
+
):
|
156 |
+
if isinstance(num_inference_steps, int):
|
157 |
+
num_inference_steps = [num_inference_steps] * self.num_stages
|
158 |
+
|
159 |
+
if use_ode_dopri5:
|
160 |
+
assert self.class_cond, "ODE (dopri5) sampling is only supported for class-conditional models now"
|
161 |
+
from pixelflow.solver_ode_wrapper import ODE
|
162 |
+
sample_fn = ODE(t0=0, t1=1, sampler_type="dopri5", num_steps=num_inference_steps[0], atol=1e-06, rtol=0.001).sample
|
163 |
+
else:
|
164 |
+
# default Euler
|
165 |
+
sample_fn = None
|
166 |
+
|
167 |
+
self._guidance_scale = guidance_scale
|
168 |
+
batch_size = len(prompt)
|
169 |
+
if self.class_cond:
|
170 |
+
prompt_embeds = torch.tensor(prompt, dtype=torch.int32).to(device)
|
171 |
+
negative_prompt_embeds = 1000 * torch.ones_like(prompt_embeds)
|
172 |
+
if self.do_classifier_free_guidance:
|
173 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
174 |
+
else:
|
175 |
+
prompt_embeds, prompt_attention_mask = self.encode_prompt(
|
176 |
+
prompt,
|
177 |
+
device,
|
178 |
+
num_images_per_prompt,
|
179 |
+
guidance_scale > 1,
|
180 |
+
"",
|
181 |
+
prompt_embeds=None,
|
182 |
+
negative_prompt_embeds=None,
|
183 |
+
use_attention_mask=True,
|
184 |
+
max_length=self.max_token_length,
|
185 |
+
)
|
186 |
+
|
187 |
+
init_factor = 2 ** (self.num_stages - 1)
|
188 |
+
height, width = height // init_factor, width // init_factor
|
189 |
+
shape = (batch_size * num_images_per_prompt, 3, height, width)
|
190 |
+
latents = randn_tensor(shape, device=device, dtype=torch.float32)
|
191 |
+
|
192 |
+
for stage_idx in range(self.num_stages):
|
193 |
+
stage_start = time.time()
|
194 |
+
# Set the number of inference steps for the current stage
|
195 |
+
self.scheduler.set_timesteps(num_inference_steps[stage_idx], stage_idx, device=device, shift=shift)
|
196 |
+
Timesteps = self.scheduler.Timesteps
|
197 |
+
|
198 |
+
if stage_idx > 0:
|
199 |
+
height, width = height * 2, width * 2
|
200 |
+
latents = F.interpolate(latents, size=(height, width), mode='nearest')
|
201 |
+
original_start_t = self.scheduler.original_start_t[stage_idx]
|
202 |
+
gamma = self.scheduler.gamma
|
203 |
+
alpha = 1 / (math.sqrt(1 - (1 / gamma)) * (1 - original_start_t) + original_start_t)
|
204 |
+
beta = alpha * (1 - original_start_t) / math.sqrt(- gamma)
|
205 |
+
|
206 |
+
# bs, ch, height, width = latents.shape
|
207 |
+
noise = self.sample_block_noise(*latents.shape)
|
208 |
+
noise = noise.to(device=device, dtype=latents.dtype)
|
209 |
+
latents = alpha * latents + beta * noise
|
210 |
+
|
211 |
+
size_tensor = torch.tensor([latents.shape[-1] // self.patch_size], dtype=torch.int32, device=device)
|
212 |
+
pos_embed = get_2d_rotary_pos_embed(
|
213 |
+
embed_dim=self.head_dim,
|
214 |
+
crops_coords=((0, 0), (latents.shape[-1] // self.patch_size, latents.shape[-1] // self.patch_size)),
|
215 |
+
grid_size=(latents.shape[-1] // self.patch_size, latents.shape[-1] // self.patch_size),
|
216 |
+
)
|
217 |
+
rope_pos = torch.stack(pos_embed, -1)
|
218 |
+
|
219 |
+
if sample_fn is not None:
|
220 |
+
# dopri5
|
221 |
+
model_kwargs = dict(class_labels=prompt_embeds, cfg_scale=self.guidance_scale(None, stage_idx), latent_size=size_tensor, pos_embed=rope_pos)
|
222 |
+
if stage_idx == 0:
|
223 |
+
latents = torch.cat([latents] * 2)
|
224 |
+
stage_T_start = self.scheduler.Timesteps_per_stage[stage_idx][0].item()
|
225 |
+
stage_T_end = self.scheduler.Timesteps_per_stage[stage_idx][-1].item()
|
226 |
+
latents = sample_fn(latents, self.transformer.c2i_forward_cfg_torchdiffq, stage_T_start, stage_T_end, **model_kwargs)[-1]
|
227 |
+
if stage_idx == self.num_stages - 1:
|
228 |
+
latents = latents[:latents.shape[0] // 2]
|
229 |
+
else:
|
230 |
+
# euler
|
231 |
+
for T in Timesteps:
|
232 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
233 |
+
timestep = T.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
|
234 |
+
if self.class_cond:
|
235 |
+
noise_pred = self.transformer(latent_model_input, timestep=timestep, class_labels=prompt_embeds, latent_size=size_tensor, pos_embed=rope_pos)
|
236 |
+
else:
|
237 |
+
encoder_hidden_states = prompt_embeds
|
238 |
+
encoder_attention_mask = prompt_attention_mask
|
239 |
+
|
240 |
+
noise_pred = self.transformer(
|
241 |
+
latent_model_input,
|
242 |
+
encoder_hidden_states=encoder_hidden_states,
|
243 |
+
encoder_attention_mask=encoder_attention_mask,
|
244 |
+
timestep=timestep,
|
245 |
+
latent_size=size_tensor,
|
246 |
+
pos_embed=rope_pos,
|
247 |
+
)
|
248 |
+
|
249 |
+
if self.do_classifier_free_guidance:
|
250 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
251 |
+
noise_pred = noise_pred_uncond + self.guidance_scale(T, stage_idx) * (noise_pred_text - noise_pred_uncond)
|
252 |
+
|
253 |
+
latents = self.scheduler.step(model_output=noise_pred, sample=latents)
|
254 |
+
stage_end = time.time()
|
255 |
+
|
256 |
+
samples = (latents / 2 + 0.5).clamp(0, 1)
|
257 |
+
samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
|
258 |
+
return samples
|
259 |
+
|
260 |
+
@property
|
261 |
+
def device(self):
|
262 |
+
return next(self.transformer.parameters()).device
|
263 |
+
|
264 |
+
@property
|
265 |
+
def dtype(self):
|
266 |
+
return next(self.transformer.parameters()).dtype
|
267 |
+
|
268 |
+
def guidance_scale(self, step=None, stage_idx=None):
|
269 |
+
if not self.class_cond:
|
270 |
+
return self._guidance_scale
|
271 |
+
scale_dict = {0: 0, 1: 1/6, 2: 2/3, 3: 1}
|
272 |
+
return (self._guidance_scale - 1) * scale_dict[stage_idx] + 1
|
273 |
+
|
274 |
+
@property
|
275 |
+
def do_classifier_free_guidance(self):
|
276 |
+
return self._guidance_scale > 0
|
pixelflow/scheduling_pixelflow.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def cal_rectify_ratio(start_t, gamma):
|
7 |
+
return 1 / (math.sqrt(1 - (1 / gamma)) * (1 - start_t) + start_t)
|
8 |
+
|
9 |
+
|
10 |
+
class PixelFlowScheduler:
|
11 |
+
def __init__(self, num_train_timesteps, num_stages, gamma=-1 / 3):
|
12 |
+
assert num_stages > 0, f"num_stages must be positive, got {num_stages}"
|
13 |
+
self.num_stages = num_stages
|
14 |
+
self.gamma = gamma
|
15 |
+
|
16 |
+
self.Timesteps = torch.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=torch.float32)
|
17 |
+
|
18 |
+
self.t = self.Timesteps / num_train_timesteps # normalized time in [0, 1]
|
19 |
+
|
20 |
+
self.stage_range = [x / num_stages for x in range(num_stages + 1)]
|
21 |
+
|
22 |
+
self.original_start_t = dict()
|
23 |
+
self.start_t, self.end_t = dict(), dict()
|
24 |
+
self.t_window_per_stage = dict()
|
25 |
+
self.Timesteps_per_stage = dict()
|
26 |
+
stage_distance = list()
|
27 |
+
|
28 |
+
# stage_idx = 0: min t, min resolution, most noisy
|
29 |
+
# stage_idx = num_stages - 1 : max t, max resolution, most clear
|
30 |
+
for stage_idx in range(num_stages):
|
31 |
+
start_idx = max(int(num_train_timesteps * self.stage_range[stage_idx]), 0)
|
32 |
+
end_idx = min(int(num_train_timesteps * self.stage_range[stage_idx + 1]), num_train_timesteps)
|
33 |
+
|
34 |
+
start_t = self.t[start_idx].item()
|
35 |
+
end_t = self.t[end_idx].item() if end_idx < num_train_timesteps else 1.0
|
36 |
+
|
37 |
+
self.original_start_t[stage_idx] = start_t
|
38 |
+
|
39 |
+
if stage_idx > 0:
|
40 |
+
start_t *= cal_rectify_ratio(start_t, gamma)
|
41 |
+
|
42 |
+
self.start_t[stage_idx] = start_t
|
43 |
+
self.end_t[stage_idx] = end_t
|
44 |
+
stage_distance.append(end_t - start_t)
|
45 |
+
|
46 |
+
total_stage_distance = sum(stage_distance)
|
47 |
+
t_within_stage = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float64)[:-1]
|
48 |
+
|
49 |
+
for stage_idx in range(num_stages):
|
50 |
+
start_ratio = 0.0 if stage_idx == 0 else sum(stage_distance[:stage_idx]) / total_stage_distance
|
51 |
+
end_ratio = 1.0 if stage_idx == num_stages - 1 else sum(stage_distance[:stage_idx + 1]) / total_stage_distance
|
52 |
+
|
53 |
+
Timestep_start = self.Timesteps[int(num_train_timesteps * start_ratio)]
|
54 |
+
Timestep_end = self.Timesteps[min(int(num_train_timesteps * end_ratio), num_train_timesteps - 1)]
|
55 |
+
|
56 |
+
self.t_window_per_stage[stage_idx] = t_within_stage
|
57 |
+
|
58 |
+
if stage_idx == num_stages - 1:
|
59 |
+
self.Timesteps_per_stage[stage_idx] = torch.linspace(Timestep_start.item(), Timestep_end.item(), num_train_timesteps, dtype=torch.float64)
|
60 |
+
else:
|
61 |
+
self.Timesteps_per_stage[stage_idx] = torch.linspace(Timestep_start.item(), Timestep_end.item(), num_train_timesteps + 1, dtype=torch.float64)[:-1]
|
62 |
+
|
63 |
+
@staticmethod
|
64 |
+
def time_linear_to_Timesteps(t, t_start, t_end, T_start, T_end):
|
65 |
+
"""
|
66 |
+
linearly map t to T: T = k * t + b
|
67 |
+
"""
|
68 |
+
k = (T_end - T_start) / (t_end - t_start)
|
69 |
+
b = T_start - t_start * k
|
70 |
+
return k * t + b
|
71 |
+
|
72 |
+
def set_timesteps(self, num_inference_steps, stage_index, device=None, shift=1.0):
|
73 |
+
self.num_inference_steps = num_inference_steps
|
74 |
+
|
75 |
+
stage_T_start = self.Timesteps_per_stage[stage_index][0].item()
|
76 |
+
stage_T_end = self.Timesteps_per_stage[stage_index][-1].item()
|
77 |
+
|
78 |
+
t_start = self.t_window_per_stage[stage_index][0].item()
|
79 |
+
t_end = self.t_window_per_stage[stage_index][-1].item()
|
80 |
+
|
81 |
+
t = np.linspace(t_start, t_end, num_inference_steps, dtype=np.float64)
|
82 |
+
t = t / (shift + (1 - shift) * t)
|
83 |
+
|
84 |
+
Timesteps = self.time_linear_to_Timesteps(t, t_start, t_end, stage_T_start, stage_T_end)
|
85 |
+
self.Timesteps = torch.from_numpy(Timesteps).to(device=device)
|
86 |
+
|
87 |
+
self.t = torch.from_numpy(np.append(t, 1.0)).to(device=device, dtype=torch.float64)
|
88 |
+
self._step_index = None
|
89 |
+
|
90 |
+
def step(self, model_output, sample):
|
91 |
+
if self.step_index is None:
|
92 |
+
self._step_index = 0
|
93 |
+
|
94 |
+
sample = sample.to(torch.float32)
|
95 |
+
t = self.t[self.step_index].float()
|
96 |
+
t_next = self.t[self.step_index + 1].float()
|
97 |
+
|
98 |
+
prev_sample = sample + (t_next - t) * model_output
|
99 |
+
self._step_index += 1
|
100 |
+
|
101 |
+
return prev_sample.to(model_output.dtype)
|
102 |
+
|
103 |
+
@property
|
104 |
+
def step_index(self):
|
105 |
+
"""Current step index for the scheduler."""
|
106 |
+
return self._step_index
|
pixelflow/solver_ode_wrapper.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torchdiffeq import odeint
|
3 |
+
|
4 |
+
|
5 |
+
# https://github.com/willisma/SiT/blob/main/transport/integrators.py#L77
|
6 |
+
class ODE:
|
7 |
+
"""ODE solver class"""
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
*,
|
11 |
+
t0,
|
12 |
+
t1,
|
13 |
+
sampler_type,
|
14 |
+
num_steps,
|
15 |
+
atol,
|
16 |
+
rtol,
|
17 |
+
):
|
18 |
+
assert t0 < t1, "ODE sampler has to be in forward time"
|
19 |
+
|
20 |
+
self.t = torch.linspace(t0, t1, num_steps)
|
21 |
+
self.atol = atol
|
22 |
+
self.rtol = rtol
|
23 |
+
self.sampler_type = sampler_type
|
24 |
+
|
25 |
+
def time_linear_to_Timesteps(self, t, t_start, t_end, T_start, T_end):
|
26 |
+
# T = k * t + b
|
27 |
+
k = (T_end - T_start) / (t_end - t_start)
|
28 |
+
b = T_start - t_start * k
|
29 |
+
return k * t + b
|
30 |
+
|
31 |
+
def sample(self, x, model, T_start, T_end, **model_kwargs):
|
32 |
+
device = x[0].device if isinstance(x, tuple) else x.device
|
33 |
+
def _fn(t, x):
|
34 |
+
t = torch.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else torch.ones(x.size(0)).to(device) * t
|
35 |
+
model_output = model(x, self.time_linear_to_Timesteps(t, 0, 1, T_start, T_end), **model_kwargs)
|
36 |
+
assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
|
37 |
+
return model_output
|
38 |
+
|
39 |
+
t = self.t.to(device)
|
40 |
+
atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
|
41 |
+
rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
|
42 |
+
samples = odeint(
|
43 |
+
_fn,
|
44 |
+
x,
|
45 |
+
t,
|
46 |
+
method=self.sampler_type,
|
47 |
+
atol=atol,
|
48 |
+
rtol=rtol
|
49 |
+
)
|
50 |
+
return samples
|
pixelflow/utils/config.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
|
3 |
+
|
4 |
+
def get_obj_from_str(string, reload=False):
|
5 |
+
module, cls = string.rsplit(".", 1)
|
6 |
+
if reload:
|
7 |
+
module_imp = importlib.import_module(module)
|
8 |
+
importlib.reload(module_imp)
|
9 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
10 |
+
|
11 |
+
|
12 |
+
def instantiate_from_config(config):
|
13 |
+
if not "target" in config:
|
14 |
+
raise KeyError("Expected key `target` to instantiate.")
|
15 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
16 |
+
|
17 |
+
|
18 |
+
def instantiate_optimizer_from_config(config, params):
|
19 |
+
if not "target" in config:
|
20 |
+
raise KeyError("Expected key `target` to instantiate.")
|
21 |
+
return get_obj_from_str(config["target"])(params, **config.get("params", dict()))
|
22 |
+
|
23 |
+
|
24 |
+
def instantiate_dataset_from_config(config, transform):
|
25 |
+
if not "target" in config:
|
26 |
+
raise KeyError("Expected key `target` to instantiate.")
|
27 |
+
return get_obj_from_str(config["target"])(transform=transform, **config.get("params", dict()))
|
pixelflow/utils/logger.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
|
4 |
+
|
5 |
+
class PathSimplifierFormatter(logging.Formatter):
|
6 |
+
def format(self, record):
|
7 |
+
record.short_path = os.path.relpath(record.pathname)
|
8 |
+
return super().format(record)
|
9 |
+
|
10 |
+
|
11 |
+
def setup_logger(log_directory, experiment_name, process_rank, source_module=__name__):
|
12 |
+
handlers = [logging.StreamHandler()]
|
13 |
+
|
14 |
+
if process_rank == 0:
|
15 |
+
log_file_path = os.path.join(log_directory, f"{experiment_name}.log")
|
16 |
+
handlers.append(logging.FileHandler(log_file_path))
|
17 |
+
|
18 |
+
log_formatter = PathSimplifierFormatter(
|
19 |
+
fmt='[%(asctime)s %(short_path)s:%(lineno)d] %(message)s',
|
20 |
+
datefmt='%Y-%m-%d %H:%M:%S'
|
21 |
+
)
|
22 |
+
|
23 |
+
for handler in handlers:
|
24 |
+
handler.setFormatter(log_formatter)
|
25 |
+
|
26 |
+
logging.basicConfig(level=logging.INFO, handlers=handlers)
|
27 |
+
return logging.getLogger(source_module)
|
pixelflow/utils/misc.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
def seed_everything(seed=0, deterministic_ops=True, allow_tf32=False):
|
7 |
+
"""
|
8 |
+
Sets the seed for reproducibility across various libraries and frameworks, and configures PyTorch backend settings.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
seed (int): The seed value for random number generation. Default is 0.
|
12 |
+
deterministic_ops (bool): Whether to enable deterministic operations in PyTorch.
|
13 |
+
Enabling this can make results reproducible at the cost of potential performance degradation. Default is True.
|
14 |
+
allow_tf32 (bool): Whether to allow TensorFloat-32 (TF32) precision in PyTorch operations. TF32 can improve performance but may affect reproducibility. Default is False.
|
15 |
+
|
16 |
+
Effects:
|
17 |
+
- Seeds Python's random module, NumPy, and PyTorch (CPU and GPU).
|
18 |
+
- Sets the environment variable `PYTHONHASHSEED` to the specified seed.
|
19 |
+
- Configures PyTorch to use deterministic algorithms if `deterministic_ops` is True.
|
20 |
+
- Configures TensorFloat-32 precision based on `allow_tf32`.
|
21 |
+
- Issues warnings if configurations may impact reproducibility.
|
22 |
+
|
23 |
+
Notes:
|
24 |
+
- Setting `torch.backends.cudnn.deterministic` to False allows nondeterministic operations, which may introduce variability.
|
25 |
+
- Allowing TF32 (`allow_tf32=True`) may lead to non-reproducible results, especially in matrix operations.
|
26 |
+
"""
|
27 |
+
# Seed standard random number generators
|
28 |
+
random.seed(seed)
|
29 |
+
np.random.seed(seed)
|
30 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
31 |
+
|
32 |
+
# Seed PyTorch random number generators
|
33 |
+
torch.manual_seed(seed)
|
34 |
+
torch.cuda.manual_seed_all(seed)
|
35 |
+
|
36 |
+
# Configure deterministic operations
|
37 |
+
if deterministic_ops:
|
38 |
+
torch.backends.cudnn.deterministic = True
|
39 |
+
torch.use_deterministic_algorithms(True)
|
40 |
+
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
41 |
+
else:
|
42 |
+
torch.backends.cudnn.deterministic = False
|
43 |
+
print("WARNING: torch.backends.cudnn.deterministic is set to False, reproducibility is not guaranteed.")
|
44 |
+
|
45 |
+
# Configure TensorFloat-32 precision
|
46 |
+
if allow_tf32:
|
47 |
+
print("WARNING: TensorFloat-32 (TF32) is enabled; reproducibility is not guaranteed.")
|
48 |
+
|
49 |
+
torch.backends.cudnn.allow_tf32 = allow_tf32 # Default True in PyTorch 2.6.0
|
50 |
+
torch.backends.cuda.matmul.allow_tf32 = allow_tf32 # Default False in PyTorch 2.6.0
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
einops
|
2 |
+
pandas
|
3 |
+
pyarrow
|
4 |
+
omegaconf
|
5 |
+
diffusers==0.32.2
|
6 |
+
transformers==4.48.0
|
7 |
+
torchdiffeq==0.2.4
|