import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
# Select GPU if available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pre-trained GPT-2 model and tokenizer, placing the model on the device
model_name = "gpt2-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
def generate_text(input_text, max_length=16, num_beams=5, do_sample=False, no_repeat_ngram_size=2):
    """
    Generate text from the given input text using beam search.

    Parameters:
    - input_text (str): The input text to start generation from.
    - max_length (int): Maximum total length (prompt + generation), in tokens.
    - num_beams (int): Number of beams for beam search.
    - do_sample (bool): Whether to sample tokens instead of decoding greedily.
    - no_repeat_ngram_size (int): N-gram size that may not repeat in the output.

    Returns:
    - generated_text (str): The generated text.
    """
    # Encode the input text and move it to the same device as the model
    input_ids = tokenizer(input_text, return_tensors='pt')['input_ids'].to(device)
    # Generate with beam search; pad_token_id is set explicitly because GPT-2
    # has no pad token and transformers warns if it is left unset
    output = model.generate(input_ids, max_length=max_length, num_beams=num_beams,
                            do_sample=do_sample, no_repeat_ngram_size=no_repeat_ngram_size,
                            pad_token_id=tokenizer.eos_token_id)
    # Decode the generated token IDs back into a string
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text
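
# A quick sanity check one could run (hypothetical prompt; the output is
# deterministic across runs because beam search with do_sample=False always
# follows the highest-scoring beam):
#
#   print(generate_text("The future of AI is", max_length=32))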
def generate_text_with_nucleus_search(input_text, max_length=128, do_sample=True, top_p=0.9):
    """
    Generate text from the given input text using nucleus (top-p) sampling.

    Parameters:
    - input_text (str): The input text to start generation from.
    - max_length (int): Maximum total length (prompt + generation), in tokens.
    - do_sample (bool): Whether to sample; must be True for top_p to take effect.
    - top_p (float): Nucleus sampling parameter; only the smallest set of tokens
      whose cumulative probability exceeds top_p is kept for sampling.

    Returns:
    - generated_text (str): The generated text.
    """
    # Encode the input text and move it to the same device as the model
    input_ids = tokenizer(input_text, return_tensors='pt')['input_ids'].to(device)
    # Generate with nucleus sampling
    output = model.generate(input_ids, max_length=max_length, do_sample=do_sample,
                            top_p=top_p, pad_token_id=tokenizer.eos_token_id)
    # Decode the generated token IDs back into a string
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text
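
# Hypothetical usage sketch: with do_sample=True and top_p=0.9, each step
# samples from the smallest token set covering 90% of the probability mass,
# so repeated calls with the same prompt will typically differ:
#
#   print(generate_text_with_nucleus_search("Once upon a time", max_length=64))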
# Build one Gradio interface per decoding strategy; each interface gets its own
# component instances, since a component should belong to a single interface
beam_demo = gr.Interface(
    generate_text,
    gr.Textbox(lines=5, label="Input Text", placeholder="Enter text for generation..."),
    gr.Textbox(label="Generated Text", placeholder="Generated text will appear here..."),
    title="Text Generation with GPT-2",
    description="Generate text using beam search with the GPT-2 model.",
    allow_flagging="never")
nucleus_demo = gr.Interface(
    generate_text_with_nucleus_search,
    gr.Textbox(lines=5, label="Input Text", placeholder="Enter text for generation..."),
    gr.Textbox(label="Generated Text", placeholder="Generated text will appear here..."),
    title="Text Generation with Nucleus Sampling",
    description="Generate text using nucleus sampling with the GPT-2 model.",
    allow_flagging="never")
# Serve both demos from a single app; calling launch() twice in sequence would
# block on the first call and the second interface would never start
gr.TabbedInterface([beam_demo, nucleus_demo],
                   ["Beam Search", "Nucleus Sampling"]).launch(share=True)
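
# To try this script locally (assuming the dependencies are installed and the
# file is saved as, say, app.py -- the filename here is hypothetical):
#
#   pip install torch transformers gradio
#   python app.py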