|
import torch |
|
from model import BigramLanguageModel, decode |
|
import gradio as gr |
|
|
|
# Run on GPU when available; the checkpoint tensors are remapped to this
# device at load time below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Instantiate the model and restore the trained weights.
model = BigramLanguageModel()
model.load_state_dict(torch.load("./neo_gpt.pth", map_location=device))
# `map_location` only remaps the loaded tensors; the module itself must
# also be moved, otherwise a CUDA `device` leaves the model on CPU and
# inference inputs placed on `device` would mismatch.
model.to(device)
# Inference-only app: switch off dropout / batch-norm training behavior.
model.eval()
|
def generate_text(max_new_tokens):
    """Generate text from the model starting from an empty (zero) context.

    Args:
        max_new_tokens: Desired number of tokens to generate. Gradio's
            ``gr.Number`` delivers a float, so the value is coerced to
            ``int`` before being passed to ``generate``.

    Returns:
        The generated token ids decoded to a string via ``decode``.
    """
    # Single zero token as the seed context, on the same device as the model
    # (the original created it on CPU, which breaks when device == 'cuda').
    context = torch.zeros((1, 1), dtype=torch.long, device=device)
    # Gradients are never needed for generation; skipping autograd tracking
    # saves memory and time.
    with torch.no_grad():
        generated = model.generate(context, max_new_tokens=int(max_new_tokens))
    return decode(generated[0].tolist())
|
|
|
|
|
|
|
# UI copy displayed in the Gradio interface header.
title = "Text Generation: Write Like Shakespeare"

description = "This Gradio app uses a large language model (LLM) to generate text in the style of William Shakespeare."
|
|
|
|
|
|
|
# Build the Gradio UI: one numeric input (token count) -> one text output.
g_app = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Number(
            value=10,
            # Force an integer widget: without precision=0, gr.Number hands
            # the callback a float, which breaks the token-count loop in
            # the model's generate().
            precision=0,
            label="Number of Output Tokens",
            info="Specify the desired length of the text to be generated.",
        )
    ],
    outputs=[gr.TextArea(lines=5, label="Generated Text")],
    title=title,
    description=description,
)


# Launch only when executed as a script, so importing this module
# (e.g. from tests or a deployment wrapper) has no side effects.
if __name__ == "__main__":
    g_app.launch()
|
|