Commit 6b65ad0
Parent(s): 1157ef0

Update main_dl.py

main_dl.py  CHANGED  (+21 -11)
@@ -17,7 +17,6 @@ import csv
 import tiktoken
 from sklearn.preprocessing import LabelEncoder
 from tensorflow import keras
-# import keras
 from keras_nlp.layers import TransformerEncoder
 from tensorflow.keras import layers
 from tensorflow.keras.preprocessing.sequence import pad_sequences
@@ -46,6 +45,8 @@ def load_vocab(file_path):
 def decode_sequence_rnn(input_sentence, src, tgt):
     global translation_model
 
+    if translation_model not in globals():
+        load_rnn()
     vocab_size = 15000
     sequence_length = 50
 
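Note on the guard added above: globals() returns a dict keyed by variable names, so the conventional membership test is written against the name string rather than against the object itself. A minimal, self-contained sketch of that name-based lazy-load pattern (get_model and dummy_loader are illustrative stand-ins, not code from this commit):

def dummy_loader():
    # stand-in for an expensive load such as load_rnn() / load_transformer()
    return "loaded-model-object"

def get_model():
    global model
    if "model" not in globals():   # test the *name*; globals() is keyed by strings
        model = dummy_loader()     # load once and cache it in the module namespace
    return model

print(get_model())  # first call performs the load
print(get_model())  # second call reuses the cached module-level object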
@@ -180,9 +181,11 @@ class PositionalEmbedding(layers.Layer):
         })
         return config
 
-def decode_sequence_tranf(input_sentence, src, tgt):
+def decode_sequence_transf(input_sentence, src, tgt):
     global translation_model
 
+    if translation_model not in globals():
+        load_transformer()
     vocab_size = 15000
     sequence_length = 30
 
@@ -221,7 +224,7 @@ def decode_sequence_tranf(input_sentence, src, tgt):
 
 # ==== End Transforformer section ====
 
-def load_all_data():
+def load_rnn():
 
     merge = Merge( dataPath+"/rnn_en-fr_split", dataPath, "seq2seq_rnn-model-en-fr.h5").merge(cleanup=False)
     merge = Merge( dataPath+"/rnn_fr-en_split", dataPath, "seq2seq_rnn-model-fr-en.h5").merge(cleanup=False)
@@ -229,7 +232,9 @@ def load_all_data():
     rnn_fr_en = keras.models.load_model(dataPath+"/seq2seq_rnn-model-fr-en.h5") # , compile=False)
     rnn_en_fr.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
     rnn_fr_en.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
-
+    return rnn_en_fr, rnn_fr_en
+
+def load_transformer():
     custom_objects = {"TransformerDecoder": TransformerDecoder, "PositionalEmbedding": PositionalEmbedding}
     with keras.saving.custom_object_scope(custom_objects):
         transformer_en_fr = keras.models.load_model( "data/transformer-model-en-fr.h5")
@@ -239,9 +244,10 @@ def load_all_data():
     transformer_en_fr.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
     transformer_fr_en.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
 
-    return
+    return transformer_en_fr, transformer_fr_en
 
-rnn_en_fr, rnn_fr_en
+rnn_en_fr, rnn_fr_en = load_rnn()
+transformer_en_fr, transformer_fr_en = load_transformer()
 
 # ==== Language identifier ====
 
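The module-level calls above load all four Keras models eagerly, as soon as main_dl.py is imported. Purely as an illustration of an alternative (not something this commit does), a one-time lazy load can be memoised with functools.lru_cache; fake_load below stands in for the real keras.models.load_model(...) and compile(...) sequence:

from functools import lru_cache

def fake_load():
    # placeholder for the expensive model loading done in load_rnn()/load_transformer()
    print("loading models...")
    return ("model_en_fr", "model_fr_en")

@lru_cache(maxsize=None)
def get_rnn_models():
    return fake_load()   # runs once; later calls return the cached tuple

en_fr, fr_en = get_rnn_models()   # first call triggers the load
en_fr, fr_en = get_rnn_models()   # served from the cache, no reload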
@@ -277,10 +283,13 @@ def init_dl_identifier():
     else: print("dl_model vide")
     return
 
+init_dl_identifier()
+
 def lang_id_dl(sentences):
     global dl_model, label_encoder, lan_to_language
 
-
+    if dl_model not in globals():
+        init_dl_identifier()
     if "str" in str(type(sentences)): predictions = dl_model.predict(encode_text([sentences]))
     else: predictions = dl_model.predict(encode_text(sentences))
     # Décodage des prédictions en langues
@@ -293,7 +302,8 @@ def lang_id_dl(sentences):
 
 @api.get('/', name="Vérification que l'API fonctionne")
 def check_api():
-
+    load_rnn()
+    load_transformer()
     init_dl_identifier()
     return {'message': "L'API fonctionne"}
 
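With the two extra calls, a single GET / request now warms up the RNN models, the transformer models and the language identifier. A usage sketch with FastAPI's test client, assuming the FastAPI instance in main_dl.py is the api object used by the decorators:

from fastapi.testclient import TestClient
from main_dl import api               # the FastAPI app behind the @api.get routes

client = TestClient(api)
response = client.get("/")            # triggers load_rnn(), load_transformer(), init_dl_identifier()
print(response.json())                # expected: {'message': "L'API fonctionne"}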
@@ -316,10 +326,10 @@ async def trad_transformer(lang_tgt:str,
 
     if (lang_tgt=='en'):
         translation_model = transformer_fr_en
-        return
+        return decode_sequence_transf(texte, "fr", "en")
     else:
         translation_model = transformer_en_fr
-        return
+        return decode_sequence_transf(texte, "en", "fr")
 
 @api.get('/small_vocab/plot_model', name="Affiche le modèle")
 def affiche_modele(lang_tgt:str,
@@ -345,5 +355,5 @@ def affiche_modele(lang_tgt:str,
     return Response(content=image_data, media_type="image/png")
 
 @api.get('/lang_id_dl', name="Id de langue par DL")
-def language_id_dl(sentence:List[str] = Query(..., min_length=1)):
+async def language_id_dl(sentence:List[str] = Query(..., min_length=1)):
     return lang_id_dl(sentence)
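The endpoint takes a list-valued query parameter, so the sentence key is simply repeated in the query string. A request sketch, assuming the API is served locally on uvicorn's default port 8000:

import requests

resp = requests.get(
    "http://localhost:8000/lang_id_dl",
    # repeated "sentence" keys feed the List[str] Query parameter
    params=[("sentence", "Bonjour tout le monde"), ("sentence", "Hello world")],
)
print(resp.json())   # language predicted for each sentence by lang_id_dl()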