Saif Rehman Nasir committed on
Commit
544fa79
·
1 Parent(s): d2872cd

Move model to object store

Browse files
Files changed (2) hide show
  1. app.py +38 -17
  2. saved_model.pth +0 -3
app.py CHANGED
@@ -1,38 +1,59 @@
1
  import gradio as gr
2
  import torch
3
  from model import *
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
 
6
 
7
- model = torch.load('saved_model.pth', map_location= torch.device(device), weights_only=False)
8
 
9
  def generate_text(context, num_of_tokens, temperature=1.0):
10
- if context == None or context == '':
11
- idx = torch.zeros((1,1), dtype=torch.long)
12
  else:
13
  idx = torch.tensor(encode(context), dtype=torch.long).unsqueeze(0)
14
  text = ""
15
- for token in model.generate(idx, max_new_tokens=num_of_tokens,temperature=temperature):
16
- text+= token
 
 
17
  yield text
18
 
19
 
20
  with gr.Blocks() as demo:
21
  gr.HTML("<h1 align='center'> Shakespeare Text Generator</h1>")
22
 
23
- context = gr.Textbox(label = "Enter context (optional)")
24
 
25
  with gr.Row():
26
- num_of_tokens = gr.Number( label = "Max tokens to generate", value = 100)
27
- tmp = gr.Slider(label= "Temperature", minimum = 0.0, maximum = 1.0, value = 1.0 )
28
-
29
- inputs = [
30
- context,
31
- num_of_tokens,tmp
32
- ]
33
 
34
  generate_btn = gr.Button(value="Generate")
35
- outputs = [gr.Textbox(label= "Generated text: ")]
36
- generate_btn.click(fn = generate_text, inputs= inputs, outputs= outputs)
37
-
38
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
  from model import *
4
+ import requests
5
+ import os
6
+
7
# Map the model onto a GPU when one is available, else fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Object-store base URL and HTTP basic-auth credentials used to fetch the
# model weights at start-up.
# NOTE(review): all three come from the environment and will be None if
# unset — download() below would then build a malformed URL; confirm these
# are configured as deployment secrets.
object_store_url = os.getenv("OBJECT_STORE")
username = os.getenv("USERNAME")
password = os.getenv("PASSWORD")
11
+
12
+
13
def download(filename, directory):
    """Fetch *filename* from the object store into the current directory.

    The file is requested from ``{object_store_url}{directory}/{filename}``
    using the module-level basic-auth credentials. On a non-200 response the
    error is printed and no file is written — deliberately best-effort; the
    caller decides how to handle a missing file.
    """
    # NOTE(review): the URL originally contained a garbled placeholder here;
    # interpolating `filename` is the only reading that uses the parameter.
    download_url = f"{object_store_url}{directory}/{filename}"
    # Bounded timeout so a hung object store cannot block app start-up forever.
    response = requests.get(download_url, auth=(username, password), timeout=60)
    if response.status_code == 200:
        with open(filename, "wb") as file:
            file.write(response.content)
        print("File downloaded successfully")
    else:
        print(f"Failed to download file. Status code: {response.status_code}")
        print(response.text)
23
+
24
 
25
# Pull the weights from the object store at start-up, then deserialize them
# onto the selected device.
download("saved_model.pth", "ShakespeareGPT")
# SECURITY NOTE(review): weights_only=False unpickles arbitrary objects from
# the downloaded file — acceptable only because the object store is trusted
# and credential-protected; prefer weights_only=True with a state_dict if
# the model class supports it.
model = torch.load(
    "saved_model.pth", map_location=torch.device(device), weights_only=False
)
29
 
 
30
 
31
def generate_text(context, num_of_tokens, temperature=1.0):
    """Stream Shakespeare-style text as a Gradio generator function.

    Args:
        context: Optional seed string; None or "" starts from a zero token.
        num_of_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature, forwarded to ``model.generate``.

    Yields:
        The cumulative generated text after each new token, so the bound
        Gradio output textbox updates incrementally.
    """
    # `is None` instead of `== None` per PEP 8; empty prompt falls back to a
    # single zero token as the generation seed.
    if context is None or context == "":
        idx = torch.zeros((1, 1), dtype=torch.long)
    else:
        # `encode` is provided by the star import from the project's `model`
        # module — presumably a tokenizer; confirm against model.py.
        idx = torch.tensor(encode(context), dtype=torch.long).unsqueeze(0)
    text = ""
    for token in model.generate(
        idx, max_new_tokens=num_of_tokens, temperature=temperature
    ):
        text += token
        # Yield inside the loop so the UI streams partial output per token.
        yield text
42
 
43
 
44
# --- Gradio UI: seed-text input, generation controls, streamed output ---
with gr.Blocks() as demo:
    gr.HTML("<h1 align='center'> Shakespeare Text Generator</h1>")

    # Optional seed text; generate_text falls back to a zero token when empty.
    context = gr.Textbox(label="Enter context (optional)")

    with gr.Row():
        num_of_tokens = gr.Number(label="Max tokens to generate", value=100)
        tmp = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=1.0)

    inputs = [context, num_of_tokens, tmp]

    generate_btn = gr.Button(value="Generate")
    outputs = [gr.Textbox(label="Generated text: ")]
    # generate_text is a generator, so the output textbox streams updates
    # as tokens are produced.
    generate_btn.click(fn=generate_text, inputs=inputs, outputs=outputs)

demo.launch()
saved_model.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f267ea6518686be472643110dfc37d669bbcb5138ea05040d1262d0d89f60d78
3
- size 52717798