miwojc committed on
Commit
4bfba61
·
1 Parent(s): 28db877

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -142
app.py CHANGED
@@ -1,146 +1,13 @@
1
  import json
2
- import random
3
  import requests
4
- from mtranslate import translate
5
- import streamlit as st
6
# Shared assets and configuration for the demo.
LOGO = "https://raw.githubusercontent.com/nlp-en-es/assets/main/logo.png"

# Inference-API endpoints for the two Spanish GPT-2 checkpoints the user can
# choose between; keys double as the labels shown in the model selectbox.
MODELS = {
    "Model trained on OSCAR": {
        "url": "https://api-inference.huggingface.co/models/flax-community/gpt-2-spanish"
    },
    "Model trained on the Large Spanish Corpus": {
        "url": "https://api-inference.huggingface.co/models/mrm8488/spanish-gpt2"
    },
}

# Canned prompts: each label maps to a list of candidate seed texts from
# which one is picked at random when the label is selected.
PROMPT_LIST = {
    "Érase una vez...": ["Érase una vez "],
    "¡Hola!": ["¡Hola! Me llamo "],
    "¿Ser o no ser?": ["En mi opinión, 'ser' es "],
}
20
def query(payload, model_name):
    """POST *payload* to the Inference API endpoint for *model_name*.

    Parameters
    ----------
    payload : dict
        Request body (inputs / parameters / options) for the Inference API.
    model_name : str
        Key into the module-level ``MODELS`` table.

    Returns
    -------
    dict or list
        Decoded JSON response: a list of generations on success, or a dict
        carrying an ``"error"`` key on failure (handled by the caller).
    """
    url = MODELS[model_name]["url"]
    print("model url:", url)
    data = json.dumps(payload)
    # A timeout keeps the app from hanging forever on a stalled API call —
    # requests has no default timeout.
    response = requests.post(url, headers={}, data=data, timeout=60)
    return json.loads(response.content.decode("utf-8"))
27
def process(
    text: str, model_name: str, max_len: int, temp: float, top_k: int, top_p: float
):
    """Build a text-generation request for *text* and send it to *model_name*.

    Wraps the prompt and the sampling parameters into the Inference-API
    request shape, then delegates the HTTP call to ``query``.
    """
    generation_params = {
        "max_new_tokens": max_len,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temp,
        "repetition_penalty": 2.0,
    }
    request_body = {
        "inputs": text,
        "parameters": generation_params,
        "options": {"use_cache": True},
    }
    return query(request_body, model_name)
44
# ---------------------------------------------------------------------------
# Page setup
# ---------------------------------------------------------------------------
st.set_page_config(page_title="Spanish GPT-2 Demo", page_icon=LOGO)
st.title("Spanish GPT-2")

# ---------------------------------------------------------------------------
# Sidebar: configurable generation parameters
# ---------------------------------------------------------------------------
st.sidebar.image(LOGO)
st.sidebar.subheader("Configurable parameters")
max_len = st.sidebar.number_input(
    "Maximum length",
    value=100,
    help="The maximum length of the sequence to be generated.",
)
temp = st.sidebar.slider(
    "Temperature",
    value=1.0,
    min_value=0.1,
    max_value=100.0,
    help="The value used to module the next token probabilities.",
)
top_k = st.sidebar.number_input(
    "Top k",
    value=10,
    help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
)
top_p = st.sidebar.number_input(
    "Top p",
    value=0.95,
    help=" If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
)
# NOTE(review): this control is collected but never forwarded to the request
# built in process() — confirm whether do_sample should be added to the
# generation parameters.
do_sample = st.sidebar.selectbox(
    "Sampling?",
    (True, False),
    help="Whether or not to use sampling; use greedy decoding otherwise.",
)

# ---------------------------------------------------------------------------
# Body: description, model/prompt selection, generation
# ---------------------------------------------------------------------------
st.markdown(
    """
Spanish GPT-2 models trained from scratch on two different datasets. One
model is trained on the Spanish portion of
[OSCAR](https://huggingface.co/datasets/viewer/?dataset=oscar)
and the other on the
[large_spanish_corpus](https://huggingface.co/datasets/viewer/?dataset=large_spanish_corpus)
aka BETO's corpus.

The models are trained with Flax and using TPUs sponsored by Google since this is part of the
[Flax/Jax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104)
organised by HuggingFace.
"""
)
model_name = st.selectbox("Model", (list(MODELS.keys())))
ALL_PROMPTS = list(PROMPT_LIST.keys()) + ["Custom"]
# Default to the "Custom" entry (last in the list).
prompt = st.selectbox("Prompt", ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)
if prompt == "Custom":
    prompt_box = "Enter your text here"
else:
    # Pick one of the canned seed texts for the chosen prompt label.
    prompt_box = random.choice(PROMPT_LIST[prompt])
text = st.text_area("Enter text", prompt_box)

if st.button("Run"):
    with st.spinner(text="Getting results..."):
        st.subheader("Result")
        # Fixed: this debug print previously hard-coded "11,938" instead of
        # interpolating the actual maximum-length setting.
        print(f"maxlen:{max_len}, temp:{temp}, top_k:{top_k}, top_p:{top_p}")
        result = process(
            text=text,
            model_name=model_name,
            max_len=int(max_len),
            temp=temp,
            top_k=int(top_k),
            top_p=float(top_p),
        )
        print("result:", result)
        if "error" in result:
            # The Inference API reports failures as {"error": ...} where the
            # value is either a message string or a list of messages.
            if isinstance(result["error"], str):
                # NOTE(review): st.write has no print-style `end` parameter —
                # confirm the Streamlit version in use tolerates this keyword.
                st.write(f'{result["error"]}.', end=" ")
                if "estimated_time" in result:
                    st.write(
                        f'Please try again in about {result["estimated_time"]:.0f} seconds.'
                    )
            else:
                if isinstance(result["error"], list):
                    for error in result["error"]:
                        st.write(f"{error}")
        else:
            result = result[0]["generated_text"]
            # Insert a space before each newline so Markdown keeps the breaks.
            st.write(result.replace("\n", " \n"))
            st.text("English translation")
            st.write(translate(result, "en", "es").replace("\n", " \n"))

st.markdown(
    """
### Team members
- Manuel Romero ([mrm8488](https://huggingface.co/mrm8488))
- María Grandury ([mariagrandury](https://huggingface.co/mariagrandury))
- Pablo González de Prado ([Pablogps](https://huggingface.co/Pablogps))
- Daniel Vera ([daveni](https://huggingface.co/daveni))
- Sri Lakshmi ([srisweet](https://huggingface.co/srisweet))
- José Posada ([jdposa](https://huggingface.co/jdposa))
- Santiago Hincapie ([shpotes](https://huggingface.co/shpotes))
- Jorge ([jorgealro](https://huggingface.co/jorgealro))

### More information
You can find more information about these models in their cards:
- [Model trained on OSCAR](https://huggingface.co/models/flax-community/gpt-2-spanish)
- [Model trained on the Large Spanish Corpus](https://huggingface.co/mrm8488/spanish-gpt2)
"""
)
 
import json
import os

import requests

# Inference API endpoint for the public gpt2 model.
API_URL = "https://api-inference.huggingface.co/models/gpt2"
# Fixed: API_TOKEN was referenced without ever being assigned, which raised
# NameError at import time. Read it from the environment instead; an empty
# token still allows unauthenticated (rate-limited) access.
API_TOKEN = os.environ.get("HF_API_TOKEN", "")
headers = {"Authorization": f"Bearer {API_TOKEN}"}
8
def query(payload):
    """POST *payload* to ``API_URL`` and return the decoded JSON response.

    Parameters
    ----------
    payload : str or dict
        Inference-API request body — either a bare prompt string or a dict
        with ``inputs``/``parameters`` keys.

    Returns
    -------
    dict or list
        The JSON-decoded API reply (a list of generations on success, or a
        dict with an ``"error"`` key on failure).
    """
    data = json.dumps(payload)
    # A timeout prevents the caller from hanging forever on a stalled API
    # call — requests has no default timeout.
    response = requests.request("POST", API_URL, headers=headers, data=data, timeout=60)
    return json.loads(response.content.decode("utf-8"))
12
# Module-level smoke test: send a sample prompt to the API at import time
# and keep the decoded response in `data`.
data = query("Can you please let us know more details about your ")