Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,146 +1,13 @@
|
|
1 |
import json
|
2 |
-
|
3 |
import requests
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
"url": "https://api-inference.huggingface.co/models/flax-community/gpt-2-spanish"
|
10 |
-
},
|
11 |
-
"Model trained on the Large Spanish Corpus": {
|
12 |
-
"url": "https://api-inference.huggingface.co/models/mrm8488/spanish-gpt2"
|
13 |
-
},
|
14 |
-
}
|
15 |
-
PROMPT_LIST = {
|
16 |
-
"Érase una vez...": ["Érase una vez "],
|
17 |
-
"¡Hola!": ["¡Hola! Me llamo "],
|
18 |
-
"¿Ser o no ser?": ["En mi opinión, 'ser' es "],
|
19 |
-
}
|
20 |
-
def query(payload, model_name):
|
21 |
data = json.dumps(payload)
|
22 |
-
|
23 |
-
response = requests.request(
|
24 |
-
"POST", MODELS[model_name]["url"], headers={}, data=data
|
25 |
-
)
|
26 |
return json.loads(response.content.decode("utf-8"))
|
27 |
-
|
28 |
-
|
29 |
-
):
|
30 |
-
payload = {
|
31 |
-
"inputs": text,
|
32 |
-
"parameters": {
|
33 |
-
"max_new_tokens": max_len,
|
34 |
-
"top_k": top_k,
|
35 |
-
"top_p": top_p,
|
36 |
-
"temperature": temp,
|
37 |
-
"repetition_penalty": 2.0,
|
38 |
-
},
|
39 |
-
"options": {
|
40 |
-
"use_cache": True,
|
41 |
-
},
|
42 |
-
}
|
43 |
-
return query(payload, model_name)
|
44 |
-
# Page
|
45 |
-
st.set_page_config(page_title="Spanish GPT-2 Demo", page_icon=LOGO)
|
46 |
-
st.title("Spanish GPT-2")
|
47 |
-
# Sidebar
|
48 |
-
st.sidebar.image(LOGO)
|
49 |
-
st.sidebar.subheader("Configurable parameters")
|
50 |
-
max_len = st.sidebar.number_input(
|
51 |
-
"Maximum length",
|
52 |
-
value=100,
|
53 |
-
help="The maximum length of the sequence to be generated.",
|
54 |
-
)
|
55 |
-
temp = st.sidebar.slider(
|
56 |
-
"Temperature",
|
57 |
-
value=1.0,
|
58 |
-
min_value=0.1,
|
59 |
-
max_value=100.0,
|
60 |
-
help="The value used to module the next token probabilities.",
|
61 |
-
)
|
62 |
-
top_k = st.sidebar.number_input(
|
63 |
-
"Top k",
|
64 |
-
value=10,
|
65 |
-
help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
|
66 |
-
)
|
67 |
-
top_p = st.sidebar.number_input(
|
68 |
-
"Top p",
|
69 |
-
value=0.95,
|
70 |
-
help=" If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
|
71 |
-
)
|
72 |
-
do_sample = st.sidebar.selectbox(
|
73 |
-
"Sampling?",
|
74 |
-
(True, False),
|
75 |
-
help="Whether or not to use sampling; use greedy decoding otherwise.",
|
76 |
-
)
|
77 |
-
# Body
|
78 |
-
st.markdown(
|
79 |
-
"""
|
80 |
-
Spanish GPT-2 models trained from scratch on two different datasets. One
|
81 |
-
model is trained on the Spanish portion of
|
82 |
-
[OSCAR](https://huggingface.co/datasets/viewer/?dataset=oscar)
|
83 |
-
and the other on the
|
84 |
-
[large_spanish_corpus](https://huggingface.co/datasets/viewer/?dataset=large_spanish_corpus)
|
85 |
-
aka BETO's corpus.
|
86 |
-
|
87 |
-
The models are trained with Flax and using TPUs sponsored by Google since this is part of the
|
88 |
-
[Flax/Jax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104)
|
89 |
-
organised by HuggingFace.
|
90 |
-
"""
|
91 |
-
)
|
92 |
-
model_name = st.selectbox("Model", (list(MODELS.keys())))
|
93 |
-
ALL_PROMPTS = list(PROMPT_LIST.keys()) + ["Custom"]
|
94 |
-
prompt = st.selectbox("Prompt", ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)
|
95 |
-
if prompt == "Custom":
|
96 |
-
prompt_box = "Enter your text here"
|
97 |
-
else:
|
98 |
-
prompt_box = random.choice(PROMPT_LIST[prompt])
|
99 |
-
text = st.text_area("Enter text", prompt_box)
|
100 |
-
if st.button("Run"):
|
101 |
-
with st.spinner(text="Getting results..."):
|
102 |
-
st.subheader("Result")
|
103 |
-
print(f"maxlen:{max_len}, temp:{temp}, top_k:{top_k}, top_p:{top_p}")
|
104 |
-
result = process(
|
105 |
-
text=text,
|
106 |
-
model_name=model_name,
|
107 |
-
max_len=int(max_len),
|
108 |
-
temp=temp,
|
109 |
-
top_k=int(top_k),
|
110 |
-
top_p=float(top_p),
|
111 |
-
)
|
112 |
-
print("result:", result)
|
113 |
-
if "error" in result:
|
114 |
-
if type(result["error"]) is str:
|
115 |
-
st.write(f'{result["error"]}.', end=" ")
|
116 |
-
if "estimated_time" in result:
|
117 |
-
st.write(
|
118 |
-
f'Please try again in about {result["estimated_time"]:.0f} seconds.'
|
119 |
-
)
|
120 |
-
else:
|
121 |
-
if type(result["error"]) is list:
|
122 |
-
for error in result["error"]:
|
123 |
-
st.write(f"{error}")
|
124 |
-
else:
|
125 |
-
result = result[0]["generated_text"]
|
126 |
-
st.write(result.replace("\n", " \n"))
|
127 |
-
st.text("English translation")
|
128 |
-
st.write(translate(result, "en", "es").replace("\n", " \n"))
|
129 |
-
st.markdown(
|
130 |
-
"""
|
131 |
-
### Team members
|
132 |
-
- Manuel Romero ([mrm8488](https://huggingface.co/mrm8488))
|
133 |
-
- María Grandury ([mariagrandury](https://huggingface.co/mariagrandury))
|
134 |
-
- Pablo González de Prado ([Pablogps](https://huggingface.co/Pablogps))
|
135 |
-
- Daniel Vera ([daveni](https://huggingface.co/daveni))
|
136 |
-
- Sri Lakshmi ([srisweet](https://huggingface.co/srisweet))
|
137 |
-
- José Posada ([jdposa](https://huggingface.co/jdposa))
|
138 |
-
- Santiago Hincapie ([shpotes](https://huggingface.co/shpotes))
|
139 |
-
- Jorge ([jorgealro](https://huggingface.co/jorgealro))
|
140 |
-
|
141 |
-
### More information
|
142 |
-
You can find more information about these models in their cards:
|
143 |
-
- [Model trained on OSCAR](https://huggingface.co/models/flax-community/gpt-2-spanish)
|
144 |
-
- [Model trained on the Large Spanish Corpus](https://huggingface.co/mrm8488/spanish-gpt2)
|
145 |
-
"""
|
146 |
-
)
|
|
|
1 |
import json
|
2 |
+
|
3 |
import requests
|
4 |
+
|
5 |
+
API_URL = "https://api-inference.huggingface.co/models/gpt2"
|
6 |
+
headers = {"Authorization": f"Bearer {API_TOKEN}"}
|
7 |
+
|
8 |
+
def query(payload):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
data = json.dumps(payload)
|
10 |
+
response = requests.request("POST", API_URL, headers=headers, data=data)
|
|
|
|
|
|
|
11 |
return json.loads(response.content.decode("utf-8"))
|
12 |
+
|
13 |
+
data = query("Can you please let us know more details about your ")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|