Spaces:
Sleeping
Sleeping
Make it fun
Browse files
app.py
CHANGED
@@ -6,8 +6,6 @@ import torch.nn.functional as F
|
|
6 |
import transformers
|
7 |
import pandas as pd
|
8 |
|
9 |
-
st.title("Streamlit + Transformers")
|
10 |
-
|
11 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
12 |
|
13 |
from transformers import MarianMTModel, MarianTokenizer
|
@@ -21,13 +19,15 @@ model_name = st.radio("Select a model", [
|
|
21 |
if model_name == 'other':
|
22 |
model_name = st.text_input("Enter model name", 'Helsinki-NLP/opus-mt-ROMANCE-en')
|
23 |
|
|
|
|
|
24 |
|
25 |
|
26 |
-
@st.
|
27 |
def get_tokenizer(model_name):
|
28 |
return MarianTokenizer.from_pretrained(model_name)
|
29 |
|
30 |
-
@st.
|
31 |
def get_model(model_name):
|
32 |
model = MarianMTModel.from_pretrained(model_name).to(device)
|
33 |
print(f"Loaded model, {model.num_parameters():,d} parameters.")
|
@@ -36,8 +36,10 @@ def get_model(model_name):
|
|
36 |
tokenizer = get_tokenizer(model_name)
|
37 |
model = get_model(model_name)
|
38 |
|
39 |
-
if tokenizer.supported_language_codes
|
40 |
-
st.
|
|
|
|
|
41 |
|
42 |
|
43 |
input_text = st.text_input("Enter text to translate", "Hola, mi nombre es Juan")
|
@@ -45,10 +47,22 @@ input_text = input_text.strip()
|
|
45 |
if not input_text:
|
46 |
st.stop()
|
47 |
|
|
|
|
|
|
|
|
|
48 |
output_so_far = st.text_input("Enter text translated so far", "Hello, my")
|
49 |
|
50 |
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
|
51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
# tokenize the output so far
|
53 |
with tokenizer.as_target_tokenizer():
|
54 |
output_tokens = tokenizer.tokenize(output_so_far)
|
@@ -62,7 +76,6 @@ with torch.no_grad():
|
|
62 |
input_ids = input_ids,
|
63 |
decoder_input_ids = torch.tensor([decoder_input_ids]).to(device))
|
64 |
|
65 |
-
|
66 |
last_token_logits = model_output.logits[0, -1].cpu()
|
67 |
assert len(last_token_logits.shape) == 1
|
68 |
most_likely_tokens = last_token_logits.topk(k=20)
|
@@ -79,5 +92,11 @@ with tokenizer.as_target_tokenizer():
|
|
79 |
'cumulative probability': probs_for_likely_tokens.cumsum(0)
|
80 |
})
|
81 |
|
82 |
-
|
83 |
st.write(probs_table)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
import transformers
|
7 |
import pandas as pd
|
8 |
|
|
|
|
|
9 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
10 |
|
11 |
from transformers import MarianMTModel, MarianTokenizer
|
|
|
19 |
if model_name == 'other':
|
20 |
model_name = st.text_input("Enter model name", 'Helsinki-NLP/opus-mt-ROMANCE-en')
|
21 |
|
22 |
+
if not hasattr(st, "cache_resource"):
|
23 |
+
st.cache_resource = st.experimental_singleton
|
24 |
|
25 |
|
26 |
+
@st.cache_resource
|
27 |
def get_tokenizer(model_name):
|
28 |
return MarianTokenizer.from_pretrained(model_name)
|
29 |
|
30 |
+
@st.cache_resource
|
31 |
def get_model(model_name):
|
32 |
model = MarianMTModel.from_pretrained(model_name).to(device)
|
33 |
print(f"Loaded model, {model.num_parameters():,d} parameters.")
|
|
|
36 |
tokenizer = get_tokenizer(model_name)
|
37 |
model = get_model(model_name)
|
38 |
|
39 |
+
if tokenizer.supported_language_codes:
|
40 |
+
lang_code = st.selectbox("Select a language", tokenizer.supported_language_codes)
|
41 |
+
else:
|
42 |
+
lang_code = None
|
43 |
|
44 |
|
45 |
input_text = st.text_input("Enter text to translate", "Hola, mi nombre es Juan")
|
|
|
47 |
if not input_text:
|
48 |
st.stop()
|
49 |
|
50 |
+
# prepend the language code if necessary
|
51 |
+
if lang_code:
|
52 |
+
input_text = f"{lang_code} {input_text}"
|
53 |
+
|
54 |
output_so_far = st.text_input("Enter text translated so far", "Hello, my")
|
55 |
|
56 |
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
|
57 |
|
58 |
+
example_generations = model.generate(
|
59 |
+
input_ids,
|
60 |
+
num_beams=4,
|
61 |
+
num_return_sequences=4,
|
62 |
+
)
|
63 |
+
st.write("Example generations:")
|
64 |
+
st.write(tokenizer.batch_decode(example_generations, skip_special_tokens=True))
|
65 |
+
|
66 |
# tokenize the output so far
|
67 |
with tokenizer.as_target_tokenizer():
|
68 |
output_tokens = tokenizer.tokenize(output_so_far)
|
|
|
76 |
input_ids = input_ids,
|
77 |
decoder_input_ids = torch.tensor([decoder_input_ids]).to(device))
|
78 |
|
|
|
79 |
last_token_logits = model_output.logits[0, -1].cpu()
|
80 |
assert len(last_token_logits.shape) == 1
|
81 |
most_likely_tokens = last_token_logits.topk(k=20)
|
|
|
92 |
'cumulative probability': probs_for_likely_tokens.cumsum(0)
|
93 |
})
|
94 |
|
|
|
95 |
st.write(probs_table)
|
96 |
+
|
97 |
+
loss_table = pd.DataFrame({
|
98 |
+
'token': [tokenizer.decode(token_id) for token_id in decoder_input_ids[1:]],
|
99 |
+
'loss': F.cross_entropy(model_output.logits[0, :-1], torch.tensor(decoder_input_ids[1:]).to(device), reduction='none').cpu()
|
100 |
+
})
|
101 |
+
st.write(loss_table)
|
102 |
+
st.write("Total loss so far:", loss_table.loss.sum())
|