Spaces:
Sleeping
Sleeping
File size: 1,816 Bytes
5e8301a 705eec3 78fc2e0 544fa79 5e8301a 544fa79 5e8301a 705eec3 f8e3be7 544fa79 f8e3be7 53821d4 544fa79 53821d4 705eec3 1a0060e 705eec3 544fa79 705eec3 544fa79 dc6fd47 705eec3 544fa79 705eec3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import gradio as gr
import torch
from model import *
import requests
import os
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Object-store endpoint and the credentials used for it are read from the
# environment; each value is None when its variable is not set.
object_store_url = os.environ.get("OBJECT_STORE")
username = os.environ.get("USERNAME")
password = os.environ.get("PASSWORD")
def download(filename, directory, timeout=60):
    """Fetch ``filename`` from the object store into the working directory.

    The URL is built as ``{object_store_url}{directory}/{filename}`` and the
    request authenticates with the module-level ``username``/``password``.

    Args:
        filename: Name of the remote file; also the local path written to.
        directory: Remote directory (appended directly to the base URL).
        timeout: Value passed to ``requests.get(timeout=...)`` so a stalled
            server cannot block start-up indefinitely.

    Returns:
        True when the server answered 200 and the file was written,
        False otherwise (the status code and response body are printed).

    Raises:
        requests.RequestException: On connection errors or timeouts.
        OSError: If the local file cannot be written.
    """
    download_url = f"{object_store_url}{directory}/{filename}"
    # Write to a sibling ".part" file first so an interrupted transfer never
    # leaves a truncated file under the final name.
    partial_path = f"{filename}.part"

    # stream=True avoids holding the whole payload in memory at once.
    with requests.get(
        download_url,
        auth=(username, password),
        stream=True,
        timeout=timeout,
    ) as response:
        if response.status_code != 200:
            print(f"Failed to download file. Status code: {response.status_code}")
            print(response.text)
            return False

        with open(partial_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                file.write(chunk)

    os.replace(partial_path, filename)
    print("File downloaded successfully")
    return True
# Fetch the checkpoint from the object store into the working directory.
download("saved_model.pth", "ShakespeareGPT")

# A failed download is only reported on stdout, so stop here with an
# actionable message rather than an error raised from inside torch.load.
if not os.path.isfile("saved_model.pth"):
    raise FileNotFoundError(
        "saved_model.pth is missing: the download from the object store "
        "failed (check the OBJECT_STORE, USERNAME and PASSWORD variables)"
    )

# SECURITY: weights_only=False unpickles arbitrary Python objects, which can
# execute code embedded in the file. Only acceptable while the object store
# is fully trusted; prefer saving a state_dict and loading it with
# weights_only=True.
# NOTE(review): presumably the pickled object needs the classes pulled in by
# `from model import *` to be importable -- confirm before changing the format.
model = torch.load(
    "saved_model.pth", map_location=torch.device(device), weights_only=False
)
def generate_text(context, num_of_tokens, temperature=1.0):
    """Stream text produced by the loaded model.

    Args:
        context: Optional prompt. When it is None or empty, generation starts
            from a single token with id 0.
        num_of_tokens: Maximum number of new tokens; coerced to ``int`` because
            UI number widgets may deliver a float such as ``100.0``.
        temperature: Forwarded unchanged to ``model.generate``.

    Yields:
        The text accumulated so far, once per piece returned by
        ``model.generate``, so a UI can display the output as it grows.
    """
    # Create the prompt tensor on the same device the model was loaded onto;
    # a CPU tensor fed to a CUDA model raises a device-mismatch error.
    if not context:
        idx = torch.zeros((1, 1), dtype=torch.long, device=device)
    else:
        idx = torch.tensor(
            encode(context), dtype=torch.long, device=device
        ).unsqueeze(0)

    max_new_tokens = int(num_of_tokens)

    # NOTE(review): temperature is passed through as-is; if model.generate
    # divides logits by it, a value of 0 would need special handling -- confirm.
    text = ""
    for token in model.generate(
        idx, max_new_tokens=max_new_tokens, temperature=temperature
    ):
        text += token
        yield text
# Build the web UI: an optional prompt box, generation settings, a trigger
# button and an output box fed by the streaming generate_text generator.
# NOTE(review): the nesting below is reconstructed from syntax; it assumes the
# Row holds only the two generation settings -- confirm against the live app.
with gr.Blocks() as demo:
    gr.HTML("<h1 align='center'> Shakespeare Text Generator</h1>")
    context = gr.Textbox(label="Enter context (optional)")
    with gr.Row():
        # precision=0 makes the component round its value and deliver an int,
        # since a token count must be a whole number.
        num_of_tokens = gr.Number(
            label="Max tokens to generate", value=100, precision=0
        )
        # NOTE(review): the slider allows 0.0; verify that the model's
        # generate() copes with a zero temperature.
        tmp = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=1.0)
    inputs = [context, num_of_tokens, tmp]
    generate_btn = gr.Button(value="Generate")
    outputs = [gr.Textbox(label="Generated text: ")]
    # inputs are passed positionally: (context, num_of_tokens, temperature).
    generate_btn.click(fn=generate_text, inputs=inputs, outputs=outputs)

demo.launch()
|