liruiw commited on
Commit
179feba
·
verified ·
1 Parent(s): 73a6f2d

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +9 -3
  2. config.json +892 -0
  3. model.safetensors +3 -0
  4. random_states_0.pkl +3 -0
README.md CHANGED
@@ -1,3 +1,9 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - model_hub_mixin
4
+ - pytorch_model_hub_mixin
5
+ ---
6
+
7
+ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
8
+ - Library: [More Information Needed]
9
+ - Docs: [More Information Needed]
config.json ADDED
@@ -0,0 +1,892 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Diffusion": true,
3
+ "S": 1024,
4
+ "T": 12,
5
+ "action_contrastive_loss": false,
6
+ "action_domains": [
7
+ "language_table",
8
+ "ucsd_pick_and_place_dataset_converted_externally_to_rlds",
9
+ "kaist_nonprehensile_converted_externally_to_rlds",
10
+ "ucsd_kitchen_dataset_converted_externally_to_rlds",
11
+ "utokyo_xarm_bimanual_converted_externally_to_rlds",
12
+ "stanford_hydra_dataset_converted_externally_to_rlds",
13
+ "austin_sirius_dataset_converted_externally_to_rlds",
14
+ "berkeley_fanuc_manipulation",
15
+ "berkeley_mvp_converted_externally_to_rlds",
16
+ "berkeley_rpt_converted_externally_to_rlds",
17
+ "cmu_play_fusion",
18
+ "iamlab_cmu_pickup_insert_converted_externally_to_rlds",
19
+ "qut_dexterous_manpulation",
20
+ "robo_net",
21
+ "dlr_sara_grid_clamp_converted_externally_to_rlds",
22
+ "cmu_stretch",
23
+ "columbia_cairlab_pusht_real",
24
+ "droid",
25
+ "toto",
26
+ "io_ai_tech",
27
+ "dobbe",
28
+ "berkeley_gnm_cory_hall",
29
+ "plex_robosuite",
30
+ "berkeley_cable_routing",
31
+ "imperial_wrist_dataset",
32
+ "bc_z",
33
+ "kuka",
34
+ "roboturk",
35
+ "robomimic",
36
+ "nyu_door_opening_surprising_effectiveness"
37
+ ],
38
+ "action_loss_weight": 1.0,
39
+ "action_network": "modulate",
40
+ "action_stats": [
41
+ [
42
+ [
43
+ 0.00014842326345387846,
44
+ -0.0005635050474666059
45
+ ],
46
+ [
47
+ 0.030163198709487915,
48
+ 0.042305462062358856
49
+ ]
50
+ ],
51
+ [
52
+ [
53
+ 0.14697618782520294,
54
+ -0.12370482087135315,
55
+ 0.051797714084386826,
56
+ -0.07087255269289017
57
+ ],
58
+ [
59
+ 0.48446422815322876,
60
+ 0.4629215896129608,
61
+ 0.5406527519226074,
62
+ 0.8932808637619019
63
+ ]
64
+ ],
65
+ [
66
+ [
67
+ 0.0019365031039342284,
68
+ 0.00024300716177094728,
69
+ 0.0008035349892452359,
70
+ -0.0021898974664509296,
71
+ 4.0033126424532384e-05,
72
+ -0.0037320367991924286,
73
+ 181.3382110595703,
74
+ 170.906005859375,
75
+ 186.00167846679688,
76
+ 153.71453857421875,
77
+ 174.13552856445312,
78
+ 83.48700714111328,
79
+ 32.40367889404297,
80
+ 1.0105239152908325,
81
+ 0.70769864320755,
82
+ 0.33714282512664795,
83
+ 0.4880707859992981,
84
+ 0.49914246797561646,
85
+ 0.7389975786209106,
86
+ 1.229773998260498
87
+ ],
88
+ [
89
+ 0.014657468535006046,
90
+ 0.016509365290403366,
91
+ 0.012914449907839298,
92
+ 0.023157890886068344,
93
+ 0.020476851612329483,
94
+ 0.019903959706425667,
95
+ 27.33956527709961,
96
+ 34.60658645629883,
97
+ 25.291311264038086,
98
+ 53.12702941894531,
99
+ 34.06013870239258,
100
+ 50.88619613647461,
101
+ 33.19802474975586,
102
+ 0.5348116159439087,
103
+ 0.4462398290634155,
104
+ 0.13059484958648682,
105
+ 0.37823718786239624,
106
+ 0.3443543016910553,
107
+ 0.4008789658546448,
108
+ 0.5644641518592834
109
+ ]
110
+ ],
111
+ [
112
+ [
113
+ 408.48406982421875,
114
+ 118.63397979736328,
115
+ 198.41452026367188,
116
+ -121.96654510498047,
117
+ -34.39997863769531,
118
+ 52.22698974609375,
119
+ 0.7438188791275024,
120
+ 0.038725052028894424
121
+ ],
122
+ [
123
+ 122.93132019042969,
124
+ 107.72244262695312,
125
+ 128.7881317138672,
126
+ 115.888916015625,
127
+ 27.235536575317383,
128
+ 40.505306243896484,
129
+ 0.43652451038360596,
130
+ 0.1929423063993454
131
+ ]
132
+ ],
133
+ [
134
+ [
135
+ 0.46141278743743896,
136
+ 0.10542168468236923,
137
+ 0.25353577733039856,
138
+ -1.685599684715271,
139
+ -0.05627294257283211,
140
+ -0.4933978319168091,
141
+ 0.34305593371391296,
142
+ 0.44558027386665344,
143
+ 0.5123444199562073,
144
+ 0.2677648663520813,
145
+ 1.277772307395935,
146
+ 0.12375026196241379,
147
+ 0.11488401144742966,
148
+ 0.33085882663726807
149
+ ],
150
+ [
151
+ 0.053879205137491226,
152
+ 0.05707748979330063,
153
+ 0.04467933997511864,
154
+ 2.469133138656616,
155
+ 0.18709321320056915,
156
+ 0.22889232635498047,
157
+ 0.47383636236190796,
158
+ 0.06301931291818619,
159
+ 0.041924599558115005,
160
+ 0.047896090894937515,
161
+ 2.701061248779297,
162
+ 0.2972114384174347,
163
+ 0.7424992918968201,
164
+ 0.4667799174785614
165
+ ]
166
+ ],
167
+ [
168
+ [
169
+ 0.0007764044567011297,
170
+ 0.0001343307230854407,
171
+ -0.00026648343191482127,
172
+ 0.0013218839885666966,
173
+ -0.004740390460938215,
174
+ 0.002773461164906621,
175
+ 0.5106820464134216
176
+ ],
177
+ [
178
+ 0.008042743429541588,
179
+ 0.00913731288164854,
180
+ 0.009599598124623299,
181
+ 0.04121660068631172,
182
+ 0.038332853466272354,
183
+ 0.04602774232625961,
184
+ 0.4999658763408661
185
+ ]
186
+ ],
187
+ [
188
+ [
189
+ 0.07727599143981934,
190
+ 0.03225162252783775,
191
+ 0.04257211461663246,
192
+ 0.0,
193
+ 0.0,
194
+ -0.01612210087478161,
195
+ 0.13071605563163757
196
+ ],
197
+ [
198
+ 0.3917539417743683,
199
+ 0.30044373869895935,
200
+ 0.27837157249450684,
201
+ 0.0,
202
+ 0.0,
203
+ 0.081514872610569,
204
+ 0.9911611676216125
205
+ ]
206
+ ],
207
+ [
208
+ [
209
+ 0.0007766556227579713,
210
+ -0.000321519240969792,
211
+ -0.0014813995221629739,
212
+ -0.0007485907408408821,
213
+ -0.00015667964180465788,
214
+ 0.0001845337392296642
215
+ ],
216
+ [
217
+ 0.003409236203879118,
218
+ 0.004994169808924198,
219
+ 0.005332312546670437,
220
+ 0.007559089455753565,
221
+ 0.004051606170833111,
222
+ 0.008588160388171673
223
+ ]
224
+ ],
225
+ [
226
+ [
227
+ -6.743222911609337e-05,
228
+ 0.0031809681095182896,
229
+ -0.00013550207950174809,
230
+ -0.0009742751135490835,
231
+ -8.3738968896796e-06,
232
+ -0.002912015886977315,
233
+ -0.0006995691219344735,
234
+ 0.48066604137420654
235
+ ],
236
+ [
237
+ 0.002549938391894102,
238
+ 0.012658610939979553,
239
+ 0.005411175545305014,
240
+ 0.018054410815238953,
241
+ 0.0016273874789476395,
242
+ 0.021100502461194992,
243
+ 0.005715933162719011,
244
+ 0.4996056854724884
245
+ ]
246
+ ],
247
+ [
248
+ [
249
+ 0.00014941584959160537,
250
+ -0.00028024936909787357,
251
+ -8.037472071009688e-06,
252
+ -0.00032872759038582444,
253
+ 1.9844068447127938e-05,
254
+ 3.272057801950723e-05,
255
+ 8.096991950878873e-05,
256
+ 0.4784493148326874
257
+ ],
258
+ [
259
+ 0.0015258367639034986,
260
+ 0.004546448588371277,
261
+ 0.0007782428874634206,
262
+ 0.003019175725057721,
263
+ 0.0010663573630154133,
264
+ 0.005132743623107672,
265
+ 0.004171756561845541,
266
+ 0.4998187720775604
267
+ ]
268
+ ],
269
+ [
270
+ [
271
+ 0.000523168477229774,
272
+ 3.85410530725494e-05,
273
+ -0.00017000196385197341,
274
+ -0.00029378157341852784,
275
+ -0.00036922883009538054,
276
+ -0.0001573827030370012,
277
+ 5.717058229492977e-06,
278
+ 0.5699702501296997,
279
+ 0.002427969593554735
280
+ ],
281
+ [
282
+ 0.0018014844972640276,
283
+ 0.002389610279351473,
284
+ 0.0018651892896741629,
285
+ 0.039326585829257965,
286
+ 0.03775598481297493,
287
+ 0.005358702037483454,
288
+ 0.007674811407923698,
289
+ 0.49466049671173096,
290
+ 0.04917134344577789
291
+ ]
292
+ ],
293
+ [
294
+ [
295
+ 0.5280895829200745,
296
+ 0.02888699807226658,
297
+ 0.18680934607982635,
298
+ -0.01308287400752306,
299
+ 0.9998903870582581,
300
+ 0.003612307133153081,
301
+ 0.016001908108592033,
302
+ 0.5531076192855835
303
+ ],
304
+ [
305
+ 0.08082365244626999,
306
+ 0.11135152727365494,
307
+ 0.07754139602184296,
308
+ 0.01604105904698372,
309
+ 0.0006265711272135377,
310
+ 0.007812995463609695,
311
+ 0.013805469498038292,
312
+ 0.497190922498703
313
+ ]
314
+ ],
315
+ [
316
+ [
317
+ 3.650687176559586e-06,
318
+ -0.000507326505612582,
319
+ -0.00031988348928280175,
320
+ 0.000982428900897503,
321
+ -3.8030557334423065e-05,
322
+ -0.002412878442555666,
323
+ 0.004487304482609034,
324
+ -0.0035084427800029516
325
+ ],
326
+ [
327
+ 0.012881123460829258,
328
+ 0.015523286536335945,
329
+ 0.012135118246078491,
330
+ 0.0009839057456701994,
331
+ 0.0022208373993635178,
332
+ 0.024184072390198708,
333
+ 0.9997038841247559,
334
+ 0.027473121881484985
335
+ ]
336
+ ],
337
+ [
338
+ [
339
+ -1.2117172445869073e-05,
340
+ 1.6497699107276276e-05,
341
+ -0.008070996962487698,
342
+ -3.756756632355973e-05,
343
+ -0.02855828031897545
344
+ ],
345
+ [
346
+ 0.03243051841855049,
347
+ 0.03239370137453079,
348
+ 0.09075836837291718,
349
+ 0.17676453292369843,
350
+ 0.9996473789215088
351
+ ]
352
+ ],
353
+ [
354
+ [
355
+ -1.662798604229465e-05,
356
+ -4.2423445847816765e-05,
357
+ -0.00039035530062392354,
358
+ 2.9382475986494683e-05,
359
+ 9.349627362098545e-05,
360
+ 7.735285180388018e-05,
361
+ 1.0
362
+ ],
363
+ [
364
+ 0.0004397016600705683,
365
+ 0.0005161615554243326,
366
+ 0.0012931948294863105,
367
+ 0.0005689726676791906,
368
+ 0.0007594820926897228,
369
+ 0.0007169033051468432,
370
+ 0.0
371
+ ]
372
+ ],
373
+ [
374
+ [
375
+ 0.0003590668202377856,
376
+ 0.0,
377
+ 0.001637771725654602,
378
+ 0.0,
379
+ 0.0,
380
+ 0.0,
381
+ 0.39824214577674866,
382
+ 0.0054051512852311134
383
+ ],
384
+ [
385
+ 0.004081381484866142,
386
+ 0.0,
387
+ 0.003803750965744257,
388
+ 0.0,
389
+ 0.0,
390
+ 0.0,
391
+ 0.48956871032714844,
392
+ 0.07332666963338852
393
+ ]
394
+ ],
395
+ [
396
+ [
397
+ 0.0,
398
+ 0.0,
399
+ 0.0,
400
+ 0.0,
401
+ 0.009795918129384518,
402
+ -0.0013581214006990194,
403
+ 0.0016117944614961743,
404
+ 0.0
405
+ ],
406
+ [
407
+ 0.0,
408
+ 0.0,
409
+ 0.0,
410
+ 0.0,
411
+ 0.09848489612340927,
412
+ 0.013396660797297955,
413
+ 0.01613754965364933,
414
+ 0.0
415
+ ]
416
+ ],
417
+ [
418
+ [
419
+ 0.5393196940422058,
420
+ 0.0013354304246604443,
421
+ 0.3156941831111908,
422
+ 0.3164699971675873,
423
+ -0.09018929302692413,
424
+ -0.049339085817337036,
425
+ 0.40983372926712036
426
+ ],
427
+ [
428
+ 0.11741136759519577,
429
+ 0.17491821944713593,
430
+ 0.16181626915931702,
431
+ 2.7440731525421143,
432
+ 0.3496827483177185,
433
+ 0.7598394155502319,
434
+ 0.43021807074546814
435
+ ]
436
+ ],
437
+ [
438
+ [
439
+ 0.0,
440
+ -0.6654278039932251,
441
+ 0.1888580173254013,
442
+ 0.03210142254829407,
443
+ 0.00612324383109808,
444
+ 0.38309070467948914,
445
+ 0.009382354095578194,
446
+ 0.3636060655117035
447
+ ],
448
+ [
449
+ 0.0,
450
+ 0.5778681635856628,
451
+ 0.296415776014328,
452
+ 0.3222154378890991,
453
+ 0.07795067131519318,
454
+ 0.12293250113725662,
455
+ 0.19515110552310944,
456
+ 0.10152395814657211
457
+ ]
458
+ ],
459
+ [
460
+ [
461
+ 2.957821561722085e-05,
462
+ 0.00012851174687966704,
463
+ -0.00010689908231142908,
464
+ -5.97012804064434e-05,
465
+ 0.00022397778229787946,
466
+ 6.999688048381358e-05,
467
+ 0.09176551550626755
468
+ ],
469
+ [
470
+ 0.0028307351749390364,
471
+ 0.002651946386322379,
472
+ 0.0025581379886716604,
473
+ 0.018357520923018456,
474
+ 0.025164088234305382,
475
+ 0.024045433849096298,
476
+ 0.5909407734870911
477
+ ]
478
+ ],
479
+ [
480
+ [
481
+ -0.00011003677354892716,
482
+ 0.001112840254791081,
483
+ -0.00011267208174103871,
484
+ -7.512857700930908e-05,
485
+ -0.0006745870341546834,
486
+ -5.703312126570381e-05,
487
+ 0.6326711177825928
488
+ ],
489
+ [
490
+ 0.043496448546648026,
491
+ 0.04464876651763916,
492
+ 0.12467490881681442,
493
+ 0.005452098790556192,
494
+ 0.011218013241887093,
495
+ 0.00624604569748044,
496
+ 0.39724212884902954
497
+ ]
498
+ ],
499
+ [
500
+ [
501
+ 0.06122741475701332,
502
+ 0.0038701200392097235
503
+ ],
504
+ [
505
+ 0.025848353281617165,
506
+ 0.0030985879711806774
507
+ ]
508
+ ],
509
+ [
510
+ [
511
+ 0.060740936547517776,
512
+ 0.053044628351926804,
513
+ -0.04193497821688652,
514
+ -0.000676018709782511,
515
+ -0.0015231040306389332,
516
+ 0.004273010417819023,
517
+ -0.05146767199039459
518
+ ],
519
+ [
520
+ 0.34811627864837646,
521
+ 0.46360549330711365,
522
+ 0.4386604428291321,
523
+ 0.023519689217209816,
524
+ 0.019431674852967262,
525
+ 0.1616460531949997,
526
+ 0.9985936880111694
527
+ ]
528
+ ],
529
+ [
530
+ [
531
+ 0.0,
532
+ 0.0,
533
+ 0.049097269773483276,
534
+ 0.07730317115783691,
535
+ -0.07240438461303711,
536
+ 0.02373087964951992,
537
+ 0.10240031778812408
538
+ ],
539
+ [
540
+ 0.0,
541
+ 0.0,
542
+ 0.3465680181980133,
543
+ 0.2670310139656067,
544
+ 0.1823672205209732,
545
+ 0.1818883866071701,
546
+ 0.21153412759304047
547
+ ]
548
+ ],
549
+ [
550
+ [
551
+ 0.0003212452866137028,
552
+ -0.0010083492379635572,
553
+ 0.00092211680021137,
554
+ 0.001238797907717526,
555
+ -4.7416866436833516e-05,
556
+ 2.5170325898216106e-05,
557
+ 0.5775114297866821,
558
+ 0.023655574768781662
559
+ ],
560
+ [
561
+ 0.003074005013331771,
562
+ 0.0067475223913788795,
563
+ 0.010976199060678482,
564
+ 0.024110153317451477,
565
+ 0.003232581540942192,
566
+ 0.0039499602280557156,
567
+ 0.49393802881240845,
568
+ 0.15197230875492096
569
+ ]
570
+ ],
571
+ [
572
+ [
573
+ 0.00028649900923483074,
574
+ -0.008722408674657345,
575
+ -0.03069918043911457,
576
+ -0.0008381816442124546,
577
+ -0.016971644014120102,
578
+ -0.05745099112391472,
579
+ -0.0026707653887569904,
580
+ -0.024192843586206436,
581
+ -0.07967454195022583,
582
+ -0.004741811193525791,
583
+ -0.030430495738983154,
584
+ -0.09769809991121292,
585
+ -0.006405732128769159,
586
+ -0.03590046241879463,
587
+ -0.11270859092473984,
588
+ -0.007021840196102858,
589
+ -0.04052259027957916,
590
+ -0.12620966136455536,
591
+ -0.006953817792236805,
592
+ -0.04445187374949455,
593
+ -0.13877084851264954,
594
+ -0.006491991225630045,
595
+ -0.04791347682476044,
596
+ -0.15081816911697388,
597
+ -0.0057747503742575645,
598
+ -0.05111181363463402,
599
+ -0.16245798766613007,
600
+ -0.004867491777986288,
601
+ -0.054257530719041824,
602
+ -0.1738300770521164,
603
+ 0.1657339185476303,
604
+ 0.15363934636116028,
605
+ 0.14477591216564178,
606
+ 0.13839827477931976,
607
+ 0.14092908799648285,
608
+ 0.15468865633010864,
609
+ 0.16648422181606293,
610
+ 0.17608821392059326,
611
+ 0.1841760128736496,
612
+ 0.19062727689743042,
613
+ -0.00996700394898653,
614
+ 0.0009040668956004083,
615
+ 0.004995268769562244,
616
+ -0.018695320934057236,
617
+ 0.0023894852492958307,
618
+ 0.009505861438810825,
619
+ -0.025692706927657127,
620
+ 0.0043935589492321014,
621
+ 0.013725746423006058,
622
+ -0.031206561252474785,
623
+ 0.006276523228734732,
624
+ 0.017453059554100037,
625
+ -0.03552345186471939,
626
+ 0.00730851711705327,
627
+ 0.0201703030616045,
628
+ -0.03902909904718399,
629
+ 0.0068913171999156475,
630
+ 0.021356917917728424,
631
+ -0.04216034710407257,
632
+ 0.005601761396974325,
633
+ 0.021485209465026855,
634
+ -0.04517875239253044,
635
+ 0.0038381200283765793,
636
+ 0.020964166149497032,
637
+ -0.04812570661306381,
638
+ 0.0018662408692762256,
639
+ 0.020047230646014214,
640
+ -0.05107533559203148,
641
+ -0.00014347363321576267,
642
+ 0.01891058310866356
643
+ ],
644
+ [
645
+ 0.04161107912659645,
646
+ 0.04643801972270012,
647
+ 0.07714500278234482,
648
+ 0.06882365792989731,
649
+ 0.07853177934885025,
650
+ 0.13730594515800476,
651
+ 0.08874603360891342,
652
+ 0.10269544273614883,
653
+ 0.18733122944831848,
654
+ 0.10394992679357529,
655
+ 0.12144353240728378,
656
+ 0.228963240981102,
657
+ 0.11596192419528961,
658
+ 0.1365172266960144,
659
+ 0.26162227988243103,
660
+ 0.1259639710187912,
661
+ 0.1488255262374878,
662
+ 0.2899230718612671,
663
+ 0.13448232412338257,
664
+ 0.1590842306613922,
665
+ 0.3123351037502289,
666
+ 0.141754150390625,
667
+ 0.16746656596660614,
668
+ 0.3306305706501007,
669
+ 0.1480536013841629,
670
+ 0.17442390322685242,
671
+ 0.34733083844184875,
672
+ 0.1536930501461029,
673
+ 0.18031899631023407,
674
+ 0.36258259415626526,
675
+ 0.36356091499328613,
676
+ 0.3559049963951111,
677
+ 0.34951725602149963,
678
+ 0.3456125855445862,
679
+ 0.3479859530925751,
680
+ 0.35668134689331055,
681
+ 0.36377906799316406,
682
+ 0.3706970512866974,
683
+ 0.3976169526576996,
684
+ 0.39772698283195496,
685
+ 0.03057853877544403,
686
+ 0.02315731719136238,
687
+ 0.020660309121012688,
688
+ 0.054019346833229065,
689
+ 0.039159927517175674,
690
+ 0.03595462813973427,
691
+ 0.07279643416404724,
692
+ 0.051145341247320175,
693
+ 0.04803183302283287,
694
+ 0.08770721405744553,
695
+ 0.060230545699596405,
696
+ 0.05767446011304855,
697
+ 0.0995459109544754,
698
+ 0.06725169718265533,
699
+ 0.06545353680849075,
700
+ 0.10892871767282486,
701
+ 0.07302306592464447,
702
+ 0.07213838398456573,
703
+ 0.116375632584095,
704
+ 0.0778125748038292,
705
+ 0.07771806418895721,
706
+ 0.12312348932027817,
707
+ 0.0816049873828888,
708
+ 0.08206423372030258,
709
+ 0.1288398653268814,
710
+ 0.08456701040267944,
711
+ 0.08530736714601517,
712
+ 0.13414840400218964,
713
+ 0.08684109151363373,
714
+ 0.0876062661409378
715
+ ]
716
+ ],
717
+ [
718
+ [
719
+ 0.0,
720
+ 0.0,
721
+ 0.0,
722
+ 0.06579020619392395,
723
+ 0.0,
724
+ 0.0,
725
+ -0.0417020283639431,
726
+ 0.05600078031420708,
727
+ 0.8763857483863831,
728
+ 0.0,
729
+ -0.0006691364105790854,
730
+ 0.0005162839079275727,
731
+ -0.0025432889815419912
732
+ ],
733
+ [
734
+ 0.0,
735
+ 0.0,
736
+ 0.0,
737
+ 0.3762088716030121,
738
+ 0.0,
739
+ 0.0,
740
+ 0.15823891758918762,
741
+ 0.2252153754234314,
742
+ 0.31409522891044617,
743
+ 0.0,
744
+ 0.023507647216320038,
745
+ 0.036004047840833664,
746
+ 0.05768127366900444
747
+ ]
748
+ ],
749
+ [
750
+ [
751
+ -0.15657226741313934,
752
+ 0.00228637782856822,
753
+ -0.0009536752477288246,
754
+ -0.00012742729450110346,
755
+ 0.0,
756
+ 0.0014414743054658175,
757
+ -0.0015724773984402418,
758
+ -0.0011747290845960379
759
+ ],
760
+ [
761
+ 0.9879051446914673,
762
+ 0.09651217609643936,
763
+ 0.08441831171512604,
764
+ 0.06647706776857376,
765
+ 0.0,
766
+ 0.0495850145816803,
767
+ 0.06368337571620941,
768
+ 0.06135875731706619
769
+ ]
770
+ ],
771
+ [
772
+ [
773
+ 0.17389154434204102,
774
+ 0.005625918973237276,
775
+ -0.1695142686367035,
776
+ 0.0031083673238754272,
777
+ 0.005127986893057823,
778
+ 0.012693661265075207,
779
+ -0.4065398871898651
780
+ ],
781
+ [
782
+ 0.2601781189441681,
783
+ 0.13021306693553925,
784
+ 0.4979441463947296,
785
+ 0.022246459499001503,
786
+ 0.06382154673337936,
787
+ 0.08343781530857086,
788
+ 0.913625180721283
789
+ ]
790
+ ],
791
+ [
792
+ [
793
+ 0.02290557324886322,
794
+ -0.00010951685544569045,
795
+ -0.011411379091441631,
796
+ -0.0015635089948773384,
797
+ 0.04783362150192261,
798
+ -0.0063293022103607655,
799
+ 0.0013472747523337603,
800
+ 0.001141763525083661
801
+ ],
802
+ [
803
+ 0.09001470357179642,
804
+ 0.00812098290771246,
805
+ 0.033615339547395706,
806
+ 0.013250669464468956,
807
+ 0.21339112520217896,
808
+ 0.01224832609295845,
809
+ 0.019686469808220863,
810
+ 0.00791964028030634
811
+ ]
812
+ ]
813
+ ],
814
+ "action_token_size": 64,
815
+ "arch": "STTransformerDecoder",
816
+ "attn_drop": 0.1,
817
+ "attn_dropout": 0.1,
818
+ "buffer_size": 64,
819
+ "d_action": 28,
820
+ "d_actions": [
821
+ 2,
822
+ 4,
823
+ 100,
824
+ 8,
825
+ 14,
826
+ 35,
827
+ 70,
828
+ 6,
829
+ 16,
830
+ 120,
831
+ 18,
832
+ 80,
833
+ 8,
834
+ 5,
835
+ 7,
836
+ 8,
837
+ 40,
838
+ 49,
839
+ 120,
840
+ 7,
841
+ 7,
842
+ 2,
843
+ 7,
844
+ 35,
845
+ 8,
846
+ 350,
847
+ 13,
848
+ 40,
849
+ 21,
850
+ 8
851
+ ],
852
+ "d_model": 256,
853
+ "dataloader_apply_corruption": false,
854
+ "dataloader_apply_mask": true,
855
+ "dataloader_mask_ratio_min": 0.1,
856
+ "diffloss_d": 4,
857
+ "diffloss_w": 1024,
858
+ "diffusion_batch_mul": 1,
859
+ "dim": 512,
860
+ "drop_action_ratio": 0.0,
861
+ "factored_vocab_size": 512,
862
+ "grad_checkpointing": false,
863
+ "image_vocab_size": null,
864
+ "init_actions": true,
865
+ "jointly_predict_actions": false,
866
+ "jointly_predict_states": true,
867
+ "label_drop_prob": 0.5,
868
+ "mask_ratio_min": 0.7,
869
+ "maskgit_steps": 16,
870
+ "max_corrupt_rate": 0.2,
871
+ "mlp_bias": false,
872
+ "mlp_drop": 0.05,
873
+ "mlp_ratio": 4.0,
874
+ "non_mlm_ratio": 0.2,
875
+ "num_factored_vocabs": 2,
876
+ "num_heads": 8,
877
+ "num_layers": 32,
878
+ "num_prompt_frames": 4,
879
+ "num_sampling_steps": "100",
880
+ "patch_size": 2,
881
+ "predict_unmask": false,
882
+ "proj_bias": true,
883
+ "proj_dropout": 0.1,
884
+ "qk_norm": false,
885
+ "qkv_bias": true,
886
+ "random_dummy_action": true,
887
+ "shared_action_mlps": true,
888
+ "use_actions": true,
889
+ "use_mup": false,
890
+ "vae_embed_dim": 4,
891
+ "vae_stride": 1
892
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd72f5f01d09227b4f0622194eff5a9dc4539a2227c86dcada1fd14a1a00a61f
3
+ size 4247325512
random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81c1a3f91b1b0455825bdbb8364fc1174d2a23943e095930f96b6dc133140c71
3
+ size 16100