ShoufaChen commited on
Commit
137645c
·
verified ·
1 Parent(s): be16882
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from PIL import Image
4
+ import gradio as gr
5
+ import spaces
6
+ from imagenet_en_cn import IMAGENET_1K_CLASSES
7
+ from omegaconf import OmegaConf
8
+ from huggingface_hub import snapshot_download
9
+
10
+ import torch
11
+ # from transformers import T5EncoderModel, AutoTokenizer
12
+
13
+ from pixelflow.scheduling_pixelflow import PixelFlowScheduler
14
+ from pixelflow.pipeline_pixelflow import PixelFlowPipeline
15
+ from pixelflow.utils import config as config_utils
16
+ from pixelflow.utils.misc import seed_everything
17
+
18
+
19
+ parser = argparse.ArgumentParser(description='Gradio Demo', add_help=False)
20
+ parser.add_argument('--checkpoint', type=str, help='checkpoint folder path')
21
+ parser.add_argument('--class_cond', action='store_true', help='use class conditional generation')
22
+ args = parser.parse_args()
23
+
24
+ # deploy
25
+ args.checkpoint = "pixelflow_c2i"
26
+ args.class_cond = True
27
+
28
+
29
+ if args.class_cond:
30
+ output_dir = args.checkpoint
31
+ if not os.path.exists(output_dir):
32
+ snapshot_download(repo_id="ShoufaChen/PixelFlow-Class2Image", local_dir=output_dir)
33
+ config = OmegaConf.load(f"{output_dir}/config.yaml")
34
+ model = config_utils.instantiate_from_config(config.model)
35
+ print(f"Num of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
36
+ ckpt = torch.load(f"{output_dir}/model.pt", map_location="cpu", weights_only=True)
37
+ text_encoder = None
38
+ tokenizer = None
39
+ resolution = 256
40
+ NUM_EXAMPLES = 4
41
+ else:
42
+ raise NotImplementedError("Please run locally.")
43
+ config = OmegaConf.load(f"{output_dir}/config.yaml")
44
+ model = config_utils.instantiate_from_config(config.model).to(device)
45
+ print(f"Num of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
46
+ ckpt = torch.load(f"{output_dir}/model.pt", map_location="cpu", weights_only=True)
47
+ text_encoder = T5EncoderModel.from_pretrained("google/flan-t5-xl").to(device)
48
+ tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
49
+ resolution = 1024
50
+ NUM_EXAMPLES = 1
51
+ model.load_state_dict(ckpt, strict=True)
52
+ model.eval()
53
+
54
+ print(f"outside space.GPU. {torch.cuda.is_available()=}")
55
+ if torch.cuda.is_available():
56
+ model = model.cuda()
57
+ device = torch.device("cuda")
58
+ else:
59
+ raise ValueError("No GPU")
60
+
61
+ scheduler = PixelFlowScheduler(config.scheduler.num_train_timesteps, num_stages=config.scheduler.num_stages, gamma=-1/3)
62
+
63
+ pipeline = PixelFlowPipeline(
64
+ scheduler,
65
+ model,
66
+ text_encoder=text_encoder,
67
+ tokenizer=tokenizer,
68
+ max_token_length=512,
69
+ )
70
+
71
+ @spaces.GPU
72
+ def infer(use_ode_dopri5, noise_shift, cfg_scale, class_label, seed, *num_steps_per_stage):
73
+ print(f"inside space.GPU. {torch.cuda.is_available()=}")
74
+ seed_everything(seed)
75
+ with torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad():
76
+ samples = pipeline(
77
+ prompt=[class_label] * NUM_EXAMPLES,
78
+ height=resolution,
79
+ width=resolution,
80
+ num_inference_steps=list(num_steps_per_stage),
81
+ guidance_scale=cfg_scale, # The guidance for the first frame, set it to 7 for 384p variant
82
+ device=device,
83
+ shift=noise_shift,
84
+ use_ode_dopri5=use_ode_dopri5,
85
+ )
86
+ samples = (samples * 255).round().astype("uint8")
87
+ samples = [Image.fromarray(sample) for sample in samples]
88
+ return samples
89
+
90
+
91
+ css = """
92
+ h1 {
93
+ text-align: center;
94
+ display: block;
95
+ }
96
+
97
+ .follow-link {
98
+ margin-top: 0.8em;
99
+ font-size: 1em;
100
+ text-align: center;
101
+ }
102
+ """
103
+
104
+
105
+ with gr.Blocks(css=css) as demo:
106
+ gr.Markdown("# PixelFlow: Pixel-Space Generative Models with Flow")
107
+ gr.HTML("""
108
+ <div class="follow-link">
109
+ For text-to-image generation, please follow
110
+ <a href="https://github.com/ShoufaChen/PixelFlow/tree/main?tab=readme-ov-file#demo">text-to-image</a>.
111
+ For more details, refer to our
112
+ <a href="https://arxiv.org/abs/2504.07963">arXiv paper</a> and <a href="https://github.com/ShoufaChen/PixelFlow">GitHub repo</a>.
113
+ </div>
114
+ """)
115
+
116
+ with gr.Tabs():
117
+ with gr.TabItem('Generate'):
118
+ with gr.Row():
119
+ with gr.Column():
120
+ with gr.Row():
121
+ if args.class_cond:
122
+ user_input = gr.Dropdown(
123
+ list(IMAGENET_1K_CLASSES.values()),
124
+ value='daisy [雏菊]',
125
+ type="index", label='ImageNet-1K Class'
126
+ )
127
+ else:
128
+ # text input
129
+ user_input = gr.Textbox(label='Enter your prompt', show_label=False, max_lines=1, placeholder="Enter your prompt",)
130
+ ode_dopri5 = gr.Checkbox(label="Dopri5 ODE", info="Use Dopri5 ODE solver")
131
+ noise_shift = gr.Slider(minimum=1.0, maximum=100.0, step=1, value=1.0, label='Noise Shift')
132
+ cfg_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=4.0, label='Classifier-free Guidance Scale')
133
+ num_steps_per_stage = []
134
+ for stage_idx in range(config.scheduler.num_stages):
135
+ num_steps = gr.Slider(minimum=1, maximum=100, step=1, value=10, label=f'Num Inference Steps (Stage {stage_idx})')
136
+ num_steps_per_stage.append(num_steps)
137
+ seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed')
138
+ button = gr.Button("Generate", variant="primary")
139
+ with gr.Column():
140
+ output = gr.Gallery(label='Generated Images', height=700)
141
+ button.click(infer, inputs=[ode_dopri5, noise_shift, cfg_scale, user_input, seed, *num_steps_per_stage], outputs=[output])
142
+ demo.queue()
143
+ demo.launch(share=True, debug=True)
imagenet_en_cn.py ADDED
@@ -0,0 +1,1002 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ IMAGENET_1K_CLASSES = {
2
+ 0: 'tench, Tinca tinca [丁鲷]',
3
+ 1: 'goldfish, Carassius auratus [金鱼]',
4
+ 2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias [大白鲨]',
5
+ 3: 'tiger shark, Galeocerdo cuvieri [虎鲨]',
6
+ 4: 'hammerhead, hammerhead shark [锤头鲨]',
7
+ 5: 'electric ray, crampfish, numbfish, torpedo [电鳐]',
8
+ 6: 'stingray [黄貂鱼]',
9
+ 7: 'cock [公鸡]',
10
+ 8: 'hen [母鸡]',
11
+ 9: 'ostrich, Struthio camelus [鸵鸟]',
12
+ 10: 'brambling, Fringilla montifringilla [燕雀]',
13
+ 11: 'goldfinch, Carduelis carduelis [金翅雀]',
14
+ 12: 'house finch, linnet, Carpodacus mexicanus [家朱雀]',
15
+ 13: 'junco, snowbird [灯芯草雀]',
16
+ 14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea [靛蓝雀,靛蓝鸟]',
17
+ 15: 'robin, American robin, Turdus migratorius [蓝鹀]',
18
+ 16: 'bulbul [夜莺]',
19
+ 17: 'jay [松鸦]',
20
+ 18: 'magpie [喜鹊]',
21
+ 19: 'chickadee [山雀]',
22
+ 20: 'water ouzel, dipper [河鸟]',
23
+ 21: 'kite [鸢(猛禽)]',
24
+ 22: 'bald eagle, American eagle, Haliaeetus leucocephalus [秃头鹰]',
25
+ 23: 'vulture [秃鹫]',
26
+ 24: 'great grey owl, great gray owl, Strix nebulosa [大灰猫头鹰]',
27
+ 25: 'European fire salamander, Salamandra salamandra [欧洲火蝾螈]',
28
+ 26: 'common newt, Triturus vulgaris [普通蝾螈]',
29
+ 27: 'eft [水蜥]',
30
+ 28: 'spotted salamander, Ambystoma maculatum [斑点蝾螈]',
31
+ 29: 'axolotl, mud puppy, Ambystoma mexicanum [蝾螈,泥狗]',
32
+ 30: 'bullfrog, Rana catesbeiana [牛蛙]',
33
+ 31: 'tree frog, tree-frog [树蛙]',
34
+ 32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui [尾蛙,铃蟾蜍,肋蟾蜍,尾蟾蜍]',
35
+ 33: 'loggerhead, loggerhead turtle, Caretta caretta [红海龟]',
36
+ 34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea [皮革龟]',
37
+ 35: 'mud turtle [泥龟]',
38
+ 36: 'terrapin [淡水龟]',
39
+ 37: 'box turtle, box tortoise [箱龟]',
40
+ 38: 'banded gecko [带状壁虎]',
41
+ 39: 'common iguana, iguana, Iguana iguana [普通鬣蜥]',
42
+ 40: 'American chameleon, anole, Anolis carolinensis [美国变色龙]',
43
+ 41: 'whiptail, whiptail lizard [鞭尾蜥蜴]',
44
+ 42: 'agama [飞龙科蜥蜴]',
45
+ 43: 'frilled lizard, Chlamydosaurus kingi [褶边蜥蜴]',
46
+ 44: 'alligator lizard [鳄鱼蜥蜴]',
47
+ 45: 'Gila monster, Heloderma suspectum [毒蜥]',
48
+ 46: 'green lizard, Lacerta viridis [绿蜥蜴]',
49
+ 47: 'African chameleon, Chamaeleo chamaeleon [非洲变色龙]',
50
+ 48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis [科莫多蜥蜴]',
51
+ 49: 'African crocodile, Nile crocodile, Crocodylus niloticus [非洲鳄,尼罗河鳄鱼]',
52
+ 50: 'American alligator, Alligator mississipiensis [美国鳄鱼,鳄鱼]',
53
+ 51: 'triceratops [三角龙]',
54
+ 52: 'thunder snake, worm snake, Carphophis amoenus [雷蛇,蠕虫蛇]',
55
+ 53: 'ringneck snake, ring-necked snake, ring snake [环蛇,环颈蛇]',
56
+ 54: 'hognose snake, puff adder, sand viper [希腊蛇]',
57
+ 55: 'green snake, grass snake [绿蛇,草蛇]',
58
+ 56: 'king snake, kingsnake [国王蛇]',
59
+ 57: 'garter snake, grass snake [袜带蛇,草蛇]',
60
+ 58: 'water snake [水蛇]',
61
+ 59: 'vine snake [藤蛇]',
62
+ 60: 'night snake, Hypsiglena torquata [夜蛇]',
63
+ 61: 'boa constrictor, Constrictor constrictor [大蟒蛇]',
64
+ 62: 'rock python, rock snake, Python sebae [岩石蟒蛇,岩蛇,蟒蛇]',
65
+ 63: 'Indian cobra, Naja naja [印度眼镜蛇]',
66
+ 64: 'green mamba [绿曼巴]',
67
+ 65: 'sea snake [海蛇]',
68
+ 66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus [角腹蛇]',
69
+ 67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus [菱纹响尾蛇]',
70
+ 68: 'sidewinder, horned rattlesnake, Crotalus cerastes [角响尾蛇]',
71
+ 69: 'trilobite [三叶虫]',
72
+ 70: 'harvestman, daddy longlegs, Phalangium opilio [盲蜘蛛]',
73
+ 71: 'scorpion [蝎子]',
74
+ 72: 'black and gold garden spider, Argiope aurantia [黑金花园蜘蛛]',
75
+ 73: 'barn spider, Araneus cavaticus [谷仓蜘蛛]',
76
+ 74: 'garden spider, Aranea diademata [花园蜘蛛]',
77
+ 75: 'black widow, Latrodectus mactans [黑寡妇蜘蛛]',
78
+ 76: 'tarantula [狼蛛]',
79
+ 77: 'wolf spider, hunting spider [狼蜘蛛,狩猎蜘蛛]',
80
+ 78: 'tick [壁虱]',
81
+ 79: 'centipede [蜈蚣]',
82
+ 80: 'black grouse [黑松鸡]',
83
+ 81: 'ptarmigan [松鸡,雷鸟]',
84
+ 82: 'ruffed grouse, partridge, Bonasa umbellus [披肩鸡,披肩榛鸡]',
85
+ 83: 'prairie chicken, prairie grouse, prairie fowl [草原鸡,草原松鸡]',
86
+ 84: 'peacock [孔雀]',
87
+ 85: 'quail [鹌鹑]',
88
+ 86: 'partridge [鹧鸪]',
89
+ 87: 'African grey, African gray, Psittacus erithacus [非洲灰鹦鹉]',
90
+ 88: 'macaw [金刚鹦鹉]',
91
+ 89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita [硫冠鹦鹉]',
92
+ 90: 'lorikeet [短尾鹦鹉]',
93
+ 91: 'coucal [褐翅鸦鹃]',
94
+ 92: 'bee eater [蜜蜂]',
95
+ 93: 'hornbill [犀鸟]',
96
+ 94: 'hummingbird [蜂鸟]',
97
+ 95: 'jacamar [鹟䴕]',
98
+ 96: 'toucan [犀鸟]',
99
+ 97: 'drake [野鸭]',
100
+ 98: 'red-breasted merganser, Mergus serrator [���胸秋沙鸭]',
101
+ 99: 'goose [鹅]',
102
+ 100: 'black swan, Cygnus atratus [黑天鹅]',
103
+ 101: 'tusker [大象]',
104
+ 102: 'echidna, spiny anteater, anteater [针鼹鼠]',
105
+ 103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus [鸭嘴兽]',
106
+ 104: 'wallaby, brush kangaroo [沙袋鼠]',
107
+ 105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus [考拉,考拉熊]',
108
+ 106: 'wombat [袋熊]',
109
+ 107: 'jellyfish [水母]',
110
+ 108: 'sea anemone, anemone [海葵]',
111
+ 109: 'brain coral [脑珊瑚]',
112
+ 110: 'flatworm, platyhelminth [扁形虫扁虫]',
113
+ 111: 'nematode, nematode worm, roundworm [线虫,蛔虫]',
114
+ 112: 'conch [海螺]',
115
+ 113: 'snail [蜗牛]',
116
+ 114: 'slug [鼻涕虫]',
117
+ 115: 'sea slug, nudibranch [海参]',
118
+ 116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore [石鳖]',
119
+ 117: 'chambered nautilus, pearly nautilus, nautilus [鹦鹉螺]',
120
+ 118: 'Dungeness crab, Cancer magister [珍宝蟹]',
121
+ 119: 'rock crab, Cancer irroratus [石蟹]',
122
+ 120: 'fiddler crab [招潮蟹]',
123
+ 121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica [帝王蟹,阿拉斯加蟹,阿拉斯加帝王蟹]',
124
+ 122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus [美国龙虾,缅因州龙虾]',
125
+ 123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish [大螯虾]',
126
+ 124: 'crayfish, crawfish, crawdad, crawdaddy [小龙虾]',
127
+ 125: 'hermit crab [寄居蟹]',
128
+ 126: 'isopod [等足目动物(明虾和螃蟹近亲)]',
129
+ 127: 'white stork, Ciconia ciconia [白鹳]',
130
+ 128: 'black stork, Ciconia nigra [黑鹳]',
131
+ 129: 'spoonbill [鹭]',
132
+ 130: 'flamingo [火烈鸟]',
133
+ 131: 'little blue heron, Egretta caerulea [小蓝鹭]',
134
+ 132: 'American egret, great white heron, Egretta albus [美国鹭,大白鹭]',
135
+ 133: 'bittern [麻鸦]',
136
+ 134: 'crane [鹤]',
137
+ 135: 'limpkin, Aramus pictus [秧鹤]',
138
+ 136: 'European gallinule, Porphyrio porphyrio [欧洲水鸡,紫水鸡]',
139
+ 137: 'American coot, marsh hen, mud hen, water hen, Fulica americana [沼泽泥母鸡,水母鸡]',
140
+ 138: 'bustard [鸨]',
141
+ 139: 'ruddy turnstone, Arenaria interpres [红翻石鹬]',
142
+ 140: 'red-backed sandpiper, dunlin, Erolia alpina [红背鹬,黑腹滨鹬]',
143
+ 141: 'redshank, Tringa totanus [红脚鹬]',
144
+ 142: 'dowitcher [半蹼鹬]',
145
+ 143: 'oystercatcher, oyster catcher [蛎鹬]',
146
+ 144: 'pelican [鹈鹕]',
147
+ 145: 'king penguin, Aptenodytes patagonica [国王企鹅]',
148
+ 146: 'albatross, mollymawk [信天翁,大海鸟]',
149
+ 147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus [灰鲸]',
150
+ 148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca [杀人鲸,逆戟鲸,虎鲸]',
151
+ 149: 'dugong, Dugong dugon [海牛]',
152
+ 150: 'sea lion [海狮]',
153
+ 151: 'Chihuahua [奇瓦瓦]',
154
+ 152: 'Japanese spaniel [日本猎犬]',
155
+ 153: 'Maltese dog, Maltese terrier, Maltese [马尔济斯犬]',
156
+ 154: 'Pekinese, Pekingese, Peke [狮子狗]',
157
+ 155: 'Shih-Tzu [西施犬]',
158
+ 156: 'Blenheim spaniel [布莱尼姆猎犬]',
159
+ 157: 'papillon [巴比狗]',
160
+ 158: 'toy terrier [玩具犬]',
161
+ 159: 'Rhodesian ridgeback [罗得西亚长背猎狗]',
162
+ 160: 'Afghan hound, Afghan [阿富汗猎犬]',
163
+ 161: 'basset, basset hound [猎犬]',
164
+ 162: 'beagle [比格犬,猎兔犬]',
165
+ 163: 'bloodhound, sleuthhound [侦探犬]',
166
+ 164: 'bluetick [蓝色快狗]',
167
+ 165: 'black-and-tan coonhound [黑褐猎浣熊犬]',
168
+ 166: 'Walker hound, Walker foxhound [沃克猎犬]',
169
+ 167: 'English foxhound [英国猎狐犬]',
170
+ 168: 'redbone [美洲赤狗]',
171
+ 169: 'borzoi, Russian wolfhound [俄罗斯猎狼犬]',
172
+ 170: 'Irish wolfhound [爱尔兰猎狼犬]',
173
+ 171: 'Italian greyhound [意大利灰狗]',
174
+ 172: 'whippet [惠比特犬]',
175
+ 173: 'Ibizan hound, Ibizan Podenco [依比沙猎犬]',
176
+ 174: 'Norwegian elkhound, elkhound [挪威猎犬]',
177
+ 175: 'otterhound, otter hound [奥达猎犬,水獭猎犬]',
178
+ 176: 'Saluki, gazelle hound [沙克犬,瞪羚猎犬]',
179
+ 177: 'Scottish deerhound, deerhound [苏格兰猎鹿犬,猎鹿犬]',
180
+ 178: 'Weimaraner [威玛猎犬]',
181
+ 179: 'Staffordshire bullterrier, Staffordshire bull terrier [斯塔福德郡牛头梗,斯塔福德郡斗牛梗]',
182
+ 180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier [美国斯塔福德郡梗,美国比特斗牛梗,斗牛梗]',
183
+ 181: 'Bedlington terrier [贝德灵顿梗]',
184
+ 182: 'Border terrier [边境梗]',
185
+ 183: 'Kerry blue terrier [凯丽蓝梗]',
186
+ 184: 'Irish terrier [爱尔兰梗]',
187
+ 185: 'Norfolk terrier [诺福克梗]',
188
+ 186: 'Norwich terrier [诺维奇梗]',
189
+ 187: 'Yorkshire terrier [约克郡梗]',
190
+ 188: 'wire-haired fox terrier [刚毛猎狐梗]',
191
+ 189: 'Lakeland terrier [莱克兰梗]',
192
+ 190: 'Sealyham terrier, Sealyham [锡利哈姆梗]',
193
+ 191: 'Airedale, Airedale terrier [艾尔谷犬]',
194
+ 192: 'cairn, cairn terrier [凯恩梗]',
195
+ 193: 'Australian terrier [澳大利亚梗]',
196
+ 194: 'Dandie Dinmont, Dandie Dinmont terrier [丹迪丁蒙梗]',
197
+ 195: 'Boston bull, Boston terrier [波士顿梗]',
198
+ 196: 'miniature schnauzer [迷你雪纳瑞犬]',
199
+ 197: 'giant schnauzer [巨型雪纳瑞犬]',
200
+ 198: 'standard schnauzer [标准雪纳瑞犬]',
201
+ 199: 'Scotch terrier, Scottish terrier, Scottie [苏格兰梗]',
202
+ 200: 'Tibetan terrier, chrysanthemum dog [西藏梗,菊花狗]',
203
+ 201: 'silky terrier, Sydney silky [丝毛梗]',
204
+ 202: 'soft-coated wheaten terrier [软毛麦色梗]',
205
+ 203: 'West Highland white terrier [西高地白梗]',
206
+ 204: 'Lhasa, Lhasa apso [拉萨阿普索犬]',
207
+ 205: 'flat-coated retriever [平毛寻回犬]',
208
+ 206: 'curly-coated retriever [卷毛寻回犬]',
209
+ 207: 'golden retriever [金毛猎犬]',
210
+ 208: 'Labrador retriever [拉布拉多猎犬]',
211
+ 209: 'Chesapeake Bay retriever [乞沙比克猎犬]',
212
+ 210: 'German short-haired pointer [德国短毛猎犬]',
213
+ 211: 'vizsla, Hungarian pointer [维兹拉犬]',
214
+ 212: 'English setter [英国谍犬]',
215
+ 213: 'Irish setter, red setter [爱尔兰雪达犬,红色猎犬]',
216
+ 214: 'Gordon setter [戈登雪达犬]',
217
+ 215: 'Brittany spaniel [布列塔尼犬猎犬]',
218
+ 216: 'clumber, clumber spaniel [黄毛,黄毛猎犬]',
219
+ 217: 'English springer, English springer spaniel [英国史宾格犬]',
220
+ 218: 'Welsh springer spaniel [威尔士史宾格犬]',
221
+ 219: 'cocker spaniel, English cocker spaniel, cocker [可卡犬,英国可卡犬]',
222
+ 220: 'Sussex spaniel [萨塞克斯猎犬]',
223
+ 221: 'Irish water spaniel [爱尔兰水猎犬]',
224
+ 222: 'kuvasz [哥威斯犬]',
225
+ 223: 'schipperke [舒柏奇犬]',
226
+ 224: 'groenendael [比利时牧羊犬]',
227
+ 225: 'malinois [马里努阿犬]',
228
+ 226: 'briard [伯瑞犬]',
229
+ 227: 'kelpie [凯尔皮犬]',
230
+ 228: 'komondor [匈牙利牧羊犬]',
231
+ 229: 'Old English sheepdog, bobtail [老英国牧羊犬]',
232
+ 230: 'Shetland sheepdog, Shetland sheep dog, Shetland [喜乐蒂牧羊犬]',
233
+ 231: 'collie [牧羊犬]',
234
+ 232: 'Border collie [边境牧羊犬]',
235
+ 233: 'Bouvier des Flandres, Bouviers des Flandres [法兰德斯牧牛狗]',
236
+ 234: 'Rottweiler [罗特韦尔犬]',
237
+ 235: 'German shepherd, German shepherd dog, German police dog, alsatian [德国牧羊犬,德国警犬,阿尔萨斯]',
238
+ 236: 'Doberman, Doberman pinscher [多伯曼犬,杜宾犬]',
239
+ 237: 'miniature pinscher [迷你杜宾犬]',
240
+ 238: 'Greater Swiss Mountain dog [大瑞士山地犬]',
241
+ 239: 'Bernese mountain dog [伯恩山犬]',
242
+ 240: 'Appenzeller [Appenzeller狗]',
243
+ 241: 'EntleBucher [EntleBucher狗]',
244
+ 242: 'boxer [拳师狗]',
245
+ 243: 'bull mastiff [斗牛獒]',
246
+ 244: 'Tibetan mastiff [藏獒]',
247
+ 245: 'French bulldog [法国斗牛犬]',
248
+ 246: 'Great Dane [大丹犬]',
249
+ 247: 'Saint Bernard, St Bernard [圣伯纳德狗]',
250
+ 248: 'Eskimo dog, husky [爱斯基摩犬,哈士奇]',
251
+ 249: 'malamute, malemute, Alaskan malamute [雪橇犬,阿拉斯加爱斯基摩狗]',
252
+ 250: 'Siberian husky [哈士奇]',
253
+ 251: 'dalmatian, coach dog, carriage dog [达尔马提亚,教练车狗]',
254
+ 252: 'affenpinscher, monkey pinscher, monkey dog [狮毛狗]',
255
+ 253: 'basenji [巴辛吉狗]',
256
+ 254: 'pug, pug-dog [哈巴狗,狮子狗]',
257
+ 255: 'Leonberg [莱昂贝格狗]',
258
+ 256: 'Newfoundland, Newfoundland dog [纽芬兰岛狗]',
259
+ 257: 'Great Pyrenees [大白熊犬]',
260
+ 258: 'Samoyed, Samoyede [萨摩耶犬]',
261
+ 259: 'Pomeranian [博美犬]',
262
+ 260: 'chow, chow chow [松狮,松狮]',
263
+ 261: 'keeshond [荷兰卷尾狮毛狗]',
264
+ 262: 'Brabancon griffon [布鲁塞尔格林芬犬]',
265
+ 263: 'Pembroke, Pembroke Welsh corgi [彭布洛克威尔士科基犬]',
266
+ 264: 'Cardigan, Cardigan Welsh corgi [威尔士柯基犬]',
267
+ 265: 'toy poodle [玩具贵宾犬]',
268
+ 266: 'miniature poodle [迷你贵宾犬]',
269
+ 267: 'standard poodle [标准贵宾犬]',
270
+ 268: 'Mexican hairless [墨西哥无毛犬]',
271
+ 269: 'timber wolf, grey wolf, gray wolf, Canis lupus [灰狼]',
272
+ 270: 'white wolf, Arctic wolf, Canis lupus tundrarum [白狼,北极狼]',
273
+ 271: 'red wolf, maned wolf, Canis rufus, Canis niger [红太狼,鬃狼,犬犬鲁弗斯]',
274
+ 272: 'coyote, prairie wolf, brush wolf, Canis latrans [狼,草原狼,刷狼,郊狼]',
275
+ 273: 'dingo, warrigal, warragal, Canis dingo [澳洲野狗,澳大利亚野犬]',
276
+ 274: 'dhole, Cuon alpinus [豺]',
277
+ 275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus [非洲猎犬,土狼犬]',
278
+ 276: 'hyena, hyaena [鬣狗]',
279
+ 277: 'red fox, Vulpes vulpes [红狐狸]',
280
+ 278: 'kit fox, Vulpes macrotis [沙狐]',
281
+ 279: 'Arctic fox, white fox, Alopex lagopus [北极狐狸,白狐狸]',
282
+ 280: 'grey fox, gray fox, Urocyon cinereoargenteus [灰狐狸]',
283
+ 281: 'tabby, tabby cat [虎斑猫]',
284
+ 282: 'tiger cat [山猫,虎猫]',
285
+ 283: 'Persian cat [波斯猫]',
286
+ 284: 'Siamese cat, Siamese [暹罗暹罗猫,]',
287
+ 285: 'Egyptian cat [埃及猫]',
288
+ 286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor [美洲狮,美洲豹]',
289
+ 287: 'lynx, catamount [猞猁,山猫]',
290
+ 288: 'leopard, Panthera pardus [豹子]',
291
+ 289: 'snow leopard, ounce, Panthera uncia [雪豹]',
292
+ 290: 'jaguar, panther, Panthera onca, Felis onca [美洲虎]',
293
+ 291: 'lion, king of beasts, Panthera leo [狮子]',
294
+ 292: 'tiger, Panthera tigris [老虎]',
295
+ 293: 'cheetah, chetah, Acinonyx jubatus [猎豹]',
296
+ 294: 'brown bear, bruin, Ursus arctos [棕熊]',
297
+ 295: 'American black bear, black bear, Ursus americanus, Euarctos americanus [美洲黑熊]',
298
+ 296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus [冰熊,北极熊]',
299
+ 297: 'sloth bear, Melursus ursinus, Ursus ursinus [懒熊]',
300
+ 298: 'mongoose [猫鼬]',
301
+ 299: 'meerkat, mierkat [猫鼬,海猫]',
302
+ 300: 'tiger beetle [虎甲虫]',
303
+ 301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle [瓢虫]',
304
+ 302: 'ground beetle, carabid beetle [土鳖虫]',
305
+ 303: 'long-horned beetle, longicorn, longicorn beetle [天牛]',
306
+ 304: 'leaf beetle, chrysomelid [龟甲虫]',
307
+ 305: 'dung beetle [粪甲虫]',
308
+ 306: 'rhinoceros beetle [犀牛甲虫]',
309
+ 307: 'weevil [象甲]',
310
+ 308: 'fly [苍蝇]',
311
+ 309: 'bee [蜜蜂]',
312
+ 310: 'ant, emmet, pismire [蚂蚁]',
313
+ 311: 'grasshopper, hopper [蚱蜢]',
314
+ 312: 'cricket [蟋蟀]',
315
+ 313: 'walking stick, walkingstick, stick insect [竹节虫]',
316
+ 314: 'cockroach, roach [蟑螂]',
317
+ 315: 'mantis, mantid [螳螂]',
318
+ 316: 'cicada, cicala [蝉]',
319
+ 317: 'leafhopper [叶蝉]',
320
+ 318: 'lacewing, lacewing fly [草蜻蛉]',
321
+ 319: 'dragonfly, darning needle, devils darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk [蜻蜓]',
322
+ 320: 'damselfly [豆娘,蜻蛉]',
323
+ 321: 'admiral [优红蛱蝶]',
324
+ 322: 'ringlet, ringlet butterfly [小环蝴蝶]',
325
+ 323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus [君主蝴蝶,大斑蝶]',
326
+ 324: 'cabbage butterfly [菜粉蝶]',
327
+ 325: 'sulphur butterfly, sulfur butterfly [白蝴蝶]',
328
+ 326: 'lycaenid, lycaenid butterfly [灰蝶]',
329
+ 327: 'starfish, sea star [海星]',
330
+ 328: 'sea urchin [海胆]',
331
+ 329: 'sea cucumber, holothurian [海参,海黄瓜]',
332
+ 330: 'wood rabbit, cottontail, cottontail rabbit [野兔]',
333
+ 331: 'hare [兔]',
334
+ 332: 'Angora, Angora rabbit [安哥拉兔]',
335
+ 333: 'hamster [仓鼠]',
336
+ 334: 'porcupine, hedgehog [刺猬,豪猪,]',
337
+ 335: 'fox squirrel, eastern fox squirrel, Sciurus niger [黑松鼠]',
338
+ 336: 'marmot [土拨鼠]',
339
+ 337: 'beaver [海狸]',
340
+ 338: 'guinea pig, Cavia cobaya [豚鼠,豚鼠]',
341
+ 339: 'sorrel [栗色马]',
342
+ 340: 'zebra [斑马]',
343
+ 341: 'hog, pig, grunter, squealer, Sus scrofa [猪]',
344
+ 342: 'wild boar, boar, Sus scrofa [野猪]',
345
+ 343: 'warthog [疣猪]',
346
+ 344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius [河马]',
347
+ 345: 'ox [牛]',
348
+ 346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis [水牛,亚洲水牛]',
349
+ 347: 'bison [野牛]',
350
+ 348: 'ram, tup [公羊]',
351
+ 349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis [大角羊,洛矶山大角羊]',
352
+ 350: 'ibex, Capra ibex [山羊]',
353
+ 351: 'hartebeest [狷羚]',
354
+ 352: 'impala, Aepyceros melampus [黑斑羚]',
355
+ 353: 'gazelle [瞪羚]',
356
+ 354: 'Arabian camel, dromedary, Camelus dromedarius [阿拉伯单峰骆驼,骆驼]',
357
+ 355: 'llama [羊驼]',
358
+ 356: 'weasel [黄鼠狼]',
359
+ 357: 'mink [水貂]',
360
+ 358: 'polecat, fitch, foulmart, foumart, Mustela putorius [臭猫]',
361
+ 359: 'black-footed ferret, ferret, Mustela nigripes [黑足鼬]',
362
+ 360: 'otter [水獭]',
363
+ 361: 'skunk, polecat, wood pussy [臭鼬,木猫]',
364
+ 362: 'badger [獾]',
365
+ 363: 'armadillo [犰狳]',
366
+ 364: 'three-toed sloth, ai, Bradypus tridactylus [树懒]',
367
+ 365: 'orangutan, orang, orangutang, Pongo pygmaeus [猩猩,婆罗洲猩猩]',
368
+ 366: 'gorilla, Gorilla gorilla [大猩猩]',
369
+ 367: 'chimpanzee, chimp, Pan troglodytes [黑猩猩]',
370
+ 368: 'gibbon, Hylobates lar [长臂猿]',
371
+ 369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus [合趾猿长臂猿,合趾猿]',
372
+ 370: 'guenon, guenon monkey [长尾猴]',
373
+ 371: 'patas, hussar monkey, Erythrocebus patas [赤猴]',
374
+ 372: 'baboon [狒狒]',
375
+ 373: 'macaque [恒河猴,猕猴]',
376
+ 374: 'langur [白头叶猴]',
377
+ 375: 'colobus, colobus monkey [疣猴]',
378
+ 376: 'proboscis monkey, Nasalis larvatus [长鼻猴]',
379
+ 377: 'marmoset [狨(美洲产小型长尾猴)]',
380
+ 378: 'capuchin, ringtail, Cebus capucinus [卷尾猴]',
381
+ 379: 'howler monkey, howler [吼猴]',
382
+ 380: 'titi, titi monkey [伶猴]',
383
+ 381: 'spider monkey, Ateles geoffroyi [蜘蛛猴]',
384
+ 382: 'squirrel monkey, Saimiri sciureus [松鼠猴]',
385
+ 383: 'Madagascar cat, ring-tailed lemur, Lemur catta [马达加斯加环尾狐猴,鼠狐猴]',
386
+ 384: 'indri, indris, Indri indri, Indri brevicaudatus [大狐猴,马达加斯加大狐猴]',
387
+ 385: 'Indian elephant, Elephas maximus [印度大象,亚洲象]',
388
+ 386: 'African elephant, Loxodonta africana [非洲象,非洲象]',
389
+ 387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens [小熊猫]',
390
+ 388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca [大熊猫]',
391
+ 389: 'barracouta, snoek [杖鱼]',
392
+ 390: 'eel [鳗鱼]',
393
+ 391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch [银鲑,银鲑鱼]',
394
+ 392: 'rock beauty, Holocanthus tricolor [三色刺蝶鱼]',
395
+ 393: 'anemone fish [海葵鱼]',
396
+ 394: 'sturgeon [鲟鱼]',
397
+ 395: 'gar, garfish, garpike, billfish, Lepisosteus osseus [雀鳝]',
398
+ 396: 'lionfish [狮子鱼]',
399
+ 397: 'puffer, pufferfish, blowfish, globefish [河豚]',
400
+ 398: 'abacus [算盘]',
401
+ 399: 'abaya [长袍]',
402
+ 400: 'academic gown, academic robe, judge robe [学位袍]',
403
+ 401: 'accordion, piano accordion, squeeze box [手风琴]',
404
+ 402: 'acoustic guitar [原声吉他]',
405
+ 403: 'aircraft carrier, carrier, flattop, attack aircraft carrier [航空母舰]',
406
+ 404: 'airliner [客机]',
407
+ 405: 'airship, dirigible [飞艇]',
408
+ 406: 'altar [祭坛]',
409
+ 407: 'ambulance [救护车]',
410
+ 408: 'amphibian, amphibious vehicle [水陆两用车]',
411
+ 409: 'analog clock [模拟时钟]',
412
+ 410: 'apiary, bee house [蜂房]',
413
+ 411: 'apron [围裙]',
414
+ 412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin [垃圾桶]',
415
+ 413: 'assault rifle, assault gun [攻击步枪,枪]',
416
+ 414: 'backpack, back pack, knapsack, packsack, rucksack, haversack [背包]',
417
+ 415: 'bakery, bakeshop, bakehouse [面包店,面包铺,]',
418
+ 416: 'balance beam, beam [平衡木]',
419
+ 417: 'balloon [热气球]',
420
+ 418: 'ballpoint, ballpoint pen, ballpen, Biro [圆珠笔]',
421
+ 419: 'Band Aid [创可贴]',
422
+ 420: 'banjo [班卓琴]',
423
+ 421: 'bannister, banister, balustrade, balusters, handrail [栏杆,楼梯扶手]',
424
+ 422: 'barbell [杠铃]',
425
+ 423: 'barber chair [理发师的椅子]',
426
+ 424: 'barbershop [理发店]',
427
+ 425: 'barn [牲口棚]',
428
+ 426: 'barometer [晴雨表]',
429
+ 427: 'barrel, cask [圆筒]',
430
+ 428: 'barrow, garden cart, lawn cart, wheelbarrow [园地小车,手推车]',
431
+ 429: 'baseball [棒球]',
432
+ 430: 'basketball [篮球]',
433
+ 431: 'bassinet [婴儿床]',
434
+ 432: 'bassoon [巴松管,低音管]',
435
+ 433: 'bathing cap, swimming cap [游泳帽]',
436
+ 434: 'bath towel [沐浴毛巾]',
437
+ 435: 'bathtub, bathing tub, bath, tub [浴缸,澡盆]',
438
+ 436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon [沙滩车,旅行车]',
439
+ 437: 'beacon, lighthouse, beacon light, pharos [灯塔]',
440
+ 438: 'beaker [高脚杯]',
441
+ 439: 'bearskin, busby, shako [熊皮高帽]',
442
+ 440: 'beer bottle [啤酒瓶]',
443
+ 441: 'beer glass [啤酒杯]',
444
+ 442: 'bell cote, bell cot [钟塔]',
445
+ 443: 'bib [(小儿用的)围嘴]',
446
+ 444: 'bicycle-built-for-two, tandem bicycle, tandem [串联自行车,]',
447
+ 445: 'bikini, two-piece [比基尼]',
448
+ 446: 'binder, ring-binder [装订册]',
449
+ 447: 'binoculars, field glasses, opera glasses [双筒望远镜]',
450
+ 448: 'birdhouse [鸟舍]',
451
+ 449: 'boathouse [船库]',
452
+ 450: 'bobsled, bobsleigh, bob [雪橇]',
453
+ 451: 'bolo tie, bolo, bola tie, bola [饰扣式领带]',
454
+ 452: 'bonnet, poke bonnet [阔边女帽]',
455
+ 453: 'bookcase [书橱]',
456
+ 454: 'bookshop, bookstore, bookstall [书店,书摊]',
457
+ 455: 'bottlecap [瓶盖]',
458
+ 456: 'bow [弓箭]',
459
+ 457: 'bow tie, bow-tie, bowtie [蝴蝶结领结]',
460
+ 458: 'brass, memorial tablet, plaque [铜制牌位]',
461
+ 459: 'brassiere, bra, bandeau [奶罩]',
462
+ 460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty [防波堤,海堤]',
463
+ 461: 'breastplate, aegis, egis [铠甲]',
464
+ 462: 'broom [扫帚]',
465
+ 463: 'bucket, pail [桶]',
466
+ 464: 'buckle [扣环]',
467
+ 465: 'bulletproof vest [防弹背心]',
468
+ 466: 'bullet train, bullet [动车,子弹头列车]',
469
+ 467: 'butcher shop, meat market [肉铺,肉菜市场]',
470
+ 468: 'cab, hack, taxi, taxicab [出租车]',
471
+ 469: 'caldron, cauldron [大锅]',
472
+ 470: 'candle, taper, wax light [蜡烛]',
473
+ 471: 'cannon [大炮]',
474
+ 472: 'canoe [独木舟]',
475
+ 473: 'can opener, tin opener [开瓶器,开罐器]',
476
+ 474: 'cardigan [开衫]',
477
+ 475: 'car mirror [车镜]',
478
+ 476: 'carousel, carrousel, merry-go-round, roundabout, whirligig [旋转木马]',
479
+ 477: 'carpenters kit, tool kit [木匠的工具包,工具包]',
480
+ 478: 'carton [纸箱]',
481
+ 479: 'car wheel [车轮]',
482
+ 480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM [取款机,自动取款机]',
483
+ 481: 'cassette [盒式录音带]',
484
+ 482: 'cassette player [卡带播放器]',
485
+ 483: 'castle [城堡]',
486
+ 484: 'catamaran [双体船]',
487
+ 485: 'CD player [CD播放器]',
488
+ 486: 'cello, violoncello [大提琴]',
489
+ 487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone [移动电话,手机]',
490
+ 488: 'chain [铁链]',
491
+ 489: 'chainlink fence [围栏]',
492
+ 490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour [链甲]',
493
+ 491: 'chain saw, chainsaw [电锯,油锯]',
494
+ 492: 'chest [箱子]',
495
+ 493: 'chiffonier, commode [衣柜,洗脸台]',
496
+ 494: 'chime, bell, gong [编钟,钟,锣]',
497
+ 495: 'china cabinet, china closet [中国橱柜]',
498
+ 496: 'Christmas stocking [圣诞袜]',
499
+ 497: 'church, church building [教堂,教堂建筑]',
500
+ 498: 'cinema, movie theater, movie theatre, movie house, picture palace [电影院,剧场]',
501
+ 499: 'cleaver, meat cleaver, chopper [切肉刀,菜刀]',
502
+ 500: 'cliff dwelling [悬崖屋]',
503
+ 501: 'cloak [斗篷]',
504
+ 502: 'clog, geta, patten, sabot [木屐,木鞋]',
505
+ 503: 'cocktail shaker [鸡尾酒调酒器]',
506
+ 504: 'coffee mug [咖啡杯]',
507
+ 505: 'coffeepot [咖啡壶]',
508
+ 506: 'coil, spiral, volute, whorl, helix [螺旋结构(楼梯)]',
509
+ 507: 'combination lock [组合锁]',
510
+ 508: 'computer keyboard, keypad [电脑键盘,键盘]',
511
+ 509: 'confectionery, confectionary, candy store [糖果,糖果店]',
512
+ 510: 'container ship, containership, container vessel [集装箱船]',
513
+ 511: 'convertible [敞篷车]',
514
+ 512: 'corkscrew, bottle screw [开瓶器,瓶螺杆]',
515
+ 513: 'cornet, horn, trumpet, trump [短号,喇叭]',
516
+ 514: 'cowboy boot [牛仔靴]',
517
+ 515: 'cowboy hat, ten-gallon hat [牛仔帽]',
518
+ 516: 'cradle [摇篮]',
519
+ 517: 'crane [起重机]',
520
+ 518: 'crash helmet [头盔]',
521
+ 519: 'crate [板条箱]',
522
+ 520: 'crib, cot [小儿床]',
523
+ 521: 'Crock Pot [砂锅]',
524
+ 522: 'croquet ball [槌球]',
525
+ 523: 'crutch [拐杖]',
526
+ 524: 'cuirass [胸甲]',
527
+ 525: 'dam, dike, dyke [大坝,堤防]',
528
+ 526: 'desk [书桌]',
529
+ 527: 'desktop computer [台式电脑]',
530
+ 528: 'dial telephone, dial phone [有线电话]',
531
+ 529: 'diaper, nappy, napkin [尿布湿]',
532
+ 530: 'digital clock [数字时钟]',
533
+ 531: 'digital watch [数字手表]',
534
+ 532: 'dining table, board [餐桌板]',
535
+ 533: 'dishrag, dishcloth [抹布]',
536
+ 534: 'dishwasher, dish washer, dishwashing machine [洗碗机,洗碟机]',
537
+ 535: 'disk brake, disc brake [盘式制动器]',
538
+ 536: 'dock, dockage, docking facility [码头,船坞,码头设施]',
539
+ 537: 'dogsled, dog sled, dog sleigh [狗拉雪橇]',
540
+ 538: 'dome [圆顶]',
541
+ 539: 'doormat, welcome mat [门垫,垫子]',
542
+ 540: 'drilling platform, offshore rig [钻井平台,海上钻井]',
543
+ 541: 'drum, membranophone, tympan [鼓,乐器,鼓膜]',
544
+ 542: 'drumstick [鼓槌]',
545
+ 543: 'dumbbell [哑铃]',
546
+ 544: 'Dutch oven [荷兰烤箱]',
547
+ 545: 'electric fan, blower [电风扇,鼓风机]',
548
+ 546: 'electric guitar [电吉他]',
549
+ 547: 'electric locomotive [电力机车]',
550
+ 548: 'entertainment center [电视,电视柜]',
551
+ 549: 'envelope [信封]',
552
+ 550: 'espresso maker [浓缩咖啡机]',
553
+ 551: 'face powder [扑面粉]',
554
+ 552: 'feather boa, boa [女用长围巾]',
555
+ 553: 'file, file cabinet, filing cabinet [文件,文件柜,档案柜]',
556
+ 554: 'fireboat [消防船]',
557
+ 555: 'fire engine, fire truck [消防车]',
558
+ 556: 'fire screen, fireguard [火炉栏]',
559
+ 557: 'flagpole, flagstaff [旗杆]',
560
+ 558: 'flute, transverse flute [长笛]',
561
+ 559: 'folding chair [折叠椅]',
562
+ 560: 'football helmet [橄榄球头盔]',
563
+ 561: 'forklift [叉车]',
564
+ 562: 'fountain [喷泉]',
565
+ 563: 'fountain pen [钢笔]',
566
+ 564: 'four-poster [有四根帷柱的床]',
567
+ 565: 'freight car [运货车厢]',
568
+ 566: 'French horn, horn [圆号,喇叭]',
569
+ 567: 'frying pan, frypan, skillet [煎锅]',
570
+ 568: 'fur coat [裘皮大衣]',
571
+ 569: 'garbage truck, dustcart [垃圾车]',
572
+ 570: 'gasmask, respirator, gas helmet [防毒面具,呼吸器]',
573
+ 571: 'gas pump, gasoline pump, petrol pump, island dispenser [汽油泵]',
574
+ 572: 'goblet [高脚杯]',
575
+ 573: 'go-kart [卡丁车]',
576
+ 574: 'golf ball [高尔夫球]',
577
+ 575: 'golfcart, golf cart [高尔夫球车]',
578
+ 576: 'gondola [狭长小船]',
579
+ 577: 'gong, tam-tam [锣]',
580
+ 578: 'gown [礼服]',
581
+ 579: 'grand piano, grand [钢琴]',
582
+ 580: 'greenhouse, nursery, glasshouse [温室,苗圃]',
583
+ 581: 'grille, radiator grille [散热器格栅]',
584
+ 582: 'grocery store, grocery, food market, market [杂货店,食品市场]',
585
+ 583: 'guillotine [断头台]',
586
+ 584: 'hair slide [小发夹]',
587
+ 585: 'hair spray [头发喷雾]',
588
+ 586: 'half track [半履带装甲车]',
589
+ 587: 'hammer [锤子]',
590
+ 588: 'hamper [大篮子]',
591
+ 589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier [手摇鼓风机,吹风机]',
592
+ 590: 'hand-held computer, hand-held microcomputer [手提电脑]',
593
+ 591: 'handkerchief, hankie, hanky, hankey [手帕]',
594
+ 592: 'hard disc, hard disk, fixed disk [硬盘]',
595
+ 593: 'harmonica, mouth organ, harp, mouth harp [口琴,口风琴]',
596
+ 594: 'harp [竖琴]',
597
+ 595: 'harvester, reaper [收割机]',
598
+ 596: 'hatchet [斧头]',
599
+ 597: 'holster [手枪皮套]',
600
+ 598: 'home theater, home theatre [家庭影院]',
601
+ 599: 'honeycomb [蜂窝]',
602
+ 600: 'hook, claw [钩爪]',
603
+ 601: 'hoopskirt, crinoline [衬裙]',
604
+ 602: 'horizontal bar, high bar [单杠]',
605
+ 603: 'horse cart, horse-cart [马车]',
606
+ 604: 'hourglass [沙漏]',
607
+ 605: 'iPod [手机,iPad]',
608
+ 606: 'iron, smoothing iron [熨斗]',
609
+ 607: 'jack-o-lantern [南瓜灯笼]',
610
+ 608: 'jean, blue jean, denim [牛仔裤,蓝色牛仔裤]',
611
+ 609: 'jeep, landrover [吉普车]',
612
+ 610: 'jersey, T-shirt, tee shirt [运动衫,T恤]',
613
+ 611: 'jigsaw puzzle [拼图]',
614
+ 612: 'jinrikisha, ricksha, rickshaw [人力车]',
615
+ 613: 'joystick [操纵杆]',
616
+ 614: 'kimono [和服]',
617
+ 615: 'knee pad [护膝]',
618
+ 616: 'knot [蝴蝶结]',
619
+ 617: 'lab coat, laboratory coat [大褂,实验室外套]',
620
+ 618: 'ladle [长柄勺]',
621
+ 619: 'lampshade, lamp shade [灯罩]',
622
+ 620: 'laptop, laptop computer [笔记本电脑]',
623
+ 621: 'lawn mower, mower [割草机]',
624
+ 622: 'lens cap, lens cover [镜头盖]',
625
+ 623: 'letter opener, paper knife, paperknife [开信刀,裁纸刀]',
626
+ 624: 'library [图书馆]',
627
+ 625: 'lifeboat [救生艇]',
628
+ 626: 'lighter, light, igniter, ignitor [点火器,打火机]',
629
+ 627: 'limousine, limo [豪华轿车]',
630
+ 628: 'liner, ocean liner [远洋班轮]',
631
+ 629: 'lipstick, lip rouge [唇膏,口红]',
632
+ 630: 'Loafer [平底便鞋]',
633
+ 631: 'lotion [洗剂]',
634
+ 632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system [扬声器]',
635
+ 633: 'loupe, jewelers loupe [放大镜]',
636
+ 634: 'lumbermill, sawmill [锯木厂]',
637
+ 635: 'magnetic compass [磁罗盘]',
638
+ 636: 'mailbag, postbag [邮袋]',
639
+ 637: 'mailbox, letter box [信箱]',
640
+ 638: 'maillot [女游泳衣]',
641
+ 639: 'maillot, tank suit [有肩带浴衣]',
642
+ 640: 'manhole cover [窨井盖]',
643
+ 641: 'maraca [沙球(一种打击乐器)]',
644
+ 642: 'marimba, xylophone [马林巴木琴]',
645
+ 643: 'mask [面膜]',
646
+ 644: 'matchstick [火柴]',
647
+ 645: 'maypole [花柱]',
648
+ 646: 'maze, labyrinth [迷宫]',
649
+ 647: 'measuring cup [量杯]',
650
+ 648: 'medicine chest, medicine cabinet [药箱]',
651
+ 649: 'megalith, megalithic structure [巨石,巨石结构]',
652
+ 650: 'microphone, mike [麦克风]',
653
+ 651: 'microwave, microwave oven [微波炉]',
654
+ 652: 'military uniform [军装]',
655
+ 653: 'milk can [奶桶]',
656
+ 654: 'minibus [迷你巴士]',
657
+ 655: 'miniskirt, mini [迷你裙]',
658
+ 656: 'minivan [面包车]',
659
+ 657: 'missile [导弹]',
660
+ 658: 'mitten [连指手套]',
661
+ 659: 'mixing bowl [搅拌钵]',
662
+ 660: 'mobile home, manufactured home [活动房屋(由汽车拖拉的)]',
663
+ 661: 'Model T [T型发动机小汽车]',
664
+ 662: 'modem [调制解调器]',
665
+ 663: 'monastery [修道院]',
666
+ 664: 'monitor [显示器]',
667
+ 665: 'moped [电瓶车]',
668
+ 666: 'mortar [砂浆]',
669
+ 667: 'mortarboard [学士]',
670
+ 668: 'mosque [清真寺]',
671
+ 669: 'mosquito net [蚊帐]',
672
+ 670: 'motor scooter, scooter [摩托车]',
673
+ 671: 'mountain bike, all-terrain bike, off-roader [山地自行车]',
674
+ 672: 'mountain tent [登山帐]',
675
+ 673: 'mouse, computer mouse [鼠标,电脑鼠标]',
676
+ 674: 'mousetrap [捕鼠器]',
677
+ 675: 'moving van [搬家车]',
678
+ 676: 'muzzle [口套]',
679
+ 677: 'nail [钉子]',
680
+ 678: 'neck brace [颈托]',
681
+ 679: 'necklace [项链]',
682
+ 680: 'nipple [乳头(瓶)]',
683
+ 681: 'notebook, notebook computer [笔记本,笔记本电脑]',
684
+ 682: 'obelisk [方尖碑]',
685
+ 683: 'oboe, hautboy, hautbois [双簧管]',
686
+ 684: 'ocarina, sweet potato [陶笛,卵形笛]',
687
+ 685: 'odometer, hodometer, mileometer, milometer [里程表]',
688
+ 686: 'oil filter [滤油器]',
689
+ 687: 'organ, pipe organ [风琴,管风琴]',
690
+ 688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO [示波器]',
691
+ 689: 'overskirt [罩裙]',
692
+ 690: 'oxcart [牛车]',
693
+ 691: 'oxygen mask [氧气面罩]',
694
+ 692: 'packet [包装]',
695
+ 693: 'paddle, boat paddle [船桨]',
696
+ 694: 'paddlewheel, paddle wheel [明轮,桨轮]',
697
+ 695: 'padlock [挂锁,扣锁]',
698
+ 696: 'paintbrush [画笔]',
699
+ 697: 'pajama, pyjama, pjs, jammies [睡衣]',
700
+ 698: 'palace [宫殿]',
701
+ 699: 'panpipe, pandean pipe, syrinx [排箫,鸣管]',
702
+ 700: 'paper towel [纸巾]',
703
+ 701: 'parachute, chute [降落伞]',
704
+ 702: 'parallel bars, bars [双杠]',
705
+ 703: 'park bench [公园长椅]',
706
+ 704: 'parking meter [停车收费表,停车计时器]',
707
+ 705: 'passenger car, coach, carriage [客车,教练车]',
708
+ 706: 'patio, terrace [露台,阳台]',
709
+ 707: 'pay-phone, pay-station [付费电话]',
710
+ 708: 'pedestal, plinth, footstall [基座,基脚]',
711
+ 709: 'pencil box, pencil case [铅笔盒]',
712
+ 710: 'pencil sharpener [卷笔刀]',
713
+ 711: 'perfume, essence [香水(瓶)]',
714
+ 712: 'Petri dish [培养皿]',
715
+ 713: 'photocopier [复印机]',
716
+ 714: 'pick, plectrum, plectron [拨弦片,拨子]',
717
+ 715: 'pickelhaube [尖顶头盔]',
718
+ 716: 'picket fence, paling [栅栏,栅栏]',
719
+ 717: 'pickup, pickup truck [皮卡,皮卡车]',
720
+ 718: 'pier [桥墩]',
721
+ 719: 'piggy bank, penny bank [存钱罐]',
722
+ 720: 'pill bottle [药瓶]',
723
+ 721: 'pillow [枕头]',
724
+ 722: 'ping-pong ball [乒乓球]',
725
+ 723: 'pinwheel [风车]',
726
+ 724: 'pirate, pirate ship [海盗船]',
727
+ 725: 'pitcher, ewer [水罐]',
728
+ 726: 'plane, carpenters plane, woodworking plane [木工刨]',
729
+ 727: 'planetarium [天文馆]',
730
+ 728: 'plastic bag [塑料袋]',
731
+ 729: 'plate rack [板架]',
732
+ 730: 'plow, plough [犁型铲雪机]',
733
+ 731: 'plunger, plumbers helper [手压皮碗泵]',
734
+ 732: 'Polaroid camera, Polaroid Land camera [宝丽来相机]',
735
+ 733: 'pole [电线杆]',
736
+ 734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria [警车,巡逻车]',
737
+ 735: 'poncho [雨披]',
738
+ 736: 'pool table, billiard table, snooker table [台球桌]',
739
+ 737: 'pop bottle, soda bottle [充气饮料瓶]',
740
+ 738: 'pot, flowerpot [花盆]',
741
+ 739: 'potters wheel [陶工旋盘]',
742
+ 740: 'power drill [电钻]',
743
+ 741: 'prayer rug, prayer mat [祈祷垫,地毯]',
744
+ 742: 'printer [打印机]',
745
+ 743: 'prison, prison house [监狱]',
746
+ 744: 'projectile, missile [炮弹,导弹]',
747
+ 745: 'projector [投影仪]',
748
+ 746: 'puck, hockey puck [冰球]',
749
+ 747: 'punching bag, punch bag, punching ball, punchball [沙包,吊球]',
750
+ 748: 'purse [钱包]',
751
+ 749: 'quill, quill pen [羽管笔]',
752
+ 750: 'quilt, comforter, comfort, puff [被子]',
753
+ 751: 'racer, race car, racing car [赛车]',
754
+ 752: 'racket, racquet [球拍]',
755
+ 753: 'radiator [散热器]',
756
+ 754: 'radio, wireless [收音机]',
757
+ 755: 'radio telescope, radio reflector [射电望远镜,无线电反射器]',
758
+ 756: 'rain barrel [雨桶]',
759
+ 757: 'recreational vehicle, RV, R.V. [休闲车,房车]',
760
+ 758: 'reel [卷轴,卷筒]',
761
+ 759: 'reflex camera [反射式照相机]',
762
+ 760: 'refrigerator, icebox [冰箱,冰柜]',
763
+ 761: 'remote control, remote [遥控器]',
764
+ 762: 'restaurant, eating house, eating place, eatery [餐厅,饮食店,食堂]',
765
+ 763: 'revolver, six-gun, six-shooter [左轮手枪]',
766
+ 764: 'rifle [步枪]',
767
+ 765: 'rocking chair, rocker [摇椅]',
768
+ 766: 'rotisserie [电转烤肉架]',
769
+ 767: 'rubber eraser, rubber, pencil eraser [橡皮]',
770
+ 768: 'rugby ball [橄榄球]',
771
+ 769: 'rule, ruler [直尺]',
772
+ 770: 'running shoe [跑步鞋]',
773
+ 771: 'safe [保险柜]',
774
+ 772: 'safety pin [安全别针]',
775
+ 773: 'saltshaker, salt shaker [盐瓶(调味用)]',
776
+ 774: 'sandal [凉鞋]',
777
+ 775: 'sarong [纱笼,围裙]',
778
+ 776: 'sax, saxophone [萨克斯管]',
779
+ 777: 'scabbard [剑鞘]',
780
+ 778: 'scale, weighing machine [秤,称重机]',
781
+ 779: 'school bus [校车]',
782
+ 780: 'schooner [帆船]',
783
+ 781: 'scoreboard [记分牌]',
784
+ 782: 'screen, CRT screen [屏幕]',
785
+ 783: 'screw [螺丝]',
786
+ 784: 'screwdriver [螺丝刀]',
787
+ 785: 'seat belt, seatbelt [安全带]',
788
+ 786: 'sewing machine [缝纫机]',
789
+ 787: 'shield, buckler [盾牌,盾牌]',
790
+ 788: 'shoe shop, shoe-shop, shoe store [皮鞋店,鞋店]',
791
+ 789: 'shoji [障子]',
792
+ 790: 'shopping basket [购物篮]',
793
+ 791: 'shopping cart [购物车]',
794
+ 792: 'shovel [铁锹]',
795
+ 793: 'shower cap [浴帽]',
796
+ 794: 'shower curtain [浴帘]',
797
+ 795: 'ski [滑雪板]',
798
+ 796: 'ski mask [滑雪面罩]',
799
+ 797: 'sleeping bag [睡袋]',
800
+ 798: 'slide rule, slipstick [滑尺]',
801
+ 799: 'sliding door [滑动门]',
802
+ 800: 'slot, one-armed bandit [角子老虎机]',
803
+ 801: 'snorkel [潜水通气管]',
804
+ 802: 'snowmobile [雪橇]',
805
+ 803: 'snowplow, snowplough [扫雪机,扫雪机]',
806
+ 804: 'soap dispenser [皂液器]',
807
+ 805: 'soccer ball [足球]',
808
+ 806: 'sock [袜子]',
809
+ 807: 'solar dish, solar collector, solar furnace [碟式太阳能,太阳能集热器,太阳能炉]',
810
+ 808: 'sombrero [宽边帽]',
811
+ 809: 'soup bowl [汤碗]',
812
+ 810: 'space bar [空格键]',
813
+ 811: 'space heater [空间加热器]',
814
+ 812: 'space shuttle [航天飞机]',
815
+ 813: 'spatula [铲(搅拌或涂敷用的)]',
816
+ 814: 'speedboat [快艇]',
817
+ 815: 'spider web, spiders web [蜘蛛网]',
818
+ 816: 'spindle [纺锤,纱锭]',
819
+ 817: 'sports car, sport car [跑车]',
820
+ 818: 'spotlight, spot [聚光灯]',
821
+ 819: 'stage [舞台]',
822
+ 820: 'steam locomotive [蒸汽机车]',
823
+ 821: 'steel arch bridge [钢拱桥]',
824
+ 822: 'steel drum [钢滚筒]',
825
+ 823: 'stethoscope [听诊器]',
826
+ 824: 'stole [女用披肩]',
827
+ 825: 'stone wall [石头墙]',
828
+ 826: 'stopwatch, stop watch [秒表]',
829
+ 827: 'stove [火炉]',
830
+ 828: 'strainer [过滤器]',
831
+ 829: 'streetcar, tram, tramcar, trolley, trolley car [有轨电车,电车]',
832
+ 830: 'stretcher [担架]',
833
+ 831: 'studio couch, day bed [沙发床]',
834
+ 832: 'stupa, tope [佛塔]',
835
+ 833: 'submarine, pigboat, sub, U-boat [潜艇,潜水艇]',
836
+ 834: 'suit, suit of clothes [套装,衣服]',
837
+ 835: 'sundial [日晷]',
838
+ 836: 'sunglass [太阳镜]',
839
+ 837: 'sunglasses, dark glasses, shades [太阳镜,墨镜]',
840
+ 838: 'sunscreen, sunblock, sun blocker [防晒霜,防晒剂]',
841
+ 839: 'suspension bridge [悬索桥]',
842
+ 840: 'swab, swob, mop [拖把]',
843
+ 841: 'sweatshirt [运动衫]',
844
+ 842: 'swimming trunks, bathing trunks [游泳裤]',
845
+ 843: 'swing [秋千]',
846
+ 844: 'switch, electric switch, electrical switch [开关,电器开关]',
847
+ 845: 'syringe [注射器]',
848
+ 846: 'table lamp [台灯]',
849
+ 847: 'tank, army tank, armored combat vehicle, armoured combat vehicle [坦克,装甲战车,装甲战斗车辆]',
850
+ 848: 'tape player [磁带播放器]',
851
+ 849: 'teapot [茶壶]',
852
+ 850: 'teddy, teddy bear [泰迪,泰迪熊]',
853
+ 851: 'television, television system [电视]',
854
+ 852: 'tennis ball [网球]',
855
+ 853: 'thatch, thatched roof [茅草,茅草屋顶]',
856
+ 854: 'theater curtain, theatre curtain [幕布,剧院的帷幕]',
857
+ 855: 'thimble [顶针]',
858
+ 856: 'thresher, thrasher, threshing machine [脱粒机]',
859
+ 857: 'throne [宝座]',
860
+ 858: 'tile roof [瓦屋顶]',
861
+ 859: 'toaster [烤面包机]',
862
+ 860: 'tobacco shop, tobacconist shop, tobacconist [烟草店,烟草]',
863
+ 861: 'toilet seat [马桶]',
864
+ 862: 'torch [火炬]',
865
+ 863: 'totem pole [图腾柱]',
866
+ 864: 'tow truck, tow car, wrecker [拖车,牵引车,清障车]',
867
+ 865: 'toyshop [玩具店]',
868
+ 866: 'tractor [拖拉机]',
869
+ 867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi [拖车,铰接式卡车]',
870
+ 868: 'tray [托盘]',
871
+ 869: 'trench coat [风衣]',
872
+ 870: 'tricycle, trike, velocipede [三轮车]',
873
+ 871: 'trimaran [三体船]',
874
+ 872: 'tripod [三脚架]',
875
+ 873: 'triumphal arch [凯旋门]',
876
+ 874: 'trolleybus, trolley coach, trackless trolley [无轨电车]',
877
+ 875: 'trombone [长号]',
878
+ 876: 'tub, vat [浴盆,浴缸]',
879
+ 877: 'turnstile [旋转式栅门]',
880
+ 878: 'typewriter keyboard [打字机键盘]',
881
+ 879: 'umbrella [伞]',
882
+ 880: 'unicycle, monocycle [独轮车]',
883
+ 881: 'upright, upright piano [直立式钢琴]',
884
+ 882: 'vacuum, vacuum cleaner [真空吸尘器]',
885
+ 883: 'vase [花瓶]',
886
+ 884: 'vault [拱顶]',
887
+ 885: 'velvet [天鹅绒]',
888
+ 886: 'vending machine [自动售货机]',
889
+ 887: 'vestment [祭服]',
890
+ 888: 'viaduct [高架桥]',
891
+ 889: 'violin, fiddle [小提琴,小提琴]',
892
+ 890: 'volleyball [排球]',
893
+ 891: 'waffle iron [松饼机]',
894
+ 892: 'wall clock [挂钟]',
895
+ 893: 'wallet, billfold, notecase, pocketbook [钱包,皮夹]',
896
+ 894: 'wardrobe, closet, press [衣柜,壁橱]',
897
+ 895: 'warplane, military plane [军用飞机]',
898
+ 896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin [洗脸盆,洗手盆]',
899
+ 897: 'washer, automatic washer, washing machine [洗衣机,自动洗衣机]',
900
+ 898: 'water bottle [水瓶]',
901
+ 899: 'water jug [水壶]',
902
+ 900: 'water tower [水塔]',
903
+ 901: 'whiskey jug [威士忌壶]',
904
+ 902: 'whistle [哨子]',
905
+ 903: 'wig [假发]',
906
+ 904: 'window screen [纱窗]',
907
+ 905: 'window shade [百叶窗]',
908
+ 906: 'Windsor tie [温莎领带]',
909
+ 907: 'wine bottle [葡萄酒瓶]',
910
+ 908: 'wing [飞机翅膀,飞机]',
911
+ 909: 'wok [炒菜锅]',
912
+ 910: 'wooden spoon [木制的勺子]',
913
+ 911: 'wool, woolen, woollen [毛织品,羊绒]',
914
+ 912: 'worm fence, snake fence, snake-rail fence, Virginia fence [栅栏,围栏]',
915
+ 913: 'wreck [沉船]',
916
+ 914: 'yawl [双桅船]',
917
+ 915: 'yurt [蒙古包]',
918
+ 916: 'web site, website, internet site, site [网站,互联网网站]',
919
+ 917: 'comic book [漫画]',
920
+ 918: 'crossword puzzle, crossword [纵横字谜]',
921
+ 919: 'street sign [路标]',
922
+ 920: 'traffic light, traffic signal, stoplight [交通信号灯]',
923
+ 921: 'book jacket, dust cover, dust jacket, dust wrapper [防尘罩,书皮]',
924
+ 922: 'menu [菜单]',
925
+ 923: 'plate [盘子]',
926
+ 924: 'guacamole [鳄梨酱]',
927
+ 925: 'consomme [清汤]',
928
+ 926: 'hot pot, hotpot [罐焖土豆烧肉]',
929
+ 927: 'trifle [蛋糕]',
930
+ 928: 'ice cream, icecream [冰淇淋]',
931
+ 929: 'ice lolly, lolly, lollipop, popsicle [雪糕,冰棍,冰棒]',
932
+ 930: 'French loaf [法式面包]',
933
+ 931: 'bagel, beigel [百吉饼]',
934
+ 932: 'pretzel [椒盐脆饼]',
935
+ 933: 'cheeseburger [芝士汉堡]',
936
+ 934: 'hotdog, hot dog, red hot [热狗]',
937
+ 935: 'mashed potato [土豆泥]',
938
+ 936: 'head cabbage [结球甘蓝]',
939
+ 937: 'broccoli [西兰花]',
940
+ 938: 'cauliflower [菜花]',
941
+ 939: 'zucchini, courgette [绿皮密生西葫芦]',
942
+ 940: 'spaghetti squash [西葫芦]',
943
+ 941: 'acorn squash [小青南瓜]',
944
+ 942: 'butternut squash [南瓜]',
945
+ 943: 'cucumber, cuke [黄瓜]',
946
+ 944: 'artichoke, globe artichoke [朝鲜蓟]',
947
+ 945: 'bell pepper [甜椒]',
948
+ 946: 'cardoon [刺棘蓟]',
949
+ 947: 'mushroom [蘑菇]',
950
+ 948: 'Granny Smith [绿苹果]',
951
+ 949: 'strawberry [草莓]',
952
+ 950: 'orange [橘子]',
953
+ 951: 'lemon [柠檬]',
954
+ 952: 'fig [无花果]',
955
+ 953: 'pineapple, ananas [菠萝]',
956
+ 954: 'banana [香蕉]',
957
+ 955: 'jackfruit, jak, jack [菠萝蜜]',
958
+ 956: 'custard apple [蛋奶冻苹果]',
959
+ 957: 'pomegranate [石榴]',
960
+ 958: 'hay [干草]',
961
+ 959: 'carbonara [烤面条加干酪沙司]',
962
+ 960: 'chocolate sauce, chocolate syrup [巧克力酱,巧克力糖浆]',
963
+ 961: 'dough [面团]',
964
+ 962: 'meat loaf, meatloaf [瑞士肉包,肉饼]',
965
+ 963: 'pizza, pizza pie [披萨,披萨饼]',
966
+ 964: 'potpie [馅饼]',
967
+ 965: 'burrito [卷饼]',
968
+ 966: 'red wine [红葡萄酒]',
969
+ 967: 'espresso [意大利浓咖啡]',
970
+ 968: 'cup [杯子]',
971
+ 969: 'eggnog [蛋酒]',
972
+ 970: 'alp [高山]',
973
+ 971: 'bubble [泡泡]',
974
+ 972: 'cliff, drop, drop-off [悬崖]',
975
+ 973: 'coral reef [珊瑚礁]',
976
+ 974: 'geyser [间歇泉]',
977
+ 975: 'lakeside, lakeshore [湖边,湖岸]',
978
+ 976: 'promontory, headland, head, foreland [海角]',
979
+ 977: 'sandbar, sand bar [沙洲,沙坝]',
980
+ 978: 'seashore, coast, seacoast, sea-coast [海滨,海岸]',
981
+ 979: 'valley, vale [峡谷]',
982
+ 980: 'volcano [火山]',
983
+ 981: 'ballplayer, baseball player [棒球,棒球运动员]',
984
+ 982: 'groom, bridegroom [新郎]',
985
+ 983: 'scuba diver [潜水员]',
986
+ 984: 'rapeseed [油菜]',
987
+ 985: 'daisy [雏菊]',
988
+ 986: 'yellow ladys slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum [杓兰]',
989
+ 987: 'corn [玉米]',
990
+ 988: 'acorn [橡子]',
991
+ 989: 'hip, rose hip, rosehip [玫瑰果]',
992
+ 990: 'buckeye, horse chestnut, conker [七叶树果实]',
993
+ 991: 'coral fungus [珊瑚菌]',
994
+ 992: 'agaric [木耳]',
995
+ 993: 'gyromitra [鹿花菌]',
996
+ 994: 'stinkhorn, carrion fungus [鬼笔菌]',
997
+ 995: 'earthstar [地星(菌类)]',
998
+ 996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa [多叶奇果菌]',
999
+ 997: 'bolete [牛肝菌]',
1000
+ 998: 'ear, spike, capitulum [玉米穗]',
1001
+ 999: 'toilet tissue, toilet paper, bathroom tissue [卫生纸]',
1002
+ }
pixelflow/data_in1k.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ImageNet-1K Dataset and DataLoader
2
+
3
+ from einops import rearrange
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch.utils.data.distributed import DistributedSampler
7
+ from torchvision.datasets import ImageFolder
8
+ from torchvision import transforms
9
+ from PIL import Image
10
+ import math
11
+ from functools import partial
12
+ import numpy as np
13
+ import random
14
+
15
+ from diffusers.models.embeddings import get_2d_rotary_pos_embed
16
+
17
+ # https://github.com/facebookresearch/DiT/blob/main/train.py#L85
18
+ def center_crop_arr(pil_image, image_size):
19
+ while min(*pil_image.size) >= 2 * image_size:
20
+ pil_image = pil_image.resize(
21
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
22
+ )
23
+
24
+ scale = image_size / min(*pil_image.size)
25
+ pil_image = pil_image.resize(
26
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
27
+ )
28
+
29
+ arr = np.array(pil_image)
30
+ crop_y = (arr.shape[0] - image_size) // 2
31
+ crop_x = (arr.shape[1] - image_size) // 2
32
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
33
+
34
+
35
+ def collate_fn(examples, config, noise_scheduler_copy):
36
+ patch_size = config.model.params.patch_size
37
+ pixel_values = torch.stack([eg[0] for eg in examples])
38
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
39
+ input_ids = [eg[1] for eg in examples]
40
+
41
+ batch_size = len(examples)
42
+ stage_indices = list(range(config.scheduler.num_stages)) * (batch_size // config.scheduler.num_stages + 1)
43
+ stage_indices = stage_indices[:batch_size]
44
+
45
+ random.shuffle(stage_indices)
46
+ stage_indices = torch.tensor(stage_indices, dtype=torch.int32)
47
+ orig_height, orig_width = pixel_values.shape[-2:]
48
+ timesteps = torch.randint(0, config.scheduler.num_train_timesteps, (batch_size,))
49
+
50
+ sample_list, input_ids_list, pos_embed_list, seq_len_list, target_list, timestep_list = [], [], [], [], [], []
51
+ for stage_idx in range(config.scheduler.num_stages):
52
+ corrected_stage_idx = config.scheduler.num_stages - stage_idx - 1
53
+ stage_select_indices = timesteps[stage_indices == corrected_stage_idx]
54
+ Timesteps = noise_scheduler_copy.Timesteps_per_stage[corrected_stage_idx][stage_select_indices].float()
55
+ batch_size_select = Timesteps.shape[0]
56
+ pixel_values_select = pixel_values[stage_indices == corrected_stage_idx]
57
+ input_ids_select = [input_ids[i] for i in range(batch_size) if stage_indices[i] == corrected_stage_idx]
58
+
59
+ end_height, end_width = orig_height // (2 ** stage_idx), orig_width // (2 ** stage_idx)
60
+
61
+ ################ build model input ################
62
+ start_t, end_t = noise_scheduler_copy.start_t[corrected_stage_idx], noise_scheduler_copy.end_t[corrected_stage_idx]
63
+
64
+ pixel_values_end = pixel_values_select
65
+ pixel_values_start = pixel_values_select
66
+ if stage_idx > 0:
67
+ # pixel_values_end
68
+ for downsample_idx in range(1, stage_idx + 1):
69
+ pixel_values_end = F.interpolate(pixel_values_end, (orig_height // (2 ** downsample_idx), orig_width // (2 ** downsample_idx)), mode="bilinear")
70
+
71
+ # pixel_values_start
72
+ for downsample_idx in range(1, stage_idx + 2):
73
+ pixel_values_start = F.interpolate(pixel_values_start, (orig_height // (2 ** downsample_idx), orig_width // (2 ** downsample_idx)), mode="bilinear")
74
+ # upsample pixel_values_start
75
+ pixel_values_start = F.interpolate(pixel_values_start, (end_height, end_width), mode="nearest")
76
+
77
+ noise = torch.randn_like(pixel_values_end)
78
+ pixel_values_end = end_t * pixel_values_end + (1.0 - end_t) * noise
79
+ pixel_values_start = start_t * pixel_values_start + (1.0 - start_t) * noise
80
+ target = pixel_values_end - pixel_values_start
81
+
82
+ t_select = noise_scheduler_copy.t_window_per_stage[corrected_stage_idx][stage_select_indices].flatten()
83
+ while len(t_select.shape) < pixel_values_start.ndim:
84
+ t_select = t_select.unsqueeze(-1)
85
+ xt = t_select.float() * pixel_values_end + (1.0 - t_select.float()) * pixel_values_start
86
+
87
+ target = rearrange(target, 'b c (h ph) (w pw) -> (b h w) (c ph pw)', ph=patch_size, pw=patch_size)
88
+ xt = rearrange(xt, 'b c (h ph) (w pw) -> (b h w) (c ph pw)', ph=patch_size, pw=patch_size)
89
+
90
+ pos_embed = get_2d_rotary_pos_embed(
91
+ embed_dim=config.model.params.attention_head_dim,
92
+ crops_coords=((0, 0), (end_height // patch_size, end_width // patch_size)),
93
+ grid_size=(end_height // patch_size, end_width // patch_size),
94
+ )
95
+ seq_len = (end_height // patch_size) * (end_width // patch_size)
96
+ assert end_height == end_width, f"only support square image, got {seq_len}; TODO: latent_size_list"
97
+ sample_list.append(xt)
98
+ target_list.append(target)
99
+ pos_embed_list.extend([pos_embed] * batch_size_select)
100
+ seq_len_list.extend([seq_len] * batch_size_select)
101
+ timestep_list.append(Timesteps)
102
+ input_ids_list.extend(input_ids_select)
103
+
104
+ pixel_values = torch.cat(sample_list, dim=0).to(memory_format=torch.contiguous_format)
105
+ target_values = torch.cat(target_list, dim=0).to(memory_format=torch.contiguous_format)
106
+ pos_embed = torch.cat([torch.stack(one_pos_emb, -1) for one_pos_emb in pos_embed_list], dim=0).float()
107
+ cumsum_q_len = torch.cumsum(torch.tensor([0] + seq_len_list), 0).to(torch.int32)
108
+ latent_size_list = torch.tensor([int(math.sqrt(seq_len)) for seq_len in seq_len_list], dtype=torch.int32)
109
+
110
+ return {
111
+ "pixel_values": pixel_values,
112
+ "input_ids": input_ids_list,
113
+ "pos_embed": pos_embed,
114
+ "cumsum_q_len": cumsum_q_len,
115
+ "batch_latent_size": latent_size_list,
116
+ "seqlen_list_q": seq_len_list,
117
+ "cumsum_kv_len": None,
118
+ "batch_kv_len": None,
119
+ "timesteps": torch.cat(timestep_list, dim=0),
120
+ "target_values": target_values,
121
+ }
122
+
123
+
124
+ def build_imagenet_loader(config, noise_scheduler_copy):
125
+ if config.data.center_crop:
126
+ transform = transforms.Compose([
127
+ transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, config.data.resolution)),
128
+ transforms.RandomHorizontalFlip(),
129
+ transforms.ToTensor(),
130
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
131
+ ])
132
+ else:
133
+ transform = transforms.Compose([
134
+ transforms.Resize(round(config.data.resolution * config.data.expand_ratio), interpolation=transforms.InterpolationMode.LANCZOS),
135
+ transforms.RandomCrop(config.data.resolution),
136
+ transforms.RandomHorizontalFlip(),
137
+ transforms.ToTensor(),
138
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
139
+ ])
140
+ dataset = ImageFolder(config.data.root, transform=transform)
141
+ sampler = DistributedSampler(
142
+ dataset,
143
+ num_replicas=torch.distributed.get_world_size(),
144
+ rank=torch.distributed.get_rank(),
145
+ shuffle=True,
146
+ seed=config.seed,
147
+ )
148
+
149
+ loader = torch.utils.data.DataLoader(
150
+ dataset,
151
+ batch_size=config.data.batch_size,
152
+ collate_fn=partial(collate_fn, config=config, noise_scheduler_copy=noise_scheduler_copy),
153
+ shuffle=False,
154
+ sampler=sampler,
155
+ num_workers=config.data.num_workers,
156
+ drop_last=True,
157
+ )
158
+ return loader
pixelflow/model.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Union
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from diffusers.models.embeddings import Timesteps, TimestepEmbedding, LabelEmbedding
7
+ import warnings
8
+
9
+ try:
10
+ from flash_attn import flash_attn_varlen_func
11
+ except ImportError:
12
+ warnings.warn("`flash-attn` is not installed. Training mode may not work properly.", UserWarning)
13
+ flash_attn_varlen_func = None
14
+
15
+
16
+ def apply_rotary_emb(
17
+ x: torch.Tensor,
18
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
19
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
20
+ cos, sin = freqs_cis.unbind(-1)
21
+ cos = cos[None, None]
22
+ sin = sin[None, None]
23
+ cos, sin = cos.to(x.device), sin.to(x.device)
24
+
25
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
26
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
27
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
28
+
29
+ return out
30
+
31
+
32
+ class PatchEmbed(nn.Module):
33
+ def __init__(self, patch_size, in_channels, embed_dim, bias=True):
34
+ super().__init__()
35
+ self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias)
36
+
37
+ def forward_unfold(self, x):
38
+ out_unfold = x.matmul(self.proj.weight.view(self.proj.weight.size(0), -1).t())
39
+ if self.proj.bias is not None:
40
+ out_unfold += self.proj.bias.to(out_unfold.dtype)
41
+ return out_unfold
42
+
43
+ # force fp32 for strict numerical reproducibility (debug only)
44
+ # @torch.autocast('cuda', enabled=False)
45
+ def forward(self, x):
46
+ if self.training:
47
+ return self.forward_unfold(x)
48
+ out = self.proj(x)
49
+ out = out.flatten(2).transpose(1, 2) # BCHW -> BNC
50
+
51
+ return out
52
+
53
+ class AdaLayerNorm(nn.Module):
54
+ def __init__(self, embedding_dim):
55
+ super().__init__()
56
+ self.embedding_dim = embedding_dim
57
+ self.silu = nn.SiLU()
58
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
59
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
60
+
61
+ def forward(self, x, timestep, seqlen_list=None):
62
+ input_dtype = x.dtype
63
+ emb = self.linear(self.silu(timestep))
64
+
65
+ if seqlen_list is not None:
66
+ # equivalent to `torch.repeat_interleave` but faster
67
+ emb = torch.cat([one_emb[None].expand(repeat_time, -1) for one_emb, repeat_time in zip(emb, seqlen_list)])
68
+ else:
69
+ emb = emb.unsqueeze(1)
70
+
71
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.float().chunk(6, dim=-1)
72
+ x = self.norm(x).float() * (1 + scale_msa) + shift_msa
73
+ return x.to(input_dtype), gate_msa, shift_mlp, scale_mlp, gate_mlp
74
+
75
+
76
+ class FeedForward(nn.Module):
77
+ def __init__(self, dim, dim_out=None, mult=4, inner_dim=None, bias=True):
78
+ super().__init__()
79
+ inner_dim = int(dim * mult) if inner_dim is None else inner_dim
80
+ dim_out = dim_out if dim_out is not None else dim
81
+ self.fc1 = nn.Linear(dim, inner_dim, bias=bias)
82
+ self.fc2 = nn.Linear(inner_dim, dim_out, bias=bias)
83
+
84
+ def forward(self, hidden_states):
85
+ hidden_states = self.fc1(hidden_states)
86
+ hidden_states = F.gelu(hidden_states, approximate="tanh")
87
+ hidden_states = self.fc2(hidden_states)
88
+ return hidden_states
89
+
90
+
91
+ class RMSNorm(nn.Module):
92
+ def __init__(self, dim: int, eps=1e-6):
93
+ super().__init__()
94
+ self.weight = nn.Parameter(torch.ones(dim))
95
+ self.eps = eps
96
+
97
+ def forward(self, x):
98
+ output = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
99
+ return (self.weight * output).to(x.dtype)
100
+
101
+
102
+ class Attention(nn.Module):
103
+ def __init__(self, q_dim, kv_dim=None, heads=8, head_dim=64, dropout=0.0, bias=False):
104
+ super().__init__()
105
+ self.q_dim = q_dim
106
+ self.kv_dim = kv_dim if kv_dim is not None else q_dim
107
+ self.inner_dim = head_dim * heads
108
+ self.dropout = dropout
109
+ self.head_dim = head_dim
110
+ self.num_heads = heads
111
+
112
+ self.q_proj = nn.Linear(self.q_dim, self.inner_dim, bias=bias)
113
+ self.k_proj = nn.Linear(self.kv_dim, self.inner_dim, bias=bias)
114
+ self.v_proj = nn.Linear(self.kv_dim, self.inner_dim, bias=bias)
115
+
116
+ self.o_proj = nn.Linear(self.inner_dim, self.q_dim, bias=bias)
117
+
118
+ self.q_norm = RMSNorm(self.inner_dim)
119
+ self.k_norm = RMSNorm(self.inner_dim)
120
+
121
+ def prepare_attention_mask(
122
+ # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L694
123
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
124
+ ):
125
+ head_size = self.num_heads
126
+ if attention_mask is None:
127
+ return attention_mask
128
+
129
+ current_length: int = attention_mask.shape[-1]
130
+ if current_length != target_length:
131
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
132
+
133
+ if out_dim == 3:
134
+ if attention_mask.shape[0] < batch_size * head_size:
135
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
136
+ elif out_dim == 4:
137
+ attention_mask = attention_mask.unsqueeze(1)
138
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
139
+
140
+ return attention_mask
141
+
142
+ def forward(
143
+ self,
144
+ inputs_q,
145
+ inputs_kv,
146
+ attention_mask=None,
147
+ cross_attention=False,
148
+ rope_pos_embed=None,
149
+ cu_seqlens_q=None,
150
+ cu_seqlens_k=None,
151
+ max_seqlen_q=None,
152
+ max_seqlen_k=None,
153
+ ):
154
+
155
+ inputs_kv = inputs_q if inputs_kv is None else inputs_kv
156
+
157
+ query_states = self.q_proj(inputs_q)
158
+ key_states = self.k_proj(inputs_kv)
159
+ value_states = self.v_proj(inputs_kv)
160
+
161
+ query_states = self.q_norm(query_states)
162
+ key_states = self.k_norm(key_states)
163
+
164
+ if max_seqlen_q is None:
165
+ assert not self.training, "PixelFlow needs sequence packing for training"
166
+
167
+ bsz, q_len, _ = inputs_q.shape
168
+ _, kv_len, _ = inputs_kv.shape
169
+
170
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
171
+ key_states = key_states.view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
172
+ value_states = value_states.view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
173
+
174
+ query_states = apply_rotary_emb(query_states, rope_pos_embed)
175
+ if not cross_attention:
176
+ key_states = apply_rotary_emb(key_states, rope_pos_embed)
177
+
178
+ if attention_mask is not None:
179
+ attention_mask = self.prepare_attention_mask(attention_mask, kv_len, bsz)
180
+ # scaled_dot_product_attention expects attention_mask shape to be
181
+ # (batch, heads, source_length, target_length)
182
+ attention_mask = attention_mask.view(bsz, self.num_heads, -1, attention_mask.shape[-1])
183
+
184
+ # with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.MATH]): # strict numerical reproducibility (debug only)
185
+ attn_output = F.scaled_dot_product_attention(
186
+ query_states,
187
+ key_states,
188
+ value_states,
189
+ attn_mask=attention_mask,
190
+ dropout_p=self.dropout if self.training else 0.0,
191
+ is_causal=False,
192
+ )
193
+
194
+ attn_output = attn_output.transpose(1, 2).contiguous()
195
+ attn_output = attn_output.view(bsz, q_len, self.inner_dim)
196
+ attn_output = self.o_proj(attn_output)
197
+ return attn_output
198
+
199
+ else:
200
+ # sequence packing mode
201
+ query_states = query_states.view(-1, self.num_heads, self.head_dim)
202
+ key_states = key_states.view(-1, self.num_heads, self.head_dim)
203
+ value_states = value_states.view(-1, self.num_heads, self.head_dim)
204
+
205
+ query_states = apply_rotary_emb(query_states.permute(1, 0, 2)[None], rope_pos_embed)[0].permute(1, 0, 2)
206
+ if not cross_attention:
207
+ key_states = apply_rotary_emb(key_states.permute(1, 0, 2)[None], rope_pos_embed)[0].permute(1, 0, 2)
208
+
209
+ attn_output = flash_attn_varlen_func(
210
+ query_states,
211
+ key_states,
212
+ value_states,
213
+ cu_seqlens_q=cu_seqlens_q,
214
+ cu_seqlens_k=cu_seqlens_k,
215
+ max_seqlen_q=max_seqlen_q,
216
+ max_seqlen_k=max_seqlen_k,
217
+ )
218
+
219
+ attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
220
+ attn_output = self.o_proj(attn_output)
221
+ return attn_output
222
+
223
+
224
+ class TransformerBlock(nn.Module):
225
+ def __init__(self, dim, num_attention_heads, attention_head_dim, dropout=0.0,
226
+ cross_attention_dim=None, attention_bias=False,
227
+ ):
228
+ super().__init__()
229
+ self.norm1 = AdaLayerNorm(dim)
230
+
231
+ # Self Attention
232
+ self.attn1 = Attention(q_dim=dim, kv_dim=None, heads=num_attention_heads, head_dim=attention_head_dim, dropout=dropout, bias=attention_bias)
233
+
234
+ if cross_attention_dim is not None:
235
+ # Cross Attention
236
+ self.norm2 = RMSNorm(dim, eps=1e-6)
237
+ self.attn2 = Attention(q_dim=dim, kv_dim=cross_attention_dim, heads=num_attention_heads, head_dim=attention_head_dim, dropout=dropout, bias=attention_bias)
238
+ else:
239
+ self.attn2 = None
240
+
241
+ self.norm3 = RMSNorm(dim, eps=1e-6)
242
+ self.mlp = FeedForward(dim)
243
+
244
+ def forward(
245
+ self,
246
+ hidden_states,
247
+ encoder_hidden_states=None,
248
+ encoder_attention_mask=None,
249
+ timestep=None,
250
+ rope_pos_embed=None,
251
+ cu_seqlens_q=None,
252
+ cu_seqlens_k=None,
253
+ seqlen_list_q=None,
254
+ seqlen_list_k=None,
255
+ ):
256
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, timestep, seqlen_list_q)
257
+
258
+ attn_output = self.attn1(
259
+ inputs_q=norm_hidden_states,
260
+ inputs_kv=None,
261
+ attention_mask=None,
262
+ cross_attention=False,
263
+ rope_pos_embed=rope_pos_embed,
264
+ cu_seqlens_q=cu_seqlens_q,
265
+ cu_seqlens_k=cu_seqlens_q,
266
+ max_seqlen_q=max(seqlen_list_q) if seqlen_list_q is not None else None,
267
+ max_seqlen_k=max(seqlen_list_q) if seqlen_list_q is not None else None,
268
+ )
269
+
270
+ attn_output = (gate_msa * attn_output.float()).to(attn_output.dtype)
271
+ hidden_states = attn_output + hidden_states
272
+
273
+ if self.attn2 is not None:
274
+ norm_hidden_states = self.norm2(hidden_states)
275
+ attn_output = self.attn2(
276
+ inputs_q=norm_hidden_states,
277
+ inputs_kv=encoder_hidden_states,
278
+ attention_mask=encoder_attention_mask,
279
+ cross_attention=True,
280
+ rope_pos_embed=rope_pos_embed,
281
+ cu_seqlens_q=cu_seqlens_q,
282
+ cu_seqlens_k=cu_seqlens_k,
283
+ max_seqlen_q=max(seqlen_list_q) if seqlen_list_q is not None else None,
284
+ max_seqlen_k=max(seqlen_list_k) if seqlen_list_k is not None else None,
285
+ )
286
+ hidden_states = hidden_states + attn_output
287
+
288
+ norm_hidden_states = self.norm3(hidden_states)
289
+ norm_hidden_states = (norm_hidden_states.float() * (1 + scale_mlp) + shift_mlp).to(norm_hidden_states.dtype)
290
+ ff_output = self.mlp(norm_hidden_states)
291
+ ff_output = (gate_mlp * ff_output.float()).to(ff_output.dtype)
292
+ hidden_states = ff_output + hidden_states
293
+
294
+ return hidden_states
295
+
296
+
297
+ class PixelFlowModel(torch.nn.Module):
298
+ def __init__(self, in_channels, out_channels, num_attention_heads, attention_head_dim,
299
+ depth, patch_size, dropout=0.0, cross_attention_dim=None, attention_bias=True, num_classes=0,
300
+ ):
301
+ super().__init__()
302
+ self.patch_size = patch_size
303
+ self.attention_head_dim = attention_head_dim
304
+ self.num_classes = num_classes
305
+ self.out_channels = out_channels
306
+
307
+ embed_dim = num_attention_heads * attention_head_dim
308
+ self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
309
+
310
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
311
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embed_dim)
312
+
313
+ # [stage] embedding
314
+ self.latent_size_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embed_dim)
315
+ if self.num_classes > 0:
316
+ # class conditional
317
+ self.class_embedder = LabelEmbedding(num_classes, embed_dim, dropout_prob=0.1)
318
+
319
+ self.transformer_blocks = nn.ModuleList([
320
+ TransformerBlock(embed_dim, num_attention_heads, attention_head_dim, dropout, cross_attention_dim, attention_bias) for _ in range(depth)
321
+ ])
322
+
323
+ self.norm_out = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
324
+ self.proj_out_1 = nn.Linear(embed_dim, 2 * embed_dim)
325
+ self.proj_out_2 = nn.Linear(embed_dim, patch_size * patch_size * out_channels)
326
+
327
+ self.initialize_from_scratch()
328
+
329
+ def initialize_from_scratch(self):
330
+ print("Starting Initialization...")
331
+ def _basic_init(module):
332
+ if isinstance(module, nn.Linear):
333
+ torch.nn.init.xavier_uniform_(module.weight)
334
+ if module.bias is not None:
335
+ nn.init.constant_(module.bias, 0)
336
+ self.apply(_basic_init)
337
+
338
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
339
+ w = self.patch_embed.proj.weight.data
340
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
341
+ nn.init.constant_(self.patch_embed.proj.bias, 0)
342
+
343
+ nn.init.normal_(self.timestep_embedder.linear_1.weight, std=0.02)
344
+ nn.init.normal_(self.timestep_embedder.linear_2.weight, std=0.02)
345
+
346
+ nn.init.normal_(self.latent_size_embedder.linear_1.weight, std=0.02)
347
+ nn.init.normal_(self.latent_size_embedder.linear_2.weight, std=0.02)
348
+
349
+ if self.num_classes > 0:
350
+ nn.init.normal_(self.class_embedder.embedding_table.weight, std=0.02)
351
+
352
+ for block in self.transformer_blocks:
353
+ nn.init.constant_(block.norm1.linear.weight, 0)
354
+ nn.init.constant_(block.norm1.linear.bias, 0)
355
+
356
+ nn.init.constant_(self.proj_out_1.weight, 0)
357
+ nn.init.constant_(self.proj_out_1.bias, 0)
358
+ nn.init.constant_(self.proj_out_2.weight, 0)
359
+ nn.init.constant_(self.proj_out_2.bias, 0)
360
+
361
+ def forward(
362
+ self,
363
+ hidden_states,
364
+ encoder_hidden_states=None,
365
+ class_labels=None,
366
+ timestep=None,
367
+ latent_size=None,
368
+ encoder_attention_mask=None,
369
+ pos_embed=None,
370
+ cu_seqlens_q=None,
371
+ cu_seqlens_k=None,
372
+ seqlen_list_q=None,
373
+ seqlen_list_k=None,
374
+ ):
375
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
376
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
377
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
378
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
379
+
380
+ orig_height, orig_width = hidden_states.shape[-2], hidden_states.shape[-1]
381
+ hidden_states = hidden_states.to(torch.float32)
382
+ hidden_states = self.patch_embed(hidden_states)
383
+
384
+ # timestep, class_embed, latent_size_embed
385
+ timesteps_proj = self.time_proj(timestep)
386
+ conditioning = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
387
+
388
+ if self.num_classes > 0:
389
+ class_embed = self.class_embedder(class_labels)
390
+ conditioning += class_embed
391
+
392
+ latent_size_proj = self.time_proj(latent_size)
393
+ latent_size_embed = self.latent_size_embedder(latent_size_proj.to(dtype=hidden_states.dtype))
394
+ conditioning += latent_size_embed
395
+
396
+ for block in self.transformer_blocks:
397
+ hidden_states = block(
398
+ hidden_states,
399
+ encoder_hidden_states=encoder_hidden_states,
400
+ encoder_attention_mask=encoder_attention_mask,
401
+ timestep=conditioning,
402
+ rope_pos_embed=pos_embed,
403
+ cu_seqlens_q=cu_seqlens_q,
404
+ cu_seqlens_k=cu_seqlens_k,
405
+ seqlen_list_q=seqlen_list_q,
406
+ seqlen_list_k=seqlen_list_k,
407
+ )
408
+
409
+ shift, scale = self.proj_out_1(F.silu(conditioning)).float().chunk(2, dim=1)
410
+ if seqlen_list_q is None:
411
+ shift = shift.unsqueeze(1)
412
+ scale = scale.unsqueeze(1)
413
+ else:
414
+ shift = torch.cat([shift_i[None].expand(ri, -1) for shift_i, ri in zip(shift, seqlen_list_q)])
415
+ scale = torch.cat([scale_i[None].expand(ri, -1) for scale_i, ri in zip(scale, seqlen_list_q)])
416
+
417
+ hidden_states = (self.norm_out(hidden_states).float() * (1 + scale) + shift).to(hidden_states.dtype)
418
+ hidden_states = self.proj_out_2(hidden_states)
419
+ if self.training:
420
+ hidden_states = hidden_states.reshape(hidden_states.shape[0], self.patch_size, self.patch_size, self.out_channels)
421
+ hidden_states = hidden_states.permute(0, 3, 1, 2).flatten(1)
422
+ return hidden_states
423
+
424
+ height, width = orig_height // self.patch_size, orig_width // self.patch_size
425
+ hidden_states = hidden_states.reshape(
426
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
427
+ )
428
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
429
+ output = hidden_states.reshape(
430
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
431
+ )
432
+
433
+ return output
434
+
435
+ def c2i_forward_cfg_torchdiffq(self, hidden_states, timestep, class_labels, latent_size, pos_embed, cfg_scale):
436
+ # used for evaluation with ODE ('dopri5') solver from torchdiffeq
437
+ half = hidden_states[: len(hidden_states)//2]
438
+ combined = torch.cat([half, half], dim=0)
439
+ out = self.forward(
440
+ hidden_states=combined,
441
+ timestep=timestep,
442
+ class_labels=class_labels,
443
+ latent_size=latent_size,
444
+ pos_embed=pos_embed,
445
+ )
446
+ uncond_out, cond_out = torch.split(out, len(out)//2, dim=0)
447
+ half_output = uncond_out + cfg_scale * (cond_out - uncond_out)
448
+ output = torch.cat([half_output, half_output], dim=0)
449
+ return output
pixelflow/pipeline_pixelflow.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from einops import rearrange
2
+ import math
3
+ from typing import List, Optional, Union
4
+ import time
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers.utils.torch_utils import randn_tensor
9
+ from diffusers.models.embeddings import get_2d_rotary_pos_embed
10
+
11
+
12
+ class PixelFlowPipeline:
13
+ def __init__(
14
+ self,
15
+ scheduler,
16
+ transformer,
17
+ text_encoder=None,
18
+ tokenizer=None,
19
+ max_token_length=512,
20
+ ):
21
+ super().__init__()
22
+ self.class_cond = text_encoder is None or tokenizer is None
23
+ self.scheduler = scheduler
24
+ self.transformer = transformer
25
+ self.patch_size = transformer.patch_size
26
+ self.head_dim = transformer.attention_head_dim
27
+ self.num_stages = scheduler.num_stages
28
+
29
+ self.text_encoder = text_encoder
30
+ self.tokenizer = tokenizer
31
+ self.max_token_length = max_token_length
32
+
33
+ @torch.autocast("cuda", enabled=False)
34
+ def encode_prompt(
35
+ self,
36
+ prompt: Union[str, List[str]],
37
+ device: Optional[torch.device] = None,
38
+ num_images_per_prompt: int = 1,
39
+ do_classifier_free_guidance: bool = True,
40
+ negative_prompt: Union[str, List[str]] = "",
41
+ prompt_embeds: Optional[torch.FloatTensor] = None,
42
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
43
+ prompt_attention_mask: Optional[torch.FloatTensor] = None,
44
+ negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
45
+ use_attention_mask: bool = False,
46
+ max_length: int = 512,
47
+ ):
48
+ # Determine the batch size and normalize prompt input to a list
49
+ if prompt is not None:
50
+ if isinstance(prompt, str):
51
+ prompt = [prompt]
52
+ batch_size = len(prompt)
53
+ else:
54
+ batch_size = prompt_embeds.shape[0]
55
+
56
+ # Process prompt embeddings if not provided
57
+ if prompt_embeds is None:
58
+ text_inputs = self.tokenizer(
59
+ prompt,
60
+ padding="max_length",
61
+ max_length=max_length,
62
+ truncation=True,
63
+ add_special_tokens=True,
64
+ return_tensors="pt",
65
+ )
66
+ text_input_ids = text_inputs.input_ids.to(device)
67
+ prompt_attention_mask = text_inputs.attention_mask.to(device)
68
+ prompt_embeds = self.text_encoder(
69
+ text_input_ids,
70
+ attention_mask=prompt_attention_mask if use_attention_mask else None
71
+ )[0]
72
+
73
+ # Determine dtype from available encoder
74
+ if self.text_encoder is not None:
75
+ dtype = self.text_encoder.dtype
76
+ elif self.transformer is not None:
77
+ dtype = self.transformer.dtype
78
+ else:
79
+ dtype = None
80
+
81
+ # Move prompt embeddings to desired dtype and device
82
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
83
+
84
+ bs_embed, seq_len, _ = prompt_embeds.shape
85
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
86
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
87
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)
88
+
89
+ # Handle classifier-free guidance for negative prompts
90
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
91
+ # Normalize negative prompt to list and validate length
92
+ if isinstance(negative_prompt, str):
93
+ uncond_tokens = [negative_prompt] * batch_size
94
+ elif isinstance(negative_prompt, list):
95
+ if len(negative_prompt) != batch_size:
96
+ raise ValueError(f"The negative prompt list must have the same length as the prompt list, but got {len(negative_prompt)} and {batch_size}")
97
+ uncond_tokens = negative_prompt
98
+ else:
99
+ raise ValueError(f"Negative prompt must be a string or a list of strings, but got {type(negative_prompt)}")
100
+
101
+ # Tokenize and encode negative prompts
102
+ uncond_inputs = self.tokenizer(
103
+ uncond_tokens,
104
+ padding="max_length",
105
+ max_length=prompt_embeds.shape[1],
106
+ truncation=True,
107
+ return_attention_mask=True,
108
+ add_special_tokens=True,
109
+ return_tensors="pt",
110
+ )
111
+ negative_input_ids = uncond_inputs.input_ids.to(device)
112
+ negative_prompt_attention_mask = uncond_inputs.attention_mask.to(device)
113
+ negative_prompt_embeds = self.text_encoder(
114
+ negative_input_ids,
115
+ attention_mask=negative_prompt_attention_mask if use_attention_mask else None
116
+ )[0]
117
+
118
+ if do_classifier_free_guidance:
119
+ # Duplicate negative prompt embeddings and attention mask for each generation
120
+ seq_len_neg = negative_prompt_embeds.shape[1]
121
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
122
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
123
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len_neg, -1)
124
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1).repeat(num_images_per_prompt, 1)
125
+ else:
126
+ negative_prompt_embeds = None
127
+ negative_prompt_attention_mask = None
128
+
129
+ # Concatenate negative and positive embeddings and their masks
130
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
131
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
132
+
133
+ return prompt_embeds, prompt_attention_mask
134
+
135
+ def sample_block_noise(self, bs, ch, height, width, eps=1e-6)):
136
+ gamma = self.scheduler.gamma
137
+ dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(4), torch.eye(4) * (1 - gamma) + torch.ones(4, 4) * gamma + eps * torch.eye(4))
138
+ block_number = bs * ch * (height // 2) * (width // 2)
139
+ noise = torch.stack([dist.sample() for _ in range(block_number)]) # [block number, 4]
140
+ noise = rearrange(noise, '(b c h w) (p q) -> b c (h p) (w q)',b=bs,c=ch,h=height//2,w=width//2,p=2,q=2)
141
+ return noise
142
+
143
+ @torch.no_grad()
144
+ def __call__(
145
+ self,
146
+ prompt,
147
+ height,
148
+ width,
149
+ num_inference_steps=30,
150
+ guidance_scale=4.0,
151
+ num_images_per_prompt=1,
152
+ device=None,
153
+ shift=1.0,
154
+ use_ode_dopri5=False,
155
+ ):
156
+ if isinstance(num_inference_steps, int):
157
+ num_inference_steps = [num_inference_steps] * self.num_stages
158
+
159
+ if use_ode_dopri5:
160
+ assert self.class_cond, "ODE (dopri5) sampling is only supported for class-conditional models now"
161
+ from pixelflow.solver_ode_wrapper import ODE
162
+ sample_fn = ODE(t0=0, t1=1, sampler_type="dopri5", num_steps=num_inference_steps[0], atol=1e-06, rtol=0.001).sample
163
+ else:
164
+ # default Euler
165
+ sample_fn = None
166
+
167
+ self._guidance_scale = guidance_scale
168
+ batch_size = len(prompt)
169
+ if self.class_cond:
170
+ prompt_embeds = torch.tensor(prompt, dtype=torch.int32).to(device)
171
+ negative_prompt_embeds = 1000 * torch.ones_like(prompt_embeds)
172
+ if self.do_classifier_free_guidance:
173
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
174
+ else:
175
+ prompt_embeds, prompt_attention_mask = self.encode_prompt(
176
+ prompt,
177
+ device,
178
+ num_images_per_prompt,
179
+ guidance_scale > 1,
180
+ "",
181
+ prompt_embeds=None,
182
+ negative_prompt_embeds=None,
183
+ use_attention_mask=True,
184
+ max_length=self.max_token_length,
185
+ )
186
+
187
+ init_factor = 2 ** (self.num_stages - 1)
188
+ height, width = height // init_factor, width // init_factor
189
+ shape = (batch_size * num_images_per_prompt, 3, height, width)
190
+ latents = randn_tensor(shape, device=device, dtype=torch.float32)
191
+
192
+ for stage_idx in range(self.num_stages):
193
+ stage_start = time.time()
194
+ # Set the number of inference steps for the current stage
195
+ self.scheduler.set_timesteps(num_inference_steps[stage_idx], stage_idx, device=device, shift=shift)
196
+ Timesteps = self.scheduler.Timesteps
197
+
198
+ if stage_idx > 0:
199
+ height, width = height * 2, width * 2
200
+ latents = F.interpolate(latents, size=(height, width), mode='nearest')
201
+ original_start_t = self.scheduler.original_start_t[stage_idx]
202
+ gamma = self.scheduler.gamma
203
+ alpha = 1 / (math.sqrt(1 - (1 / gamma)) * (1 - original_start_t) + original_start_t)
204
+ beta = alpha * (1 - original_start_t) / math.sqrt(- gamma)
205
+
206
+ # bs, ch, height, width = latents.shape
207
+ noise = self.sample_block_noise(*latents.shape)
208
+ noise = noise.to(device=device, dtype=latents.dtype)
209
+ latents = alpha * latents + beta * noise
210
+
211
+ size_tensor = torch.tensor([latents.shape[-1] // self.patch_size], dtype=torch.int32, device=device)
212
+ pos_embed = get_2d_rotary_pos_embed(
213
+ embed_dim=self.head_dim,
214
+ crops_coords=((0, 0), (latents.shape[-1] // self.patch_size, latents.shape[-1] // self.patch_size)),
215
+ grid_size=(latents.shape[-1] // self.patch_size, latents.shape[-1] // self.patch_size),
216
+ )
217
+ rope_pos = torch.stack(pos_embed, -1)
218
+
219
+ if sample_fn is not None:
220
+ # dopri5
221
+ model_kwargs = dict(class_labels=prompt_embeds, cfg_scale=self.guidance_scale(None, stage_idx), latent_size=size_tensor, pos_embed=rope_pos)
222
+ if stage_idx == 0:
223
+ latents = torch.cat([latents] * 2)
224
+ stage_T_start = self.scheduler.Timesteps_per_stage[stage_idx][0].item()
225
+ stage_T_end = self.scheduler.Timesteps_per_stage[stage_idx][-1].item()
226
+ latents = sample_fn(latents, self.transformer.c2i_forward_cfg_torchdiffq, stage_T_start, stage_T_end, **model_kwargs)[-1]
227
+ if stage_idx == self.num_stages - 1:
228
+ latents = latents[:latents.shape[0] // 2]
229
+ else:
230
+ # euler
231
+ for T in Timesteps:
232
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
233
+ timestep = T.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
234
+ if self.class_cond:
235
+ noise_pred = self.transformer(latent_model_input, timestep=timestep, class_labels=prompt_embeds, latent_size=size_tensor, pos_embed=rope_pos)
236
+ else:
237
+ encoder_hidden_states = prompt_embeds
238
+ encoder_attention_mask = prompt_attention_mask
239
+
240
+ noise_pred = self.transformer(
241
+ latent_model_input,
242
+ encoder_hidden_states=encoder_hidden_states,
243
+ encoder_attention_mask=encoder_attention_mask,
244
+ timestep=timestep,
245
+ latent_size=size_tensor,
246
+ pos_embed=rope_pos,
247
+ )
248
+
249
+ if self.do_classifier_free_guidance:
250
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
251
+ noise_pred = noise_pred_uncond + self.guidance_scale(T, stage_idx) * (noise_pred_text - noise_pred_uncond)
252
+
253
+ latents = self.scheduler.step(model_output=noise_pred, sample=latents)
254
+ stage_end = time.time()
255
+
256
+ samples = (latents / 2 + 0.5).clamp(0, 1)
257
+ samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
258
+ return samples
259
+
260
+ @property
261
+ def device(self):
262
+ return next(self.transformer.parameters()).device
263
+
264
+ @property
265
+ def dtype(self):
266
+ return next(self.transformer.parameters()).dtype
267
+
268
+ def guidance_scale(self, step=None, stage_idx=None):
269
+ if not self.class_cond:
270
+ return self._guidance_scale
271
+ scale_dict = {0: 0, 1: 1/6, 2: 2/3, 3: 1}
272
+ return (self._guidance_scale - 1) * scale_dict[stage_idx] + 1
273
+
274
+ @property
275
+ def do_classifier_free_guidance(self):
276
+ return self._guidance_scale > 0
pixelflow/scheduling_pixelflow.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+ def cal_rectify_ratio(start_t, gamma):
7
+ return 1 / (math.sqrt(1 - (1 / gamma)) * (1 - start_t) + start_t)
8
+
9
+
10
+ class PixelFlowScheduler:
11
+ def __init__(self, num_train_timesteps, num_stages, gamma=-1 / 3):
12
+ assert num_stages > 0, f"num_stages must be positive, got {num_stages}"
13
+ self.num_stages = num_stages
14
+ self.gamma = gamma
15
+
16
+ self.Timesteps = torch.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=torch.float32)
17
+
18
+ self.t = self.Timesteps / num_train_timesteps # normalized time in [0, 1]
19
+
20
+ self.stage_range = [x / num_stages for x in range(num_stages + 1)]
21
+
22
+ self.original_start_t = dict()
23
+ self.start_t, self.end_t = dict(), dict()
24
+ self.t_window_per_stage = dict()
25
+ self.Timesteps_per_stage = dict()
26
+ stage_distance = list()
27
+
28
+ # stage_idx = 0: min t, min resolution, most noisy
29
+ # stage_idx = num_stages - 1 : max t, max resolution, most clear
30
+ for stage_idx in range(num_stages):
31
+ start_idx = max(int(num_train_timesteps * self.stage_range[stage_idx]), 0)
32
+ end_idx = min(int(num_train_timesteps * self.stage_range[stage_idx + 1]), num_train_timesteps)
33
+
34
+ start_t = self.t[start_idx].item()
35
+ end_t = self.t[end_idx].item() if end_idx < num_train_timesteps else 1.0
36
+
37
+ self.original_start_t[stage_idx] = start_t
38
+
39
+ if stage_idx > 0:
40
+ start_t *= cal_rectify_ratio(start_t, gamma)
41
+
42
+ self.start_t[stage_idx] = start_t
43
+ self.end_t[stage_idx] = end_t
44
+ stage_distance.append(end_t - start_t)
45
+
46
+ total_stage_distance = sum(stage_distance)
47
+ t_within_stage = torch.linspace(0, 1, num_train_timesteps + 1, dtype=torch.float64)[:-1]
48
+
49
+ for stage_idx in range(num_stages):
50
+ start_ratio = 0.0 if stage_idx == 0 else sum(stage_distance[:stage_idx]) / total_stage_distance
51
+ end_ratio = 1.0 if stage_idx == num_stages - 1 else sum(stage_distance[:stage_idx + 1]) / total_stage_distance
52
+
53
+ Timestep_start = self.Timesteps[int(num_train_timesteps * start_ratio)]
54
+ Timestep_end = self.Timesteps[min(int(num_train_timesteps * end_ratio), num_train_timesteps - 1)]
55
+
56
+ self.t_window_per_stage[stage_idx] = t_within_stage
57
+
58
+ if stage_idx == num_stages - 1:
59
+ self.Timesteps_per_stage[stage_idx] = torch.linspace(Timestep_start.item(), Timestep_end.item(), num_train_timesteps, dtype=torch.float64)
60
+ else:
61
+ self.Timesteps_per_stage[stage_idx] = torch.linspace(Timestep_start.item(), Timestep_end.item(), num_train_timesteps + 1, dtype=torch.float64)[:-1]
62
+
63
+ @staticmethod
64
+ def time_linear_to_Timesteps(t, t_start, t_end, T_start, T_end):
65
+ """
66
+ linearly map t to T: T = k * t + b
67
+ """
68
+ k = (T_end - T_start) / (t_end - t_start)
69
+ b = T_start - t_start * k
70
+ return k * t + b
71
+
72
+ def set_timesteps(self, num_inference_steps, stage_index, device=None, shift=1.0):
73
+ self.num_inference_steps = num_inference_steps
74
+
75
+ stage_T_start = self.Timesteps_per_stage[stage_index][0].item()
76
+ stage_T_end = self.Timesteps_per_stage[stage_index][-1].item()
77
+
78
+ t_start = self.t_window_per_stage[stage_index][0].item()
79
+ t_end = self.t_window_per_stage[stage_index][-1].item()
80
+
81
+ t = np.linspace(t_start, t_end, num_inference_steps, dtype=np.float64)
82
+ t = t / (shift + (1 - shift) * t)
83
+
84
+ Timesteps = self.time_linear_to_Timesteps(t, t_start, t_end, stage_T_start, stage_T_end)
85
+ self.Timesteps = torch.from_numpy(Timesteps).to(device=device)
86
+
87
+ self.t = torch.from_numpy(np.append(t, 1.0)).to(device=device, dtype=torch.float64)
88
+ self._step_index = None
89
+
90
+ def step(self, model_output, sample):
91
+ if self.step_index is None:
92
+ self._step_index = 0
93
+
94
+ sample = sample.to(torch.float32)
95
+ t = self.t[self.step_index].float()
96
+ t_next = self.t[self.step_index + 1].float()
97
+
98
+ prev_sample = sample + (t_next - t) * model_output
99
+ self._step_index += 1
100
+
101
+ return prev_sample.to(model_output.dtype)
102
+
103
+ @property
104
+ def step_index(self):
105
+ """Current step index for the scheduler."""
106
+ return self._step_index
pixelflow/solver_ode_wrapper.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchdiffeq import odeint
3
+
4
+
5
+ # https://github.com/willisma/SiT/blob/main/transport/integrators.py#L77
6
+ class ODE:
7
+ """ODE solver class"""
8
+ def __init__(
9
+ self,
10
+ *,
11
+ t0,
12
+ t1,
13
+ sampler_type,
14
+ num_steps,
15
+ atol,
16
+ rtol,
17
+ ):
18
+ assert t0 < t1, "ODE sampler has to be in forward time"
19
+
20
+ self.t = torch.linspace(t0, t1, num_steps)
21
+ self.atol = atol
22
+ self.rtol = rtol
23
+ self.sampler_type = sampler_type
24
+
25
+ def time_linear_to_Timesteps(self, t, t_start, t_end, T_start, T_end):
26
+ # T = k * t + b
27
+ k = (T_end - T_start) / (t_end - t_start)
28
+ b = T_start - t_start * k
29
+ return k * t + b
30
+
31
+ def sample(self, x, model, T_start, T_end, **model_kwargs):
32
+ device = x[0].device if isinstance(x, tuple) else x.device
33
+ def _fn(t, x):
34
+ t = torch.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else torch.ones(x.size(0)).to(device) * t
35
+ model_output = model(x, self.time_linear_to_Timesteps(t, 0, 1, T_start, T_end), **model_kwargs)
36
+ assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
37
+ return model_output
38
+
39
+ t = self.t.to(device)
40
+ atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
41
+ rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
42
+ samples = odeint(
43
+ _fn,
44
+ x,
45
+ t,
46
+ method=self.sampler_type,
47
+ atol=atol,
48
+ rtol=rtol
49
+ )
50
+ return samples
pixelflow/utils/config.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+
4
+ def get_obj_from_str(string, reload=False):
5
+ module, cls = string.rsplit(".", 1)
6
+ if reload:
7
+ module_imp = importlib.import_module(module)
8
+ importlib.reload(module_imp)
9
+ return getattr(importlib.import_module(module, package=None), cls)
10
+
11
+
12
+ def instantiate_from_config(config):
13
+ if not "target" in config:
14
+ raise KeyError("Expected key `target` to instantiate.")
15
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
16
+
17
+
18
+ def instantiate_optimizer_from_config(config, params):
19
+ if not "target" in config:
20
+ raise KeyError("Expected key `target` to instantiate.")
21
+ return get_obj_from_str(config["target"])(params, **config.get("params", dict()))
22
+
23
+
24
+ def instantiate_dataset_from_config(config, transform):
25
+ if not "target" in config:
26
+ raise KeyError("Expected key `target` to instantiate.")
27
+ return get_obj_from_str(config["target"])(transform=transform, **config.get("params", dict()))
pixelflow/utils/logger.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+
4
+
5
+ class PathSimplifierFormatter(logging.Formatter):
6
+ def format(self, record):
7
+ record.short_path = os.path.relpath(record.pathname)
8
+ return super().format(record)
9
+
10
+
11
+ def setup_logger(log_directory, experiment_name, process_rank, source_module=__name__):
12
+ handlers = [logging.StreamHandler()]
13
+
14
+ if process_rank == 0:
15
+ log_file_path = os.path.join(log_directory, f"{experiment_name}.log")
16
+ handlers.append(logging.FileHandler(log_file_path))
17
+
18
+ log_formatter = PathSimplifierFormatter(
19
+ fmt='[%(asctime)s %(short_path)s:%(lineno)d] %(message)s',
20
+ datefmt='%Y-%m-%d %H:%M:%S'
21
+ )
22
+
23
+ for handler in handlers:
24
+ handler.setFormatter(log_formatter)
25
+
26
+ logging.basicConfig(level=logging.INFO, handlers=handlers)
27
+ return logging.getLogger(source_module)
pixelflow/utils/misc.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ import torch
5
+
6
+ def seed_everything(seed=0, deterministic_ops=True, allow_tf32=False):
7
+ """
8
+ Sets the seed for reproducibility across various libraries and frameworks, and configures PyTorch backend settings.
9
+
10
+ Args:
11
+ seed (int): The seed value for random number generation. Default is 0.
12
+ deterministic_ops (bool): Whether to enable deterministic operations in PyTorch.
13
+ Enabling this can make results reproducible at the cost of potential performance degradation. Default is True.
14
+ allow_tf32 (bool): Whether to allow TensorFloat-32 (TF32) precision in PyTorch operations. TF32 can improve performance but may affect reproducibility. Default is False.
15
+
16
+ Effects:
17
+ - Seeds Python's random module, NumPy, and PyTorch (CPU and GPU).
18
+ - Sets the environment variable `PYTHONHASHSEED` to the specified seed.
19
+ - Configures PyTorch to use deterministic algorithms if `deterministic_ops` is True.
20
+ - Configures TensorFloat-32 precision based on `allow_tf32`.
21
+ - Issues warnings if configurations may impact reproducibility.
22
+
23
+ Notes:
24
+ - Setting `torch.backends.cudnn.deterministic` to False allows nondeterministic operations, which may introduce variability.
25
+ - Allowing TF32 (`allow_tf32=True`) may lead to non-reproducible results, especially in matrix operations.
26
+ """
27
+ # Seed standard random number generators
28
+ random.seed(seed)
29
+ np.random.seed(seed)
30
+ os.environ['PYTHONHASHSEED'] = str(seed)
31
+
32
+ # Seed PyTorch random number generators
33
+ torch.manual_seed(seed)
34
+ torch.cuda.manual_seed_all(seed)
35
+
36
+ # Configure deterministic operations
37
+ if deterministic_ops:
38
+ torch.backends.cudnn.deterministic = True
39
+ torch.use_deterministic_algorithms(True)
40
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
41
+ else:
42
+ torch.backends.cudnn.deterministic = False
43
+ print("WARNING: torch.backends.cudnn.deterministic is set to False, reproducibility is not guaranteed.")
44
+
45
+ # Configure TensorFloat-32 precision
46
+ if allow_tf32:
47
+ print("WARNING: TensorFloat-32 (TF32) is enabled; reproducibility is not guaranteed.")
48
+
49
+ torch.backends.cudnn.allow_tf32 = allow_tf32 # Default True in PyTorch 2.6.0
50
+ torch.backends.cuda.matmul.allow_tf32 = allow_tf32 # Default False in PyTorch 2.6.0
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ einops
2
+ pandas
3
+ pyarrow
4
+ omegaconf
5
+ diffusers==0.32.2
6
+ transformers==4.48.0
7
+ torchdiffeq==0.2.4