Spaces:
Runtime error
Runtime error
Stefan Dumitrescu
commited on
Commit
·
c44f938
1
Parent(s):
0957f7e
Update
Browse files
app.py
CHANGED
@@ -33,9 +33,9 @@ top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, valu
|
|
33 |
|
34 |
|
35 |
@st.cache(allow_output_mutation=True)
|
36 |
-
def setModel(
|
37 |
-
model = AutoModelWithLMHead.from_pretrained(
|
38 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
39 |
return model, tokenizer
|
40 |
|
41 |
def infer(model, tokenizer, text, input_ids, max_length, temperature, top_k, top_p):
|
@@ -52,7 +52,7 @@ def infer(model, tokenizer, text, input_ids, max_length, temperature, top_k, top
|
|
52 |
|
53 |
return output_sequences
|
54 |
|
55 |
-
|
56 |
output_sequences = infer(model, tokenizer, text_element, input_ids, max_length, temperature, top_k, top_p)
|
57 |
|
58 |
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
|
|
|
33 |
|
34 |
|
35 |
@st.cache(allow_output_mutation=True)
|
36 |
+
def setModel(model_checkpoint):
|
37 |
+
model = AutoModelWithLMHead.from_pretrained(model_checkpoint)
|
38 |
+
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
39 |
return model, tokenizer
|
40 |
|
41 |
def infer(model, tokenizer, text, input_ids, max_length, temperature, top_k, top_p):
|
|
|
52 |
|
53 |
return output_sequences
|
54 |
|
55 |
+
model, tokenizer = setModel(model_checkpoint)
|
56 |
output_sequences = infer(model, tokenizer, text_element, input_ids, max_length, temperature, top_k, top_p)
|
57 |
|
58 |
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
|