Nick088 committed
Commit f2e8fa9 · verified · 1 Parent(s): 7500dd6

Update app.py

Files changed (1)
  1. app.py +2 -9
app.py CHANGED
@@ -4,14 +4,7 @@ import random
 from transformers import T5Tokenizer, T5ForConditionalGeneration
 
 def load_model(model_path, dtype):
-    if dtype == "fp32":
-        torch_dtype = torch.float32
-    elif dtype == "fp16":
-        torch_dtype = torch.float16
-    else:
-        raise ValueError("Invalid dtype. Only 'fp32' or 'fp16' are supported.")
-
-    model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch_dtype)
+    model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=dtype)
     return model
 
 def generate(
@@ -120,7 +113,7 @@ additional_inputs = [
         info="A starting point to initiate the generation process"
     ),
     gr.Radio(
-        choices=["fp32", "fp16"],
+        choices=[("fp32", torch.float32), ("fp16", torch.float16)],
         value="fp16",
         label="Model Precision",
         info="fp32 is more precised, fp16 is faster and less memory consuming",