File size: 2,832 Bytes
7dd9869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import argparse
import selfies as sf
from tqdm import tqdm
from transformers import T5EncoderModel
from transformers import set_seed
from src.scripts.mytokenizers import Tokenizer
from src.improved_diffusion import gaussian_diffusion as gd
from src.improved_diffusion import dist_util, logger
from src.improved_diffusion.respace import SpacedDiffusion
from src.improved_diffusion.transformer_model import TransformerNetModel
from src.improved_diffusion.script_util import (
    model_and_diffusion_defaults,
    add_dict_to_argparser,
)
from src.scripts.mydatasets import Lang2molDataset_submission
import streamlit as st
import os


@st.cache_resource
def get_encoder():
    """Load the BioT5 text encoder used to embed molecule descriptions.

    The ``QizhiPei/biot5-base-text2mol`` weights are loaded through
    ``T5EncoderModel.from_pretrained`` and the module is switched to eval
    mode. ``st.cache_resource`` keeps a single instance across Streamlit
    reruns so the weights are only loaded once per server process.
    """
    # nn.Module.eval() returns the module itself, so it can be chained.
    return T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol").eval()


@st.cache_resource
def get_tokenizer():
    """Build the project ``Tokenizer`` once and share it across reruns."""
    shared_tokenizer = Tokenizer()
    return shared_tokenizer


@st.cache_resource
def get_model():
    """Construct the transformer denoiser and load its EMA checkpoint.

    The architecture hyper-parameters must match the checkpoint, because
    ``load_state_dict`` is strict by default and rejects mismatched keys or
    shapes. Weights are loaded onto the CPU. The checkpoint path is relative
    to the current working directory. The network is returned in eval mode.
    """
    architecture = dict(
        in_channels=32,
        model_channels=128,
        dropout=0.1,
        vocab_size=35073,
        hidden_size=1024,
        num_attention_heads=16,
        num_hidden_layers=12,
    )
    denoiser = TransformerNetModel(**architecture)

    checkpoint_path = os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt")
    weights = dist_util.load_state_dict(checkpoint_path, map_location="cpu")
    denoiser.load_state_dict(weights)

    denoiser.eval()
    return denoiser


@st.cache_resource
def get_diffusion():
    """Create the respaced Gaussian diffusion process used for sampling.

    The base process has 2000 steps with the "sqrt" beta schedule. Only
    every 10th timestep (0, 10, ..., 1990) is kept, giving 200 sampling
    steps.
    """
    total_steps = 2000
    stride = 10
    kept_steps = list(range(0, total_steps, stride))
    base_betas = gd.get_named_beta_schedule("sqrt", total_steps)
    return SpacedDiffusion(
        use_timesteps=kept_steps,
        betas=base_betas,
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=gd.ModelVarType.FIXED_LARGE,
        loss_type=gd.LossType.E2E_MSE,
        rescale_timesteps=True,
        model_arch="transformer",
        training_mode="e2e",
    )


# --- Streamlit app: text description -> SELFIES -> SMILES ---------------------
# The get_* loaders are wrapped in st.cache_resource, so these calls are cheap
# on every rerun after the first one.
tokenizer = get_tokenizer()
encoder = get_encoder()
model = get_model()
diffusion = get_diffusion()

sample_fn = diffusion.ddim_sample_loop

text_input = st.text_area("Enter molecule description")

# Streamlit re-executes this whole script on every widget interaction,
# including the initial page load. Without this guard, a full diffusion
# sampling run would be performed for an empty description each time.
if not text_input.strip():
    st.stop()

# Pure inference: disable autograd so the encoder / denoiser forward passes
# do not build and retain gradient graphs.
with torch.no_grad(), st.spinner("Generating molecule..."):
    output = tokenizer(
        text_input,
        max_length=256,
        truncation=True,
        padding="max_length",
        add_special_tokens=True,
        return_tensors="pt",
        return_attention_mask=True,
    )
    caption_state = encoder(
        input_ids=output["input_ids"],
        attention_mask=output["attention_mask"],
    ).last_hidden_state
    caption_mask = output["attention_mask"]

    # Sample a single sequence: batch 1, 256 positions, 32 latent channels
    # (matching in_channels=32 in get_model()).
    outputs = sample_fn(
        model,
        (1, 256, 32),
        clip_denoised=False,
        denoised_fn=None,
        model_kwargs={},
        top_p=1.0,
        progress=True,
        caption=(caption_state, caption_mask),
    )
    # NOTE(review): ddim_sample_loop presumably returns a tensor already;
    # as_tensor is a no-op then (torch.tensor would copy and warn), and
    # still converts array-like results.
    logits = model.get_logits(torch.as_tensor(outputs))
    # Greedy decoding: take the highest-scoring token at each position.
    outputs = logits.argmax(dim=-1)
    outputs = tokenizer.decode(outputs)

# Strip tokenizer padding / end-of-sequence markers before SELFIES decoding.
selfies_string = (
    outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
)
try:
    result = sf.decoder(selfies_string).replace("\t", "")
except sf.DecoderError as err:
    # The generator can emit strings that are not valid SELFIES; report it on
    # the page instead of crashing with a traceback.
    st.error(f"Could not decode generated SELFIES {selfies_string!r}: {err}")
    st.stop()

st.write(result)