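"""Interactive text generation with a fine-tuned causal LM checkpoint.

Loads a model from a local checkpoint directory (falling back to an explicit
model.safetensors path if the directory load fails), pairs it with the
distilgpt2 tokenizer, and generates completions for prompts entered on the
command line.
"""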
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_model_and_tokenizer(model_path):
    """Load the fine-tuned model and its tokenizer."""
    try:
        print(f"Attempting to load model from directory: {model_path}")
        model = AutoModelForCausalLM.from_pretrained(model_path)
    except Exception as e:
        print(f"Failed to load from directory. Error: {e}")

        # Fall back to the checkpoint's weights file. Note that from_pretrained
        # normally expects a directory (containing config.json) or a hub id, so
        # loading a bare .safetensors file may fail as well.
        safetensors_path = os.path.join(model_path, "model.safetensors")
        if os.path.exists(safetensors_path):
            print(f"Attempting to load model from file: {safetensors_path}")
            model = AutoModelForCausalLM.from_pretrained(safetensors_path)
        else:
            raise ValueError(f"Could not find model at {model_path} or {safetensors_path}")

    # The checkpoint is assumed to be a distilgpt2 fine-tune, so its stock
    # tokenizer is reused here.
    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

    return model, tokenizer


def generate_text(model, tokenizer, prompt, max_length=125, num_return_sequences=1):
    """Generate one or more sampled completions for a prompt."""
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    # Sampling must be enabled explicitly; otherwise generate() falls back to
    # greedy decoding and ignores top_k, top_p and temperature.
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            do_sample=True,
            no_repeat_ngram_size=6,
            top_k=25,
            top_p=0.99,
            temperature=0.34,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 tokenizers have no pad token
        )

    return [tokenizer.decode(seq, skip_special_tokens=True) for seq in output]


def main():
    model_path = r"literalpathtothefoldernamed\checkpoint-4000"

    print("Attempting to load model...")
    model, tokenizer = load_model_and_tokenizer(model_path)

    print("Model loaded successfully. Enter prompts to generate text. Type 'quit' to exit.")

    while True:
        prompt = input("Enter a prompt: ")
        if prompt.lower() == 'quit':
            break

        generated_texts = generate_text(model, tokenizer, prompt)

        print("\nGenerated Text:")
        for i, text in enumerate(generated_texts, 1):
            print(f"{i}. {text}\n")


if __name__ == "__main__":
    main()