File size: 1,418 Bytes
b7318af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import argparse
from transformers import GPT2LMHeadModel, GPT2Tokenizer


def test_model(prompt, model_path='/Users/raghul.v/Desktop/research/pii_extraction_test/results', model_name='distilgpt2'):  # NOTE(review): default path is machine-specific; callers should pass model_path explicitly
    """Generate and print three sampled completions for *prompt*.

    Loads the tokenizer from the base model (``model_name``) and the
    fine-tuned weights from ``model_path``, then samples three
    continuations (top-k/top-p sampling, max 50 tokens) and prints each.

    Args:
        prompt: Text the model should continue.
        model_path: Directory containing the fine-tuned model weights.
        model_name: Base model whose tokenizer is used.

    Returns:
        None. Output is printed to stdout.
    """
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_path)  # fine-tuned weights
    model.eval()  # disable dropout for inference

    # Tokenize via __call__ to also get the attention mask; plain
    # encode() drops it, which makes generate() guess about padding.
    encoded = tokenizer(prompt, return_tensors="pt")

    # GPT-2 defines no pad token; reuse EOS so generate() pads cleanly
    # instead of emitting a warning and picking a pad id implicitly.
    sample_outputs = model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        max_length=50,
        top_k=50,
        top_p=0.95,
        num_return_sequences=3,
    )

    for idx, sample_output in enumerate(sample_outputs):
        decoded_output = tokenizer.decode(sample_output, skip_special_tokens=True)
        print(f"Generated Text {idx}: {decoded_output}")


if __name__ == '__main__':
    # CLI entry point: defaults reproduce the previous hard-coded call,
    # so `--prompt ...` alone behaves exactly as before, while the model
    # location and base tokenizer are now overridable without editing code.
    parser = argparse.ArgumentParser(description="Enter the prompt for the model.")
    parser.add_argument('--prompt', type=str, required=True, help='Prompt for the model')
    parser.add_argument('--model_path', type=str, default='results',
                        help="Directory of the fine-tuned model (default: 'results')")
    parser.add_argument('--model_name', type=str, default='distilgpt2',
                        help="Base model for the tokenizer (default: 'distilgpt2')")
    args = parser.parse_args()

    test_model(args.prompt, model_path=args.model_path, model_name=args.model_name)