diff --git a/.huggingface/.gitignore b/.huggingface/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..f59ec20aabf5842d237244ece8c81ab184faeac1
--- /dev/null
+++ b/.huggingface/.gitignore
@@ -0,0 +1 @@
+*
\ No newline at end of file
diff --git a/.huggingface/download/.gitattributes.lock b/.huggingface/download/.gitattributes.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/.gitattributes.metadata b/.huggingface/download/.gitattributes.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..f534382fc686a3e0adbd51fc13e422378ee6b324
--- /dev/null
+++ b/.huggingface/download/.gitattributes.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+5d7f077bd6e1a90e4cb8544726b05f855a1e0d13
+1723652257.0751407
diff --git a/.huggingface/download/README.md.lock b/.huggingface/download/README.md.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/README.md.metadata b/.huggingface/download/README.md.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..56bd14420994489959bf71dbcbe3e8867ad420a8
--- /dev/null
+++ b/.huggingface/download/README.md.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+73a7808dee60f254271c2a2e61364c9a4679842b
+1723652257.100425
diff --git a/.huggingface/download/app.py.lock b/.huggingface/download/app.py.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/app.py.metadata b/.huggingface/download/app.py.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..895b884de95b05d5c31d0f31edcb658bfd5aeb96
--- /dev/null
+++ b/.huggingface/download/app.py.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+869732a935815773960fec8cd94428792ba3f924
+1723652257.0650802
diff --git a/.huggingface/download/examples/captured_p.webp.lock b/.huggingface/download/examples/captured_p.webp.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/examples/captured_p.webp.metadata b/.huggingface/download/examples/captured_p.webp.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..84e056fd03d4911b03ed62d158969d2d0616e152
--- /dev/null
+++ b/.huggingface/download/examples/captured_p.webp.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+a0b83623bd3e8528385869ca5370eb5fc886b6f5
+1723652257.0598044
diff --git a/.huggingface/download/examples/chair_p.webp.lock b/.huggingface/download/examples/chair_p.webp.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/examples/chair_p.webp.metadata b/.huggingface/download/examples/chair_p.webp.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..2bfd2c82bd5e757b3c0ce114f118c75c356f2509
--- /dev/null
+++ b/.huggingface/download/examples/chair_p.webp.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+a7812c95b549d96ed0495ca85af88441d3f9d457
+1723652257.061323
diff --git a/.huggingface/download/examples/flamingo_p.webp.lock b/.huggingface/download/examples/flamingo_p.webp.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/examples/flamingo_p.webp.metadata b/.huggingface/download/examples/flamingo_p.webp.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..308b248bdac3dfd20edeba3941c2e74c39eaec46
--- /dev/null
+++ b/.huggingface/download/examples/flamingo_p.webp.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+5b0164267efdcd7aa67dc4993ba0f28f05b75d62
+1723652257.0590506
diff --git a/.huggingface/download/examples/hamburger_p.webp.lock b/.huggingface/download/examples/hamburger_p.webp.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/examples/hamburger_p.webp.metadata b/.huggingface/download/examples/hamburger_p.webp.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..c6c3438bfa4ccfa08cd00a94d9a9ead8b5b0ac45
--- /dev/null
+++ b/.huggingface/download/examples/hamburger_p.webp.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+7e5b467a5bf2c52c67e34ae55457cd7940884fe9
+1723652257.0512598
diff --git a/.huggingface/download/examples/horse_p.webp.lock b/.huggingface/download/examples/horse_p.webp.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/examples/horse_p.webp.metadata b/.huggingface/download/examples/horse_p.webp.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..6e92f72b901c4d5fe0228870243e005d49c7943e
--- /dev/null
+++ b/.huggingface/download/examples/horse_p.webp.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+c3bd936b078a421f12321f12fc72a45936dac1f3
+1723652257.069789
diff --git a/.huggingface/download/examples/iso_house.webp.lock b/.huggingface/download/examples/iso_house.webp.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/examples/iso_house.webp.metadata b/.huggingface/download/examples/iso_house.webp.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..76e129c6b5bb2921455ee24a5414795944f3bc7b
--- /dev/null
+++ b/.huggingface/download/examples/iso_house.webp.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+1d8d71b4417299de911fedfe7a89118eefd7103a
+1723652257.4452686
diff --git a/.huggingface/download/examples/marble_p.webp.lock b/.huggingface/download/examples/marble_p.webp.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/examples/marble_p.webp.metadata b/.huggingface/download/examples/marble_p.webp.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..5f40ebb25ae513abbc67beb9926f4a8debd13d16
--- /dev/null
+++ b/.huggingface/download/examples/marble_p.webp.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+548a0862bbedf4fca1bd243021703481e81443b5
+1723652257.4804113
diff --git a/.huggingface/download/examples/police_woman_p.webp.lock b/.huggingface/download/examples/police_woman_p.webp.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/examples/police_woman_p.webp.metadata b/.huggingface/download/examples/police_woman_p.webp.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..21cf7ad94c8c0ae903e2b229614928cb715c7e00
--- /dev/null
+++ b/.huggingface/download/examples/police_woman_p.webp.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+a860e1cfe4e28a9cae59239e851cead16193e4b8
+1723652257.419765
diff --git a/.huggingface/download/examples/poly_fox.webp.lock b/.huggingface/download/examples/poly_fox.webp.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/examples/poly_fox.webp.metadata b/.huggingface/download/examples/poly_fox.webp.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..3876446e8fbd33cf057308d5228c77032f5593e9
--- /dev/null
+++ b/.huggingface/download/examples/poly_fox.webp.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+e749c683447d9d51dd17671a87c406ac8abba6ba
+1723652257.612526
diff --git a/.huggingface/download/examples/robot_p.webp.lock b/.huggingface/download/examples/robot_p.webp.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/examples/robot_p.webp.metadata b/.huggingface/download/examples/robot_p.webp.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..ad85c04c64e4e95b221bc3d989697b35856923ef
--- /dev/null
+++ b/.huggingface/download/examples/robot_p.webp.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+ffe480ebfe0216b8701aa35e29152ac778ca7c57
+1723652257.4218733
diff --git a/.huggingface/download/examples/teapot.webp.lock b/.huggingface/download/examples/teapot.webp.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/examples/teapot.webp.metadata b/.huggingface/download/examples/teapot.webp.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..e2751eabb873b31777afc87cff5f8aa01ff7b81f
--- /dev/null
+++ b/.huggingface/download/examples/teapot.webp.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+b6f13c13498eceba7583a95ef7b7a087f84fc73d
+1723652257.4243357
diff --git a/.huggingface/download/examples/tiger_girl.webp.lock b/.huggingface/download/examples/tiger_girl.webp.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/examples/tiger_girl.webp.metadata b/.huggingface/download/examples/tiger_girl.webp.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..326a1003dbc5913344b03c80399b3ad4ecf060d0
--- /dev/null
+++ b/.huggingface/download/examples/tiger_girl.webp.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+489d048b5af2d805cb1ae41eef18a2d8abcb35aa
+1723652257.461125
diff --git a/.huggingface/download/examples/unicorn_p.webp.lock b/.huggingface/download/examples/unicorn_p.webp.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/examples/unicorn_p.webp.metadata b/.huggingface/download/examples/unicorn_p.webp.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..5aca62c94aec54efa6d313ae8c845e721d2c54ad
--- /dev/null
+++ b/.huggingface/download/examples/unicorn_p.webp.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+333d17406c01cfd93b0d7790926db6519bd2df4c
+1723652257.4432185
diff --git a/.huggingface/download/requirements.txt.lock b/.huggingface/download/requirements.txt.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/requirements.txt.metadata b/.huggingface/download/requirements.txt.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..2b57eede632c3e3d0c477de945522c9995d64b0c
--- /dev/null
+++ b/.huggingface/download/requirements.txt.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+c35dd54b42673760547e7b5f03f48d5ff67a7437
+1723652257.9081738
diff --git a/.huggingface/download/tsr/__pycache__/system.cpython-310.pyc.lock b/.huggingface/download/tsr/__pycache__/system.cpython-310.pyc.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/__pycache__/system.cpython-310.pyc.metadata b/.huggingface/download/tsr/__pycache__/system.cpython-310.pyc.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..849e51f63498b1835762ba63e92a2f5c34d7f6fc
--- /dev/null
+++ b/.huggingface/download/tsr/__pycache__/system.cpython-310.pyc.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+e959f300ee2347d00c603bd99f9bd867dadf4499
+1723652257.8178537
diff --git a/.huggingface/download/tsr/__pycache__/utils.cpython-310.pyc.lock b/.huggingface/download/tsr/__pycache__/utils.cpython-310.pyc.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/__pycache__/utils.cpython-310.pyc.metadata b/.huggingface/download/tsr/__pycache__/utils.cpython-310.pyc.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..c90bf6d550752a2abe42dd886036eed9ad38f0ce
--- /dev/null
+++ b/.huggingface/download/tsr/__pycache__/utils.cpython-310.pyc.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+4fcac64634b0d8e7f5f48ece0f9bd046ca3bedbb
+1723652257.8950891
diff --git a/.huggingface/download/tsr/models/__pycache__/camera.cpython-310.pyc.lock b/.huggingface/download/tsr/models/__pycache__/camera.cpython-310.pyc.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/__pycache__/camera.cpython-310.pyc.metadata b/.huggingface/download/tsr/models/__pycache__/camera.cpython-310.pyc.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..386b36ecea6d7e792c4257c3e09c6ee61a51c459
--- /dev/null
+++ b/.huggingface/download/tsr/models/__pycache__/camera.cpython-310.pyc.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+cf99205bb6f64375040f3c92d02df1f48695385c
+1723652258.0100667
diff --git a/.huggingface/download/tsr/models/__pycache__/isosurface.cpython-310.pyc.lock b/.huggingface/download/tsr/models/__pycache__/isosurface.cpython-310.pyc.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/__pycache__/isosurface.cpython-310.pyc.metadata b/.huggingface/download/tsr/models/__pycache__/isosurface.cpython-310.pyc.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..01801a02b4a3a703a85aae56a3b5c2a661a9f7e3
--- /dev/null
+++ b/.huggingface/download/tsr/models/__pycache__/isosurface.cpython-310.pyc.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+5858895751ab820d26e557b5e9dfebc41679cae1
+1723652257.9176967
diff --git a/.huggingface/download/tsr/models/__pycache__/nerf_renderer.cpython-310.pyc.lock b/.huggingface/download/tsr/models/__pycache__/nerf_renderer.cpython-310.pyc.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/__pycache__/nerf_renderer.cpython-310.pyc.metadata b/.huggingface/download/tsr/models/__pycache__/nerf_renderer.cpython-310.pyc.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..75f312f908dd2a6393b444a60923544ba971094b
--- /dev/null
+++ b/.huggingface/download/tsr/models/__pycache__/nerf_renderer.cpython-310.pyc.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+f165c9fdb535e8888a8ab3393c6f6b8d7deb9065
+1723652257.9275925
diff --git a/.huggingface/download/tsr/models/__pycache__/network_utils.cpython-310.pyc.lock b/.huggingface/download/tsr/models/__pycache__/network_utils.cpython-310.pyc.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/__pycache__/network_utils.cpython-310.pyc.metadata b/.huggingface/download/tsr/models/__pycache__/network_utils.cpython-310.pyc.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..a29c0aed75b71b461bb5f9fcf6ba3667f1c47482
--- /dev/null
+++ b/.huggingface/download/tsr/models/__pycache__/network_utils.cpython-310.pyc.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+9681517e0c22f29e8b5bc02aaf2a29f06087e554
+1723652257.8843312
diff --git a/.huggingface/download/tsr/models/isosurface.py.lock b/.huggingface/download/tsr/models/isosurface.py.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/isosurface.py.metadata b/.huggingface/download/tsr/models/isosurface.py.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..11d635ad9c5bc5b7ab797716a3cdf3d7a8b74320
--- /dev/null
+++ b/.huggingface/download/tsr/models/isosurface.py.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+39c8389b5562f5ebb787ef85bcfff56d85aa51db
+1723652258.0929828
diff --git a/.huggingface/download/tsr/models/nerf_renderer.py.lock b/.huggingface/download/tsr/models/nerf_renderer.py.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/nerf_renderer.py.metadata b/.huggingface/download/tsr/models/nerf_renderer.py.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..6871c9be20e24cf2f1d444fbe61a39f9acb74c90
--- /dev/null
+++ b/.huggingface/download/tsr/models/nerf_renderer.py.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+b2d3e652c423e62c4e54fdf7c0751602f51b107b
+1723652258.3381011
diff --git a/.huggingface/download/tsr/models/network_utils.py.lock b/.huggingface/download/tsr/models/network_utils.py.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/network_utils.py.metadata b/.huggingface/download/tsr/models/network_utils.py.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..638d7f879fc013f3edd9fa40fe73092bc581d61b
--- /dev/null
+++ b/.huggingface/download/tsr/models/network_utils.py.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+3844f533bf3b6c9afce6de3857255ee08125b1ba
+1723652258.4015312
diff --git a/.huggingface/download/tsr/models/tokenizers/__pycache__/dino.cpython-310.pyc.lock b/.huggingface/download/tsr/models/tokenizers/__pycache__/dino.cpython-310.pyc.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/tokenizers/__pycache__/dino.cpython-310.pyc.metadata b/.huggingface/download/tsr/models/tokenizers/__pycache__/dino.cpython-310.pyc.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..b8f886c929be63b932bb58094f3f165a84aceb6b
--- /dev/null
+++ b/.huggingface/download/tsr/models/tokenizers/__pycache__/dino.cpython-310.pyc.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+a492f56dc156938d3250d77f5c182ab05e1655ea
+1723652258.4204426
diff --git a/.huggingface/download/tsr/models/tokenizers/__pycache__/image.cpython-310.pyc.lock b/.huggingface/download/tsr/models/tokenizers/__pycache__/image.cpython-310.pyc.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/tokenizers/__pycache__/image.cpython-310.pyc.metadata b/.huggingface/download/tsr/models/tokenizers/__pycache__/image.cpython-310.pyc.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..7e4b7dad51e835a7234aa7d3b9b6ef011b3c96bb
--- /dev/null
+++ b/.huggingface/download/tsr/models/tokenizers/__pycache__/image.cpython-310.pyc.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+06d33036639d982296542f5d441f991d3e01ddb0
+1723652258.2991421
diff --git a/.huggingface/download/tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc.lock b/.huggingface/download/tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc.metadata b/.huggingface/download/tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..e6cdf47e17272f30e45a0c997772301a1764f5bc
--- /dev/null
+++ b/.huggingface/download/tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+31191fd86c3d1668aefd27059738590dc066cf7b
+1723652258.3975506
diff --git a/.huggingface/download/tsr/models/tokenizers/image.py.lock b/.huggingface/download/tsr/models/tokenizers/image.py.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/tokenizers/image.py.metadata b/.huggingface/download/tsr/models/tokenizers/image.py.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..cd4a705ca9dcc22addb8107007688fa2c8c4b32e
--- /dev/null
+++ b/.huggingface/download/tsr/models/tokenizers/image.py.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+34a092e7170c87da53363822af554c78e5f8083f
+1723652258.3694685
diff --git a/.huggingface/download/tsr/models/tokenizers/triplane.py.lock b/.huggingface/download/tsr/models/tokenizers/triplane.py.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/tokenizers/triplane.py.metadata b/.huggingface/download/tsr/models/tokenizers/triplane.py.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..d9a4cc267e07edb6e7035af97ba177134a5ab2f4
--- /dev/null
+++ b/.huggingface/download/tsr/models/tokenizers/triplane.py.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+ecdd7fd2201c974bb70b18a90a633287b814886f
+1723652258.4722672
diff --git a/.huggingface/download/tsr/models/transformer/__pycache__/attention.cpython-310.pyc.lock b/.huggingface/download/tsr/models/transformer/__pycache__/attention.cpython-310.pyc.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/transformer/__pycache__/attention.cpython-310.pyc.metadata b/.huggingface/download/tsr/models/transformer/__pycache__/attention.cpython-310.pyc.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..792756604534a357a91f5480a6fb73e212afe977
--- /dev/null
+++ b/.huggingface/download/tsr/models/transformer/__pycache__/attention.cpython-310.pyc.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+02bf1262175cdbe2bb9651042c247f53c2ab0e91
+1723652258.4920888
diff --git a/.huggingface/download/tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc.lock b/.huggingface/download/tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc.metadata b/.huggingface/download/tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..77bf98742ff9075cd6073d6d9905caa208e24a30
--- /dev/null
+++ b/.huggingface/download/tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+8c2f3d2a8c487f237c94433bd8aaafd26afb8ce0
+1723652258.853393
diff --git a/.huggingface/download/tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc.lock b/.huggingface/download/tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc.metadata b/.huggingface/download/tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..53b37dac48350b199e22225212a2a9667237b50f
--- /dev/null
+++ b/.huggingface/download/tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+39cb242f767c41c73138dda36700ef1554bbf31f
+1723652258.8555367
diff --git a/.huggingface/download/tsr/models/transformer/attention.py.lock b/.huggingface/download/tsr/models/transformer/attention.py.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/transformer/attention.py.metadata b/.huggingface/download/tsr/models/transformer/attention.py.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..45ef4e0f4dc77f49f54d0b8c7eae271b3bffa4ca
--- /dev/null
+++ b/.huggingface/download/tsr/models/transformer/attention.py.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+eb873231e3ad195a7fef3a2c4ef217be3056cd4e
+1723652258.8082144
diff --git a/.huggingface/download/tsr/models/transformer/basic_transformer_block.py.lock b/.huggingface/download/tsr/models/transformer/basic_transformer_block.py.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/transformer/basic_transformer_block.py.metadata b/.huggingface/download/tsr/models/transformer/basic_transformer_block.py.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..a1abfd5d05f5836c7ddc1713c7f3a82de169d246
--- /dev/null
+++ b/.huggingface/download/tsr/models/transformer/basic_transformer_block.py.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+a55ea6189c9e1da2f86f6d0957ae30f26fefff0a
+1723652258.8690145
diff --git a/.huggingface/download/tsr/models/transformer/transformer_1d.py.lock b/.huggingface/download/tsr/models/transformer/transformer_1d.py.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/models/transformer/transformer_1d.py.metadata b/.huggingface/download/tsr/models/transformer/transformer_1d.py.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..336fa56e70e329398341a7d14b98fd9dc8df05f0
--- /dev/null
+++ b/.huggingface/download/tsr/models/transformer/transformer_1d.py.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+7546232c5126a6622f1fd701ba36f9d4b53b9178
+1723652258.8556223
diff --git a/.huggingface/download/tsr/system.py.lock b/.huggingface/download/tsr/system.py.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/system.py.metadata b/.huggingface/download/tsr/system.py.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..3315edb602420260e8297a8513defa1ad0abe77c
--- /dev/null
+++ b/.huggingface/download/tsr/system.py.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+de5ae6ef75afd3bbbe0bdfacc7258a8c51409cc5
+1723652258.860138
diff --git a/.huggingface/download/tsr/utils.py.lock b/.huggingface/download/tsr/utils.py.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/tsr/utils.py.metadata b/.huggingface/download/tsr/utils.py.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..873d54339e4474925e9226ba5021cbe23784a8e2
--- /dev/null
+++ b/.huggingface/download/tsr/utils.py.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+6a1b59aef75d02d39b29222d300fe9241bb11444
+1723652258.8708167
diff --git a/.huggingface/download/wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl.lock b/.huggingface/download/wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl.lock
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/.huggingface/download/wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl.metadata b/.huggingface/download/wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl.metadata
new file mode 100644
index 0000000000000000000000000000000000000000..da0763d7e0795a213b3e82b539bab44050db58e0
--- /dev/null
+++ b/.huggingface/download/wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl.metadata
@@ -0,0 +1,3 @@
+97d3c892634a70bd5ef0cdac7ad3ef7af3b0fa9e
+4af160ba1274e2205d3529a7b82efdb6946c2158a78e19631ed840301055b8d6
+1723652259.3841252
diff --git a/examples/captured_p.webp b/examples/captured_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..a0b83623bd3e8528385869ca5370eb5fc886b6f5
Binary files /dev/null and b/examples/captured_p.webp differ
diff --git a/examples/chair_p.webp b/examples/chair_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..a7812c95b549d96ed0495ca85af88441d3f9d457
Binary files /dev/null and b/examples/chair_p.webp differ
diff --git a/examples/flamingo_p.webp b/examples/flamingo_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..5b0164267efdcd7aa67dc4993ba0f28f05b75d62
Binary files /dev/null and b/examples/flamingo_p.webp differ
diff --git a/examples/hamburger_p.webp b/examples/hamburger_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..7e5b467a5bf2c52c67e34ae55457cd7940884fe9
Binary files /dev/null and b/examples/hamburger_p.webp differ
diff --git a/examples/horse_p.webp b/examples/horse_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..c3bd936b078a421f12321f12fc72a45936dac1f3
Binary files /dev/null and b/examples/horse_p.webp differ
diff --git a/examples/iso_house.webp b/examples/iso_house.webp
new file mode 100644
index 0000000000000000000000000000000000000000..1d8d71b4417299de911fedfe7a89118eefd7103a
Binary files /dev/null and b/examples/iso_house.webp differ
diff --git a/examples/marble_p.webp b/examples/marble_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..548a0862bbedf4fca1bd243021703481e81443b5
Binary files /dev/null and b/examples/marble_p.webp differ
diff --git a/examples/police_woman_p.webp b/examples/police_woman_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..a860e1cfe4e28a9cae59239e851cead16193e4b8
Binary files /dev/null and b/examples/police_woman_p.webp differ
diff --git a/examples/poly_fox.webp b/examples/poly_fox.webp
new file mode 100644
index 0000000000000000000000000000000000000000..e749c683447d9d51dd17671a87c406ac8abba6ba
Binary files /dev/null and b/examples/poly_fox.webp differ
diff --git a/examples/robot_p.webp b/examples/robot_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..ffe480ebfe0216b8701aa35e29152ac778ca7c57
Binary files /dev/null and b/examples/robot_p.webp differ
diff --git a/examples/teapot.webp b/examples/teapot.webp
new file mode 100644
index 0000000000000000000000000000000000000000..b6f13c13498eceba7583a95ef7b7a087f84fc73d
Binary files /dev/null and b/examples/teapot.webp differ
diff --git a/examples/tiger_girl.webp b/examples/tiger_girl.webp
new file mode 100644
index 0000000000000000000000000000000000000000..489d048b5af2d805cb1ae41eef18a2d8abcb35aa
Binary files /dev/null and b/examples/tiger_girl.webp differ
diff --git a/examples/unicorn_p.webp b/examples/unicorn_p.webp
new file mode 100644
index 0000000000000000000000000000000000000000..333d17406c01cfd93b0d7790926db6519bd2df4c
Binary files /dev/null and b/examples/unicorn_p.webp differ
diff --git a/tsr/__pycache__/system.cpython-310.pyc b/tsr/__pycache__/system.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e959f300ee2347d00c603bd99f9bd867dadf4499
Binary files /dev/null and b/tsr/__pycache__/system.cpython-310.pyc differ
diff --git a/tsr/__pycache__/utils.cpython-310.pyc b/tsr/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4fcac64634b0d8e7f5f48ece0f9bd046ca3bedbb
Binary files /dev/null and b/tsr/__pycache__/utils.cpython-310.pyc differ
diff --git a/tsr/models/__pycache__/camera.cpython-310.pyc b/tsr/models/__pycache__/camera.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf99205bb6f64375040f3c92d02df1f48695385c
Binary files /dev/null and b/tsr/models/__pycache__/camera.cpython-310.pyc differ
diff --git a/tsr/models/__pycache__/isosurface.cpython-310.pyc b/tsr/models/__pycache__/isosurface.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5858895751ab820d26e557b5e9dfebc41679cae1
Binary files /dev/null and b/tsr/models/__pycache__/isosurface.cpython-310.pyc differ
diff --git a/tsr/models/__pycache__/nerf_renderer.cpython-310.pyc b/tsr/models/__pycache__/nerf_renderer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f165c9fdb535e8888a8ab3393c6f6b8d7deb9065
Binary files /dev/null and b/tsr/models/__pycache__/nerf_renderer.cpython-310.pyc differ
diff --git a/tsr/models/__pycache__/network_utils.cpython-310.pyc b/tsr/models/__pycache__/network_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9681517e0c22f29e8b5bc02aaf2a29f06087e554
Binary files /dev/null and b/tsr/models/__pycache__/network_utils.cpython-310.pyc differ
diff --git a/tsr/models/isosurface.py b/tsr/models/isosurface.py
new file mode 100644
index 0000000000000000000000000000000000000000..39c8389b5562f5ebb787ef85bcfff56d85aa51db
--- /dev/null
+++ b/tsr/models/isosurface.py
@@ -0,0 +1,48 @@
+from typing import Callable, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torchmcubes import marching_cubes
+
+
+class IsosurfaceHelper(nn.Module):
+    points_range: Tuple[float, float] = (0, 1)
+
+    @property
+    def grid_vertices(self) -> torch.FloatTensor:
+        raise NotImplementedError
+
+
+class MarchingCubeHelper(IsosurfaceHelper):
+    def __init__(self, resolution: int) -> None:
+        super().__init__()
+        self.resolution = resolution
+        self.mc_func: Callable = marching_cubes
+        self._grid_vertices: Optional[torch.FloatTensor] = None
+
+    @property
+    def grid_vertices(self) -> torch.FloatTensor:
+        if self._grid_vertices is None:
+            # keep the vertices on CPU so that we can support very large resolution
+            x, y, z = (
+                torch.linspace(*self.points_range, self.resolution),
+                torch.linspace(*self.points_range, self.resolution),
+                torch.linspace(*self.points_range, self.resolution),
+            )
+            x, y, z = torch.meshgrid(x, y, z, indexing="ij")
+            verts = torch.cat(
+                [x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1
+            ).reshape(-1, 3)
+            self._grid_vertices = verts
+        return self._grid_vertices
+
+    def forward(
+        self,
+        level: torch.FloatTensor,
+    ) -> Tuple[torch.FloatTensor, torch.LongTensor]:
+        level = -level.view(self.resolution, self.resolution, self.resolution)
+        v_pos, t_pos_idx = self.mc_func(level.detach(), 0.0)
+        v_pos = v_pos[..., [2, 1, 0]]
+        v_pos = v_pos / (self.resolution - 1.0)
+        return v_pos, t_pos_idx
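A quick way to sanity-check the new `MarchingCubeHelper` is to feed it an analytic sphere. This is an illustrative sketch only, assuming the repo root is on `PYTHONPATH` and `torchmcubes` is installed:

```python
import torch

from tsr.models.isosurface import MarchingCubeHelper

helper = MarchingCubeHelper(resolution=64)

# grid_vertices spans [0, 1]^3; recentre it to evaluate a signed field.
pts = helper.grid_vertices - 0.5            # (64**3, 3)
level = 0.25 - pts.norm(dim=-1)             # positive inside a sphere of radius 0.25

# forward() negates the field internally and extracts the zero level set.
v_pos, t_pos_idx = helper(level)
print(v_pos.shape, t_pos_idx.shape)         # vertices in [0, 1]^3, triangle indices
```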
diff --git a/tsr/models/nerf_renderer.py b/tsr/models/nerf_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2d3e652c423e62c4e54fdf7c0751602f51b107b
--- /dev/null
+++ b/tsr/models/nerf_renderer.py
@@ -0,0 +1,180 @@
+from dataclasses import dataclass, field
+from typing import Dict
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, reduce
+
+from ..utils import (
+    BaseModule,
+    chunk_batch,
+    get_activation,
+    rays_intersect_bbox,
+    scale_tensor,
+)
+
+
+class TriplaneNeRFRenderer(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        radius: float
+
+        feature_reduction: str = "concat"
+        density_activation: str = "trunc_exp"
+        density_bias: float = -1.0
+        color_activation: str = "sigmoid"
+        num_samples_per_ray: int = 128
+        randomized: bool = False
+
+    cfg: Config
+
+    def configure(self) -> None:
+        assert self.cfg.feature_reduction in ["concat", "mean"]
+        self.chunk_size = 0
+
+    def set_chunk_size(self, chunk_size: int):
+        assert (
+            chunk_size >= 0
+        ), "chunk_size must be a non-negative integer (0 for no chunking)."
+        self.chunk_size = chunk_size
+
+    def query_triplane(
+        self,
+        decoder: torch.nn.Module,
+        positions: torch.Tensor,
+        triplane: torch.Tensor,
+    ) -> Dict[str, torch.Tensor]:
+        input_shape = positions.shape[:-1]
+        positions = positions.view(-1, 3)
+
+        # positions in (-radius, radius)
+        # normalized to (-1, 1) for grid sample
+        positions = scale_tensor(
+            positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
+        )
+
+        def _query_chunk(x):
+            indices2D: torch.Tensor = torch.stack(
+                (x[..., [0, 1]], x[..., [0, 2]], x[..., [1, 2]]),
+                dim=-3,
+            )
+            out: torch.Tensor = F.grid_sample(
+                rearrange(triplane, "Np Cp Hp Wp -> Np Cp Hp Wp", Np=3),  # identity rearrange; asserts Np == 3
+                rearrange(indices2D, "Np N Nd -> Np () N Nd", Np=3),
+                align_corners=False,
+                mode="bilinear",
+            )
+            if self.cfg.feature_reduction == "concat":
+                out = rearrange(out, "Np Cp () N -> N (Np Cp)", Np=3)
+            elif self.cfg.feature_reduction == "mean":
+                out = reduce(out, "Np Cp () N -> N Cp", Np=3, reduction="mean")
+            else:
+                raise NotImplementedError
+
+            net_out: Dict[str, torch.Tensor] = decoder(out)
+            return net_out
+
+        if self.chunk_size > 0:
+            net_out = chunk_batch(_query_chunk, self.chunk_size, positions)
+        else:
+            net_out = _query_chunk(positions)
+
+        net_out["density_act"] = get_activation(self.cfg.density_activation)(
+            net_out["density"] + self.cfg.density_bias
+        )
+        net_out["color"] = get_activation(self.cfg.color_activation)(
+            net_out["features"]
+        )
+
+        net_out = {k: v.view(*input_shape, -1) for k, v in net_out.items()}
+
+        return net_out
+
+    def _forward(
+        self,
+        decoder: torch.nn.Module,
+        triplane: torch.Tensor,
+        rays_o: torch.Tensor,
+        rays_d: torch.Tensor,
+        **kwargs,
+    ):
+        rays_shape = rays_o.shape[:-1]
+        rays_o = rays_o.view(-1, 3)
+        rays_d = rays_d.view(-1, 3)
+        n_rays = rays_o.shape[0]
+
+        t_near, t_far, rays_valid = rays_intersect_bbox(rays_o, rays_d, self.cfg.radius)
+        t_near, t_far = t_near[rays_valid], t_far[rays_valid]
+
+        t_vals = torch.linspace(
+            0, 1, self.cfg.num_samples_per_ray + 1, device=triplane.device
+        )
+        t_mid = (t_vals[:-1] + t_vals[1:]) / 2.0
+        z_vals = t_near * (1 - t_mid[None]) + t_far * t_mid[None]  # (N_rays, N_samples)
+
+        xyz = (
+            rays_o[:, None, :] + z_vals[..., None] * rays_d[..., None, :]
+        )  # (N_rays, N_sample, 3)
+
+        mlp_out = self.query_triplane(
+            decoder=decoder,
+            positions=xyz,
+            triplane=triplane,
+        )
+
+        eps = 1e-10
+        # deltas = z_vals[:, 1:] - z_vals[:, :-1] # (N_rays, N_samples)
+        deltas = t_vals[1:] - t_vals[:-1]  # (N_samples,); uniform spacing in t
+        alpha = 1 - torch.exp(
+            -deltas * mlp_out["density_act"][..., 0]
+        )  # (N_rays, N_samples)
+        accum_prod = torch.cat(
+            [
+                torch.ones_like(alpha[:, :1]),
+                torch.cumprod(1 - alpha[:, :-1] + eps, dim=-1),
+            ],
+            dim=-1,
+        )
+        weights = alpha * accum_prod  # (N_rays, N_samples)
+        comp_rgb_ = (weights[..., None] * mlp_out["color"]).sum(dim=-2)  # (N_rays, 3)
+        opacity_ = weights.sum(dim=-1)  # (N_rays)
+
+        comp_rgb = torch.zeros(
+            n_rays, 3, dtype=comp_rgb_.dtype, device=comp_rgb_.device
+        )
+        opacity = torch.zeros(n_rays, dtype=opacity_.dtype, device=opacity_.device)
+        comp_rgb[rays_valid] = comp_rgb_
+        opacity[rays_valid] = opacity_
+
+        comp_rgb += 1 - opacity[..., None]
+        comp_rgb = comp_rgb.view(*rays_shape, 3)
+
+        return comp_rgb
+
+    def forward(
+        self,
+        decoder: torch.nn.Module,
+        triplane: torch.Tensor,
+        rays_o: torch.Tensor,
+        rays_d: torch.Tensor,
+    ) -> torch.Tensor:
+        if triplane.ndim == 4:
+            comp_rgb = self._forward(decoder, triplane, rays_o, rays_d)
+        else:
+            comp_rgb = torch.stack(
+                [
+                    self._forward(decoder, triplane[i], rays_o[i], rays_d[i])
+                    for i in range(triplane.shape[0])
+                ],
+                dim=0,
+            )
+
+        return comp_rgb
+
+    def train(self, mode=True):
+        self.randomized = mode and self.cfg.randomized
+        return super().train(mode=mode)
+
+    def eval(self):
+        self.randomized = False
+        return super().eval()
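Hypothetical end-to-end sketch for `TriplaneNeRFRenderer`. It assumes `BaseModule` (defined in `tsr/utils.py`, not shown in this hunk) can be constructed from a plain config dict, as in upstream TripoSR; the decoder below is a stand-in for `NeRFMLP` and only needs to return a dict with `density` and `features`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from tsr.models.nerf_renderer import TriplaneNeRFRenderer


class ToyDecoder(nn.Module):
    """Stand-in decoder: maps sampled triplane features to density + RGB features."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.net = nn.Linear(in_channels, 4)

    def forward(self, x):
        out = self.net(x)
        return {"density": out[..., 0:1], "features": out[..., 1:4]}


# Assumption: BaseModule subclasses accept a plain dict for their Config.
renderer = TriplaneNeRFRenderer({"radius": 1.0, "feature_reduction": "concat"})
renderer.set_chunk_size(8192)  # 0 would disable chunking

num_channels = 16
triplane = torch.randn(3, num_channels, 32, 32)     # (Np, Cp, Hp, Wp)
decoder = ToyDecoder(3 * num_channels)              # "concat" stacks the three planes

rays_o = torch.zeros(4, 3)                          # four rays starting at the origin
rays_d = F.normalize(torch.randn(4, 3), dim=-1)

rgb = renderer(decoder, triplane, rays_o, rays_d)   # (4, 3), composited over white
print(rgb.shape)
```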
diff --git a/tsr/models/network_utils.py b/tsr/models/network_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3844f533bf3b6c9afce6de3857255ee08125b1ba
--- /dev/null
+++ b/tsr/models/network_utils.py
@@ -0,0 +1,124 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from ..utils import BaseModule
+
+
+class TriplaneUpsampleNetwork(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        in_channels: int
+        out_channels: int
+
+    cfg: Config
+
+    def configure(self) -> None:
+        self.upsample = nn.ConvTranspose2d(
+            self.cfg.in_channels, self.cfg.out_channels, kernel_size=2, stride=2
+        )
+
+    def forward(self, triplanes: torch.Tensor) -> torch.Tensor:
+        triplanes_up = rearrange(
+            self.upsample(
+                rearrange(triplanes, "B Np Ci Hp Wp -> (B Np) Ci Hp Wp", Np=3)
+            ),
+            "(B Np) Co Hp Wp -> B Np Co Hp Wp",
+            Np=3,
+        )
+        return triplanes_up
+
+
+class NeRFMLP(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        in_channels: int
+        n_neurons: int
+        n_hidden_layers: int
+        activation: str = "relu"
+        bias: bool = True
+        weight_init: Optional[str] = "kaiming_uniform"
+        bias_init: Optional[str] = None
+
+    cfg: Config
+
+    def configure(self) -> None:
+        layers = [
+            self.make_linear(
+                self.cfg.in_channels,
+                self.cfg.n_neurons,
+                bias=self.cfg.bias,
+                weight_init=self.cfg.weight_init,
+                bias_init=self.cfg.bias_init,
+            ),
+            self.make_activation(self.cfg.activation),
+        ]
+        for i in range(self.cfg.n_hidden_layers - 1):
+            layers += [
+                self.make_linear(
+                    self.cfg.n_neurons,
+                    self.cfg.n_neurons,
+                    bias=self.cfg.bias,
+                    weight_init=self.cfg.weight_init,
+                    bias_init=self.cfg.bias_init,
+                ),
+                self.make_activation(self.cfg.activation),
+            ]
+        layers += [
+            self.make_linear(
+                self.cfg.n_neurons,
+                4,  # density 1 + features 3
+                bias=self.cfg.bias,
+                weight_init=self.cfg.weight_init,
+                bias_init=self.cfg.bias_init,
+            )
+        ]
+        self.layers = nn.Sequential(*layers)
+
+    def make_linear(
+        self,
+        dim_in,
+        dim_out,
+        bias=True,
+        weight_init=None,
+        bias_init=None,
+    ):
+        layer = nn.Linear(dim_in, dim_out, bias=bias)
+
+        if weight_init is None:
+            pass
+        elif weight_init == "kaiming_uniform":
+            torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity="relu")
+        else:
+            raise NotImplementedError
+
+        if bias:
+            if bias_init is None:
+                pass
+            elif bias_init == "zero":
+                torch.nn.init.zeros_(layer.bias)
+            else:
+                raise NotImplementedError
+
+        return layer
+
+    def make_activation(self, activation):
+        if activation == "relu":
+            return nn.ReLU(inplace=True)
+        elif activation == "silu":
+            return nn.SiLU(inplace=True)
+        else:
+            raise NotImplementedError
+
+    def forward(self, x):
+        inp_shape = x.shape[:-1]
+        x = x.reshape(-1, x.shape[-1])
+
+        features = self.layers(x)
+        features = features.reshape(*inp_shape, -1)
+        out = {"density": features[..., 0:1], "features": features[..., 1:4]}
+
+        return out
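Shape sketch for the two modules added here, under the same assumption that `BaseModule` subclasses can be built from a plain config dict; the channel sizes are illustrative:

```python
import torch

from tsr.models.network_utils import NeRFMLP, TriplaneUpsampleNetwork

# ConvTranspose2d with kernel 2, stride 2 doubles each plane's spatial size.
upsampler = TriplaneUpsampleNetwork({"in_channels": 1024, "out_channels": 40})
planes = torch.randn(2, 3, 1024, 32, 32)                # (B, Np, Ci, Hp, Wp)
print(upsampler(planes).shape)                          # torch.Size([2, 3, 40, 64, 64])

# With "concat" feature reduction the MLP sees 3 * out_channels input features.
mlp = NeRFMLP({"in_channels": 120, "n_neurons": 64, "n_hidden_layers": 9})
feats = torch.randn(2, 4096, 120)                       # arbitrary leading dims, channels last
out = mlp(feats)
print(out["density"].shape, out["features"].shape)      # (2, 4096, 1) and (2, 4096, 3)
```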
diff --git a/tsr/models/tokenizers/__pycache__/dino.cpython-310.pyc b/tsr/models/tokenizers/__pycache__/dino.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a492f56dc156938d3250d77f5c182ab05e1655ea
Binary files /dev/null and b/tsr/models/tokenizers/__pycache__/dino.cpython-310.pyc differ
diff --git a/tsr/models/tokenizers/__pycache__/image.cpython-310.pyc b/tsr/models/tokenizers/__pycache__/image.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..06d33036639d982296542f5d441f991d3e01ddb0
Binary files /dev/null and b/tsr/models/tokenizers/__pycache__/image.cpython-310.pyc differ
diff --git a/tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc b/tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..31191fd86c3d1668aefd27059738590dc066cf7b
Binary files /dev/null and b/tsr/models/tokenizers/__pycache__/triplane.cpython-310.pyc differ
diff --git a/tsr/models/tokenizers/image.py b/tsr/models/tokenizers/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..34a092e7170c87da53363822af554c78e5f8083f
--- /dev/null
+++ b/tsr/models/tokenizers/image.py
@@ -0,0 +1,67 @@
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from huggingface_hub import hf_hub_download
+from transformers.models.vit.modeling_vit import ViTModel
+
+from ...utils import BaseModule
+
+
+class DINOSingleImageTokenizer(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        pretrained_model_name_or_path: str = "facebook/dino-vitb16"
+        enable_gradient_checkpointing: bool = False
+
+    cfg: Config
+
+    def configure(self) -> None:
+        self.model: ViTModel = ViTModel(
+            ViTModel.config_class.from_pretrained(
+                hf_hub_download(
+                    repo_id=self.cfg.pretrained_model_name_or_path,
+                    filename="config.json",
+                )
+            )
+        )
+
+        if self.cfg.enable_gradient_checkpointing:
+            self.model.encoder.gradient_checkpointing = True
+
+        self.register_buffer(
+            "image_mean",
+            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
+            persistent=False,
+        )
+        self.register_buffer(
+            "image_std",
+            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
+            persistent=False,
+        )
+
+    def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+        packed = False
+        if images.ndim == 4:
+            packed = True
+            images = images.unsqueeze(1)
+
+        batch_size, n_input_views = images.shape[:2]
+        images = (images - self.image_mean) / self.image_std
+        out = self.model(
+            rearrange(images, "B N C H W -> (B N) C H W"), interpolate_pos_encoding=True
+        )
+        local_features, global_features = out.last_hidden_state, out.pooler_output
+        local_features = local_features.permute(0, 2, 1)
+        local_features = rearrange(
+            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
+        )
+        if packed:
+            local_features = local_features.squeeze(1)
+
+        return local_features
+
+    def detokenize(self, *args, **kwargs):
+        raise NotImplementedError
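Rough usage sketch for `DINOSingleImageTokenizer`. Note that `configure()` only downloads `config.json` from the Hub and builds a randomly initialised ViT (the DINO weights presumably arrive with the rest of the system checkpoint), so the values below are placeholders; shapes assume the default `facebook/dino-vitb16` (patch size 16, hidden size 768):

```python
import torch

from tsr.models.tokenizers.image import DINOSingleImageTokenizer

# Assumption: BaseModule subclasses accept a plain dict; {} keeps all defaults.
tokenizer = DINOSingleImageTokenizer({})

images = torch.rand(2, 3, 512, 512)    # (B, C, H, W), values in [0, 1]
tokens = tokenizer(images)             # (B, Ct, Nt): one token sequence per image
print(tokens.shape)                    # (2, 768, 1025): 32*32 patches + CLS token
```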
diff --git a/tsr/models/tokenizers/triplane.py b/tsr/models/tokenizers/triplane.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecdd7fd2201c974bb70b18a90a633287b814886f
--- /dev/null
+++ b/tsr/models/tokenizers/triplane.py
@@ -0,0 +1,45 @@
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+from ...utils import BaseModule
+
+
+class Triplane1DTokenizer(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        plane_size: int
+        num_channels: int
+
+    cfg: Config
+
+    def configure(self) -> None:
+        self.embeddings = nn.Parameter(
+            torch.randn(
+                (3, self.cfg.num_channels, self.cfg.plane_size, self.cfg.plane_size),
+                dtype=torch.float32,
+            )
+            * 1
+            / math.sqrt(self.cfg.num_channels)
+        )
+
+    def forward(self, batch_size: int) -> torch.Tensor:
+        return rearrange(
+            repeat(self.embeddings, "Np Ct Hp Wp -> B Np Ct Hp Wp", B=batch_size),
+            "B Np Ct Hp Wp -> B Ct (Np Hp Wp)",
+        )
+
+    def detokenize(self, tokens: torch.Tensor) -> torch.Tensor:
+        batch_size, Ct, Nt = tokens.shape
+        assert Nt == self.cfg.plane_size**2 * 3
+        assert Ct == self.cfg.num_channels
+        return rearrange(
+            tokens,
+            "B Ct (Np Hp Wp) -> B Np Ct Hp Wp",
+            Np=3,
+            Hp=self.cfg.plane_size,
+            Wp=self.cfg.plane_size,
+        )
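Round-trip sketch for the learnable triplane tokens (same plain-dict construction assumption; sizes are illustrative):

```python
import torch

from tsr.models.tokenizers.triplane import Triplane1DTokenizer

tok = Triplane1DTokenizer({"plane_size": 32, "num_channels": 1024})

tokens = tok(batch_size=2)        # (B, Ct, Np*Hp*Wp) == (2, 1024, 3072)
planes = tok.detokenize(tokens)   # back to (B, Np, Ct, Hp, Wp) == (2, 3, 1024, 32, 32)
print(tokens.shape, planes.shape)
```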
diff --git a/tsr/models/transformer/__pycache__/attention.cpython-310.pyc b/tsr/models/transformer/__pycache__/attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..02bf1262175cdbe2bb9651042c247f53c2ab0e91
Binary files /dev/null and b/tsr/models/transformer/__pycache__/attention.cpython-310.pyc differ
diff --git a/tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc b/tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c2f3d2a8c487f237c94433bd8aaafd26afb8ce0
Binary files /dev/null and b/tsr/models/transformer/__pycache__/basic_transformer_block.cpython-310.pyc differ
diff --git a/tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc b/tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..39cb242f767c41c73138dda36700ef1554bbf31f
Binary files /dev/null and b/tsr/models/transformer/__pycache__/transformer_1d.cpython-310.pyc differ
diff --git a/tsr/models/transformer/attention.py b/tsr/models/transformer/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb873231e3ad195a7fef3a2c4ef217be3056cd4e
--- /dev/null
+++ b/tsr/models/transformer/attention.py
@@ -0,0 +1,628 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Attention(nn.Module):
+    r"""
+    A cross attention layer.
+
+    Parameters:
+        query_dim (`int`):
+            The number of channels in the query.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+        heads (`int`,  *optional*, defaults to 8):
+            The number of heads to use for multi-head attention.
+        dim_head (`int`,  *optional*, defaults to 64):
+            The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        bias (`bool`, *optional*, defaults to False):
+            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+        upcast_attention (`bool`, *optional*, defaults to False):
+            Set to `True` to upcast the attention computation to `float32`.
+        upcast_softmax (`bool`, *optional*, defaults to False):
+            Set to `True` to upcast the softmax computation to `float32`.
+        cross_attention_norm (`str`, *optional*, defaults to `None`):
+            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+            The number of groups to use for the group norm in the cross attention.
+        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the added key and value projections. If `None`, no projection is used.
+        norm_num_groups (`int`, *optional*, defaults to `None`):
+            The number of groups to use for the group norm in the attention.
+        spatial_norm_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the spatial normalization.
+        out_bias (`bool`, *optional*, defaults to `True`):
+            Set to `True` to use a bias in the output linear layer.
+        scale_qk (`bool`, *optional*, defaults to `True`):
+            Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+        only_cross_attention (`bool`, *optional*, defaults to `False`):
+            Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+            `added_kv_proj_dim` is not `None`.
+        eps (`float`, *optional*, defaults to 1e-5):
+            An additional value added to the denominator in group normalization that is used for numerical stability.
+        rescale_output_factor (`float`, *optional*, defaults to 1.0):
+            A factor to rescale the output by dividing it with this value.
+        residual_connection (`bool`, *optional*, defaults to `False`):
+            Set to `True` to add the residual connection to the output.
+        _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+            Set to `True` if the attention block is loaded from a deprecated state dict.
+        processor (`AttnProcessor`, *optional*, defaults to `None`):
+            The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
+            `AttnProcessor` otherwise.
+    """
+
+    def __init__(
+        self,
+        query_dim: int,
+        cross_attention_dim: Optional[int] = None,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        upcast_attention: bool = False,
+        upcast_softmax: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        cross_attention_norm_num_groups: int = 32,
+        added_kv_proj_dim: Optional[int] = None,
+        norm_num_groups: Optional[int] = None,
+        out_bias: bool = True,
+        scale_qk: bool = True,
+        only_cross_attention: bool = False,
+        eps: float = 1e-5,
+        rescale_output_factor: float = 1.0,
+        residual_connection: bool = False,
+        _from_deprecated_attn_block: bool = False,
+        processor: Optional["AttnProcessor"] = None,
+        out_dim: int = None,
+    ):
+        super().__init__()
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
+        self.cross_attention_dim = (
+            cross_attention_dim if cross_attention_dim is not None else query_dim
+        )
+        self.upcast_attention = upcast_attention
+        self.upcast_softmax = upcast_softmax
+        self.rescale_output_factor = rescale_output_factor
+        self.residual_connection = residual_connection
+        self.dropout = dropout
+        self.fused_projections = False
+        self.out_dim = out_dim if out_dim is not None else query_dim
+
+        # we make use of this private variable to know whether this class is loaded
+        # with a deprecated state dict so that we can convert it on the fly
+        self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        # for slice_size > 0 the attention score computation
+        # is split across the batch axis to save memory
+        # You can set slice_size with `set_attention_slice`
+        self.sliceable_head_dim = heads
+
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.only_cross_attention = only_cross_attention
+
+        if self.added_kv_proj_dim is None and self.only_cross_attention:
+            raise ValueError(
+                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+            )
+
+        if norm_num_groups is not None:
+            self.group_norm = nn.GroupNorm(
+                num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
+            )
+        else:
+            self.group_norm = None
+
+        self.spatial_norm = None
+
+        if cross_attention_norm is None:
+            self.norm_cross = None
+        elif cross_attention_norm == "layer_norm":
+            self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = self.cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels,
+                num_groups=cross_attention_norm_num_groups,
+                eps=1e-5,
+                affine=True,
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+            )
+
+        linear_cls = nn.Linear
+
+        self.linear_cls = linear_cls
+        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
+
+        if not self.only_cross_attention:
+            # only relevant for the `AddedKVProcessor` classes
+            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+        else:
+            self.to_k = None
+            self.to_v = None
+
+        if self.added_kv_proj_dim is not None:
+            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(nn.Dropout(dropout))
+
+        # set attention processor
+        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+        if processor is None:
+            processor = (
+                AttnProcessor2_0()
+                if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+                else AttnProcessor()
+            )
+        self.set_processor(processor)
+
+    def set_processor(self, processor: "AttnProcessor") -> None:
+        self.processor = processor
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        **cross_attention_kwargs,
+    ) -> torch.Tensor:
+        r"""
+        The forward method of the `Attention` class.
+
+        Args:
+            hidden_states (`torch.Tensor`):
+                The hidden states of the query.
+            encoder_hidden_states (`torch.Tensor`, *optional*):
+                The hidden states of the encoder.
+            attention_mask (`torch.Tensor`, *optional*):
+                The attention mask to use. If `None`, no mask is applied.
+            **cross_attention_kwargs:
+                Additional keyword arguments to pass along to the cross attention.
+
+        Returns:
+            `torch.Tensor`: The output of the attention layer.
+        """
+        # The `Attention` class can call different attention processors / attention functions
+        # here we simply pass along all tensors to the selected processor class
+        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **cross_attention_kwargs,
+        )
+
+    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+        r"""
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
+        is the number of heads initialized while constructing the `Attention` class.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
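+
+        Example (illustrative only, the inverse of `head_to_batch_dim` for a module with `heads=8`):
+
+            >>> import torch
+            >>> attn = Attention(query_dim=320, heads=8, dim_head=40)
+            >>> attn.batch_to_head_dim(torch.randn(16, 64, 40)).shape
+            torch.Size([2, 64, 320])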
+        """
+        head_size = self.heads
+        batch_size, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(
+            batch_size // head_size, seq_len, dim * head_size
+        )
+        return tensor
+
+    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+        r"""
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]`, where
+        `heads` is the number of heads initialized while constructing the `Attention` class.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
+                reshaped to `[batch_size * heads, seq_len, dim // heads]`.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
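+
+        Example (illustrative only, for a module with `heads=8` and the default `out_dim=3`):
+
+            >>> import torch
+            >>> attn = Attention(query_dim=320, heads=8, dim_head=40)
+            >>> attn.head_to_batch_dim(torch.randn(2, 64, 320)).shape
+            torch.Size([16, 64, 40])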
+        """
+        head_size = self.heads
+        batch_size, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if out_dim == 3:
+            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
+
+        return tensor
+
+    def get_attention_scores(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r"""
+        Compute the attention scores.
+
+        Args:
+            query (`torch.Tensor`): The query tensor.
+            key (`torch.Tensor`): The key tensor.
+            attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
+
+        Returns:
+            `torch.Tensor`: The attention probabilities/scores.
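+
+        Example (illustrative shapes; query and key are already in the `batch * heads` layout):
+
+            >>> import torch
+            >>> attn = Attention(query_dim=320, heads=8, dim_head=40)
+            >>> q = torch.randn(16, 64, 40)
+            >>> k = torch.randn(16, 77, 40)
+            >>> attn.get_attention_scores(q, k).shape
+            torch.Size([16, 64, 77])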
+        """
+        dtype = query.dtype
+        if self.upcast_attention:
+            query = query.float()
+            key = key.float()
+
+        if attention_mask is None:
+            baddbmm_input = torch.empty(
+                query.shape[0],
+                query.shape[1],
+                key.shape[1],
+                dtype=query.dtype,
+                device=query.device,
+            )
+            beta = 0
+        else:
+            baddbmm_input = attention_mask
+            beta = 1
+
+        attention_scores = torch.baddbmm(
+            baddbmm_input,
+            query,
+            key.transpose(-1, -2),
+            beta=beta,
+            alpha=self.scale,
+        )
+        del baddbmm_input
+
+        if self.upcast_softmax:
+            attention_scores = attention_scores.float()
+
+        attention_probs = attention_scores.softmax(dim=-1)
+        del attention_scores
+
+        attention_probs = attention_probs.to(dtype)
+
+        return attention_probs
+
+    def prepare_attention_mask(
+        self,
+        attention_mask: torch.Tensor,
+        target_length: int,
+        batch_size: int,
+        out_dim: int = 3,
+    ) -> torch.Tensor:
+        r"""
+        Prepare the attention mask for the attention computation.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                The attention mask to prepare.
+            target_length (`int`):
+                The target length of the attention mask. This is the length of the attention mask after padding.
+            batch_size (`int`):
+                The batch size, which is used to repeat the attention mask.
+            out_dim (`int`, *optional*, defaults to `3`):
+                The output dimension of the attention mask. Can be either `3` or `4`.
+
+        Returns:
+            `torch.Tensor`: The prepared attention mask.
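+
+        Example (illustrative; a `(batch, 1, key_tokens)` mask already at `target_length` is repeated per head):
+
+            >>> import torch
+            >>> attn = Attention(query_dim=320, heads=8, dim_head=40)
+            >>> mask = torch.zeros(2, 1, 77)
+            >>> attn.prepare_attention_mask(mask, target_length=77, batch_size=2).shape
+            torch.Size([16, 1, 77])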
+        """
+        head_size = self.heads
+        if attention_mask is None:
+            return attention_mask
+
+        current_length: int = attention_mask.shape[-1]
+        if current_length != target_length:
+            if attention_mask.device.type == "mps":
+                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+                # Instead, we can manually construct the padding tensor.
+                padding_shape = (
+                    attention_mask.shape[0],
+                    attention_mask.shape[1],
+                    target_length,
+                )
+                padding = torch.zeros(
+                    padding_shape,
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+                attention_mask = torch.cat([attention_mask, padding], dim=2)
+            else:
+                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+                #       we want to instead pad by (0, remaining_length), where remaining_length is:
+                #       remaining_length: int = target_length - current_length
+                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+        if out_dim == 3:
+            if attention_mask.shape[0] < batch_size * head_size:
+                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        elif out_dim == 4:
+            attention_mask = attention_mask.unsqueeze(1)
+            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+        return attention_mask
+
+    def norm_encoder_hidden_states(
+        self, encoder_hidden_states: torch.Tensor
+    ) -> torch.Tensor:
+        r"""
+        Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
+        `Attention` class.
+
+        Args:
+            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+        Returns:
+            `torch.Tensor`: The normalized encoder hidden states.
+        """
+        assert (
+            self.norm_cross is not None
+        ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            # Group norm norms along the channels dimension and expects
+            # input to be in the shape of (N, C, *). In this case, we want
+            # to norm along the hidden dimension, so we need to move
+            # (batch_size, sequence_length, hidden_size) ->
+            # (batch_size, hidden_size, sequence_length)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            raise ValueError(f"Unknown cross attention norm layer: {type(self.norm_cross)}")
+
+        return encoder_hidden_states
+
+    @torch.no_grad()
+    def fuse_projections(self, fuse=True):
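+        # Concatenate the q/k/v (or, for cross-attention, k/v) projection weights into a single linear
+        # layer so that all projections can be computed with one matmul.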
+        is_cross_attention = self.cross_attention_dim != self.query_dim
+        device = self.to_q.weight.data.device
+        dtype = self.to_q.weight.data.dtype
+
+        if not is_cross_attention:
+            # fetch weight matrices.
+            concatenated_weights = torch.cat(
+                [self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]
+            )
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            # create a new single projection layer and copy over the weights.
+            self.to_qkv = self.linear_cls(
+                in_features, out_features, bias=False, device=device, dtype=dtype
+            )
+            self.to_qkv.weight.copy_(concatenated_weights)
+
+        else:
+            concatenated_weights = torch.cat(
+                [self.to_k.weight.data, self.to_v.weight.data]
+            )
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_kv = self.linear_cls(
+                in_features, out_features, bias=False, device=device, dtype=dtype
+            )
+            self.to_kv.weight.copy_(concatenated_weights)
+
+        self.fused_projections = fuse
+
+
+class AttnProcessor:
+    r"""
+    Default processor for performing attention-related computations.
+    """
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(
+                batch_size, channel, height * width
+            ).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape
+            if encoder_hidden_states is None
+            else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(
+            attention_mask, sequence_length, batch_size
+        )
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
+                1, 2
+            )
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(
+                encoder_hidden_states
+            )
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class AttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
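+
+    Example (illustrative; `Attention` normally selects this processor automatically on PyTorch 2.x, but it can
+    also be set explicitly):
+
+        >>> attn = Attention(query_dim=320, heads=8, dim_head=40)
+        >>> attn.set_processor(AttnProcessor2_0())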
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "AttnProcessor2_0 requires PyTorch 2.0 or newer. Please upgrade PyTorch to use it."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(
+                batch_size, channel, height * width
+            ).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape
+            if encoder_hidden_states is None
+            else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(
+                attention_mask, sequence_length, batch_size
+            )
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(
+                batch_size, attn.heads, -1, attention_mask.shape[-1]
+            )
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
+                1, 2
+            )
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(
+                encoder_hidden_states
+            )
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(
+            batch_size, -1, attn.heads * head_dim
+        )
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(
+                batch_size, channel, height, width
+            )
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
diff --git a/tsr/models/transformer/basic_transformer_block.py b/tsr/models/transformer/basic_transformer_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..a55ea6189c9e1da2f86f6d0957ae30f26fefff0a
--- /dev/null
+++ b/tsr/models/transformer/basic_transformer_block.py
@@ -0,0 +1,314 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .attention import Attention
+
+
+class BasicTransformerBlock(nn.Module):
+    r"""
+    A basic Transformer block.
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether the attention layers should contain a bias parameter.
+        only_cross_attention (`bool`, *optional*):
+            Whether to use only cross-attention layers. In this case two cross attention layers are used.
+        double_self_attention (`bool`, *optional*):
+            Whether to use two self-attention layers. In this case no cross attention layers are used.
+        upcast_attention (`bool`, *optional*):
+            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+            Whether to use learnable elementwise affine parameters for normalization.
+        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+            The normalization layer to use. Only `"layer_norm"` is supported by this implementation.
+        final_dropout (`bool`, *optional*, defaults to `False`):
+            Whether to apply a final dropout after the last feed-forward layer.
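+
+    Example (an illustrative sketch; the dimensions below are made up):
+
+        >>> import torch
+        >>> block = BasicTransformerBlock(dim=512, num_attention_heads=8, attention_head_dim=64, cross_attention_dim=768)
+        >>> hidden_states = torch.randn(2, 1024, 512)
+        >>> encoder_hidden_states = torch.randn(2, 77, 768)
+        >>> block(hidden_states, encoder_hidden_states=encoder_hidden_states).shape
+        torch.Size([2, 1024, 512])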
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        dropout=0.0,
+        cross_attention_dim: Optional[int] = None,
+        activation_fn: str = "geglu",
+        attention_bias: bool = False,
+        only_cross_attention: bool = False,
+        double_self_attention: bool = False,
+        upcast_attention: bool = False,
+        norm_elementwise_affine: bool = True,
+        norm_type: str = "layer_norm",
+        final_dropout: bool = False,
+    ):
+        super().__init__()
+        self.only_cross_attention = only_cross_attention
+
+        assert norm_type == "layer_norm"
+
+        # Define 3 blocks. Each block has its own normalization layer.
+        # 1. Self-Attn
+        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+            upcast_attention=upcast_attention,
+        )
+
+        # 2. Cross-Attn
+        if cross_attention_dim is not None or double_self_attention:
+            # This second block is a cross-attention block when `cross_attention_dim` is given, or a
+            # second self-attention block when `double_self_attention` is set.
+            self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=cross_attention_dim
+                if not double_self_attention
+                else None,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+            )  # is self-attn if encoder_hidden_states is none
+        else:
+            self.norm2 = None
+            self.attn2 = None
+
+        # 3. Feed-forward
+        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+        self.ff = FeedForward(
+            dim,
+            dropout=dropout,
+            activation_fn=activation_fn,
+            final_dropout=final_dropout,
+        )
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+        # Sets chunk feed-forward
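+        # e.g. `block.set_chunk_feed_forward(chunk_size=256, dim=1)` runs the feed-forward over the
+        # sequence dimension in slices of 256 tokens (illustrative values; the chunked dimension must be
+        # divisible by `chunk_size`).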
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+    ) -> torch.FloatTensor:
+        # Notice that normalization is always applied before the real computation in the following blocks.
+        # 1. Self-Attention
+        norm_hidden_states = self.norm1(hidden_states)
+
+        attn_output = self.attn1(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states
+            if self.only_cross_attention
+            else None,
+            attention_mask=attention_mask,
+        )
+
+        hidden_states = attn_output + hidden_states
+
+        # 2. Cross-Attention
+        if self.attn2 is not None:
+            norm_hidden_states = self.norm2(hidden_states)
+
+            attn_output = self.attn2(
+                norm_hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+            )
+            hidden_states = attn_output + hidden_states
+
+        # 3. Feed-forward
+        norm_hidden_states = self.norm3(hidden_states)
+
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+                raise ValueError(
+                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `set_chunk_feed_forward`."
+                )
+
+            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+            ff_output = torch.cat(
+                [
+                    self.ff(hid_slice)
+                    for hid_slice in norm_hidden_states.chunk(
+                        num_chunks, dim=self._chunk_dim
+                    )
+                ],
+                dim=self._chunk_dim,
+            )
+        else:
+            ff_output = self.ff(norm_hidden_states)
+
+        hidden_states = ff_output + hidden_states
+
+        return hidden_states
+
+
+class FeedForward(nn.Module):
+    r"""
+    A feed-forward layer.
+
+    Parameters:
+        dim (`int`): The number of channels in the input.
+        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        final_dropout (`bool`, *optional*, defaults to `False`): Apply a final dropout.
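+
+    Example (illustrative; the default GEGLU feed-forward preserves the channel dimension):
+
+        >>> import torch
+        >>> ff = FeedForward(dim=512)
+        >>> ff(torch.randn(2, 1024, 512)).shape
+        torch.Size([2, 1024, 512])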
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        mult: int = 4,
+        dropout: float = 0.0,
+        activation_fn: str = "geglu",
+        final_dropout: bool = False,
+    ):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = dim_out if dim_out is not None else dim
+        linear_cls = nn.Linear
+
+        if activation_fn == "gelu":
+            act_fn = GELU(dim, inner_dim)
+        elif activation_fn == "gelu-approximate":
+            act_fn = GELU(dim, inner_dim, approximate="tanh")
+        elif activation_fn == "geglu":
+            act_fn = GEGLU(dim, inner_dim)
+        elif activation_fn == "geglu-approximate":
+            act_fn = ApproximateGELU(dim, inner_dim)
+        else:
+            raise ValueError(f"Unknown activation function: {activation_fn}")
+
+        self.net = nn.ModuleList([])
+        # project in
+        self.net.append(act_fn)
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(linear_cls(inner_dim, dim_out))
+        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+        if final_dropout:
+            self.net.append(nn.Dropout(dropout))
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states
+
+
+class GELU(nn.Module):
+    r"""
+    GELU activation function, with optional tanh approximation via `approximate="tanh"`.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+    """
+
+    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out)
+        self.approximate = approximate
+
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+        if gate.device.type != "mps":
+            return F.gelu(gate, approximate=self.approximate)
+        # mps: gelu is not implemented for float16
+        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(
+            dtype=gate.dtype
+        )
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        hidden_states = self.gelu(hidden_states)
+        return hidden_states
+
+
+class GEGLU(nn.Module):
+    r"""
+    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+    """
+
+    def __init__(self, dim_in: int, dim_out: int):
+        super().__init__()
+        linear_cls = nn.Linear
+
+        self.proj = linear_cls(dim_in, dim_out * 2)
+
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+        if gate.device.type != "mps":
+            return F.gelu(gate)
+        # mps: gelu is not implemented for float16
+        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+    def forward(self, hidden_states):
+        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+        return hidden_states * self.gelu(gate)
+
+
+class ApproximateGELU(nn.Module):
+    r"""
+    The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
+    https://arxiv.org/abs/1606.08415.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+    """
+
+    def __init__(self, dim_in: int, dim_out: int):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        return x * torch.sigmoid(1.702 * x)
diff --git a/tsr/models/transformer/transformer_1d.py b/tsr/models/transformer/transformer_1d.py
new file mode 100644
index 0000000000000000000000000000000000000000..7546232c5126a6622f1fd701ba36f9d4b53b9178
--- /dev/null
+++ b/tsr/models/transformer/transformer_1d.py
@@ -0,0 +1,216 @@
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...utils import BaseModule
+from .basic_transformer_block import BasicTransformerBlock
+
+
+class Transformer1D(BaseModule):
+    """
+    A 1D Transformer model for sequence data.
+
+    Parameters:
+        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+        in_channels (`int`, *optional*):
+            The number of channels in the input and output.
+        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+        attention_bias (`bool`, *optional*):
+            Configure if the `TransformerBlocks` attention should contain a bias parameter.
+    """
+
+    @dataclass
+    class Config(BaseModule.Config):
+        num_attention_heads: int = 16
+        attention_head_dim: int = 88
+        in_channels: Optional[int] = None
+        out_channels: Optional[int] = None
+        num_layers: int = 1
+        dropout: float = 0.0
+        norm_num_groups: int = 32
+        cross_attention_dim: Optional[int] = None
+        attention_bias: bool = False
+        activation_fn: str = "geglu"
+        only_cross_attention: bool = False
+        double_self_attention: bool = False
+        upcast_attention: bool = False
+        norm_type: str = "layer_norm"
+        norm_elementwise_affine: bool = True
+        gradient_checkpointing: bool = False
+
+    cfg: Config
+
+    def configure(self) -> None:
+        self.num_attention_heads = self.cfg.num_attention_heads
+        self.attention_head_dim = self.cfg.attention_head_dim
+        inner_dim = self.num_attention_heads * self.attention_head_dim
+
+        linear_cls = nn.Linear
+
+        # 1. Define input layers
+        self.in_channels = self.cfg.in_channels
+
+        self.norm = torch.nn.GroupNorm(
+            num_groups=self.cfg.norm_num_groups,
+            num_channels=self.cfg.in_channels,
+            eps=1e-6,
+            affine=True,
+        )
+        self.proj_in = linear_cls(self.cfg.in_channels, inner_dim)
+
+        # 2. Define transformer blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    inner_dim,
+                    self.num_attention_heads,
+                    self.attention_head_dim,
+                    dropout=self.cfg.dropout,
+                    cross_attention_dim=self.cfg.cross_attention_dim,
+                    activation_fn=self.cfg.activation_fn,
+                    attention_bias=self.cfg.attention_bias,
+                    only_cross_attention=self.cfg.only_cross_attention,
+                    double_self_attention=self.cfg.double_self_attention,
+                    upcast_attention=self.cfg.upcast_attention,
+                    norm_type=self.cfg.norm_type,
+                    norm_elementwise_affine=self.cfg.norm_elementwise_affine,
+                )
+                for d in range(self.cfg.num_layers)
+            ]
+        )
+
+        # 3. Define output layers
+        self.out_channels = (
+            self.cfg.in_channels
+            if self.cfg.out_channels is None
+            else self.cfg.out_channels
+        )
+
+        self.proj_out = linear_cls(inner_dim, self.cfg.in_channels)
+
+        self.gradient_checkpointing = self.cfg.gradient_checkpointing
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+    ):
+        """
+        The [`Transformer1D`] forward method.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, in_channels, seq_len)`):
+                Input `hidden_states`.
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_len, embed_dims)`, *optional*):
+                Conditional embeddings for the cross-attention layers. If not given, cross-attention defaults to
+                self-attention.
+            attention_mask (`torch.Tensor`, *optional*):
+                An attention mask of shape `(batch, key_tokens)` applied to the self-attention of `hidden_states`.
+                If `1` the mask is kept, otherwise if `0` it is discarded. The mask will be converted into a bias,
+                which adds large negative values to the attention scores corresponding to "discard" tokens.
+            encoder_attention_mask (`torch.Tensor`, *optional*):
+                Cross-attention mask applied to `encoder_hidden_states`. Two formats are supported:
+
+                    * Mask `(batch, sequence_length)` True = keep, False = discard.
+                    * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+                above. This bias will be added to the cross-attention scores.
+
+        Returns:
+            `torch.Tensor` of shape `(batch_size, in_channels, seq_len)`: The output hidden states.
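+
+        Example (an illustrative sketch; the config values below are made up):
+
+            >>> import torch
+            >>> model = Transformer1D(
+            ...     {"num_attention_heads": 8, "attention_head_dim": 64, "in_channels": 512, "cross_attention_dim": 768}
+            ... )
+            >>> tokens = torch.randn(2, 512, 1024)  # (batch_size, in_channels, seq_len)
+            >>> cond = torch.randn(2, 77, 768)  # (batch_size, sequence_len, cross_attention_dim)
+            >>> model(tokens, encoder_hidden_states=cond).shape
+            torch.Size([2, 512, 1024])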
+        """
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+        #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+        #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch,                    1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None and attention_mask.ndim == 2:
+            # assume that mask is expressed as:
+            #   (1 = keep,      0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #       (keep = +0,     discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (
+                1 - encoder_attention_mask.to(hidden_states.dtype)
+            ) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 1. Input
+        batch, _, seq_len = hidden_states.shape
+        residual = hidden_states
+
+        hidden_states = self.norm(hidden_states)
+        inner_dim = hidden_states.shape[1]
+        hidden_states = hidden_states.permute(0, 2, 1).reshape(
+            batch, seq_len, inner_dim
+        )
+        hidden_states = self.proj_in(hidden_states)
+
+        # 2. Blocks
+        for block in self.transformer_blocks:
+            if self.training and self.gradient_checkpointing:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    block,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                )
+
+        # 3. Output
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = (
+            hidden_states.reshape(batch, seq_len, inner_dim)
+            .permute(0, 2, 1)
+            .contiguous()
+        )
+
+        output = hidden_states + residual
+
+        return output
diff --git a/tsr/system.py b/tsr/system.py
new file mode 100644
index 0000000000000000000000000000000000000000..de5ae6ef75afd3bbbe0bdfacc7258a8c51409cc5
--- /dev/null
+++ b/tsr/system.py
@@ -0,0 +1,203 @@
+import math
+import os
+from dataclasses import dataclass, field
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+import trimesh
+from einops import rearrange
+from huggingface_hub import hf_hub_download
+from omegaconf import OmegaConf
+from PIL import Image
+
+from .models.isosurface import MarchingCubeHelper
+from .utils import (
+    BaseModule,
+    ImagePreprocessor,
+    find_class,
+    get_spherical_cameras,
+    scale_tensor,
+)
+
+
+class TSR(BaseModule):
+    @dataclass
+    class Config(BaseModule.Config):
+        cond_image_size: int
+
+        image_tokenizer_cls: str
+        image_tokenizer: dict
+
+        tokenizer_cls: str
+        tokenizer: dict
+
+        backbone_cls: str
+        backbone: dict
+
+        post_processor_cls: str
+        post_processor: dict
+
+        decoder_cls: str
+        decoder: dict
+
+        renderer_cls: str
+        renderer: dict
+
+    cfg: Config
+
+    @classmethod
+    def from_pretrained(
+        cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str, token=None
+    ):
+        if os.path.isdir(pretrained_model_name_or_path):
+            config_path = os.path.join(pretrained_model_name_or_path, config_name)
+            weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
+        else:
+            config_path = hf_hub_download(
+                repo_id=pretrained_model_name_or_path, filename=config_name, token=token
+            )
+            weight_path = hf_hub_download(
+                repo_id=pretrained_model_name_or_path, filename=weight_name, token=token
+            )
+
+        cfg = OmegaConf.load(config_path)
+        OmegaConf.resolve(cfg)
+        model = cls(cfg)
+        ckpt = torch.load(weight_path, map_location="cpu")
+        model.load_state_dict(ckpt)
+        return model
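+
+    # Example usage (an illustrative sketch; the repo id, file names, and image preprocessing below are
+    # assumptions and may differ from the actual release):
+    #
+    #   model = TSR.from_pretrained("stabilityai/TripoSR", config_name="config.yaml", weight_name="model.ckpt")
+    #   model.to("cuda:0")
+    #   scene_codes = model([image], device="cuda:0")  # `image`: an RGB input with background removed
+    #   meshes = model.extract_mesh(scene_codes)
+    #   meshes[0].export("mesh.obj")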
+
+    def configure(self):
+        self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
+            self.cfg.image_tokenizer
+        )
+        self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
+        self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
+        self.post_processor = find_class(self.cfg.post_processor_cls)(
+            self.cfg.post_processor
+        )
+        self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
+        self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
+        self.image_processor = ImagePreprocessor()
+        self.isosurface_helper = None
+
+    def forward(
+        self,
+        image: Union[
+            PIL.Image.Image,
+            np.ndarray,
+            torch.FloatTensor,
+            List[PIL.Image.Image],
+            List[np.ndarray],
+            List[torch.FloatTensor],
+        ],
+        device: str,
+    ) -> torch.FloatTensor:
+        rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
+            device
+        )
+        batch_size = rgb_cond.shape[0]
+
+        input_image_tokens: torch.Tensor = self.image_tokenizer(
+            rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
+        )
+
+        input_image_tokens = rearrange(
+            input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
+        )
+
+        tokens: torch.Tensor = self.tokenizer(batch_size)
+
+        tokens = self.backbone(
+            tokens,
+            encoder_hidden_states=input_image_tokens,
+        )
+
+        scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
+        return scene_codes
+
+    def render(
+        self,
+        scene_codes,
+        n_views: int,
+        elevation_deg: float = 0.0,
+        camera_distance: float = 1.9,
+        fovy_deg: float = 40.0,
+        height: int = 256,
+        width: int = 256,
+        return_type: str = "pil",
+    ):
+        rays_o, rays_d = get_spherical_cameras(
+            n_views, elevation_deg, camera_distance, fovy_deg, height, width
+        )
+        rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)
+
+        def process_output(image: torch.FloatTensor):
+            if return_type == "pt":
+                return image
+            elif return_type == "np":
+                return image.detach().cpu().numpy()
+            elif return_type == "pil":
+                return Image.fromarray(
+                    (image.detach().cpu().numpy() * 255.0).astype(np.uint8)
+                )
+            else:
+                raise NotImplementedError
+
+        images = []
+        for scene_code in scene_codes:
+            images_ = []
+            for i in range(n_views):
+                with torch.no_grad():
+                    image = self.renderer(
+                        self.decoder, scene_code, rays_o[i], rays_d[i]
+                    )
+                images_.append(process_output(image))
+            images.append(images_)
+
+        return images
+
+    def set_marching_cubes_resolution(self, resolution: int):
+        if (
+            self.isosurface_helper is not None
+            and self.isosurface_helper.resolution == resolution
+        ):
+            return
+        self.isosurface_helper = MarchingCubeHelper(resolution)
+
+    def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 25.0):
+        self.set_marching_cubes_resolution(resolution)
+        meshes = []
+        for scene_code in scene_codes:
+            with torch.no_grad():
+                density = self.renderer.query_triplane(
+                    self.decoder,
+                    scale_tensor(
+                        self.isosurface_helper.grid_vertices.to(scene_codes.device),
+                        self.isosurface_helper.points_range,
+                        (-self.renderer.cfg.radius, self.renderer.cfg.radius),
+                    ),
+                    scene_code,
+                )["density_act"]
+            v_pos, t_pos_idx = self.isosurface_helper(-(density - threshold))
+            v_pos = scale_tensor(
+                v_pos,
+                self.isosurface_helper.points_range,
+                (-self.renderer.cfg.radius, self.renderer.cfg.radius),
+            )
+            with torch.no_grad():
+                color = self.renderer.query_triplane(
+                    self.decoder,
+                    v_pos,
+                    scene_code,
+                )["color"]
+            mesh = trimesh.Trimesh(
+                vertices=v_pos.cpu().numpy(),
+                faces=t_pos_idx.cpu().numpy(),
+                vertex_colors=color.cpu().numpy(),
+            )
+            meshes.append(mesh)
+        return meshes
diff --git a/tsr/utils.py b/tsr/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a1b59aef75d02d39b29222d300fe9241bb11444
--- /dev/null
+++ b/tsr/utils.py
@@ -0,0 +1,482 @@
+import importlib
+import math
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import imageio
+import numpy as np
+import PIL.Image
+import rembg
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import trimesh
+from omegaconf import DictConfig, OmegaConf
+from PIL import Image
+
+
+def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
+    scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg)
+    return scfg
+
+
+def find_class(cls_string):
+    module_string = ".".join(cls_string.split(".")[:-1])
+    cls_name = cls_string.split(".")[-1]
+    module = importlib.import_module(module_string, package=None)
+    cls = getattr(module, cls_name)
+    return cls
+
+
+def get_intrinsic_from_fov(fov, H, W, bs=-1):
+    focal_length = 0.5 * H / np.tan(0.5 * fov)
+    intrinsic = np.identity(3, dtype=np.float32)
+    intrinsic[0, 0] = focal_length
+    intrinsic[1, 1] = focal_length
+    intrinsic[0, 2] = W / 2.0
+    intrinsic[1, 2] = H / 2.0
+
+    if bs > 0:
+        intrinsic = intrinsic[None].repeat(bs, axis=0)
+
+    return torch.from_numpy(intrinsic)
+
+
+class BaseModule(nn.Module):
+    @dataclass
+    class Config:
+        pass
+
+    cfg: Config  # add this to every subclass of BaseModule to enable static type checking
+
+    def __init__(
+        self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
+    ) -> None:
+        super().__init__()
+        self.cfg = parse_structured(self.Config, cfg)
+        self.configure(*args, **kwargs)
+
+    def configure(self, *args, **kwargs) -> None:
+        raise NotImplementedError
+
+
+class ImagePreprocessor:
+    def convert_and_resize(
+        self,
+        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
+        size: int,
+    ):
+        if isinstance(image, PIL.Image.Image):
+            image = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
+        elif isinstance(image, np.ndarray):
+            if image.dtype == np.uint8:
+                image = torch.from_numpy(image.astype(np.float32) / 255.0)
+            else:
+                image = torch.from_numpy(image)
+        elif isinstance(image, torch.Tensor):
+            pass
+
+        batched = image.ndim == 4
+
+        if not batched:
+            image = image[None, ...]
+        image = F.interpolate(
+            image.permute(0, 3, 1, 2),
+            (size, size),
+            mode="bilinear",
+            align_corners=False,
+            antialias=True,
+        ).permute(0, 2, 3, 1)
+        if not batched:
+            image = image[0]
+        return image
+
+    def __call__(
+        self,
+        image: Union[
+            PIL.Image.Image,
+            np.ndarray,
+            torch.FloatTensor,
+            List[PIL.Image.Image],
+            List[np.ndarray],
+            List[torch.FloatTensor],
+        ],
+        size: int,
+    ) -> Any:
+        if isinstance(image, (np.ndarray, torch.FloatTensor)) and image.ndim == 4:
+            image = self.convert_and_resize(image, size)
+        else:
+            if not isinstance(image, list):
+                image = [image]
+            image = [self.convert_and_resize(im, size) for im in image]
+            image = torch.stack(image, dim=0)
+        return image
+
+
+def rays_intersect_bbox(
+    rays_o: torch.Tensor,
+    rays_d: torch.Tensor,
+    radius: float,
+    near: float = 0.0,
+    valid_thresh: float = 0.01,
+):
+    input_shape = rays_o.shape[:-1]
+    rays_o, rays_d = rays_o.view(-1, 3), rays_d.view(-1, 3)
+    rays_d_valid = torch.where(
+        rays_d.abs() < 1e-6, torch.full_like(rays_d, 1e-6), rays_d
+    )
+    if isinstance(radius, (int, float)):
+        radius = torch.FloatTensor(
+            [[-radius, radius], [-radius, radius], [-radius, radius]]
+        ).to(rays_o.device)
+    radius = (
+        1.0 - 1.0e-3
+    ) * radius  # tighten the radius to make sure the intersection point lies in the bounding box
+    interx0 = (radius[..., 1] - rays_o) / rays_d_valid
+    interx1 = (radius[..., 0] - rays_o) / rays_d_valid
+    t_near = torch.minimum(interx0, interx1).amax(dim=-1).clamp_min(near)
+    t_far = torch.maximum(interx0, interx1).amin(dim=-1)
+
+    # check whether a ray intersects the bbox or not
+    rays_valid = t_far - t_near > valid_thresh
+
+    t_near[torch.where(~rays_valid)] = 0.0
+    t_far[torch.where(~rays_valid)] = 0.0
+
+    t_near = t_near.view(*input_shape, 1)
+    t_far = t_far.view(*input_shape, 1)
+    rays_valid = rays_valid.view(*input_shape)
+
+    return t_near, t_far, rays_valid
+
+
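+# `chunk_batch` evaluates `func` on slices of the leading (batch) dimension and concatenates the
+# results, bounding peak memory. Illustrative use (hypothetical `decoder` and `points`):
+#   out = chunk_batch(decoder, 8192, points)  # like decoder(points), but evaluated 8192 rows at a time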
+def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any:
+    if chunk_size <= 0:
+        return func(*args, **kwargs)
+    B = None
+    for arg in list(args) + list(kwargs.values()):
+        if isinstance(arg, torch.Tensor):
+            B = arg.shape[0]
+            break
+    assert (
+        B is not None
+    ), "No tensor found in args or kwargs, cannot determine batch size."
+    out = defaultdict(list)
+    out_type = None
+    # max(1, B) to support B == 0
+    for i in range(0, max(1, B), chunk_size):
+        out_chunk = func(
+            *[
+                arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
+                for arg in args
+            ],
+            **{
+                k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
+                for k, arg in kwargs.items()
+            },
+        )
+        if out_chunk is None:
+            continue
+        out_type = type(out_chunk)
+        if isinstance(out_chunk, torch.Tensor):
+            out_chunk = {0: out_chunk}
+        elif isinstance(out_chunk, (tuple, list)):
+            chunk_length = len(out_chunk)
+            out_chunk = {idx: chunk for idx, chunk in enumerate(out_chunk)}
+        elif isinstance(out_chunk, dict):
+            pass
+        else:
+            raise TypeError(
+                f"Return value of func must be of type torch.Tensor, list, tuple or dict, got {type(out_chunk)}."
+            )
+        for k, v in out_chunk.items():
+            v = v if torch.is_grad_enabled() or v is None else v.detach()
+            out[k].append(v)
+
+    if out_type is None:
+        return None
+
+    out_merged: Dict[Any, Optional[torch.Tensor]] = {}
+    for k, v in out.items():
+        if all([vv is None for vv in v]):
+            # allow None in return value
+            out_merged[k] = None
+        elif all([isinstance(vv, torch.Tensor) for vv in v]):
+            out_merged[k] = torch.cat(v, dim=0)
+        else:
+            raise TypeError(
+                f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
+            )
+
+    if out_type is torch.Tensor:
+        return out_merged[0]
+    elif out_type in [tuple, list]:
+        return out_type([out_merged[i] for i in range(chunk_length)])
+    elif out_type is dict:
+        return out_merged
+
+
+ValidScale = Union[Tuple[float, float], torch.FloatTensor]
+
+
+def scale_tensor(dat: torch.FloatTensor, inp_scale: ValidScale, tgt_scale: ValidScale):
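+    """Linearly rescale `dat` from the `inp_scale` range to the `tgt_scale` range; `None` scales default to (0, 1)."""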
+    if inp_scale is None:
+        inp_scale = (0, 1)
+    if tgt_scale is None:
+        tgt_scale = (0, 1)
+    if isinstance(tgt_scale, torch.FloatTensor):
+        assert dat.shape[-1] == tgt_scale.shape[-1]
+    dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
+    dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
+    return dat
+
+
+def get_activation(name) -> Callable:
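+    """Return an activation function by name: `None`/"none" give the identity; otherwise exp, sigmoid, tanh, softplus, or a lookup in `torch.nn.functional`."""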
+    if name is None:
+        return lambda x: x
+    name = name.lower()
+    if name == "none":
+        return lambda x: x
+    elif name == "exp":
+        return lambda x: torch.exp(x)
+    elif name == "sigmoid":
+        return lambda x: torch.sigmoid(x)
+    elif name == "tanh":
+        return lambda x: torch.tanh(x)
+    elif name == "softplus":
+        return lambda x: F.softplus(x)
+    else:
+        try:
+            return getattr(F, name)
+        except AttributeError:
+            raise ValueError(f"Unknown activation function: {name}")
+
+
+def get_ray_directions(
+    H: int,
+    W: int,
+    focal: Union[float, Tuple[float, float]],
+    principal: Optional[Tuple[float, float]] = None,
+    use_pixel_centers: bool = True,
+    normalize: bool = True,
+) -> torch.FloatTensor:
+    """
+    Get ray directions for all pixels in camera coordinate.
+    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+               ray-tracing-generating-camera-rays/standard-coordinate-systems
+
+    Inputs:
+        H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers
+    Outputs:
+        directions: (H, W, 3), the direction of the rays in camera coordinate
+    """
+    pixel_center = 0.5 if use_pixel_centers else 0
+
+    if isinstance(focal, float):
+        fx, fy = focal, focal
+        cx, cy = W / 2, H / 2
+    else:
+        fx, fy = focal
+        assert principal is not None
+        cx, cy = principal
+
+    i, j = torch.meshgrid(
+        torch.arange(W, dtype=torch.float32) + pixel_center,
+        torch.arange(H, dtype=torch.float32) + pixel_center,
+        indexing="xy",
+    )
+
+    directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1)
+
+    if normalize:
+        directions = F.normalize(directions, dim=-1)
+
+    return directions
+
+
+def get_rays(
+    directions,
+    c2w,
+    keepdim=False,
+    noise_scale=0.0,
+    normalize=False,
+) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
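+    """
+    Transform camera-space ray directions into world-space ray origins and
+    directions using the camera-to-world matrix `c2w`. Handles flattened
+    (N_rays, 3), per-image (H, W, 3) and batched (B, H, W, 3) direction maps.
+    """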
+    # Rotate ray directions from camera coordinates to world coordinates
+    assert directions.shape[-1] == 3
+
+    if directions.ndim == 2:  # (N_rays, 3)
+        if c2w.ndim == 2:  # (4, 4)
+            c2w = c2w[None, :, :]
+        assert c2w.ndim == 3  # (N_rays, 4, 4) or (1, 4, 4)
+        rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1)  # (N_rays, 3)
+        rays_o = c2w[:, :3, 3].expand(rays_d.shape)
+    elif directions.ndim == 3:  # (H, W, 3)
+        assert c2w.ndim in [2, 3]
+        if c2w.ndim == 2:  # (4, 4)
+            rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
+                -1
+            )  # (H, W, 3)
+            rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
+        elif c2w.ndim == 3:  # (B, 4, 4)
+            rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
+                -1
+            )  # (B, H, W, 3)
+            rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
+    elif directions.ndim == 4:  # (B, H, W, 3)
+        assert c2w.ndim == 3  # (B, 4, 4)
+        rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
+            -1
+        )  # (B, H, W, 3)
+        rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
+
+    # add camera noise to avoid grid-like artifacts
+    # https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373
+    if noise_scale > 0:
+        rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale
+        rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale
+
+    if normalize:
+        rays_d = F.normalize(rays_d, dim=-1)
+    if not keepdim:
+        rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
+
+    return rays_o, rays_d
+
+
+def get_spherical_cameras(
+    n_views: int,
+    elevation_deg: float,
+    camera_distance: float,
+    fovy_deg: float,
+    height: int,
+    width: int,
+):
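+    """
+    Place `n_views` cameras evenly in azimuth around the z-axis at the given
+    elevation and distance, all looking at the origin with +z as up, and return
+    per-pixel ray origins and normalized directions of shape
+    (n_views, height, width, 3).
+    """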
+    azimuth_deg = torch.linspace(0, 360.0, n_views + 1)[:n_views]
+    elevation_deg = torch.full_like(azimuth_deg, elevation_deg)
+    camera_distances = torch.full_like(elevation_deg, camera_distance)
+
+    elevation = elevation_deg * math.pi / 180
+    azimuth = azimuth_deg * math.pi / 180
+
+    # convert spherical coordinates to cartesian coordinates
+    # right hand coordinate system, x back, y right, z up
+    # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
+    camera_positions = torch.stack(
+        [
+            camera_distances * torch.cos(elevation) * torch.cos(azimuth),
+            camera_distances * torch.cos(elevation) * torch.sin(azimuth),
+            camera_distances * torch.sin(elevation),
+        ],
+        dim=-1,
+    )
+
+    # default scene center at origin
+    center = torch.zeros_like(camera_positions)
+    # default camera up direction as +z
+    up = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None, :].repeat(n_views, 1)
+
+    fovy = torch.full_like(elevation_deg, fovy_deg) * math.pi / 180
+
+    lookat = F.normalize(center - camera_positions, dim=-1)
+    right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1)
+    up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
+    c2w3x4 = torch.cat(
+        [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
+        dim=-1,
+    )
+    c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1)
+    c2w[:, 3, 3] = 1.0
+
+    # get directions by dividing directions_unit_focal by focal length
+    focal_length = 0.5 * height / torch.tan(0.5 * fovy)
+    directions_unit_focal = get_ray_directions(
+        H=height,
+        W=width,
+        focal=1.0,
+    )
+    directions = directions_unit_focal[None, :, :, :].repeat(n_views, 1, 1, 1)
+    directions[:, :, :, :2] = (
+        directions[:, :, :, :2] / focal_length[:, None, None, None]
+    )
+    # must use normalize=True to normalize directions here
+    rays_o, rays_d = get_rays(directions, c2w, keepdim=True, normalize=True)
+
+    return rays_o, rays_d
+
+
+def remove_background(
+    image: PIL.Image.Image,
+    rembg_session: Any = None,
+    force: bool = False,
+    **rembg_kwargs,
+) -> PIL.Image.Image:
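+    """
+    Remove the image background with rembg; skipped when the image is already
+    RGBA with some transparent pixels, unless `force` is set.
+    """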
+    do_remove = True
+    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
+        do_remove = False
+    do_remove = do_remove or force
+    if do_remove:
+        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
+    return image
+
+
+def resize_foreground(
+    image: PIL.Image.Image,
+    ratio: float,
+) -> PIL.Image.Image:
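+    """
+    Crop an RGBA image to the bounding box of its non-transparent pixels, pad
+    it to a square, then pad further so the foreground occupies roughly
+    `ratio` of the output side length.
+    """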
+    image = np.array(image)
+    assert image.shape[-1] == 4
+    alpha = np.where(image[..., 3] > 0)
+    y1, y2, x1, x2 = (
+        alpha[0].min(),
+        alpha[0].max(),
+        alpha[1].min(),
+        alpha[1].max(),
+    )
+    # crop the foreground
+    fg = image[y1:y2, x1:x2]
+    # pad to square
+    size = max(fg.shape[0], fg.shape[1])
+    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
+    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
+    new_image = np.pad(
+        fg,
+        ((ph0, ph1), (pw0, pw1), (0, 0)),
+        mode="constant",
+        constant_values=((0, 0), (0, 0), (0, 0)),
+    )
+
+    # compute padding according to the ratio
+    new_size = int(new_image.shape[0] / ratio)
+    # pad to size, double side
+    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
+    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
+    new_image = np.pad(
+        new_image,
+        ((ph0, ph1), (pw0, pw1), (0, 0)),
+        mode="constant",
+        constant_values=((0, 0), (0, 0), (0, 0)),
+    )
+    new_image = PIL.Image.fromarray(new_image)
+    return new_image
+
+
+def save_video(
+    frames: List[PIL.Image.Image],
+    output_path: str,
+    fps: int = 30,
+):
+    # use imageio to save video
+    frames = [np.array(frame) for frame in frames]
+    writer = imageio.get_writer(output_path, fps=fps)
+    for frame in frames:
+        writer.append_data(frame)
+    writer.close()
+
+
+def to_gradio_3d_orientation(mesh):
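+    # Rotate the mesh (-90 deg about x, then +90 deg about y) so it shows up
+    # in the orientation expected by Gradio's 3D model viewer.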
+    mesh.apply_transform(trimesh.transformations.rotation_matrix(-np.pi/2, [1, 0, 0]))
+    # mesh.apply_scale([1, 1, -1])
+    mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi/2, [0, 1, 0]))
+    return mesh
diff --git a/wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl b/wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl
new file mode 100644
index 0000000000000000000000000000000000000000..88fdefc57ff16c5fd6354342b8801509d33c529e
--- /dev/null
+++ b/wheel/torchmcubes-0.1.0-cp310-cp310-linux_x86_64.whl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4af160ba1274e2205d3529a7b82efdb6946c2158a78e19631ed840301055b8d6
+size 5824388