Model selector
app.py CHANGED
@@ -11,7 +11,17 @@ st.title("Streamlit + Transformers")
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 from transformers import MarianMTModel, MarianTokenizer
-
+
+model_name = st.radio("Select a model", [
+    'Helsinki-NLP/opus-mt-roa-en',
+    'Helsinki-NLP/opus-mt-en-roa',
+    'other'
+])
+
+if model_name == 'other':
+    model_name = st.text_input("Enter model name", 'Helsinki-NLP/opus-mt-ROMANCE-en')
+
+
 
 @st.experimental_singleton
 def get_tokenizer(model_name):
@@ -26,6 +36,9 @@ def get_model(model_name):
 tokenizer = get_tokenizer(model_name)
 model = get_model(model_name)
 
+if tokenizer.supported_language_codes is not None:
+    st.write(f"Supported languages: {tokenizer.supported_language_codes}")
+
 
 input_text = st.text_input("Enter text to translate", "Hola, mi nombre es Juan")
 input_text = input_text.strip()
@@ -52,7 +65,7 @@ with torch.no_grad():
 
 last_token_logits = model_output.logits[0, -1].cpu()
 assert len(last_token_logits.shape) == 1
-most_likely_tokens = last_token_logits.topk(k=
+most_likely_tokens = last_token_logits.topk(k=20)
 
 probs = last_token_logits.softmax(dim=-1)
 probs_for_likely_tokens = probs[most_likely_tokens.indices]
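For context, the selector-plus-cached-loader pattern the first hunk builds can be sketched as a self-contained script. The bodies of get_tokenizer and get_model are not shown in this diff, so the from_pretrained calls below are assumptions based on the usual transformers API; @st.experimental_singleton is the caching decorator the app targets (later Streamlit releases renamed this behavior st.cache_resource).

import streamlit as st
import torch
from transformers import MarianMTModel, MarianTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = st.radio("Select a model", [
    'Helsinki-NLP/opus-mt-roa-en',
    'Helsinki-NLP/opus-mt-en-roa',
    'other'
])
if model_name == 'other':
    model_name = st.text_input("Enter model name", 'Helsinki-NLP/opus-mt-ROMANCE-en')

@st.experimental_singleton
def get_tokenizer(name):
    # Assumed loader body: the diff only shows the decorated signature.
    return MarianTokenizer.from_pretrained(name)

@st.experimental_singleton
def get_model(name):
    # Assumed loader body; weights load once per process and are reused
    # across Streamlit reruns instead of reloading on every interaction.
    return MarianMTModel.from_pretrained(name).to(device)

tokenizer = get_tokenizer(model_name)
model = get_model(model_name)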
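The second hunk surfaces MarianTokenizer.supported_language_codes, the list of >>xxx<< target-language tags present in a multi-target checkpoint's vocabulary (for single-target checkpoints the list is simply empty). A sketch of how a user acts on that list, with '>>fra<<' assumed to be among the codes reported for Helsinki-NLP/opus-mt-en-roa:

# tokenizer/model loaded as above for 'Helsinki-NLP/opus-mt-en-roa'
st.write(f"Supported languages: {tokenizer.supported_language_codes}")
# e.g. ['>>fra<<', '>>spa<<', '>>ita<<', ...]

# Multi-target Marian models choose the output language from a tag
# prefixed to the source text ('>>fra<<' here is an assumed example):
batch = tokenizer([">>fra<< Hello, my name is Juan"], return_tensors="pt")
translated = model.generate(**batch)
st.write(tokenizer.decode(translated[0], skip_special_tokens=True))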
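The last hunk completes the previously truncated topk call with k=20 (the removed line's k value is cut off in this view). The forward pass that produces model_output sits outside the diff, so the setup below is an assumption: a single decoder step that asks which token the model would emit next for the default input.

import torch
from transformers import MarianMTModel, MarianTokenizer

name = 'Helsinki-NLP/opus-mt-ROMANCE-en'
tokenizer = MarianTokenizer.from_pretrained(name)
model = MarianMTModel.from_pretrained(name)

batch = tokenizer(["Hola, mi nombre es Juan"], return_tensors="pt")
# One-step decode: start the decoder and score the first output token.
decoder_start = torch.tensor([[model.config.decoder_start_token_id]])
with torch.no_grad():
    model_output = model(**batch, decoder_input_ids=decoder_start)

# The lines from the diff: logits for the last position, the 20
# highest-scoring vocabulary entries, and their softmax probabilities.
last_token_logits = model_output.logits[0, -1].cpu()
assert len(last_token_logits.shape) == 1          # shape: (vocab_size,)
most_likely_tokens = last_token_logits.topk(k=20)

probs = last_token_logits.softmax(dim=-1)
probs_for_likely_tokens = probs[most_likely_tokens.indices]

for tok, p in zip(tokenizer.convert_ids_to_tokens(most_likely_tokens.indices.tolist()),
                  probs_for_likely_tokens.tolist()):
    print(tok, round(p, 4))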