{ "cells": [ { "cell_type": "markdown", "id": "6db2e2b4", "metadata": { "id": "6db2e2b4" }, "source": [ "## Build and Compose Conditioners\n", "\n", "### Overview\n", "Protein design with Chroma is highly customizable and programmable. Our Conditioner framework enables automatic conditional sampling for a wide range of protein specifications. A conditioner can impose either restraints, which bias the distribution of states using classifier guidance, or constraints, which directly limit the scope of the underlying sampling process. For a detailed explanation, see Supplementary Appendix M of our paper. We provide a variety of predefined conditioners, including ones for substructure, symmetry, shape, semantics, and natural-language prompts (see `chroma.layers.structure.conditioners`). These conditioners can be used individually or in any combination to suit your needs." ] }, { "cell_type": "markdown", "id": "3b4c35a7", "metadata": { "id": "3b4c35a7" }, "source": [ "### Composing Conditioners\n", "\n", "Conditioners in Chroma can be combined using `conditioners.ComposedConditioner`, much as layers are chained in `torch.nn.Sequential`. 
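As a rough mental model of ordered composition, the toy sketch below applies plain-Python functions in list order. `ToyComposedConditioner`, `fix_first`, and `mirror` are hypothetical names invented for illustration. Real Chroma conditioners operate on tensors and can also modify the energy, so this is not the library's implementation. The sketch shows that order matters. Mirroring after pinning a coordinate propagates the pinned value to the mirrored copy, while mirroring first does not.

```python
from typing import Callable, List

# Toy state: a flat list of coordinates (hypothetical; Chroma uses tensors).
Coords = List[float]
ToyConditioner = Callable[[Coords], Coords]


class ToyComposedConditioner:
    # Applies toy conditioners in list order, analogous to torch.nn.Sequential.
    def __init__(self, conditioners: List[ToyConditioner]) -> None:
        self.conditioners = list(conditioners)

    def __call__(self, x: Coords) -> Coords:
        for conditioner in self.conditioners:
            x = conditioner(x)
        return x


def fix_first(x: Coords) -> Coords:
    # Hypothetical substructure-style step: pin coordinate 0 to a fixed value.
    return [5.0] + x[1:]


def mirror(x: Coords) -> Coords:
    # Hypothetical symmetry-style step: append a mirrored copy of every coordinate.
    return x + [-xi for xi in x]


a = ToyComposedConditioner([fix_first, mirror])([1.0, 2.0])
b = ToyComposedConditioner([mirror, fix_first])([1.0, 2.0])
print(a)  # [5.0, 2.0, -5.0, -2.0]
print(b)  # [5.0, 2.0, -1.0, -2.0]
```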
Define the individual conditioners, then pass them as a list to `ComposedConditioner`, which applies them in sequence.\n", "\n", "```python\n", "composed_conditioner = conditioners.ComposedConditioner([conditioner1, conditioner2, conditioner3])\n", "```" ] }, { "cell_type": "markdown", "id": "b2lOsBQFhypc", "metadata": { "id": "b2lOsBQFhypc" }, "source": [ "#### Setup" ] }, { "cell_type": "code", "execution_count": 1, "id": "4db56efd", "metadata": { "id": "4db56efd" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Note: you may need to restart the kernel to use updated packages.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0854e7da4ca04f71a86260c2e66bbfcc", "version_major": 2, "version_minor": 0 }, "text/plain": [] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import locale\n", "# Force UTF-8 as the preferred encoding (a common Colab workaround for %pip)\n", "locale.getpreferredencoding = lambda: \"UTF-8\"\n", "%pip install generate-chroma > /dev/null 2>&1\n", "from chroma import api\n", "# Register your Chroma API key (entered interactively)\n", "api.register_key(input(\"Enter API key: \"))" ] }, { "cell_type": "markdown", "id": "f3ee7c51", "metadata": { "id": "f3ee7c51" }, "source": [ "#### Example 1: Combining Symmetry and Secondary Structure\n", "In this example, we first apply secondary-structure guidance and then cyclic symmetry: a secondary structure classifier steers sampling toward a beta-rich asymmetric unit (AU), which is then symmetrized."
] }, { "cell_type": "code", "execution_count": 10, "id": "37b9c48f", "metadata": { "id": "37b9c48f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using cached data from /tmp/chroma_weights/90e339502ae6b372797414167ce5a632/weights.pt\n", "Loaded from cache\n", "Using cached data from /tmp/chroma_weights/03a3a9af343ae74998768a2711c8b7ce/weights.pt\n", "Loaded from cache\n", "Data saved to /tmp/chroma_weights/3262b44702040b1dcfccd71ebbcf451d/weights.pt\n", "Computing reference stats for 2g3n\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c67d34c17a45442cb2804cc3c8060222", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Integrating SDE: 0%| | 0/500 [00:00 8\u001b[0m substruct_conditioner \u001b[38;5;241m=\u001b[39m \u001b[43mconditioners\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mSubstructureConditioner\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mprotein\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbackbone_model\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mchroma\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackbone_network\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mselection\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mx < 25 and y < 25\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;66;03m# C_3 symmetry\u001b[39;00m\n\u001b[1;32m 12\u001b[0m c_symmetry \u001b[38;5;241m=\u001b[39m conditioners\u001b[38;5;241m.\u001b[39mSymmetryConditioner(G\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mC_3\u001b[39m\u001b[38;5;124m\"\u001b[39m, num_chain_neighbors\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m3\u001b[39m)\n", "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/structure/conditioners.py:881\u001b[0m, in 
\u001b[0;36mSubstructureConditioner.__init__\u001b[0;34m(self, protein, backbone_model, selection, rg, weight, tspan, weight_max, gamma, center_init)\u001b[0m\n\u001b[1;32m 879\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbase_distribution \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbackbone_model\u001b[38;5;241m.\u001b[39mnoise_perturb\u001b[38;5;241m.\u001b[39mbase_gaussian\n\u001b[1;32m 880\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnoise_schedule \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbackbone_model\u001b[38;5;241m.\u001b[39mnoise_perturb\u001b[38;5;241m.\u001b[39mnoise_schedule\n\u001b[0;32m--> 881\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconditional_distribution \u001b[38;5;241m=\u001b[39m \u001b[43mmvn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mConditionalBackboneMVNGlobular\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 882\u001b[0m \u001b[43m \u001b[49m\u001b[43mcovariance_model\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbase_distribution\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcovariance_model\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 883\u001b[0m \u001b[43m \u001b[49m\u001b[43mcomplex_scaling\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbase_distribution\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcomplex_scaling\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 884\u001b[0m \u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 885\u001b[0m \u001b[43m \u001b[49m\u001b[43mC\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mC\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 886\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mD\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mD\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 887\u001b[0m \u001b[43m \u001b[49m\u001b[43mgamma\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgamma\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 888\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 889\u001b[0m X \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconditional_distribution\u001b[38;5;241m.\u001b[39msample(\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 890\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtspan \u001b[38;5;241m=\u001b[39m tspan\n", "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/structure/mvn.py:563\u001b[0m, in \u001b[0;36mConditionalBackboneMVNGlobular.__init__\u001b[0;34m(self, covariance_model, complex_scaling, sigma_translation, X, C, D, gamma, **kwargs)\u001b[0m\n\u001b[1;32m 560\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mregister_buffer(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mR\u001b[39m\u001b[38;5;124m\"\u001b[39m, R)\n\u001b[1;32m 561\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mregister_buffer(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRRt\u001b[39m\u001b[38;5;124m\"\u001b[39m, RRt)\n\u001b[0;32m--> 563\u001b[0m R_clamp, RRt_clamp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_condition_RRt\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mRRt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mD\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 564\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mregister_buffer(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mR_clamp\u001b[39m\u001b[38;5;124m\"\u001b[39m, R_clamp)\n\u001b[1;32m 565\u001b[0m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mregister_buffer(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRRt_clamp\u001b[39m\u001b[38;5;124m\"\u001b[39m, RRt_clamp)\n", "File \u001b[0;32m~/anaconda3/envs/mlfold/lib/python3.8/site-packages/chroma/layers/structure/mvn.py:712\u001b[0m, in \u001b[0;36mConditionalBackboneMVNGlobular._condition_RRt\u001b[0;34m(self, RRt, D)\u001b[0m\n\u001b[1;32m 709\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mregister_buffer(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mS22\u001b[39m\u001b[38;5;124m\"\u001b[39m, RRt[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnonzero_indices][:, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnonzero_indices])\n\u001b[1;32m 711\u001b[0m S_clamp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mS11 \u001b[38;5;241m-\u001b[39m ((\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mS12 \u001b[38;5;241m@\u001b[39m torch\u001b[38;5;241m.\u001b[39mlinalg\u001b[38;5;241m.\u001b[39mpinv(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mS22) \u001b[38;5;241m@\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mS21))\n\u001b[0;32m--> 712\u001b[0m R_clamp \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinalg\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcholesky\u001b[49m\u001b[43m(\u001b[49m\u001b[43mS_clamp\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 713\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mregister_buffer(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRRt_clamp_restricted\u001b[39m\u001b[38;5;124m\"\u001b[39m, R_clamp \u001b[38;5;241m@\u001b[39m R_clamp\u001b[38;5;241m.\u001b[39mt())\n\u001b[1;32m 714\u001b[0m RRt_clamp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_scatter(\n\u001b[1;32m 715\u001b[0m torch\u001b[38;5;241m.\u001b[39mzeros_like(RRt), 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mzero_indices, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mRRt_clamp_restricted\n\u001b[1;32m 716\u001b[0m )\n", "\u001b[0;31mRuntimeError\u001b[0m: torch.linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 27 is not positive-definite)." ] } ], "source": [ "from chroma.data import Protein\n", "\n", "PDB_ID = '3BDI'\n", "chroma = Chroma()\n", "\n", "protein = Protein(PDB_ID, canonicalize=True, device='cuda')\n", "# Regenerate residues with x < 25 Å and y < 25 Å\n", "substruct_conditioner = conditioners.SubstructureConditioner(\n", " protein, backbone_model=chroma.backbone_network, selection=\"x < 25 and y < 25\")\n", "\n", "# C_3 symmetry\n", "c_symmetry = conditioners.SymmetryConditioner(G=\"C_3\", num_chain_neighbors=3)\n", "\n", "# Compose: substructure conditioning first, then C_3 symmetry\n", "composed_cond = conditioners.ComposedConditioner([substruct_conditioner, c_symmetry])\n", "\n", "protein, trajectories = chroma.sample(\n", " protein_init=protein,\n", " conditioner=composed_cond,\n", " langevin_factor=4.0,\n", " langevin_isothermal=True,\n", " inverse_temperature=8.0,\n", " sde_func='langevin',\n", " steps=500,\n", " full_output=True,\n", ")\n", "\n", "protein" ] }, { "cell_type": "markdown", "id": "de3c2b97", "metadata": { "id": "de3c2b97" }, "source": [ "### Build your own conditioners: 2D protein lattices\n", "\n", "A strength of this conditioner framework is its generality: it supports both constraints (operations on $x$) and restraints (changes to $U$). Sampling under restraints can nonetheless be, and often is, challenging, because the resulting effective energy landscape can become arbitrarily rugged and difficult to integrate. We therefore advise caution when using or developing new conditioners and conditioner combinations. 
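The toy sketch below makes the distinction concrete. It uses noise-free gradient descent on a quadratic energy as a stand-in for the sampler. It is not Chroma's SDE integrator, and `grad_base`, `grad_restraint`, `apply_constraint`, and `run` are hypothetical names. The restraint adds a harmonic term to $U$ that pulls the first coordinate toward 3.0, so that coordinate settles at a compromise with the base energy (1.5). The constraint overwrites the last coordinate after every step, so that coordinate ends exactly at -2.0.

```python
from typing import List

Vector = List[float]


def grad_base(x: Vector) -> Vector:
    # Toy base energy U(x) = 0.5 * sum(x_i ** 2), so its gradient is x itself.
    return list(x)


def grad_restraint(x: Vector, target: float = 3.0, k: float = 1.0) -> Vector:
    # Hypothetical restraint: adds 0.5 * k * (x_0 - target) ** 2 to U,
    # so only the gradient of the first coordinate changes.
    g = [0.0] * len(x)
    g[0] = k * (x[0] - target)
    return g


def apply_constraint(x: Vector, fixed_value: float = -2.0) -> Vector:
    # Hypothetical constraint: overwrite the last coordinate directly.
    return x[:-1] + [fixed_value]


def run(x: Vector, steps: int, step_size: float,
        use_restraint: bool, use_constraint: bool) -> Vector:
    # Deterministic gradient descent on U: a noise-free stand-in for the
    # sampler, not Chroma's SDE integrator.
    for _ in range(steps):
        g = grad_base(x)
        if use_restraint:
            # Restraint: changes U, and therefore the update direction.
            g = [gi + ri for gi, ri in zip(g, grad_restraint(x))]
        x = [xi - step_size * gi for xi, gi in zip(x, g)]
        if use_constraint:
            # Constraint: acts on x itself after every update.
            x = apply_constraint(x)
    return x


x_free = run([1.0, 1.0, 1.0], 200, 0.1, use_restraint=False, use_constraint=False)
x_cond = run([1.0, 1.0, 1.0], 200, 0.1, use_restraint=True, use_constraint=True)
print(x_free)  # all entries near 0
print(x_cond)  # first entry near 1.5 (restrained), last entry exactly -2.0 (constrained)
```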
We find that inspecting diffusion trajectories (including the unconstrained and denoised trajectories, $\\hat{x}_t$ and $\\tilde{x}_t$) can be a useful way to identify integration problems and to arrive at better conditioner forms or better sampling regimes.\n", "\n", "Here we show how to build a conditioner that generates a periodic 2D lattice. The same approach can be extended to generate 3D protein materials." ] }, { "cell_type": "code", "execution_count": 13, "id": "2bb9dcf3", "metadata": { "id": "2bb9dcf3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using cached data from /tmp/chroma_weights/90e339502ae6b372797414167ce5a632/weights.pt\n", "Loaded from cache\n", "Using cached data from /tmp/chroma_weights/03a3a9af343ae74998768a2711c8b7ce/weights.pt\n", "Loaded from cache\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "14663b93ee8a48d3932635b073d60eda", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Integrating SDE: 0%| | 0/500 [00:00