|
import argparse |
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer |
|
|
|
|
|
def test_model(prompt, model_path='/Users/raghul.v/Desktop/research/pii_extraction_test/results', model_name='distilgpt2'):
    """Sample generations from a fine-tuned GPT-2 model for a given prompt.

    Loads the tokenizer for ``model_name`` and the (fine-tuned) model weights
    from ``model_path``, then samples 3 continuations with top-k/top-p sampling.

    Args:
        prompt: Text prompt to condition generation on.
        model_path: Directory containing the fine-tuned model weights.
            NOTE(review): default is a machine-specific absolute path; callers
            should pass this explicitly (the __main__ entry point does).
        model_name: Base model name used to load the matching tokenizer.

    Returns:
        list[str]: The decoded generated texts (also printed to stdout).
    """
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2LMHeadModel.from_pretrained(model_path)
    # Inference only: disable dropout etc.
    model.eval()

    generated = tokenizer.encode(prompt, return_tensors="pt")

    sample_outputs = model.generate(
        generated,
        do_sample=True,
        max_length=50,
        top_k=50,
        top_p=0.95,
        num_return_sequences=3,
        # GPT-2 has no pad token; use EOS to silence the pad_token_id
        # warning and make padding during generation well-defined.
        pad_token_id=tokenizer.eos_token_id,
    )

    decoded_outputs = []
    for idx, sample_output in enumerate(sample_outputs):
        decoded_output = tokenizer.decode(sample_output, skip_special_tokens=True)
        decoded_outputs.append(decoded_output)
        print(f"Generated Text {idx}: {decoded_output}")

    # Return the texts so the function is usable programmatically,
    # not just as a print-only script helper.
    return decoded_outputs
|
|
|
|
|
if __name__ == '__main__':
    # CLI entry point: read the prompt from the command line and run
    # generation against the locally fine-tuned model in ./results.
    arg_parser = argparse.ArgumentParser(description="Enter the prompt for the model.")
    arg_parser.add_argument(
        '--prompt',
        type=str,
        required=True,
        help='Prompt for the model',
    )
    cli_args = arg_parser.parse_args()
    test_model(cli_args.prompt, model_path='results', model_name='distilgpt2')