nferruz committed on
Commit 275ee36 · 1 Parent(s): a230244

Update README.md

Files changed (1)
  1. README.md +19 -6
README.md CHANGED
@@ -24,7 +24,8 @@ In the example below, ProtGPT2 generates sequences that follow the amino acid 'M
  ```
  >>> from transformers import pipeline
  >>> protgpt2 = pipeline('text-generation', model="nferruz/ProtGPT2")
- >>> sequences = protgpt2("M", max_length=100, do_sample=True, top_k=950, repetition_penalty=1.2, num_return_sequences=10, eos_token_id=0)
+ # length is expressed in tokens, where each token has an average length of 4 amino acids.
+ >>> sequences = protgpt2("<|endoftext|>", max_length=100, do_sample=True, top_k=950, repetition_penalty=1.2, num_return_sequences=10, eos_token_id=0)
  >>> for seq in sequences:
          print(seq)
  {'generated_text': 'MINDLLDISRIISGKMTLDRAEVNLTAIARQVVEEQRQAAEAKSIQLLCSTPDTNHYVFG\nDFDRLKQTLWNLLSNAVKFTPSGGTVELELGYNAEGMEVYVKDSGIGIDPAFLPYVFDRF\nRQSDAADSRNYGGLGLGLAIVKHLLDLHEGNVSAQSEGFGKGATFTVLLPLKPLKRELAA\nVNRHTAVQQSAPLNDNLAGMKILIVEDRPDTNEMVSYILEEAGAIVETAESGAAALTSLK\nSYSPDLVLSDIGMPMMDGYEMIEYIREWKTTKGG'}
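
A note on the `max_length` comment introduced above: the parameter counts tokens, not amino acids, and each token covers roughly 4 residues on average. The sketch below illustrates that rule of thumb; the `tokens_for_residues` helper and the 250-residue target are made-up illustrations, not part of the transformers API, and actual token counts vary with the BPE segmentation. It also shows how to strip the FASTA-style line breaks visible in the example output.

```
# Hypothetical helper: rough token budget for a target protein length,
# assuming the ~4 amino acids per token average mentioned in the diff.
def tokens_for_residues(n_residues, residues_per_token=4):
    return max(1, round(n_residues / residues_per_token))

print(tokens_for_residues(250))  # ~62 -> pass max_length=62 to the pipeline

# Generated text keeps FASTA-style '\n' breaks (see the example output);
# remove them to obtain a continuous amino acid string.
example = 'MINDLLDISRIISGKMTLDRAEVNLTAIARQVVEEQRQAAEAKSIQLLCSTPDTNHYVFG\nDFDRLKQTLWNLLSNAVKFTPSGGTVELELGYNAEGMEVYVKDSGIGIDPAFLPYVFDRF'
print(example.replace('\n', ''))
```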
@@ -54,15 +55,27 @@ The HuggingFace script run_clm.py can be found here: https://github.com/huggingf
 
  ### **How to select the best sequences**
  We've observed that perplexity values correlate with AlphaFold2's pLDDT.
- We recommend to compute perplexity for each sequence with the HuggingFace evaluate method `perplexity`:
+ We recommend computing the perplexity of each sequence as follows:
 
  ```
- from evaluate import load
- perplexity = load("perplexity", module_type="metric")
- results = perplexity.compute(predictions=predictions, model_id='nferruz/ProtGPT2')
+ import math
+ import torch
+ from transformers import AutoTokenizer, GPT2LMHeadModel
+
+ def calculatePerplexity(sequence, model, tokenizer):
+     input_ids = torch.tensor(tokenizer.encode(sequence)).unsqueeze(0).to(device)
+     with torch.no_grad():
+         outputs = model(input_ids, labels=input_ids)
+     loss, logits = outputs[:2]
+     return math.exp(loss)
+
+ # Load the model and tokenizer (previously downloaded) and generate sequences
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ tokenizer = AutoTokenizer.from_pretrained('/path/to/tokenizer') # replace with the actual path
+ model = GPT2LMHeadModel.from_pretrained('/path/to/output').to(device)
+ input_ids = tokenizer.encode("<|endoftext|>", return_tensors='pt').to(device)
+ output = model.generate(input_ids, max_length=400, do_sample=True, top_k=950, repetition_penalty=1.2, num_return_sequences=10, eos_token_id=0)
+
+ # Take (for example) the first sequence and decode it back to text
+ sequence = tokenizer.decode(output[0])
+ ppl = calculatePerplexity(sequence, model, tokenizer)
  ```
 
- Where `predictions` is a list containing the generated sequences.
+ Where `ppl` is the perplexity value for that sequence.
  We do not yet have a threshold for what perplexity value gives a 'good' or 'bad' sequence, but given the fast inference times, it is best to sample many sequences, order them by perplexity, and select those with the lowest values (the lower the better).
 
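To make the closing recommendation concrete, here is a minimal sketch of the sample-many-and-rank workflow; it assumes `model`, `tokenizer`, `device`, and `calculatePerplexity` are already set up as in the snippet added by this commit, and the batch size of 100 is an arbitrary illustrative choice.

```
# Sample a batch of sequences, score each by perplexity, and keep the
# lowest-perplexity candidates (lower is better).
input_ids = tokenizer.encode("<|endoftext|>", return_tensors='pt').to(device)
outputs = model.generate(input_ids, max_length=400, do_sample=True, top_k=950,
                         repetition_penalty=1.2, num_return_sequences=100,
                         eos_token_id=0)

sequences = [tokenizer.decode(ids) for ids in outputs]
ranked = sorted(sequences, key=lambda s: calculatePerplexity(s, model, tokenizer))
best = ranked[:10]  # ten lowest-perplexity candidates
```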