# ProteinBERT_HL / testforhuggingface.py
# miladansari's picture
# Update testforhuggingface.py
# bb4aec7
# raw
# history blame
# 1.97 kB
# -*- coding: utf-8 -*-
"""testforhuggingface.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1w2iR_ooTp26Ng1TtfIv8_6bOp2WL9Xa9
"""
# Commented out IPython magic to ensure Python compatibility.
# The two lines below are IPython/Colab shell escapes ("!cmd"); they are a
# SyntaxError in a plain Python interpreter, so they are kept as comments.
# Run them once from a notebook cell (or drop the leading "!" and run them in
# a shell) before executing this script:
# !pip install git+'https://github.com/miladansari/protein_bert.git'
# !git clone https://github.com/nadavbra/shared_utils.git
#import related libraries
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import datetime
from pandas import read_csv
# %load_ext tensorboard
# Read the hemolytic test/train splits; both CSVs are indexed by a 'seq' column
# below, so that column must be present in each file.
test = read_csv('hemolytic.test.csv', skipinitialspace=True)
train = read_csv('hemolytic.train.csv', skipinitialspace=True)
seqs = train['seq']
seqs_test = test['seq']

# Sequence length handed to the encoder and to the pretrained model.
# Per the original author's note: the longest sequence in the dataset is 35,
# and the model adds a <start> and an <end> token to each sequence (35 + 2).
seq_len = 37
batch_size = 32

from proteinbert import load_pretrained_model
from proteinbert.conv_and_global_attention_model import (
    get_model_with_hidden_layers_as_outputs,
)

# Load the pretrained model from the local dump directory, then wrap the
# model built for `seq_len` so that its hidden layers are returned as outputs.
pretrained_model_generator, input_encoder = load_pretrained_model(
    local_model_dump_dir='proteinbert_preModel',
)
model = get_model_with_hidden_layers_as_outputs(
    pretrained_model_generator.create_model(seq_len)
)

# Extract the representations (embeddings) for the training sequences.
# `model.predict` is unpacked into two outputs: local and global.
X = input_encoder.encode_X(seqs, seq_len)
local_representations, global_representations = model.predict(
    X,
    batch_size=batch_size,
)

# Same extraction for the test sequences.
X_test = input_encoder.encode_X(seqs_test, seq_len)
local_representations_test, global_representations_test = model.predict(
    X_test,
    batch_size=batch_size,
)

# Simple classifier: flatten one sample's local representation and feed it to
# a single sigmoid unit.
model_D = tf.keras.models.Sequential()
model_D.add(tf.keras.layers.Flatten(input_shape=local_representations[0].shape))
model_D.add(tf.keras.layers.Dense(1, activation='sigmoid'))

# NOTE(review): the weights path is absolute (filesystem root) — confirm this
# matches the deployment layout rather than a path relative to this script.
model_D.load_weights('/model_D_weights')