Commit: debug...

Files changed:
- .gitignore (+1, -0)
- main.py (+15, -13)
.gitignore
CHANGED
@@ -1,2 +1,3 @@
 # environment
 bloom_demo
+tutorial-env
main.py
CHANGED
@@ -42,17 +42,19 @@ def model_init():
 
 tokenizer, model = model_init()
 
-
-#
-prompt =
+try:
+    # ===================== INPUT ====================== #
+    # prompt = "問：台灣最高的建築物是？答："  #@param {type:"string"}
+    prompt = st.text_input("Prompt: ")
 
-# =================== INFERENCE ==================== #
-if prompt:
-
-
-
-
-
-
-
+    # =================== INFERENCE ==================== #
+    if prompt:
+        with torch.no_grad():
+            [texts_out] = model.generate(
+                **tokenizer(
+                    prompt, return_tensors="pt"
+                ).to(device))
+        output_text = tokenizer.decode(texts_out)
+        st.markdown(output_text)
+except Exception as err:
+    st.write(str(err))
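For context, the hunk begins just after `def model_init():` and assumes that `st`, `torch`, `device`, and the cached tokenizer/model are set up earlier in main.py, which the diff does not show. The following is a minimal sketch of that preamble under stated assumptions: the checkpoint name `bigscience/bloom-560m` and the `st.cache_resource` caching decorator are illustrative guesses, not taken from this commit.

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Assumed device selection; the Space may pin this differently.
device = "cuda" if torch.cuda.is_available() else "cpu"

@st.cache_resource  # assumed: cache weights so Streamlit reruns don't reload them
def model_init():
    # Checkpoint name is a placeholder; the actual BLOOM variant is not shown in the diff.
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
    model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m").to(device)
    model.eval()
    return tokenizer, model

With a preamble like this in place, the changed lines above read straightforwardly: the prompt comes from `st.text_input`, generation runs under `torch.no_grad()`, and any exception is surfaced in the page via `st.write`. Locally, such an app is started with `streamlit run main.py`.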