Source code for dscript.pretrained

import os, sys
import torch

from .models.embedding import FullyConnectedEmbed, SkipLSTM
from .models.contact import ContactCNN
from .models.interaction import ModelInteraction


def build_lm_1(state_dict_path):
    """
    :meta private:
    """
    model = SkipLSTM(21, 100, 1024, 3)
    state_dict = torch.load(state_dict_path)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def build_human_1(state_dict_path):
    """
    :meta private:
    """
    embModel = FullyConnectedEmbed(6165, 100, 0.5)
    conModel = ContactCNN(100, 50, 7)
    model = ModelInteraction(embModel, conModel, use_W=True, pool_size=9)
    state_dict = torch.load(state_dict_path)
    model.load_state_dict(state_dict)
    model.eval()
    return model


VALID_MODELS = {
    "lm_v1": build_lm_1,
    "human_v1": build_human_1,
}


def get_state_dict(version="human_v1", verbose=True):
    """
    Download a pre-trained model if it does not already exist on the local device.

    :param version: Version of trained model to download [default: human_v1]
    :type version: str
    :param verbose: Print model download status on stdout [default: True]
    :type verbose: bool
    :return: Path to state dictionary for the pre-trained model
    :rtype: str
    """
    state_dict_basename = f"dscript_{version}.pt"
    state_dict_basedir = os.path.dirname(os.path.realpath(__file__))
    state_dict_fullname = f"{state_dict_basedir}/{state_dict_basename}"
    state_dict_url = (
        f"http://cb.csail.mit.edu/cb/dscript/data/models/{state_dict_basename}"
    )
    if not os.path.exists(state_dict_fullname):
        try:
            import urllib.request
            import shutil

            if verbose:
                print(f"Downloading model {version} from {state_dict_url}...")
            with urllib.request.urlopen(state_dict_url) as response, open(
                state_dict_fullname, "wb"
            ) as out_file:
                shutil.copyfileobj(response, out_file)
        except Exception as e:
            print(f"Unable to download model - {e}")
            sys.exit(1)
    return state_dict_fullname
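
# Illustrative usage of get_state_dict (a sketch, not part of the original module):
#
#     path = get_state_dict("lm_v1")                  # downloads on first call
#     weights = torch.load(path, map_location="cpu")  # map_location guards CPU-only machines
#     model = SkipLSTM(21, 100, 1024, 3)
#     model.load_state_dict(weights)
#
# The builders above perform the same load via plain torch.load on the returned path.
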
def get_pretrained(version="human_v1"):
    """
    Get pre-trained model object.

    Currently Available Models
    ==========================

    See the `documentation <https://d-script.readthedocs.io/en/main/data.html#trained-models>`_ for the most up-to-date list.

    - ``lm_v1`` - Language model from `Bepler & Berger <https://github.com/tbepler/protein-sequence-embedding-iclr2019>`_.
    - ``human_v1`` - Human-trained model from the D-SCRIPT manuscript.

    Default: ``human_v1``

    :param version: Version of pre-trained model to get
    :type version: str
    :return: Pre-trained model
    :rtype: dscript.models.*
    """
    if version not in VALID_MODELS:
        raise ValueError(f"Model {version} does not exist")

    state_dict_path = get_state_dict(version)
    return VALID_MODELS[version](state_dict_path)
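
# Minimal end-to-end sketch (not part of the original module). Model names come
# from VALID_MODELS above; device handling is standard PyTorch, not D-SCRIPT-specific.
if __name__ == "__main__":
    model = get_pretrained("human_v1")  # interaction model, returned in eval() mode
    lm = get_pretrained("lm_v1")        # Bepler & Berger language model

    # get_pretrained does not pick a device; move the modules explicitly if desired.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    lm = lm.to(device)
    print(f"Loaded {type(model).__name__} and {type(lm).__name__} on {device}")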