Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -4,14 +4,7 @@ import random
|
|
4 |
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
5 |
|
6 |
def load_model(model_path, dtype):
|
7 |
-
|
8 |
-
torch_dtype = torch.float32
|
9 |
-
elif dtype == "fp16":
|
10 |
-
torch_dtype = torch.float16
|
11 |
-
else:
|
12 |
-
raise ValueError("Invalid dtype. Only 'fp32' or 'fp16' are supported.")
|
13 |
-
|
14 |
-
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch_dtype)
|
15 |
return model
|
16 |
|
17 |
def generate(
|
@@ -120,7 +113,7 @@ additional_inputs = [
|
|
120 |
info="A starting point to initiate the generation process"
|
121 |
),
|
122 |
gr.Radio(
|
123 |
-
choices=["fp32", "fp16"],
|
124 |
value="fp16",
|
125 |
label="Model Precision",
|
126 |
info="fp32 is more precised, fp16 is faster and less memory consuming",
|
|
|
4 |
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
5 |
|
6 |
def load_model(model_path, dtype):
|
7 |
+
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
return model
|
9 |
|
10 |
def generate(
|
|
|
113 |
info="A starting point to initiate the generation process"
|
114 |
),
|
115 |
gr.Radio(
|
116 |
+
choices=[("fp32", torch.float32), ("fp16", torch.float16)],
|
117 |
value="fp16",
|
118 |
label="Model Precision",
|
119 |
info="fp32 is more precised, fp16 is faster and less memory consuming",
|