{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "6db2e2b4",
      "metadata": {
        "id": "6db2e2b4"
      },
      "source": [
        "## Build and Compose Conditioners\n",
        "\n",
        "### Overview\n",
        "Protein design with Chroma is highly customizable and programmable. Our Conditioner framework enables automatic conditional sampling tailored to a wide range of protein specifications. Conditioning can take the form of restraints (which bias the distribution of states using classifier guidance) or constraints (which directly limit the scope of the underlying sampling process). For a detailed explanation, see Supplementary Appendix M of our paper. We provide a variety of pre-defined conditioners, including ones for substructure, symmetry, shape, semantics, and even natural-language prompts (see `chroma.layers.structure.conditioners`). These conditioners can be used in any combination to suit your needs."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "3b4c35a7",
      "metadata": {
        "id": "3b4c35a7"
      },
      "source": [
        "### Composing Conditioners\n",
        "\n",
        "Conditioners in Chroma can be combined using `conditioners.ComposedConditioner`, much as layers are sequenced in `torch.nn.Sequential`. Define the individual conditioners, then pass them as a single list; the composed conditioner applies their transformations sequentially.\n",
        "\n",
        "```python\n",
        "composed_conditioner = conditioners.ComposedConditioner([conditioner1, conditioner2, conditioner3])\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "b2lOsBQFhypc",
      "metadata": {
        "id": "b2lOsBQFhypc"
      },
      "source": [
        "#### Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "4db56efd",
      "metadata": {
        "id": "4db56efd"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Note: you may need to restart the kernel to use updated packages.\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "0854e7da4ca04f71a86260c2e66bbfcc",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": []
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "import locale\n",
        "locale.getpreferredencoding = lambda: \"UTF-8\"\n",
        "%pip install generate-chroma > /dev/null 2>&1\n",
        "from chroma import api\n",
        "api.register_key(input(\"Enter API key: \"))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "f3ee7c51",
      "metadata": {
        "id": "f3ee7c51"
      },
      "source": [
        "#### Example 1: Combining Symmetry and Secondary Structure\n",
        "In this example, we first apply classifier guidance on secondary structure and then impose cyclic symmetry. The secondary-structure classifier conditions the sampling of a beta-rich asymmetric unit (AU), which is then symmetrized."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "id": "37b9c48f",
      "metadata": {
        "id": "37b9c48f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Using cached data from /tmp/chroma_weights/90e339502ae6b372797414167ce5a632/weights.pt\n",
            "Loaded from cache\n",
            "Using cached data from /tmp/chroma_weights/03a3a9af343ae74998768a2711c8b7ce/weights.pt\n",
            "Loaded from cache\n",
            "Data saved to /tmp/chroma_weights/3262b44702040b1dcfccd71ebbcf451d/weights.pt\n",
            "Computing reference stats for 2g3n\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c67d34c17a45442cb2804cc3c8060222",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Integrating SDE:   0%|          | 0/500 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "760cbac8cf3746919bff0644fd797491",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Potts Sampling:   0%|          | 0/500 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a65f9ce5a38f4c21a66a0aacfe76b99d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Sequential decoding:   0%|          | 0/300 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "dbbe249b2a304c30b440602527330c95",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "NGLWidget()"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from chroma.models import Chroma\n",
        "from chroma.layers.structure import conditioners\n",
        "\n",
        "chroma = Chroma()\n",
        "# Condition on CATH class 2 (mainly beta)\n",
        "beta = conditioners.ProClassConditioner('cath', \"2\", weight=5, max_norm=20)\n",
        "c_symmetry = conditioners.SymmetryConditioner(G=\"C_3\", num_chain_neighbors=2)\n",
        "composed_cond = conditioners.ComposedConditioner([beta, c_symmetry])\n",
        "\n",
        "symm_beta = chroma.sample(chain_lengths=[100],\n",
        "    conditioner=composed_cond,\n",
        "    langevin_factor=8,\n",
        "    inverse_temperature=8,\n",
        "    sde_func=\"langevin\",\n",
        "    steps=500)\n",
        "\n",
        "symm_beta"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "8b97d729",
      "metadata": {
        "id": "8b97d729"
      },
      "source": [
        "#### Example 2: Merging Symmetry and Substructure\n",
        "Here, the goal is to build a symmetric assembly from a single-chain protein by partially redesigning it so that three identical AUs merge into a cyclic complex. We first define the backbone region targeted for redesign, then reposition the AU to prevent clashes during symmetrization, and finally apply the symmetrization operation itself.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "id": "35c78aad",
      "metadata": {
        "id": "35c78aad"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Using cached data from /tmp/chroma_weights/90e339502ae6b372797414167ce5a632/weights.pt\n",
            "Loaded from cache\n",
            "Using cached data from /tmp/chroma_weights/03a3a9af343ae74998768a2711c8b7ce/weights.pt\n",
            "Loaded from cache\n"
          ]
        },
        {
          "ename": "RuntimeError",
          "evalue": "torch.linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 27 is not positive-definite).",
          "output_type": "error",
          "traceback": [
            "---------------------------------------------------------------------------",
            "RuntimeError                              Traceback (most recent call last)",
            "Cell In[14], line 8\n      6 protein = Protein(PDB_ID, canonicalize=True, device='cuda')\n      7 # regenerate residues with X coord < 25 A and y coord < 25 A\n----> 8 substruct_conditioner = conditioners.SubstructureConditioner(\n      9     protein, backbone_model=chroma.backbone_network, selection=\"x < 25 and y < 25\")\n     11 # C_3 symmetry\n     12 c_symmetry = conditioners.SymmetryConditioner(G=\"C_3\", num_chain_neighbors=3)\n",
            "File ~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/structure/conditioners.py:881, in SubstructureConditioner.__init__(self, protein, backbone_model, selection, rg, weight, tspan, weight_max, gamma, center_init)\n    879 self.base_distribution = self.backbone_model.noise_perturb.base_gaussian\n    880 self.noise_schedule = self.backbone_model.noise_perturb.noise_schedule\n--> 881 self.conditional_distribution = mvn.ConditionalBackboneMVNGlobular(\n    882     covariance_model=self.base_distribution.covariance_model,\n    883     complex_scaling=self.base_distribution.complex_scaling,\n    884     X=X,\n    885     C=C,\n    886     D=D,\n    887     gamma=gamma,\n    888 )\n    889 X = self.conditional_distribution.sample(1)\n    890 self.tspan = tspan\n",
            "File ~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/structure/mvn.py:563, in ConditionalBackboneMVNGlobular.__init__(self, covariance_model, complex_scaling, sigma_translation, X, C, D, gamma, **kwargs)\n    560 self.register_buffer(\"R\", R)\n    561 self.register_buffer(\"RRt\", RRt)\n--> 563 R_clamp, RRt_clamp = self._condition_RRt(self.RRt, self.D)\n    564 self.register_buffer(\"R_clamp\", R_clamp)\n    565 self.register_buffer(\"RRt_clamp\", RRt_clamp)\n",
            "File ~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/structure/mvn.py:712, in ConditionalBackboneMVNGlobular._condition_RRt(self, RRt, D)\n    709 self.register_buffer(\"S22\", RRt[self.nonzero_indices][:, self.nonzero_indices])\n    711 S_clamp = self.S11 - ((self.S12 @ torch.linalg.pinv(self.S22) @ self.S21))\n--> 712 R_clamp = torch.linalg.cholesky(S_clamp)\n    713 self.register_buffer(\"RRt_clamp_restricted\", R_clamp @ R_clamp.t())\n    714 RRt_clamp = self._scatter(\n    715     torch.zeros_like(RRt), self.zero_indices, self.RRt_clamp_restricted\n    716 )\n",
            "RuntimeError: torch.linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 27 is not positive-definite)."
          ]
        }
      ],
      "source": [
        "from chroma.data import Protein\n",
        "\n",
        "PDB_ID = '3BDI'\n",
        "chroma = Chroma()\n",
        "\n",
        "protein = Protein(PDB_ID, canonicalize=True, device='cuda')\n",
        "# regenerate residues with X coord < 25 A and y coord < 25 A\n",
        "substruct_conditioner = conditioners.SubstructureConditioner(\n",
        "    protein, backbone_model=chroma.backbone_network, selection=\"x < 25 and y < 25\")\n",
        "\n",
        "# C_3 symmetry\n",
        "c_symmetry = conditioners.SymmetryConditioner(G=\"C_3\", num_chain_neighbors=3)\n",
        "\n",
        "# Composing\n",
        "composed_cond = conditioners.ComposedConditioner([substruct_conditioner, c_symmetry])\n",
        "\n",
        "protein, trajectories = chroma.sample(\n",
        "    protein_init=protein,\n",
        "    conditioner=composed_cond,\n",
        "    langevin_factor=4.0,\n",
        "    langevin_isothermal=True,\n",
        "    inverse_temperature=8.0,\n",
        "    sde_func='langevin',\n",
        "    steps=500,\n",
        "    full_output=True,\n",
        ")\n",
        "\n",
        "protein"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "de3c2b97",
      "metadata": {
        "id": "de3c2b97"
      },
      "source": [
        "### Build your own conditioners: 2D protein lattices\n",
        "\n",
        "An attractive aspect of the conditioner framework is its generality: it supports both constraints (which involve operations on $x$) and restraints (which amount to changes to $U$). At the same time, generation under restraints can still be (and often is) challenging, as the resulting effective energy landscape can become arbitrarily rugged and difficult to integrate. We therefore advise caution when using and developing new conditioners or conditioner combinations. We find that inspecting diffusion trajectories (including unconstrained and denoised trajectories, $\\hat{x}_t$ and $\\tilde{x}_t$) can be a good tool for identifying integration challenges and defining either better conditioner forms or better sampling regimes.\n",
        "\n",
        "Here we show how to build a conditioner that generates a periodic 2D lattice. The same approach can be extended to generate 3D protein materials."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "id": "2bb9dcf3",
      "metadata": {
        "id": "2bb9dcf3"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Using cached data from /tmp/chroma_weights/90e339502ae6b372797414167ce5a632/weights.pt\n",
            "Loaded from cache\n",
            "Using cached data from /tmp/chroma_weights/03a3a9af343ae74998768a2711c8b7ce/weights.pt\n",
            "Loaded from cache\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "14663b93ee8a48d3932635b073d60eda",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Integrating SDE:   0%|          | 0/500 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "4a123abffdcc4b6b84062d85ba3d4ead",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Potts Sampling:   0%|          | 0/500 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "3cc618db756747d1993075205fa455cd",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Sequential decoding:   0%|          | 0/720 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "08f4338906a744d6bfcf78868f4853f3",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "NGLWidget()"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "import torch\n",
        "\n",
        "class Lattice2DConditioner(conditioners.Conditioner):\n",
        "    def __init__(self, M, N, cell):\n",
        "        super().__init__()\n",
        "        # Set up the coordinates of a 2D lattice\n",
        "        self.order = M*N\n",
        "        x = torch.arange(M) * cell[0]\n",
        "        y = torch.arange(N) * cell[1]\n",
        "        xx, yy = torch.meshgrid(x, y, indexing=\"ij\")\n",
        "        dX = torch.stack([xx.flatten(), yy.flatten(), torch.zeros(M * N)], dim=1)\n",
        "        self.register_buffer(\"dX\", dX)\n",
        "\n",
        "    def forward(self, X, C, O, U, t):\n",
        "        # Tessellate the unit cell on the lattice\n",
        "        X = (X[:,None,...] + self.dX[None,:,None,None]).reshape(1, -1, 4, 3)\n",
        "        C = torch.cat([C + C.unique().max() * i for i in range(self.dX.shape[0])], dim=1)\n",
        "        # Average the gradient\n",
        "        X.register_hook(lambda gradX: gradX / self.order)\n",
        "        return X, C, O, U, t\n",
        "\n",
        "chroma = Chroma()\n",
        "M, N = 3, 3\n",
        "conditioner = Lattice2DConditioner(M=M, N=N, cell=[25., 25.]).cuda()\n",
        "protein, trajectories = chroma.sample(\n",
        "    chain_lengths=[80], conditioner=conditioner, sde_func='langevin',\n",
        "    potts_symmetry_order=conditioner.order,\n",
        "    full_output=True\n",
        ")\n",
        "\n",
        "protein"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "4ecf48ff",
      "metadata": {
        "id": "4ecf48ff"
      },
      "source": [
        "#### Notes on Troubleshooting\n",
        "1. The order in which conditioners are applied matters. Generally, it is best to apply stringent, all-encompassing constraints towards the end. For instance, symmetry, a constraint that affects the entire complex, should come last in the conditioner list.\n",
        "When troubleshooting a conditioner, it helps to test it on a single protein state to verify that the resulting transformation matches your expectations.\n",
        "2. If your conditioner, like the `SymmetryConditioner`, makes multiple copies of a single protein, divide the pull-back gradients by the number of copies. This prevents excessive gradient accumulation on the asymmetric unit; the `Lattice2DConditioner` above does this with a gradient hook. Refer to Appendix M for more details.\n",
        "3. Adjusting sampling hyperparameters may be necessary when experimenting with new conditioners. Key parameters to consider include `langevin_factor`, `inverse_temperature`, `langevin_isothermal`, `steps`, and the guidance scale (especially when applying restraints). For hard constraints, it is usually advisable to use `sde_func='langevin'`."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.0"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}