File size: 6,701 Bytes
7782ac2
b0ab0d5
7782ac2
95ba5bc
 
 
 
 
 
 
 
 
7782ac2
 
4470746
711f689
200cdf1
 
 
0673854
200cdf1
 
 
0673854
 
dd9ffc6
0673854
200cdf1
4470746
200cdf1
 
 
 
 
0673854
4470746
9f1ea4e
4470746
0673854
9f1ea4e
4470746
 
0673854
200cdf1
 
711f689
4470746
711f689
81748f8
 
 
 
 
0673854
95ba5bc
 
49021fb
95ba5bc
49021fb
 
 
 
95ba5bc
 
 
49021fb
 
 
 
95ba5bc
 
 
 
 
d1da608
95ba5bc
 
 
 
 
 
 
 
 
 
 
0673854
 
 
 
95ba5bc
 
 
 
 
 
 
 
 
 
 
 
7782ac2
 
 
0673854
7c181a3
f9310fd
 
 
 
 
95ba5bc
 
 
 
b0ab0d5
 
 
f9310fd
 
 
 
95ba5bc
b0ab0d5
 
 
95ba5bc
 
 
 
 
 
 
 
b0ab0d5
 
 
95ba5bc
 
 
 
 
 
 
 
 
 
f9310fd
95ba5bc
f9310fd
95ba5bc
 
 
 
 
4f94923
 
f9310fd
4f94923
7782ac2
 
 
 
 
 
 
711f689
 
f8e8929
ebb60a3
7782ac2
 
 
711f689
f9310fd
7782ac2
f9310fd
7782ac2
 
 
 
4f94923
7782ac2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import gradio as gr
import numpy as np
import os
import torch
import subprocess

from rdkit import Chem
from src import const
from src.visualizer import save_xyz_file
from src.datasets import get_dataloader, collate_with_fragment_edges, parse_molecule
from src.lightning import DDPM
from src.linker_size_lightning import SizeClassifier


# Standalone HTML page that renders a single molecule with 3Dmol.js (loaded
# from the 3Dmol CDN). It is filled in with str.format(molecule=..., fmt=...):
# `molecule` is the raw text of a structure file, embedded inside a JavaScript
# template literal (backticks), and `fmt` is the 3Dmol format name passed to
# viewer.addModel (the app uses "sdf"). Literal CSS/JS braces are doubled
# ({{ }}) so that str.format leaves them alone.
# NOTE(review): the script uses jQuery's $(document).ready / $("#container"),
# but only 3Dmol-min.js is loaded; presumably that bundle exposes `$` — confirm
# before upgrading the 3Dmol version.
HTML_TEMPLATE = """<!DOCTYPE html>
<html>
    <head>
        <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
        <style>
            .mol-container {{
                width: 600px;
                height: 600px;
                position: relative;
            }}
            .mol-container select{{
                background-image:None;
            }}
        </style>
        <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
    </head>
    
    <body>
    <div id="container" class="mol-container"></div>
    <script>
        $(document).ready(function() {{
            let element = $("#container");
            let config = {{ backgroundColor: "white" }};
            let viewer = $3Dmol.createViewer( element, config );
            viewer.addModel(`{molecule}`, "{fmt}")
            viewer.getModel().setStyle({{ stick: {{ colorscheme:"greenCarbon" }} }})
            viewer.zoomTo();
            viewer.render();
        }});
    </script>
    </body>
</html>
"""

# Wraps a complete HTML document in an <iframe> through its `srcdoc`
# attribute, so the page (and its <script> tags) can be shown inside a Gradio
# HTML component. Filled in with str.format(html=...). The attribute value is
# delimited by single quotes, so `html` must not contain raw single quotes —
# escape it (e.g. with html.escape) before formatting.
IFRAME_TEMPLATE = """<iframe style="width: 100%; height: 700px" name="result" allow="midi; geolocation; microphone; camera; 
display-capture; encrypted-media;" sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups 
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" 
allowpaymentrequest="" frameborder="0" srcdoc='{html}'></iframe>"""


# Run the models on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# `results/` receives the per-request input/output files written by
# `generate`; `models/` holds the downloaded checkpoints.
os.makedirs("results", exist_ok=True)
os.makedirs("models", exist_ok=True)


def _download_checkpoint(url, path):
    """Download ``url`` to ``path`` with wget unless ``path`` already holds data.

    The command is passed as an argument list instead of a shell string, so
    the ``?`` in the Zenodo URLs is never glob-expanded by a shell. The file is
    first written to ``path + '.part'`` and only moved into place once wget
    exits successfully (``check=True`` raises
    ``subprocess.CalledProcessError`` otherwise). A failed or interrupted
    download therefore never leaves a truncated checkpoint at ``path`` for a
    later start-up to load.
    """
    if os.path.isfile(path) and os.path.getsize(path) > 0:
        # Already downloaded by an earlier start-up; skip the large transfer.
        return
    partial_path = path + '.part'
    subprocess.run(['wget', url, '-O', partial_path], check=True)
    os.replace(partial_path, path)


_download_checkpoint(
    'https://zenodo.org/record/7121300/files/geom_size_gnn.ckpt?download=1',
    'models/geom_size_gnn.ckpt',
)
size_nn = SizeClassifier.load_from_checkpoint('models/geom_size_gnn.ckpt', map_location=device).eval().to(device)
print('Loaded SizeGNN model')

_download_checkpoint(
    'https://zenodo.org/record/7121300/files/geom_difflinker.ckpt?download=1',
    'models/geom_difflinker.ckpt',
)
ddpm = DDPM.load_from_checkpoint('models/geom_difflinker.ckpt', map_location=device).eval().to(device)
print('Loaded diffusion model')


def sample_fn(_data):
    """Sample a linker size for each entry of a batch using the SizeGNN model.

    The classifier logits are normalised with a softmax over dim 1, one class
    label is drawn per row from the resulting categorical distribution, and
    each label is translated to an atom count via ``size_nn.linker_id2size``.
    (Assumes the logits are shaped (batch, n_classes) — TODO confirm.)

    Returns:
        A ``torch.long`` tensor of sizes on the same device as the samples.
    """
    logits, _ = size_nn.forward(_data, return_loss=False)
    class_probs = torch.softmax(logits, dim=1)
    drawn_labels = torch.distributions.Categorical(probs=class_probs).sample()
    label_array = drawn_labels.detach().cpu().numpy()
    return torch.tensor(
        [size_nn.linker_id2size[lbl] for lbl in label_array],
        device=drawn_labels.device,
        dtype=torch.long,
    )


def read_molecule_content(path):
    """Return the full text content of the file at ``path``.

    Used to embed the generated SDF file verbatim into the 3Dmol.js viewer page.
    """
    with open(path, "r") as f:
        return f.read()


def read_molecule(path):
    """Load a molecule from a .pdb, .mol, .mol2 or .sdf file with RDKit.

    The file is read with ``sanitize=False`` and ``removeHs=True``. For .sdf
    files only the first record is used. The extension match is
    case-insensitive, so e.g. ``FRAGMENTS.SDF`` is accepted as well.

    Args:
        path: filesystem path of the structure file.

    Returns:
        The parsed RDKit molecule.

    Raises:
        Exception: with message ``'Unknown file extension'`` when the
            extension is not one of the supported formats.
        ValueError: when RDKit could not parse the file. The RDKit readers
            return ``None`` instead of raising, which would otherwise only
            surface later as an ``AttributeError`` in the caller.
    """
    lowered = path.lower()
    if lowered.endswith('.pdb'):
        molecule = Chem.MolFromPDBFile(path, sanitize=False, removeHs=True)
    elif lowered.endswith('.mol'):
        molecule = Chem.MolFromMolFile(path, sanitize=False, removeHs=True)
    elif lowered.endswith('.mol2'):
        molecule = Chem.MolFromMol2File(path, sanitize=False, removeHs=True)
    elif lowered.endswith('.sdf'):
        molecule = Chem.SDMolSupplier(path, sanitize=False, removeHs=True)[0]
    else:
        raise Exception('Unknown file extension')

    if molecule is None:
        raise ValueError(f'RDKit could not parse the molecule in {path}')
    return molecule


def generate(input_file):
    """Generate a linker between the uploaded fragments and visualize it.

    Steps:
    1. Read the uploaded structure.
    2. Save it to ``results/`` as SDF and XYZ.
    3. Sample one linker with the diffusion model, with the linker size drawn
       by ``sample_fn``.
    4. Save the generated molecule as XYZ and convert it to SDF with Open Babel.
    5. Embed that SDF in a 3Dmol.js viewer.

    Args:
        input_file: value of the ``gr.File`` input; only its ``.name`` (the
            path of the uploaded file) is used. ``None`` when nothing was
            uploaded.

    Returns:
        ``[viewer_html, files]``, one value per ``[visualization,
        output_files]`` output of the button. ``viewer_html`` is the viewer
        iframe and ``files`` holds the paths of the input and output SDF/XYZ
        files. On failure the result is ``[error_message, None]``: the message
        is shown in the HTML output and the file list is cleared.
    """
    import html as html_lib

    if input_file is None:
        return ['Please upload a file with the input fragments', None]

    try:
        path = input_file.name
        molecule = read_molecule(path)
        name = '.'.join(path.split('/')[-1].split('.')[:-1])
        inp_sdf = f'results/{name}_input.sdf'
        inp_xyz = f'results/{name}_input.xyz'
        out_sdf = f'results/{name}_output.sdf'
        out_xyz = f'results/{name}_output.xyz'

        print(f'Input path={path}, name={name}')
    except Exception as e:
        return [f'Could not read the molecule: {e}', None]

    # RDKit readers may return None instead of raising when parsing fails.
    if molecule is None:
        return ['Could not read the molecule: RDKit failed to parse the file', None]

    if molecule.GetNumAtoms() > 50:
        return ['Too large molecule: upper limit is 50 heavy atoms', None]

    with Chem.SDWriter(inp_sdf) as w:
        w.write(molecule)
    Chem.MolToXYZFile(molecule, inp_xyz)

    positions, one_hot, charges = parse_molecule(molecule, is_geom=True)
    # Every input atom belongs to the fragments: the fragment mask is all ones,
    # while the anchor and linker masks are all zeros.
    anchors = np.zeros_like(charges)
    fragment_mask = np.ones_like(charges)
    linker_mask = np.zeros_like(charges)
    print('Read and parsed molecule')

    dataset = [{
        'uuid': '0',
        'name': '0',
        'positions': torch.tensor(positions, dtype=const.TORCH_FLOAT, device=device),
        'one_hot': torch.tensor(one_hot, dtype=const.TORCH_FLOAT, device=device),
        'charges': torch.tensor(charges, dtype=const.TORCH_FLOAT, device=device),
        'anchors': torch.tensor(anchors, dtype=const.TORCH_FLOAT, device=device),
        'fragment_mask': torch.tensor(fragment_mask, dtype=const.TORCH_FLOAT, device=device),
        'linker_mask': torch.tensor(linker_mask, dtype=const.TORCH_FLOAT, device=device),
        'num_atoms': len(positions),
    }]
    dataloader = get_dataloader(dataset, batch_size=1, collate_fn=collate_with_fragment_edges)
    print('Created dataloader')

    # The dataset holds a single molecule; only its (first) batch is needed.
    data = next(iter(dataloader))
    chain, node_mask = ddpm.sample_chain(data, sample_fn=sample_fn, keep_frames=1)
    print('Generated linker')
    # Split the kept frame's features at ddpm.n_dims: presumably coordinates
    # (x) followed by atom features (h) — matches save_xyz_file's (h, x) args.
    x = chain[0][:, :, :ddpm.n_dims]
    h = chain[0][:, :, ddpm.n_dims:]
    save_xyz_file('results', h, x, node_mask, names=[name], is_geom=True, suffix='output')
    print('Saved XYZ file')

    # Drop any SDF left by an earlier request with the same file name, so a
    # failed conversion cannot silently display a stale molecule.
    if os.path.exists(out_sdf):
        os.remove(out_sdf)
    # Arguments are passed as a list (no shell): `name` is derived from the
    # user-supplied file name and must never be interpreted by a shell.
    subprocess.run(['obabel', out_xyz, '-O', out_sdf])
    if not os.path.isfile(out_sdf):
        return ['Could not convert the generated molecule to SDF', None]
    print('Converted to SDF')

    generated_molecule = read_molecule_content(out_sdf)
    viewer_html = HTML_TEMPLATE.format(molecule=generated_molecule, fmt='sdf')
    # The page goes into a single-quoted srcdoc attribute. Escape it so that a
    # quote in the content cannot end the attribute early. The browser
    # un-escapes attribute values, so the iframe still receives the original page.
    return [
        IFRAME_TEMPLATE.format(html=html_lib.escape(viewer_html, quote=True)),
        [inp_sdf, inp_xyz, out_sdf, out_xyz],
    ]


# Gradio front end: upload a fragments file, press the button to run
# `generate`, then show the 3D viewer and the files it returns.
demo = gr.Blocks()
with demo:
    gr.Markdown('# DiffLinker: Equivariant 3D-Conditional Diffusion Model for Molecular Linker Design')
    with gr.Box():
        with gr.Row():
            with gr.Column():
                gr.Markdown('## Input Fragments')
                # List every format read_molecule accepts (it also reads .mol).
                gr.Markdown('Upload the file with 3D-coordinates of the input fragments in .pdb, .mol, .mol2 or .sdf format')
                input_file = gr.File(file_count='single', label='Input fragments')

    button = gr.Button('Generate Linker!')

    gr.Markdown('')
    gr.Markdown('## Output')
    visualization = gr.HTML()
    output_files = gr.File(file_count="multiple", label="Output Files")

    # `generate` must return one value per output component, in this order.
    button.click(
        fn=generate,
        inputs=[input_file],
        outputs=[visualization, output_files],
    )

demo.launch()