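"""
Sample de novo protein sequences from a fine-tuned ProtGPT2 checkpoint, score each
sequence by its perplexity under the same model, and save the ranked results to a CSV.
"""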
import math

import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Compute the perplexity of a generated sequence under the fine-tuned model
def calculate_perplexity(sequence, model, tokenizer):
    # Wrap the sequence in ProtGPT2's end-of-text tokens before scoring
    sequence = "<|endoftext|>" + sequence + "<|endoftext|>"
    input_ids = torch.tensor(tokenizer.encode(sequence)).unsqueeze(0)
    input_ids = input_ids.to(model.device)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    loss = outputs.loss
    return math.exp(loss.item())

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    path = "/workspace/sg666/MDpLM/benchmarks/Generation/ProtGPT2"

    # Load fine-tuned model and tokenizer
    model_path = path + "/finetuned_models/checkpoint-4510"
    model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Generate sequences (reuse the loaded model and tokenizer instead of reloading from disk)
    protgpt2 = pipeline('text-generation', model=model, tokenizer=tokenizer, device=device)
    sequences = protgpt2("", max_length=100, do_sample=True, top_k=950,
                         repetition_penalty=1.5, num_return_sequences=100, eos_token_id=0)
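    # Each pipeline result is a dict whose 'generated_text' field holds the raw generated sequence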

    # Store generated sequences and their associated perplexities
    generated_sequences = []
    perplexities = []

    # Calculate PPL for sequences
    for item in sequences:
        raw_sequence = item['generated_text']
        ppl = calculate_perplexity(raw_sequence, model, tokenizer)
        generated_sequences.append(raw_sequence)
        perplexities.append(ppl)

    # Clean the generated sequences
    cleaned_sequences = [seq.replace('\n', '').replace('<|endoftext|>', '') for seq in generated_sequences]

    # Create df with cleaned sequences and perplexities
    df = pd.DataFrame({"Sequence": cleaned_sequences, "Perplexity": perplexities})
    df.sort_values(by='Perplexity', inplace=True)

    # Save results
    df.to_csv(path + "/protgpt2_generated_sequences.csv", index=False)

    # View the average de novo generation perplexity
    avg_generation_ppl = df['Perplexity'].mean()
    print(f'Average de novo generation perplexity: {avg_generation_ppl}')