Nick088 committed
Commit 07bb887 · verified · 1 Parent(s): f2e8fa9

Update app.py

Files changed (1): app.py (+2, -5)
app.py CHANGED

@@ -3,9 +3,6 @@ import torch
 import random
 from transformers import T5Tokenizer, T5ForConditionalGeneration
 
-def load_model(model_path, dtype):
-    model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=dtype)
-    return model
 
 def generate(
     prompt,
@@ -21,7 +18,7 @@ def generate(
     dtype="fp16",
 ):
     tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
-    model = load_model(model_path, dtype)
+    model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=dtype)
 
     if torch.cuda.is_available():
         device = "cuda"
@@ -114,7 +111,7 @@ additional_inputs = [
     ),
     gr.Radio(
         choices=[("fp32", torch.float32), ("fp16", torch.float16)],
-        value="fp16",
+        value=torch.float16,
         label="Model Precision",
         info="fp32 is more precised, fp16 is faster and less memory consuming",
     ),
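
Both hunks appear to address the same type mismatch: from_pretrained's torch_dtype parameter expects a torch.dtype object (or the string "auto"), not a label such as "fp16", and when gr.Radio is given (label, value) choice tuples, its default value must be one of the underlying values, here torch.float16 rather than the label string. A minimal sketch of the inlined call as it behaves after this commit; "google/flan-t5-small" stands in for the Space's model_path, which the hunk does not show:

import torch
from transformers import T5ForConditionalGeneration

# torch_dtype takes a torch.dtype; the old Radio default "fp16" was a
# plain string that matched neither of the declared choice values.
model = T5ForConditionalGeneration.from_pretrained(
    "google/flan-t5-small",  # stand-in for model_path, not shown in the hunk
    torch_dtype=torch.float16,
)
print(next(model.parameters()).dtype)  # expected: torch.float16

Inlining the one-line load_model wrapper rather than keeping it as a separate function also accounts for the net +2/-5 line count shown above.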