Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from model import * | |
import requests | |
import os | |
# Choose the compute device once at import time: CUDA when PyTorch can see a
# GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Object-store base URL and basic-auth credentials, read from the process
# environment. Each name is None when its variable is not set.
# NOTE(review): "USERNAME" is also a common OS-level variable (e.g. on
# Windows), so a system value may be picked up if the secret is missing.
object_store_url = os.environ.get("OBJECT_STORE")
username = os.environ.get("USERNAME")
password = os.environ.get("PASSWORD")
def download(filename, directory, *, timeout=60, chunk_size=1 << 20):
    """Fetch ``filename`` from the object store into the working directory.

    The object is requested from ``f"{object_store_url}{directory}/{filename}"``
    using HTTP basic auth with the module-level ``username``/``password``. It
    is saved under ``filename`` relative to the current working directory.

    The body is streamed in chunks to ``<filename>.part`` and only renamed over
    ``filename`` once it has been written completely. A large checkpoint is
    therefore never held in memory all at once, and an interrupted transfer
    cannot leave a truncated file in place of ``filename``.

    Args:
        filename: Name of the object to fetch; also used as the local file name.
        directory: Path segment inside the object store that holds the file.
        timeout: Seconds ``requests`` waits to connect / for each read before
            raising, instead of blocking forever.
        chunk_size: Number of bytes per chunk when streaming the body to disk.

    Returns:
        True if the file was downloaded and written. False if the server
        answered with a non-200 status; in that case any existing local file
        is left untouched.

    Raises:
        requests.exceptions.RequestException: on connection errors or timeouts.
        OSError: if the local file cannot be written or renamed.
    """
    # NOTE(review): the URL is concatenated as-is, so OBJECT_STORE is assumed
    # to end with "/" — confirm against the deployment settings.
    download_url = f"{object_store_url}{directory}/{filename}"
    partial_path = f"{filename}.part"
    with requests.get(
        download_url,
        auth=(username, password),
        stream=True,
        timeout=timeout,
    ) as response:
        if response.status_code != 200:
            print(f"Failed to download file. Status code: {response.status_code}")
            print(response.text)
            return False
        try:
            with open(partial_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    file.write(chunk)
            # os.replace overwrites the destination (atomically on POSIX), so a
            # reader never sees a half-written checkpoint under `filename`.
            os.replace(partial_path, filename)
        finally:
            # Only still present if writing or renaming failed; don't leave
            # partial data behind.
            if os.path.exists(partial_path):
                os.remove(partial_path)
    print("File downloaded successfully")
    return True
download("saved_model.pth", "ShakespeareGPT")

# SECURITY NOTE(review): weights_only=False lets torch.load unpickle arbitrary
# Python objects, i.e. run code embedded in the checkpoint. It appears to be
# required because the loaded object is used directly as the model (its
# .generate() is called), not a plain state_dict. Only load this file from a
# trusted object store. Consider saving a state_dict and loading it with
# weights_only=True instead.
model = torch.load(
    "saved_model.pth", map_location=torch.device(device), weights_only=False
)
# The app only does inference. Switch off training-time behaviour such as
# dropout in case the checkpoint was saved while still in training mode.
# Guarded so an object that is not an nn.Module is left untouched.
if isinstance(model, torch.nn.Module):
    model.eval()
def generate_text(context, num_of_tokens, temperature=1.0):
    """Stream text from the model, yielding the accumulated output per token.

    Args:
        context: Optional prompt string. When None or empty, generation
            starts from a single token id 0.
        num_of_tokens: Maximum number of new tokens to generate. The Gradio
            number widget may deliver a float, or None when the box is
            cleared, so the value is coerced to a non-negative int.
        temperature: Sampling temperature, forwarded to ``model.generate``.

    Yields:
        The concatenation of every token produced so far, so a streaming
        output component always shows the full text.
    """
    # Build the prompt on the same device the model was loaded onto
    # (map_location=device); a CPU tensor would not match a CUDA model.
    if context:
        idx = torch.tensor(
            encode(context), dtype=torch.long, device=device
        ).unsqueeze(0)
    else:
        idx = torch.zeros((1, 1), dtype=torch.long, device=device)

    # NOTE(review): model.generate presumably needs an integral count
    # (e.g. for range()) — TODO confirm; an int is safe either way.
    max_new_tokens = max(0, int(num_of_tokens or 0))

    text = ""
    for token in model.generate(
        idx, max_new_tokens=max_new_tokens, temperature=temperature
    ):
        text += token  # tokens are appended to a str, so each is a decoded str
        yield text
with gr.Blocks() as demo:
    gr.HTML("<h1 align='center'> Shakespeare Text Generator</h1>")
    context = gr.Textbox(label="Enter context (optional)")
    with gr.Row():
        # precision=0 makes Gradio round the value and pass it as an int,
        # rather than a float token count.
        num_of_tokens = gr.Number(
            label="Max tokens to generate", value=100, precision=0
        )
        # Keep the lower bound strictly positive. NOTE(review): assumes
        # model.generate divides the logits by the temperature, so 0.0
        # would divide by zero — confirm against model.py.
        tmp = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=1.0)
    inputs = [context, num_of_tokens, tmp]
    generate_btn = gr.Button(value="Generate")
    outputs = [gr.Textbox(label="Generated text: ")]
    generate_btn.click(fn=generate_text, inputs=inputs, outputs=outputs)

# generate_text is a generator that streams partial output. Gradio serves
# generator handlers through its queue: it must be enabled explicitly on
# Gradio 3.x and is on by default in 4.x, where this call is harmless.
demo.queue().launch()