File size: 1,816 Bytes
5e8301a
705eec3
78fc2e0
544fa79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e8301a
544fa79
 
 
 
5e8301a
705eec3
f8e3be7
544fa79
 
f8e3be7
 
53821d4
544fa79
 
 
 
53821d4
705eec3
 
1a0060e
705eec3
 
544fa79
705eec3
 
544fa79
 
 
 
dc6fd47
705eec3
544fa79
 
 
705eec3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import gradio as gr
import torch
from model import *
import requests
import os

# Run on the GPU when torch can see one; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Object-store base URL and basic-auth credentials come from the environment;
# each is None when its variable is unset.
object_store_url = os.environ.get("OBJECT_STORE")
username = os.environ.get("USERNAME")
password = os.environ.get("PASSWORD")


def download(filename, directory, timeout=30):
    """Fetch ``directory/filename`` from the object store into the working dir.

    The URL is the module-level ``object_store_url`` followed by
    ``directory/filename``; the request authenticates with HTTP basic auth
    using the module-level ``username`` and ``password``.

    Args:
        filename: Name of the remote object; also the local path written to.
        directory: Path segment placed between the store URL and ``filename``.
        timeout: Seconds handed to ``requests.get`` so a stalled server
            cannot block start-up forever.

    Returns:
        True when the file was written, False when the server answered with
        a status other than 200 (the status and response body are printed).

    Raises:
        requests.exceptions.RequestException: On connection errors or when
            ``timeout`` expires.
    """
    download_url = f"{object_store_url}{directory}/{filename}"
    # Stream the body so a large checkpoint is never held in memory whole,
    # and use the response as a context manager so the connection is released.
    with requests.get(
        download_url, auth=(username, password), stream=True, timeout=timeout
    ) as response:
        if response.status_code != 200:
            print(f"Failed to download file. Status code: {response.status_code}")
            print(response.text)
            return False
        # Write to a scratch file first: an interrupted transfer then never
        # leaves a truncated file under the final name.
        partial_path = f"{filename}.part"
        with open(partial_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=1 << 20):
                file.write(chunk)
        os.replace(partial_path, filename)
    print("File downloaded successfully")
    return True


# Fetch the checkpoint at start-up, then unpickle the complete model object.
download("saved_model.pth", "ShakespeareGPT")
# NOTE(security): weights_only=False uses the general pickle loader, which can
# execute arbitrary code embedded in the file. This is only acceptable while
# the object store is fully trusted; saving a state_dict and loading it with
# weights_only=True would remove that exposure.
model = torch.load(
    "saved_model.pth", map_location=torch.device(device), weights_only=False
)
# The app only runs inference, so switch layers such as dropout out of
# training behaviour. Assumes the checkpoint is a torch.nn.Module.
model.eval()


def generate_text(context, num_of_tokens, temperature=1.0):
    """Stream generated text, yielding the accumulated output after each step.

    Args:
        context: Optional prompt string. ``None`` or ``""`` starts generation
            from a single token with id 0.
        num_of_tokens: Number of new tokens to request. UI number fields may
            deliver a float or ``None``; the value is coerced to a
            non-negative ``int`` before being passed on.
        temperature: Forwarded unchanged to ``model.generate``.

    Yields:
        The text produced so far, growing by one ``model.generate`` item
        per yield.
    """
    if not context:
        idx = torch.zeros((1, 1), dtype=torch.long)
    else:
        idx = torch.tensor(encode(context), dtype=torch.long).unsqueeze(0)
    # The model was loaded onto `device`; keep the prompt tensor on the same
    # device so running on CUDA does not hit a device mismatch.
    idx = idx.to(device)

    # Treat a cleared field (None) as zero and never request a negative count.
    max_new_tokens = max(int(num_of_tokens or 0), 0)

    text = ""
    for token in model.generate(
        idx, max_new_tokens=max_new_tokens, temperature=temperature
    ):
        text += token
        yield text


# Build the single-page UI: a prompt box, generation controls, and an output
# box that is refreshed as generate_text yields.
with gr.Blocks() as demo:
    gr.HTML("<h1 align='center'> Shakespeare Text Generator</h1>")

    context = gr.Textbox(label="Enter context (optional)")

    with gr.Row():
        # precision=0 makes the component hand the callback an int rather
        # than a float, which is what a token count needs to be.
        num_of_tokens = gr.Number(
            label="Max tokens to generate", value=100, precision=0
        )
        # NOTE(review): a temperature of 0.0 is selectable; confirm that
        # model.generate tolerates it (a plain `logits / temperature` would
        # divide by zero).
        tmp = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=1.0)

    # Order must match generate_text(context, num_of_tokens, temperature).
    inputs = [context, num_of_tokens, tmp]

    generate_btn = gr.Button(value="Generate")
    outputs = [gr.Textbox(label="Generated text: ")]
    generate_btn.click(fn=generate_text, inputs=inputs, outputs=outputs)

demo.launch()