= commited on
Commit
eb136bc
·
1 Parent(s): c9c7b63

adding app

Browse files
Files changed (1) hide show
  1. app.py +140 -0
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import T5ForConditionalGeneration, T5TokenizerFast
2
+ from torch.utils.data import DataLoader
3
+ import streamlit as st
4
+ import torch
5
+ import os
6
+
7
+
8
+ # Let us define the main page
9
+ st.markdown("Translation page 🔠")
10
+
11
+ # Dropdown for the translation type
12
+ translation_type = st.sidebar.selectbox("Translation Type", options=["French ➡️ Wolof", "Wolof ➡️ French"])
13
+
14
+ # define a dictionary of versions
15
+ models = {
16
+ "Version ✌️": {
17
+ "French ➡️ Wolof": {
18
+ "checkpoints": "wolof-translate/wolof_translate/checkpoints/t5_small_custom_train_results_fw_v4",
19
+ "tokenizer": "wolof-translate/wolof_translate/tokenizers/t5_tokenizers/tokenizer_v4.json",
20
+ "max_len": None
21
+ }
22
+ },
23
+ "Version ☝️": {
24
+ "French ➡️ Wolof": {
25
+ "checkpoints": "wolof-translate/wolof_translate/checkpoints/t5_small_custom_train_results_fw_v3",
26
+ "tokenizer": "wolof-translate/wolof_translate/tokenizers/t5_tokenizers/tokenizer_v3.json",
27
+ "max_len": 51
28
+ }
29
+ }
30
+ }
31
+
32
+ # Dropdown for the model version
33
+ version = st.sidebar.selectbox("Model version", options=["Version ☝️", "Version ✌️"])
34
+
35
+ # Recuperate the number of sentences to provide
36
+ number = st.sidebar.number_input("Give the number of sentences that you want to provide", min_value = 1,
37
+ max_value = 100)
38
+
39
+ # Recuperate the number of sentences to provide
40
+ temperature = st.sidebar.slider("How randomly need you the translated sentences to be from 0% to 100%", min_value = 0,
41
+ max_value = 100)
42
+
43
+
44
+ # make the process
45
+ try:
46
+ # recuperate checkpoints
47
+ checkpoints = torch.load(os.path.join(models[version][translation_type]['checkpoints'], "best_checkpoints.pth"))
48
+
49
+ # recuperate the tokenizer
50
+ tokenizer_file = models[version][translation_type]['tokenizer']
51
+
52
+ # recuperate the max length
53
+ max_len = models[version][translation_type]['max_len']
54
+
55
+ # let us get the best model
56
+ @st.cache_resource
57
+ def get_model():
58
+
59
+ # initialize the tokenizer
60
+ tokenizer = T5TokenizerFast(tokenizer_file=tokenizer_file)
61
+
62
+ # initialize the model
63
+ model_name = 't5-small'
64
+
65
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
66
+
67
+ # resize the token embeddings
68
+ model.resize_token_embeddings(len(tokenizer))
69
+
70
+ model.load_state_dict(checkpoints['model_state_dict'])
71
+
72
+
73
+ return model, tokenizer
74
+
75
+ model, tokenizer = get_model()
76
+
77
+ # set the model to eval mode
78
+ _ = model.eval()
79
+
80
+ # Add a title
81
+ st.header("Translate French sentences onto Wolof 👌")
82
+
83
+
84
+ # Recuperate two columns
85
+ left, right = st.columns(2)
86
+
87
+ # recuperate sentences
88
+ left.subheader('Give me some sentences in French: ')
89
+
90
+ for i in range(number):
91
+
92
+ left.text_input(f"- Sentence number {i + 1}", key = f"sentence{i}")
93
+
94
+ # run model inference on all test data
95
+ original_translations, predicted_translations, original_texts, scores = [], [], [], {}
96
+
97
+ # print a sentence recuperated from the session
98
+ right.subheader("Translation to Wolof:")
99
+
100
+ for i in range(number):
101
+
102
+ sentence = st.session_state[f"sentence{i}"] + tokenizer.eos_token
103
+
104
+ if not sentence == "":
105
+
106
+ # Let us encode the sentences
107
+ encoding = tokenizer([sentence], return_tensors='pt', max_length=max_len, padding='max_length', truncation=True)
108
+
109
+ # Let us recuperate the input ids
110
+ input_ids = encoding.input_ids
111
+
112
+ # Let us recuperate the mask
113
+ mask = encoding.attention_mask
114
+
115
+ # Let us recuperate the pad token id
116
+ pad_token_id = tokenizer.pad_token_id
117
+
118
+ # perform prediction
119
+ predictions = model.generate(input_ids, do_sample = False, top_k = 50, max_length = max_len, top_p = 0.90,
120
+ temperature = temperature/100, num_return_sequences = 0, attention_mask = mask, pad_token_id = pad_token_id)
121
+
122
+ # decode the predictions
123
+ predicted_sentence = tokenizer.batch_decode(predictions, skip_special_tokens = True)
124
+
125
+ # provide the prediction
126
+ right.write(f"{i+1}. {predicted_sentence[0]}")
127
+
128
+ else:
129
+
130
+ # provide the prediction
131
+ right.write(f"{i+1}. ")
132
+
133
+ except Exception as e:
134
+
135
+ st.warning("The chosen model is not available yet !", icon = "⚠️")
136
+
137
+ # st.write(e)
138
+
139
+
140
+