SakuraD commited on
Commit
0f1b479
Β·
1 Parent(s): 06853f2
Files changed (8) hide show
  1. app.py +102 -0
  2. hitting_baseball.mp4 +0 -0
  3. hoverboarding.mp4 +0 -0
  4. kinetics_class_index.py +402 -0
  5. requirements.txt +6 -0
  6. transforms.py +443 -0
  7. uniformer.py +379 -0
  8. yoga.mp4 +0 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+ import torchvision.transforms as T
7
+ from PIL import Image
8
+ from decord import VideoReader
9
+ from decord import cpu
10
+ from uniformer import uniformer_small
11
+ from kinetics_class_index import kinetics_classnames
12
+ from transforms import (
13
+ GroupNormalize, GroupScale, GroupCenterCrop,
14
+ Stack, ToTorchFormatTensor
15
+ )
16
+
17
+ import gradio as gr
18
+ from huggingface_hub import hf_hub_download
19
+
20
+
21
+ def get_index(num_frames, num_segments=16, dense_sample_rate=8):
22
+ sample_range = num_segments * dense_sample_rate
23
+ sample_pos = max(1, 1 + num_frames - sample_range)
24
+ t_stride = dense_sample_rate
25
+ start_idx = 0 if sample_pos == 1 else sample_pos // 2
26
+ offsets = np.array([
27
+ (idx * t_stride + start_idx) %
28
+ num_frames for idx in range(num_segments)
29
+ ])
30
+ return offsets + 1
31
+
32
+
33
+ def load_video(video_path):
34
+ vr = VideoReader(video_path, ctx=cpu(0))
35
+ num_frames = len(vr)
36
+ frame_indices = get_index(num_frames, 16, 16)
37
+
38
+ # transform
39
+ crop_size = 224
40
+ scale_size = 256
41
+ input_mean = [0.485, 0.456, 0.406]
42
+ input_std = [0.229, 0.224, 0.225]
43
+
44
+ transform = T.Compose([
45
+ GroupScale(int(scale_size)),
46
+ GroupCenterCrop(crop_size),
47
+ Stack(),
48
+ ToTorchFormatTensor(),
49
+ GroupNormalize(input_mean, input_std)
50
+ ])
51
+
52
+ images_group = list()
53
+ for frame_index in frame_indices:
54
+ img = Image.fromarray(vr[frame_index].asnumpy())
55
+ images_group.append(img)
56
+ torch_imgs = transform(images_group)
57
+ return torch_imgs
58
+
59
+
60
+ def inference(video):
61
+ vid = load_video(video)
62
+
63
+ # The model expects inputs of shape: B x C x T x H x W
64
+ TC, H, W = vid.shape
65
+ inputs = vid.reshape(1, TC//3, 3, H, W).permute(0, 2, 1, 3, 4)
66
+
67
+ prediction = model(inputs)
68
+ prediction = F.softmax(prediction, dim=1).flatten()
69
+
70
+ return {kinetics_id_to_classname[str(i)]: float(prediction[i]) for i in range(400)}
71
+
72
+
73
+ # Device on which to run the model
74
+ # Set to cuda to load on GPU
75
+ device = "cpu"
76
+ model_path = hf_hub_download(repo_id="Andy1621/uniformer", filename="uniformer_small_k400_16x8.pth")
77
+ # Pick a pretrained model
78
+ model = uniformer_small()
79
+ state_dict = torch.load(model_path, map_location='cpu')
80
+ model.load_state_dict(state_dict)
81
+
82
+ # Set to eval mode and move to desired device
83
+ model = model.to(device)
84
+ model = model.eval()
85
+
86
+ # Create an id to label name mapping
87
+ kinetics_id_to_classname = {}
88
+ for k, v in kinetics_classnames.items():
89
+ kinetics_id_to_classname[k] = v
90
+
91
+ inputs = gr.inputs.Video()
92
+ label = gr.outputs.Label(num_top_classes=5)
93
+
94
+ title = "UniFormer-S"
95
+ description = "Gradio demo for UniFormer: To use it, simply upload your video, or click one of the examples to load them. Read more at the links below."
96
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.04676' target='_blank'>[ICLR2022] UniFormer: Unified Transformer for Efficient Spatiotemporal Representation Learning</a> | <a href='https://github.com/Sense-X/UniFormer' target='_blank'>Github Repo</a></p>"
97
+
98
+ gr.Interface(
99
+ inference, inputs, outputs=label,
100
+ title=title, description=description, article=article,
101
+ examples=[['hitting_baseball.mp4'], ['hoverboarding.mp4'], ['yoga.mp4']]
102
+ ).launch(enable_queue=True, cache_examples=True)
hitting_baseball.mp4 ADDED
Binary file (687 kB). View file
 
hoverboarding.mp4 ADDED
Binary file (464 kB). View file
 
kinetics_class_index.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ kinetics_classnames = {
2
+ "0": "riding a bike",
3
+ "1": "marching",
4
+ "2": "dodgeball",
5
+ "3": "playing cymbals",
6
+ "4": "checking tires",
7
+ "5": "roller skating",
8
+ "6": "tasting beer",
9
+ "7": "clapping",
10
+ "8": "drawing",
11
+ "9": "juggling fire",
12
+ "10": "bobsledding",
13
+ "11": "petting animal (not cat)",
14
+ "12": "spray painting",
15
+ "13": "training dog",
16
+ "14": "eating watermelon",
17
+ "15": "building cabinet",
18
+ "16": "applauding",
19
+ "17": "playing harp",
20
+ "18": "balloon blowing",
21
+ "19": "sled dog racing",
22
+ "20": "wrestling",
23
+ "21": "pole vault",
24
+ "22": "hurling (sport)",
25
+ "23": "riding scooter",
26
+ "24": "shearing sheep",
27
+ "25": "sweeping floor",
28
+ "26": "eating carrots",
29
+ "27": "skateboarding",
30
+ "28": "dunking basketball",
31
+ "29": "disc golfing",
32
+ "30": "eating spaghetti",
33
+ "31": "playing flute",
34
+ "32": "riding mechanical bull",
35
+ "33": "making sushi",
36
+ "34": "trapezing",
37
+ "35": "picking fruit",
38
+ "36": "stretching leg",
39
+ "37": "playing ukulele",
40
+ "38": "tying tie",
41
+ "39": "skydiving",
42
+ "40": "playing cello",
43
+ "41": "jumping into pool",
44
+ "42": "shooting goal (soccer)",
45
+ "43": "trimming trees",
46
+ "44": "bookbinding",
47
+ "45": "ski jumping",
48
+ "46": "walking the dog",
49
+ "47": "riding unicycle",
50
+ "48": "shaving head",
51
+ "49": "hopscotch",
52
+ "50": "playing piano",
53
+ "51": "parasailing",
54
+ "52": "bartending",
55
+ "53": "kicking field goal",
56
+ "54": "finger snapping",
57
+ "55": "dining",
58
+ "56": "yawning",
59
+ "57": "peeling potatoes",
60
+ "58": "canoeing or kayaking",
61
+ "59": "front raises",
62
+ "60": "laughing",
63
+ "61": "dancing macarena",
64
+ "62": "digging",
65
+ "63": "reading newspaper",
66
+ "64": "hitting baseball",
67
+ "65": "clay pottery making",
68
+ "66": "exercising with an exercise ball",
69
+ "67": "playing saxophone",
70
+ "68": "shooting basketball",
71
+ "69": "washing hair",
72
+ "70": "lunge",
73
+ "71": "brushing hair",
74
+ "72": "curling hair",
75
+ "73": "kitesurfing",
76
+ "74": "tapping guitar",
77
+ "75": "bending back",
78
+ "76": "skipping rope",
79
+ "77": "situp",
80
+ "78": "folding paper",
81
+ "79": "cracking neck",
82
+ "80": "assembling computer",
83
+ "81": "cleaning gutters",
84
+ "82": "blowing out candles",
85
+ "83": "shaking hands",
86
+ "84": "dancing gangnam style",
87
+ "85": "windsurfing",
88
+ "86": "tap dancing",
89
+ "87": "skiing (not slalom or crosscountry)",
90
+ "88": "bandaging",
91
+ "89": "push up",
92
+ "90": "doing nails",
93
+ "91": "punching person (boxing)",
94
+ "92": "bouncing on trampoline",
95
+ "93": "scrambling eggs",
96
+ "94": "singing",
97
+ "95": "cleaning floor",
98
+ "96": "krumping",
99
+ "97": "drumming fingers",
100
+ "98": "snowmobiling",
101
+ "99": "gymnastics tumbling",
102
+ "100": "headbanging",
103
+ "101": "catching or throwing frisbee",
104
+ "102": "riding elephant",
105
+ "103": "bee keeping",
106
+ "104": "feeding birds",
107
+ "105": "snatch weight lifting",
108
+ "106": "mowing lawn",
109
+ "107": "fixing hair",
110
+ "108": "playing trumpet",
111
+ "109": "flying kite",
112
+ "110": "crossing river",
113
+ "111": "swinging legs",
114
+ "112": "sanding floor",
115
+ "113": "belly dancing",
116
+ "114": "sneezing",
117
+ "115": "clean and jerk",
118
+ "116": "side kick",
119
+ "117": "filling eyebrows",
120
+ "118": "shuffling cards",
121
+ "119": "recording music",
122
+ "120": "cartwheeling",
123
+ "121": "feeding fish",
124
+ "122": "folding clothes",
125
+ "123": "water skiing",
126
+ "124": "tobogganing",
127
+ "125": "blowing leaves",
128
+ "126": "smoking",
129
+ "127": "unboxing",
130
+ "128": "tai chi",
131
+ "129": "waxing legs",
132
+ "130": "riding camel",
133
+ "131": "slapping",
134
+ "132": "tossing salad",
135
+ "133": "capoeira",
136
+ "134": "playing cards",
137
+ "135": "playing organ",
138
+ "136": "playing violin",
139
+ "137": "playing drums",
140
+ "138": "tapping pen",
141
+ "139": "vault",
142
+ "140": "shoveling snow",
143
+ "141": "playing tennis",
144
+ "142": "getting a tattoo",
145
+ "143": "making a sandwich",
146
+ "144": "making tea",
147
+ "145": "grinding meat",
148
+ "146": "squat",
149
+ "147": "eating doughnuts",
150
+ "148": "ice fishing",
151
+ "149": "snowkiting",
152
+ "150": "kicking soccer ball",
153
+ "151": "playing controller",
154
+ "152": "giving or receiving award",
155
+ "153": "welding",
156
+ "154": "throwing discus",
157
+ "155": "throwing axe",
158
+ "156": "ripping paper",
159
+ "157": "swimming butterfly stroke",
160
+ "158": "air drumming",
161
+ "159": "blowing nose",
162
+ "160": "hockey stop",
163
+ "161": "taking a shower",
164
+ "162": "bench pressing",
165
+ "163": "planting trees",
166
+ "164": "pumping fist",
167
+ "165": "climbing tree",
168
+ "166": "tickling",
169
+ "167": "high kick",
170
+ "168": "waiting in line",
171
+ "169": "slacklining",
172
+ "170": "tango dancing",
173
+ "171": "hurdling",
174
+ "172": "carrying baby",
175
+ "173": "celebrating",
176
+ "174": "sharpening knives",
177
+ "175": "passing American football (in game)",
178
+ "176": "headbutting",
179
+ "177": "playing recorder",
180
+ "178": "brush painting",
181
+ "179": "garbage collecting",
182
+ "180": "robot dancing",
183
+ "181": "shredding paper",
184
+ "182": "pumping gas",
185
+ "183": "rock climbing",
186
+ "184": "hula hooping",
187
+ "185": "braiding hair",
188
+ "186": "opening present",
189
+ "187": "texting",
190
+ "188": "decorating the christmas tree",
191
+ "189": "answering questions",
192
+ "190": "playing keyboard",
193
+ "191": "writing",
194
+ "192": "bungee jumping",
195
+ "193": "sniffing",
196
+ "194": "eating burger",
197
+ "195": "playing accordion",
198
+ "196": "making pizza",
199
+ "197": "playing volleyball",
200
+ "198": "tasting food",
201
+ "199": "pushing cart",
202
+ "200": "spinning poi",
203
+ "201": "cleaning windows",
204
+ "202": "arm wrestling",
205
+ "203": "changing oil",
206
+ "204": "swimming breast stroke",
207
+ "205": "tossing coin",
208
+ "206": "deadlifting",
209
+ "207": "hoverboarding",
210
+ "208": "cutting watermelon",
211
+ "209": "cheerleading",
212
+ "210": "snorkeling",
213
+ "211": "washing hands",
214
+ "212": "eating cake",
215
+ "213": "pull ups",
216
+ "214": "surfing water",
217
+ "215": "eating hotdog",
218
+ "216": "holding snake",
219
+ "217": "playing harmonica",
220
+ "218": "ironing",
221
+ "219": "cutting nails",
222
+ "220": "golf chipping",
223
+ "221": "shot put",
224
+ "222": "hugging",
225
+ "223": "playing clarinet",
226
+ "224": "faceplanting",
227
+ "225": "trimming or shaving beard",
228
+ "226": "drinking shots",
229
+ "227": "riding mountain bike",
230
+ "228": "tying bow tie",
231
+ "229": "swinging on something",
232
+ "230": "skiing crosscountry",
233
+ "231": "unloading truck",
234
+ "232": "cleaning pool",
235
+ "233": "jogging",
236
+ "234": "ice climbing",
237
+ "235": "mopping floor",
238
+ "236": "making bed",
239
+ "237": "diving cliff",
240
+ "238": "washing dishes",
241
+ "239": "grooming dog",
242
+ "240": "weaving basket",
243
+ "241": "frying vegetables",
244
+ "242": "stomping grapes",
245
+ "243": "moving furniture",
246
+ "244": "cooking sausages",
247
+ "245": "doing laundry",
248
+ "246": "dying hair",
249
+ "247": "knitting",
250
+ "248": "reading book",
251
+ "249": "baby waking up",
252
+ "250": "punching bag",
253
+ "251": "surfing crowd",
254
+ "252": "cooking chicken",
255
+ "253": "pushing car",
256
+ "254": "springboard diving",
257
+ "255": "swing dancing",
258
+ "256": "massaging legs",
259
+ "257": "beatboxing",
260
+ "258": "breading or breadcrumbing",
261
+ "259": "somersaulting",
262
+ "260": "brushing teeth",
263
+ "261": "stretching arm",
264
+ "262": "juggling balls",
265
+ "263": "massaging person's head",
266
+ "264": "eating ice cream",
267
+ "265": "extinguishing fire",
268
+ "266": "hammer throw",
269
+ "267": "whistling",
270
+ "268": "crawling baby",
271
+ "269": "using remote controller (not gaming)",
272
+ "270": "playing cricket",
273
+ "271": "opening bottle",
274
+ "272": "playing xylophone",
275
+ "273": "motorcycling",
276
+ "274": "driving car",
277
+ "275": "exercising arm",
278
+ "276": "passing American football (not in game)",
279
+ "277": "playing kickball",
280
+ "278": "sticking tongue out",
281
+ "279": "flipping pancake",
282
+ "280": "catching fish",
283
+ "281": "eating chips",
284
+ "282": "shaking head",
285
+ "283": "sword fighting",
286
+ "284": "playing poker",
287
+ "285": "cooking on campfire",
288
+ "286": "doing aerobics",
289
+ "287": "paragliding",
290
+ "288": "using segway",
291
+ "289": "folding napkins",
292
+ "290": "playing bagpipes",
293
+ "291": "gargling",
294
+ "292": "skiing slalom",
295
+ "293": "strumming guitar",
296
+ "294": "javelin throw",
297
+ "295": "waxing back",
298
+ "296": "riding or walking with horse",
299
+ "297": "plastering",
300
+ "298": "long jump",
301
+ "299": "parkour",
302
+ "300": "wrapping present",
303
+ "301": "egg hunting",
304
+ "302": "archery",
305
+ "303": "cleaning toilet",
306
+ "304": "swimming backstroke",
307
+ "305": "snowboarding",
308
+ "306": "catching or throwing baseball",
309
+ "307": "massaging back",
310
+ "308": "blowing glass",
311
+ "309": "playing guitar",
312
+ "310": "playing chess",
313
+ "311": "golf driving",
314
+ "312": "presenting weather forecast",
315
+ "313": "rock scissors paper",
316
+ "314": "high jump",
317
+ "315": "baking cookies",
318
+ "316": "using computer",
319
+ "317": "washing feet",
320
+ "318": "arranging flowers",
321
+ "319": "playing bass guitar",
322
+ "320": "spraying",
323
+ "321": "cutting pineapple",
324
+ "322": "waxing chest",
325
+ "323": "auctioning",
326
+ "324": "jetskiing",
327
+ "325": "drinking",
328
+ "326": "busking",
329
+ "327": "playing monopoly",
330
+ "328": "salsa dancing",
331
+ "329": "waxing eyebrows",
332
+ "330": "watering plants",
333
+ "331": "zumba",
334
+ "332": "chopping wood",
335
+ "333": "pushing wheelchair",
336
+ "334": "carving pumpkin",
337
+ "335": "building shed",
338
+ "336": "making jewelry",
339
+ "337": "catching or throwing softball",
340
+ "338": "bending metal",
341
+ "339": "ice skating",
342
+ "340": "dancing charleston",
343
+ "341": "abseiling",
344
+ "342": "climbing a rope",
345
+ "343": "crying",
346
+ "344": "cleaning shoes",
347
+ "345": "dancing ballet",
348
+ "346": "driving tractor",
349
+ "347": "triple jump",
350
+ "348": "throwing ball",
351
+ "349": "getting a haircut",
352
+ "350": "running on treadmill",
353
+ "351": "climbing ladder",
354
+ "352": "blasting sand",
355
+ "353": "playing trombone",
356
+ "354": "drop kicking",
357
+ "355": "country line dancing",
358
+ "356": "changing wheel",
359
+ "357": "feeding goats",
360
+ "358": "tying knot (not on a tie)",
361
+ "359": "setting table",
362
+ "360": "shaving legs",
363
+ "361": "kissing",
364
+ "362": "riding mule",
365
+ "363": "counting money",
366
+ "364": "laying bricks",
367
+ "365": "barbequing",
368
+ "366": "news anchoring",
369
+ "367": "smoking hookah",
370
+ "368": "cooking egg",
371
+ "369": "peeling apples",
372
+ "370": "yoga",
373
+ "371": "sharpening pencil",
374
+ "372": "dribbling basketball",
375
+ "373": "petting cat",
376
+ "374": "playing ice hockey",
377
+ "375": "milking cow",
378
+ "376": "shining shoes",
379
+ "377": "juggling soccer ball",
380
+ "378": "scuba diving",
381
+ "379": "playing squash or racquetball",
382
+ "380": "drinking beer",
383
+ "381": "sign language interpreting",
384
+ "382": "playing basketball",
385
+ "383": "breakdancing",
386
+ "384": "testifying",
387
+ "385": "making snowman",
388
+ "386": "golf putting",
389
+ "387": "playing didgeridoo",
390
+ "388": "biking through snow",
391
+ "389": "sailing",
392
+ "390": "jumpstyle dancing",
393
+ "391": "water sliding",
394
+ "392": "grooming horse",
395
+ "393": "massaging feet",
396
+ "394": "playing paintball",
397
+ "395": "making a cake",
398
+ "396": "bowling",
399
+ "397": "contact juggling",
400
+ "398": "applying cream",
401
+ "399": "playing badminton"
402
+ }
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ einops
4
+ timm
5
+ Pillow
6
+ decord
transforms.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision
2
+ import random
3
+ from PIL import Image, ImageOps
4
+ import numpy as np
5
+ import numbers
6
+ import math
7
+ import torch
8
+
9
+
10
+ class GroupRandomCrop(object):
11
+ def __init__(self, size):
12
+ if isinstance(size, numbers.Number):
13
+ self.size = (int(size), int(size))
14
+ else:
15
+ self.size = size
16
+
17
+ def __call__(self, img_group):
18
+
19
+ w, h = img_group[0].size
20
+ th, tw = self.size
21
+
22
+ out_images = list()
23
+
24
+ x1 = random.randint(0, w - tw)
25
+ y1 = random.randint(0, h - th)
26
+
27
+ for img in img_group:
28
+ assert(img.size[0] == w and img.size[1] == h)
29
+ if w == tw and h == th:
30
+ out_images.append(img)
31
+ else:
32
+ out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
33
+
34
+ return out_images
35
+
36
+
37
+ class MultiGroupRandomCrop(object):
38
+ def __init__(self, size, groups=1):
39
+ if isinstance(size, numbers.Number):
40
+ self.size = (int(size), int(size))
41
+ else:
42
+ self.size = size
43
+ self.groups = groups
44
+
45
+ def __call__(self, img_group):
46
+
47
+ w, h = img_group[0].size
48
+ th, tw = self.size
49
+
50
+ out_images = list()
51
+
52
+ for i in range(self.groups):
53
+ x1 = random.randint(0, w - tw)
54
+ y1 = random.randint(0, h - th)
55
+
56
+ for img in img_group:
57
+ assert(img.size[0] == w and img.size[1] == h)
58
+ if w == tw and h == th:
59
+ out_images.append(img)
60
+ else:
61
+ out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
62
+
63
+ return out_images
64
+
65
+
66
+ class GroupCenterCrop(object):
67
+ def __init__(self, size):
68
+ self.worker = torchvision.transforms.CenterCrop(size)
69
+
70
+ def __call__(self, img_group):
71
+ return [self.worker(img) for img in img_group]
72
+
73
+
74
+ class GroupRandomHorizontalFlip(object):
75
+ """Randomly horizontally flips the given PIL.Image with a probability of 0.5
76
+ """
77
+
78
+ def __init__(self, is_flow=False):
79
+ self.is_flow = is_flow
80
+
81
+ def __call__(self, img_group, is_flow=False):
82
+ v = random.random()
83
+ if v < 0.5:
84
+ ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
85
+ if self.is_flow:
86
+ for i in range(0, len(ret), 2):
87
+ # invert flow pixel values when flipping
88
+ ret[i] = ImageOps.invert(ret[i])
89
+ return ret
90
+ else:
91
+ return img_group
92
+
93
+
94
+ class GroupNormalize(object):
95
+ def __init__(self, mean, std):
96
+ self.mean = mean
97
+ self.std = std
98
+
99
+ def __call__(self, tensor):
100
+ rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
101
+ rep_std = self.std * (tensor.size()[0] // len(self.std))
102
+
103
+ # TODO: make efficient
104
+ for t, m, s in zip(tensor, rep_mean, rep_std):
105
+ t.sub_(m).div_(s)
106
+
107
+ return tensor
108
+
109
+
110
+ class GroupScale(object):
111
+ """ Rescales the input PIL.Image to the given 'size'.
112
+ 'size' will be the size of the smaller edge.
113
+ For example, if height > width, then image will be
114
+ rescaled to (size * height / width, size)
115
+ size: size of the smaller edge
116
+ interpolation: Default: PIL.Image.BILINEAR
117
+ """
118
+
119
+ def __init__(self, size, interpolation=Image.BILINEAR):
120
+ self.worker = torchvision.transforms.Resize(size, interpolation)
121
+
122
+ def __call__(self, img_group):
123
+ return [self.worker(img) for img in img_group]
124
+
125
+
126
+ class GroupOverSample(object):
127
+ def __init__(self, crop_size, scale_size=None, flip=True):
128
+ self.crop_size = crop_size if not isinstance(
129
+ crop_size, int) else (crop_size, crop_size)
130
+
131
+ if scale_size is not None:
132
+ self.scale_worker = GroupScale(scale_size)
133
+ else:
134
+ self.scale_worker = None
135
+ self.flip = flip
136
+
137
+ def __call__(self, img_group):
138
+
139
+ if self.scale_worker is not None:
140
+ img_group = self.scale_worker(img_group)
141
+
142
+ image_w, image_h = img_group[0].size
143
+ crop_w, crop_h = self.crop_size
144
+
145
+ offsets = GroupMultiScaleCrop.fill_fix_offset(
146
+ False, image_w, image_h, crop_w, crop_h)
147
+ oversample_group = list()
148
+ for o_w, o_h in offsets:
149
+ normal_group = list()
150
+ flip_group = list()
151
+ for i, img in enumerate(img_group):
152
+ crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
153
+ normal_group.append(crop)
154
+ flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
155
+
156
+ if img.mode == 'L' and i % 2 == 0:
157
+ flip_group.append(ImageOps.invert(flip_crop))
158
+ else:
159
+ flip_group.append(flip_crop)
160
+
161
+ oversample_group.extend(normal_group)
162
+ if self.flip:
163
+ oversample_group.extend(flip_group)
164
+ return oversample_group
165
+
166
+
167
+ class GroupFullResSample(object):
168
+ def __init__(self, crop_size, scale_size=None, flip=True):
169
+ self.crop_size = crop_size if not isinstance(
170
+ crop_size, int) else (crop_size, crop_size)
171
+
172
+ if scale_size is not None:
173
+ self.scale_worker = GroupScale(scale_size)
174
+ else:
175
+ self.scale_worker = None
176
+ self.flip = flip
177
+
178
+ def __call__(self, img_group):
179
+
180
+ if self.scale_worker is not None:
181
+ img_group = self.scale_worker(img_group)
182
+
183
+ image_w, image_h = img_group[0].size
184
+ crop_w, crop_h = self.crop_size
185
+
186
+ w_step = (image_w - crop_w) // 4
187
+ h_step = (image_h - crop_h) // 4
188
+
189
+ offsets = list()
190
+ offsets.append((0 * w_step, 2 * h_step)) # left
191
+ offsets.append((4 * w_step, 2 * h_step)) # right
192
+ offsets.append((2 * w_step, 2 * h_step)) # center
193
+
194
+ oversample_group = list()
195
+ for o_w, o_h in offsets:
196
+ normal_group = list()
197
+ flip_group = list()
198
+ for i, img in enumerate(img_group):
199
+ crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
200
+ normal_group.append(crop)
201
+ if self.flip:
202
+ flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
203
+
204
+ if img.mode == 'L' and i % 2 == 0:
205
+ flip_group.append(ImageOps.invert(flip_crop))
206
+ else:
207
+ flip_group.append(flip_crop)
208
+
209
+ oversample_group.extend(normal_group)
210
+ oversample_group.extend(flip_group)
211
+ return oversample_group
212
+
213
+
214
+ class GroupMultiScaleCrop(object):
215
+
216
+ def __init__(self, input_size, scales=None, max_distort=1,
217
+ fix_crop=True, more_fix_crop=True):
218
+ self.scales = scales if scales is not None else [1, .875, .75, .66]
219
+ self.max_distort = max_distort
220
+ self.fix_crop = fix_crop
221
+ self.more_fix_crop = more_fix_crop
222
+ self.input_size = input_size if not isinstance(input_size, int) else [
223
+ input_size, input_size]
224
+ self.interpolation = Image.BILINEAR
225
+
226
+ def __call__(self, img_group):
227
+
228
+ im_size = img_group[0].size
229
+
230
+ crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
231
+ crop_img_group = [
232
+ img.crop(
233
+ (offset_w,
234
+ offset_h,
235
+ offset_w +
236
+ crop_w,
237
+ offset_h +
238
+ crop_h)) for img in img_group]
239
+ ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
240
+ for img in crop_img_group]
241
+ return ret_img_group
242
+
243
+ def _sample_crop_size(self, im_size):
244
+ image_w, image_h = im_size[0], im_size[1]
245
+
246
+ # find a crop size
247
+ base_size = min(image_w, image_h)
248
+ crop_sizes = [int(base_size * x) for x in self.scales]
249
+ crop_h = [
250
+ self.input_size[1] if abs(
251
+ x - self.input_size[1]) < 3 else x for x in crop_sizes]
252
+ crop_w = [
253
+ self.input_size[0] if abs(
254
+ x - self.input_size[0]) < 3 else x for x in crop_sizes]
255
+
256
+ pairs = []
257
+ for i, h in enumerate(crop_h):
258
+ for j, w in enumerate(crop_w):
259
+ if abs(i - j) <= self.max_distort:
260
+ pairs.append((w, h))
261
+
262
+ crop_pair = random.choice(pairs)
263
+ if not self.fix_crop:
264
+ w_offset = random.randint(0, image_w - crop_pair[0])
265
+ h_offset = random.randint(0, image_h - crop_pair[1])
266
+ else:
267
+ w_offset, h_offset = self._sample_fix_offset(
268
+ image_w, image_h, crop_pair[0], crop_pair[1])
269
+
270
+ return crop_pair[0], crop_pair[1], w_offset, h_offset
271
+
272
+ def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
273
+ offsets = self.fill_fix_offset(
274
+ self.more_fix_crop, image_w, image_h, crop_w, crop_h)
275
+ return random.choice(offsets)
276
+
277
+ @staticmethod
278
+ def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
279
+ w_step = (image_w - crop_w) // 4
280
+ h_step = (image_h - crop_h) // 4
281
+
282
+ ret = list()
283
+ ret.append((0, 0)) # upper left
284
+ ret.append((4 * w_step, 0)) # upper right
285
+ ret.append((0, 4 * h_step)) # lower left
286
+ ret.append((4 * w_step, 4 * h_step)) # lower right
287
+ ret.append((2 * w_step, 2 * h_step)) # center
288
+
289
+ if more_fix_crop:
290
+ ret.append((0, 2 * h_step)) # center left
291
+ ret.append((4 * w_step, 2 * h_step)) # center right
292
+ ret.append((2 * w_step, 4 * h_step)) # lower center
293
+ ret.append((2 * w_step, 0 * h_step)) # upper center
294
+
295
+ ret.append((1 * w_step, 1 * h_step)) # upper left quarter
296
+ ret.append((3 * w_step, 1 * h_step)) # upper right quarter
297
+ ret.append((1 * w_step, 3 * h_step)) # lower left quarter
298
+ ret.append((3 * w_step, 3 * h_step)) # lower righ quarter
299
+
300
+ return ret
301
+
302
+
303
+ class GroupRandomSizedCrop(object):
304
+ """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
305
+ and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
306
+ This is popularly used to train the Inception networks
307
+ size: size of the smaller edge
308
+ interpolation: Default: PIL.Image.BILINEAR
309
+ """
310
+
311
+ def __init__(self, size, interpolation=Image.BILINEAR):
312
+ self.size = size
313
+ self.interpolation = interpolation
314
+
315
+ def __call__(self, img_group):
316
+ for attempt in range(10):
317
+ area = img_group[0].size[0] * img_group[0].size[1]
318
+ target_area = random.uniform(0.08, 1.0) * area
319
+ aspect_ratio = random.uniform(3. / 4, 4. / 3)
320
+
321
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
322
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
323
+
324
+ if random.random() < 0.5:
325
+ w, h = h, w
326
+
327
+ if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
328
+ x1 = random.randint(0, img_group[0].size[0] - w)
329
+ y1 = random.randint(0, img_group[0].size[1] - h)
330
+ found = True
331
+ break
332
+ else:
333
+ found = False
334
+ x1 = 0
335
+ y1 = 0
336
+
337
+ if found:
338
+ out_group = list()
339
+ for img in img_group:
340
+ img = img.crop((x1, y1, x1 + w, y1 + h))
341
+ assert(img.size == (w, h))
342
+ out_group.append(
343
+ img.resize(
344
+ (self.size, self.size), self.interpolation))
345
+ return out_group
346
+ else:
347
+ # Fallback
348
+ scale = GroupScale(self.size, interpolation=self.interpolation)
349
+ crop = GroupRandomCrop(self.size)
350
+ return crop(scale(img_group))
351
+
352
+
353
+ class ConvertDataFormat(object):
354
+ def __init__(self, model_type):
355
+ self.model_type = model_type
356
+
357
+ def __call__(self, images):
358
+ if self.model_type == '2D':
359
+ return images
360
+ tc, h, w = images.size()
361
+ t = tc // 3
362
+ images = images.view(t, 3, h, w)
363
+ images = images.permute(1, 0, 2, 3)
364
+ return images
365
+
366
+
367
+ class Stack(object):
368
+
369
+ def __init__(self, roll=False):
370
+ self.roll = roll
371
+
372
+ def __call__(self, img_group):
373
+ if img_group[0].mode == 'L':
374
+ return np.concatenate([np.expand_dims(x, 2)
375
+ for x in img_group], axis=2)
376
+ elif img_group[0].mode == 'RGB':
377
+ if self.roll:
378
+ return np.concatenate([np.array(x)[:, :, ::-1]
379
+ for x in img_group], axis=2)
380
+ else:
381
+ #print(np.concatenate(img_group, axis=2).shape)
382
+ # print(img_group[0].shape)
383
+ return np.concatenate(img_group, axis=2)
384
+
385
+
386
+ class ToTorchFormatTensor(object):
387
+ """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
388
+ to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
389
+
390
+ def __init__(self, div=True):
391
+ self.div = div
392
+
393
+ def __call__(self, pic):
394
+ if isinstance(pic, np.ndarray):
395
+ # handle numpy array
396
+ img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
397
+ else:
398
+ # handle PIL Image
399
+ img = torch.ByteTensor(
400
+ torch.ByteStorage.from_buffer(
401
+ pic.tobytes()))
402
+ img = img.view(pic.size[1], pic.size[0], len(pic.mode))
403
+ # put it from HWC to CHW format
404
+ # yikes, this transpose takes 80% of the loading time/CPU
405
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
406
+ return img.float().div(255) if self.div else img.float()
407
+
408
+
409
+ class IdentityTransform(object):
410
+
411
+ def __call__(self, data):
412
+ return data
413
+
414
+
415
+ if __name__ == "__main__":
416
+ trans = torchvision.transforms.Compose([
417
+ GroupScale(256),
418
+ GroupRandomCrop(224),
419
+ Stack(),
420
+ ToTorchFormatTensor(),
421
+ GroupNormalize(
422
+ mean=[.485, .456, .406],
423
+ std=[.229, .224, .225]
424
+ )]
425
+ )
426
+
427
+ im = Image.open('../tensorflow-model-zoo.torch/lena_299.png')
428
+
429
+ color_group = [im] * 3
430
+ rst = trans(color_group)
431
+
432
+ gray_group = [im.convert('L')] * 9
433
+ gray_rst = trans(gray_group)
434
+
435
+ trans2 = torchvision.transforms.Compose([
436
+ GroupRandomSizedCrop(256),
437
+ Stack(),
438
+ ToTorchFormatTensor(),
439
+ GroupNormalize(
440
+ mean=[.485, .456, .406],
441
+ std=[.229, .224, .225])
442
+ ])
443
+ print(trans2(color_group))
uniformer.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ import torch
3
+ import torch.nn as nn
4
+ from functools import partial
5
+ from timm.models.layers import trunc_normal_, DropPath, to_2tuple
6
+
7
+
8
+ def conv_3xnxn(inp, oup, kernel_size=3, stride=3, groups=1):
9
+ return nn.Conv3d(inp, oup, (3, kernel_size, kernel_size), (2, stride, stride), (1, 0, 0), groups=groups)
10
+
11
+ def conv_1xnxn(inp, oup, kernel_size=3, stride=3, groups=1):
12
+ return nn.Conv3d(inp, oup, (1, kernel_size, kernel_size), (1, stride, stride), (0, 0, 0), groups=groups)
13
+
14
+ def conv_3xnxn_std(inp, oup, kernel_size=3, stride=3, groups=1):
15
+ return nn.Conv3d(inp, oup, (3, kernel_size, kernel_size), (1, stride, stride), (1, 0, 0), groups=groups)
16
+
17
+ def conv_1x1x1(inp, oup, groups=1):
18
+ return nn.Conv3d(inp, oup, (1, 1, 1), (1, 1, 1), (0, 0, 0), groups=groups)
19
+
20
+ def conv_3x3x3(inp, oup, groups=1):
21
+ return nn.Conv3d(inp, oup, (3, 3, 3), (1, 1, 1), (1, 1, 1), groups=groups)
22
+
23
+ def conv_5x5x5(inp, oup, groups=1):
24
+ return nn.Conv3d(inp, oup, (5, 5, 5), (1, 1, 1), (2, 2, 2), groups=groups)
25
+
26
+ def bn_3d(dim):
27
+ return nn.BatchNorm3d(dim)
28
+
29
+
30
+ class Mlp(nn.Module):
31
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
32
+ super().__init__()
33
+ out_features = out_features or in_features
34
+ hidden_features = hidden_features or in_features
35
+ self.fc1 = nn.Linear(in_features, hidden_features)
36
+ self.act = act_layer()
37
+ self.fc2 = nn.Linear(hidden_features, out_features)
38
+ self.drop = nn.Dropout(drop)
39
+
40
+ def forward(self, x):
41
+ x = self.fc1(x)
42
+ x = self.act(x)
43
+ x = self.drop(x)
44
+ x = self.fc2(x)
45
+ x = self.drop(x)
46
+ return x
47
+
48
+
49
+ class Attention(nn.Module):
50
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
51
+ super().__init__()
52
+ self.num_heads = num_heads
53
+ head_dim = dim // num_heads
54
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
55
+ self.scale = qk_scale or head_dim ** -0.5
56
+
57
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
58
+ self.attn_drop = nn.Dropout(attn_drop)
59
+ self.proj = nn.Linear(dim, dim)
60
+ self.proj_drop = nn.Dropout(proj_drop)
61
+
62
+ def forward(self, x):
63
+ B, N, C = x.shape
64
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
65
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
66
+
67
+ attn = (q @ k.transpose(-2, -1)) * self.scale
68
+ attn = attn.softmax(dim=-1)
69
+ attn = self.attn_drop(attn)
70
+
71
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
72
+ x = self.proj(x)
73
+ x = self.proj_drop(x)
74
+ return x
75
+
76
+
77
+ class CMlp(nn.Module):
78
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
79
+ super().__init__()
80
+ out_features = out_features or in_features
81
+ hidden_features = hidden_features or in_features
82
+ self.fc1 = conv_1x1x1(in_features, hidden_features)
83
+ self.act = act_layer()
84
+ self.fc2 = conv_1x1x1(hidden_features, out_features)
85
+ self.drop = nn.Dropout(drop)
86
+
87
+ def forward(self, x):
88
+ x = self.fc1(x)
89
+ x = self.act(x)
90
+ x = self.drop(x)
91
+ x = self.fc2(x)
92
+ x = self.drop(x)
93
+ return x
94
+
95
+
96
+ class CBlock(nn.Module):
97
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
98
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
99
+ super().__init__()
100
+ self.pos_embed = conv_3x3x3(dim, dim, groups=dim)
101
+ self.norm1 = bn_3d(dim)
102
+ self.conv1 = conv_1x1x1(dim, dim, 1)
103
+ self.conv2 = conv_1x1x1(dim, dim, 1)
104
+ self.attn = conv_5x5x5(dim, dim, groups=dim)
105
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
106
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
107
+ self.norm2 = bn_3d(dim)
108
+ mlp_hidden_dim = int(dim * mlp_ratio)
109
+ self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
110
+
111
+ def forward(self, x):
112
+ x = x + self.pos_embed(x)
113
+ x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x)))))
114
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
115
+ return x
116
+
117
+
118
+ class SABlock(nn.Module):
119
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
120
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
121
+ super().__init__()
122
+ self.pos_embed = conv_3x3x3(dim, dim, groups=dim)
123
+ self.norm1 = norm_layer(dim)
124
+ self.attn = Attention(
125
+ dim,
126
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
127
+ attn_drop=attn_drop, proj_drop=drop)
128
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
129
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
130
+ self.norm2 = norm_layer(dim)
131
+ mlp_hidden_dim = int(dim * mlp_ratio)
132
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
133
+
134
+ def forward(self, x):
135
+ x = x + self.pos_embed(x)
136
+ B, C, T, H, W = x.shape
137
+ x = x.flatten(2).transpose(1, 2)
138
+ x = x + self.drop_path(self.attn(self.norm1(x)))
139
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
140
+ x = x.transpose(1, 2).reshape(B, C, T, H, W)
141
+ return x
142
+
143
+
144
+ class SplitSABlock(nn.Module):
145
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
146
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
147
+ super().__init__()
148
+ self.pos_embed = conv_3x3x3(dim, dim, groups=dim)
149
+ self.t_norm = norm_layer(dim)
150
+ self.t_attn = Attention(
151
+ dim,
152
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
153
+ attn_drop=attn_drop, proj_drop=drop)
154
+ self.norm1 = norm_layer(dim)
155
+ self.attn = Attention(
156
+ dim,
157
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
158
+ attn_drop=attn_drop, proj_drop=drop)
159
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
160
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
161
+ self.norm2 = norm_layer(dim)
162
+ mlp_hidden_dim = int(dim * mlp_ratio)
163
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
164
+
165
+ def forward(self, x):
166
+ x = x + self.pos_embed(x)
167
+ B, C, T, H, W = x.shape
168
+ attn = x.view(B, C, T, H * W).permute(0, 3, 2, 1).contiguous()
169
+ attn = attn.view(B * H * W, T, C)
170
+ attn = attn + self.drop_path(self.t_attn(self.t_norm(attn)))
171
+ attn = attn.view(B, H * W, T, C).permute(0, 2, 1, 3).contiguous()
172
+ attn = attn.view(B * T, H * W, C)
173
+ residual = x.view(B, C, T, H * W).permute(0, 2, 3, 1).contiguous()
174
+ residual = residual.view(B * T, H * W, C)
175
+ attn = residual + self.drop_path(self.attn(self.norm1(attn)))
176
+ attn = attn.view(B, T * H * W, C)
177
+ out = attn + self.drop_path(self.mlp(self.norm2(attn)))
178
+ out = out.transpose(1, 2).reshape(B, C, T, H, W)
179
+ return out
180
+
181
+
182
+ class SpeicalPatchEmbed(nn.Module):
183
+ """ Image to Patch Embedding
184
+ """
185
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
186
+ super().__init__()
187
+ img_size = to_2tuple(img_size)
188
+ patch_size = to_2tuple(patch_size)
189
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
190
+ self.img_size = img_size
191
+ self.patch_size = patch_size
192
+ self.num_patches = num_patches
193
+ self.norm = nn.LayerNorm(embed_dim)
194
+ self.proj = conv_3xnxn(in_chans, embed_dim, kernel_size=patch_size[0], stride=patch_size[0])
195
+
196
+ def forward(self, x):
197
+ B, C, T, H, W = x.shape
198
+ # FIXME look at relaxing size constraints
199
+ # assert H == self.img_size[0] and W == self.img_size[1], \
200
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
201
+ x = self.proj(x)
202
+ B, C, T, H, W = x.shape
203
+ x = x.flatten(2).transpose(1, 2)
204
+ x = self.norm(x)
205
+ x = x.reshape(B, T, H, W, -1).permute(0, 4, 1, 2, 3).contiguous()
206
+ return x
207
+
208
+
209
+ class PatchEmbed(nn.Module):
210
+ """ Image to Patch Embedding
211
+ """
212
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, std=False):
213
+ super().__init__()
214
+ img_size = to_2tuple(img_size)
215
+ patch_size = to_2tuple(patch_size)
216
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
217
+ self.img_size = img_size
218
+ self.patch_size = patch_size
219
+ self.num_patches = num_patches
220
+ self.norm = nn.LayerNorm(embed_dim)
221
+ if std:
222
+ self.proj = conv_3xnxn_std(in_chans, embed_dim, kernel_size=patch_size[0], stride=patch_size[0])
223
+ else:
224
+ self.proj = conv_1xnxn(in_chans, embed_dim, kernel_size=patch_size[0], stride=patch_size[0])
225
+
226
+ def forward(self, x):
227
+ B, C, T, H, W = x.shape
228
+ # FIXME look at relaxing size constraints
229
+ # assert H == self.img_size[0] and W == self.img_size[1], \
230
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
231
+ x = self.proj(x)
232
+ B, C, T, H, W = x.shape
233
+ x = x.flatten(2).transpose(1, 2)
234
+ x = self.norm(x)
235
+ x = x.reshape(B, T, H, W, -1).permute(0, 4, 1, 2, 3).contiguous()
236
+ return x
237
+
238
+
239
+ class Uniformer(nn.Module):
240
+ """ Vision Transformer
241
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
242
+ https://arxiv.org/abs/2010.11929
243
+ """
244
+ def __init__(self, depth=[5, 8, 20, 7], num_classes=400, img_size=224, in_chans=3, embed_dim=[64, 128, 320, 512],
245
+ head_dim=64, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
246
+ drop_rate=0.3, attn_drop_rate=0., drop_path_rate=0., norm_layer=None, split=False, std=False):
247
+ super().__init__()
248
+
249
+ self.num_classes = num_classes
250
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
251
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
252
+
253
+ self.patch_embed1 = SpeicalPatchEmbed(
254
+ img_size=img_size, patch_size=4, in_chans=in_chans, embed_dim=embed_dim[0])
255
+ self.patch_embed2 = PatchEmbed(
256
+ img_size=img_size // 4, patch_size=2, in_chans=embed_dim[0], embed_dim=embed_dim[1], std=std)
257
+ self.patch_embed3 = PatchEmbed(
258
+ img_size=img_size // 8, patch_size=2, in_chans=embed_dim[1], embed_dim=embed_dim[2], std=std)
259
+ self.patch_embed4 = PatchEmbed(
260
+ img_size=img_size // 16, patch_size=2, in_chans=embed_dim[2], embed_dim=embed_dim[3], std=std)
261
+
262
+ self.pos_drop = nn.Dropout(p=drop_rate)
263
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))] # stochastic depth decay rule
264
+ num_heads = [dim // head_dim for dim in embed_dim]
265
+ self.blocks1 = nn.ModuleList([
266
+ CBlock(
267
+ dim=embed_dim[0], num_heads=num_heads[0], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
268
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
269
+ for i in range(depth[0])])
270
+ self.blocks2 = nn.ModuleList([
271
+ CBlock(
272
+ dim=embed_dim[1], num_heads=num_heads[1], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
273
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]], norm_layer=norm_layer)
274
+ for i in range(depth[1])])
275
+ if split:
276
+ self.blocks3 = nn.ModuleList([
277
+ SplitSABlock(
278
+ dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
279
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]], norm_layer=norm_layer)
280
+ for i in range(depth[2])])
281
+ self.blocks4 = nn.ModuleList([
282
+ SplitSABlock(
283
+ dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
284
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]+depth[2]], norm_layer=norm_layer)
285
+ for i in range(depth[3])])
286
+ else:
287
+ self.blocks3 = nn.ModuleList([
288
+ SABlock(
289
+ dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
290
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]], norm_layer=norm_layer)
291
+ for i in range(depth[2])])
292
+ self.blocks4 = nn.ModuleList([
293
+ SABlock(
294
+ dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
295
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]+depth[2]], norm_layer=norm_layer)
296
+ for i in range(depth[3])])
297
+ self.norm = bn_3d(embed_dim[-1])
298
+
299
+ # Representation layer
300
+ if representation_size:
301
+ self.num_features = representation_size
302
+ self.pre_logits = nn.Sequential(OrderedDict([
303
+ ('fc', nn.Linear(embed_dim, representation_size)),
304
+ ('act', nn.Tanh())
305
+ ]))
306
+ else:
307
+ self.pre_logits = nn.Identity()
308
+
309
+ # Classifier head
310
+ self.head = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
311
+
312
+ self.apply(self._init_weights)
313
+
314
+ for name, p in self.named_parameters():
315
+ # fill proj weight with 1 here to improve training dynamics. Otherwise temporal attention inputs
316
+ # are multiplied by 0*0, which is hard for the model to move out of.
317
+ if 't_attn.qkv.weight' in name:
318
+ nn.init.constant_(p, 0)
319
+ if 't_attn.qkv.bias' in name:
320
+ nn.init.constant_(p, 0)
321
+ if 't_attn.proj.weight' in name:
322
+ nn.init.constant_(p, 1)
323
+ if 't_attn.proj.bias' in name:
324
+ nn.init.constant_(p, 0)
325
+
326
+ def _init_weights(self, m):
327
+ if isinstance(m, nn.Linear):
328
+ trunc_normal_(m.weight, std=.02)
329
+ if isinstance(m, nn.Linear) and m.bias is not None:
330
+ nn.init.constant_(m.bias, 0)
331
+ elif isinstance(m, nn.LayerNorm):
332
+ nn.init.constant_(m.bias, 0)
333
+ nn.init.constant_(m.weight, 1.0)
334
+
335
+ @torch.jit.ignore
336
+ def no_weight_decay(self):
337
+ return {'pos_embed', 'cls_token'}
338
+
339
+ def get_classifier(self):
340
+ return self.head
341
+
342
+ def reset_classifier(self, num_classes, global_pool=''):
343
+ self.num_classes = num_classes
344
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
345
+
346
+ def forward_features(self, x):
347
+ x = self.patch_embed1(x)
348
+ x = self.pos_drop(x)
349
+ for blk in self.blocks1:
350
+ x = blk(x)
351
+ x = self.patch_embed2(x)
352
+ for blk in self.blocks2:
353
+ x = blk(x)
354
+ x = self.patch_embed3(x)
355
+ for blk in self.blocks3:
356
+ x = blk(x)
357
+ x = self.patch_embed4(x)
358
+ for blk in self.blocks4:
359
+ x = blk(x)
360
+ x = self.norm(x)
361
+ x = self.pre_logits(x)
362
+ return x
363
+
364
+ def forward(self, x):
365
+ x = self.forward_features(x)
366
+ x = x.flatten(2).mean(-1)
367
+ x = self.head(x)
368
+ return x
369
+
370
+
371
+ def uniformer_small():
372
+ return Uniformer(
373
+ depth=[3, 4, 8, 3], embed_dim=[64, 128, 320, 512],
374
+ head_dim=64, drop_rate=0.1)
375
+
376
+ def uniformer_base():
377
+ return Uniformer(
378
+ depth=[5, 8, 20, 7], embed_dim=[64, 128, 320, 512],
379
+ head_dim=64, drop_rate=0.3)
yoga.mp4 ADDED
Binary file (776 kB). View file