TianlaiChen committed on
Commit
107b8d2
·
1 Parent(s): 4306629

Add application file

Browse files
Files changed (1) hide show
  1. app.py +44 -0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
3
+ import torch
4
+ from torch.distributions.categorical import Categorical
5
+
6
# Hub checkpoint that supplies both the tokenizer and the masked-LM weights.
MODEL_ID = "TianlaiChen/PepMLM-650M"

# Loaded once at module import; generate_peptide reads these module-level names.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID)
9
+
10
def generate_peptide(protein_seq, peptide_length, top_k):
    """Sample a peptide for ``protein_seq`` by filling appended mask tokens.

    ``peptide_length`` mask tokens are appended to the protein sequence, the
    masked language model scores every position in a single forward pass, and
    one token per masked position is drawn from the ``top_k`` highest-scoring
    candidates at that position.

    Args:
        protein_seq: Target protein sequence as text.
        peptide_length: Number of positions to generate. May be an int or a
            numeric string (the UI dropdown supplies strings). Must be >= 1.
        top_k: Number of highest-scoring candidates to sample from at each
            position. May be an int or a numeric string. Must be >= 1; values
            larger than the number of sampleable tokens are clamped.

    Returns:
        The string ``"Generated Sequence: <peptide>"``.

    Raises:
        ValueError: If ``peptide_length`` or ``top_k`` cannot be converted to
            an integer, or is smaller than 1.
    """
    peptide_length = int(peptide_length)
    top_k = int(top_k)
    if peptide_length < 1:
        raise ValueError("peptide_length must be at least 1")
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    masked_peptide = '<mask>' * peptide_length
    input_sequence = protein_seq + masked_peptide
    inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Column indices (sequence positions) of every mask token in the single
    # encoded sequence.
    mask_token_indices = (
        inputs["input_ids"] == tokenizer.mask_token_id
    ).nonzero(as_tuple=True)[1]
    logits_at_masks = logits[0, mask_token_indices]

    # decode(..., skip_special_tokens=True) below drops special tokens, so
    # sampling one would silently shorten the peptide. Exclude them up front
    # so every masked position yields exactly one visible token.
    vocab_size = logits_at_masks.size(-1)
    special_ids = torch.tensor(
        sorted(set(tokenizer.all_special_ids)),
        dtype=torch.long,
        device=logits_at_masks.device,
    )
    # Guard against ids outside the model's output dimension.
    special_ids = special_ids[(special_ids >= 0) & (special_ids < vocab_size)]
    logits_at_masks = logits_at_masks.index_fill(-1, special_ids, float("-inf"))

    # topk raises if k exceeds the dimension size; also avoid pulling the
    # excluded (-inf) entries into the candidate set.
    top_k = min(top_k, vocab_size - special_ids.numel())

    # Top-k sampling: renormalise over the k best candidates per position and
    # draw one of them.
    top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
    probabilities = torch.nn.functional.softmax(top_k_logits, dim=-1)
    sampled_slots = Categorical(probabilities).sample()
    # Map the sampled slot (0..k-1) back to the vocabulary id.
    predicted_token_ids = top_k_indices.gather(
        -1, sampled_slots.unsqueeze(-1)
    ).squeeze(-1)

    generated_peptide = tokenizer.decode(predicted_token_ids, skip_special_tokens=True)
    # decode separates tokens with spaces; strip them to get a contiguous sequence.
    return f"Generated Sequence: {generated_peptide.replace(' ', '')}"
31
+
32
# Build the Gradio UI: one free-text input and two dropdowns feeding
# generate_peptide, with the result shown in a textbox.

# Dropdown options are strings; generate_peptide converts them with int().
peptide_length_choices = list(map(str, range(2, 51)))
top_k_choices = list(map(str, range(1, 11)))

# NOTE(review): `default` appears to pre-fill the textbox with this text as an
# actual value rather than a placeholder hint — confirm that is intended.
protein_input = gr.inputs.Textbox(
    label="Protein Sequence",
    default="Enter protein sequence here",
    type="text",
)
peptide_length_input = gr.inputs.Dropdown(
    choices=peptide_length_choices,
    label="Peptide Length",
    default="15",
)
top_k_input = gr.inputs.Dropdown(
    choices=top_k_choices,
    label="Top K Value",
    default="3",
)

interface = gr.Interface(
    fn=generate_peptide,
    inputs=[protein_input, peptide_length_input, top_k_input],
    outputs="textbox",
    live=True,
)

interface.launch()