Proba_Nos / app.py
JJFrancisco's picture
Update app.py
c85b0b5 verified
import os
from dotenv import load_dotenv
import gradio as gr
from AinaTheme import theme
from huggingface_hub import snapshot_download
import subprocess
import os
from translate import translate_nos
# Load variables from a local .env file into the environment before any
# setting below is read from it.
load_dotenv()

# Local directories for downloaded model snapshots and the Hugging Face cache.
MODELS_PATH = "./models"
HF_CACHE_DIR = "./hf_cache"

# Maximum input length accepted by the UI; overridable through the environment.
MAX_INPUT_CHARACTERS = int(os.getenv("MAX_INPUT_CHARACTERS", 1000))

# Source languages grouped by the backend translate() dispatches them to.
LANGS_WITHOUT_SUBWORDING = ["English", "Spanish", "Galician"]
LANGS_WITH_SUBWORDING = ["Catalan", "Basque"]
# Model paths and available languages -----------------------------------------------------------
def download_model(repo_id, revision="main"):
    """Fetch a Hugging Face repository snapshot into MODELS_PATH/<repo_id>.

    Args:
        repo_id: Repository identifier, e.g. "proxectonos/Nos_MT-OpenNMT-ca-gl".
        revision: Branch, tag or commit to download.

    Returns:
        The value returned by ``snapshot_download`` for the local copy; callers
        use it as the model directory.
    """
    destination = os.path.join(MODELS_PATH, repo_id)
    return snapshot_download(
        repo_id=repo_id,
        revision=revision,
        local_dir=destination,
        cache_dir=HF_CACHE_DIR,
    )
def write_text_to_file(filename, text):
    """Overwrite *filename* with *text*, encoded as UTF-8.

    The encoding is explicit because the text is free-form user input in
    languages with non-ASCII characters (á, ñ, ç, ...). With the platform
    default encoding, writing it can raise UnicodeEncodeError on hosts whose
    locale is not UTF-8.

    Args:
        filename: Path of the file to create or truncate.
        text: String to write.
    """
    with open(filename, "w", encoding="utf-8") as file:
        file.write(text)
""""
print("Downloading model gl-es...")
model_dir_gl_es = download_model("proxectonos/Nos_MT-OpenNMT-gl-es", revision="main")
print("Downloading model es-gl...")
model_dir_es_gl = download_model("proxectonos/Nos_MT-OpenNMT-es-gl", revision="main")
print("Downloading model gl-en...")
model_dir_gl_en = download_model("proxectonos/Nos_MT-OpenNMT-gl-en", revision="main")
print("Downloading model en-gl...")
model_dir_en_gl = download_model("proxectonos/Nos_MT-OpenNMT-en-gl", revision="main")
model_dir_gl_ca = ""
print("Downloading model ca-gl...")
model_dir_ca_gl = download_model("proxectonos/Nos_MT-OpenNMT-ca-gl", revision="main")
"""
model_dir_gl_es = model_dir_es_gl = model_dir_gl_en = model_dir_en_gl = model_dir_gl_ca = model_dir_gl_eu= " "
print("Downloading model ca-gl...")
model_dir_ca_gl = download_model("proxectonos/Nos_MT-OpenNMT-ca-gl", revision="main")
print("Downloading model eu-gl...")
model_dir_eu_gl = download_model("proxectonos/Nos_MT-OpenNMT-eu-gl", revision="main")
print("Downloading model gl-en...")
model_dir_gl_en = download_model("proxectonos/Nos_MT-OpenNMT-gl-en", revision="main")
print("Downloading model en-gl...")
model_dir_en_gl = download_model("proxectonos/Nos_MT-OpenNMT-en-gl", revision="main")
print("Models downloaded correctly!")
print(f"{os.path.join(MODELS_PATH, model_dir_eu_gl)}")
print(os.listdir(f"{os.path.join(MODELS_PATH, model_dir_eu_gl)}"))
def _model_files(model_dir, code_file, model_subdir=None):
    """Build the (codes file path, model path) pair for one direction.

    Both paths are rooted at os.path.join(MODELS_PATH, model_dir). The model
    path is that root itself unless *model_subdir* names an entry inside it.
    """
    root = os.path.join(MODELS_PATH, model_dir)
    model_path = root if model_subdir is None else f"{root}/{model_subdir}"
    return (f"{root}/{code_file}", model_path)


# Catalan/Basque-only subset of `directions` (appears unused in this file).
directions_reduced = {
    "Catalan": {"target": {
        "Galician": {"model": _model_files(model_dir_ca_gl, "ca-detok10k.code", "ct2_detok-ca-gl_sint_10k")},
    }},
    "Basque": {"target": {
        "Galician": {"model": _model_files(model_dir_eu_gl, "gl-detok10k.code", "eu_gl.ct2_10k")},
    }},
}

# Every translation direction offered in the UI:
#   source language -> {"target": {target language: spec}}
# spec["model"] is the (codes file path, model path) pair. The gl<->es specs
# also carry short language codes under "src"/"tgt".
directions = {
    "Galician": {"target": {
        "Spanish": {"src": "gl", "tgt": "es", "model": _model_files(model_dir_gl_es, "bpe/es.code")},
        "English": {"model": _model_files(model_dir_gl_en, "bpe/en.code")},
        "Catalan": {"model": _model_files(model_dir_gl_ca, "bpe/ca.code")},
        "Basque": {"model": _model_files(model_dir_gl_eu, "bpe/eu.code")},
    }},
    "Spanish": {"target": {
        "Galician": {"src": "es", "tgt": "gl", "model": _model_files(model_dir_es_gl, "bpe/gl.code")},
    }},
    "English": {"target": {
        "Galician": {"model": _model_files(model_dir_en_gl, "bpe/gl.code")},
    }},
    "Catalan": {"target": {
        "Galician": {"model": _model_files(model_dir_ca_gl, "ca-detok10k.code", "ct2_detok-ca-gl_sint_10k")},
    }},
    "Basque": {"target": {
        "Galician": {"model": _model_files(model_dir_eu_gl, "gl-detok10k.code", "eu_gl.ct2_10k")},
    }},
}

# The first source language declared above is preselected in the UI.
DEFAULT_SOURCE_LANGUAGE = next(iter(directions))
# Translation functions ------------------------------------------------------------------------------
def get_target_languages(source_language):
    """List the target languages reachable from *source_language*.

    Returns an empty list when the source language is unknown.
    """
    source_entry = directions.get(source_language, {})
    targets = source_entry.get("target", {})
    return [language for language in targets]
def get_target_language_model(source_language, target_language):
    """Return the direction spec dict for *source_language* -> *target_language*.

    The spec's "model" entry holds the (codes file path, model path) pair.
    An empty dict is returned when either language is unknown.
    """
    targets = directions.get(source_language, {}).get("target", {})
    spec = targets.get(target_language, {})
    return spec
def translate(input, source_language, target_language):
    """Translate *input* from *source_language* into *target_language*.

    Sources in LANGS_WITHOUT_SUBWORDING (ES, GL, EN) are handled by
    translate_without_subwording (onmt_translate CLI). Sources in
    LANGS_WITH_SUBWORDING (CA, EU) are handled by translate_with_subwording
    (translate_nos).

    Args:
        input: Text to translate; None is treated as empty. The name shadows
            the builtin but is kept because it is part of the public signature.
        source_language: Source language name, e.g. "Galician".
        target_language: Target language name, e.g. "English".

    Returns:
        The translated text, or "" when the input is empty/whitespace-only.

    Raises:
        Exception: if the source language or the source/target pair is not
            supported, or if the stripped input is longer than
            MAX_INPUT_CHARACTERS.
    """
    if source_language in LANGS_WITHOUT_SUBWORDING:  # ES, GL, EN
        backend = translate_without_subwording
    elif source_language in LANGS_WITH_SUBWORDING:  # CA, EU
        backend = translate_with_subwording
    else:
        raise Exception(f"Language {source_language} not supported")

    # Without this check an unknown target surfaces later as an obscure
    # TypeError ("'NoneType' object is not subscriptable") inside the backend.
    if not get_target_language_model(source_language, target_language):
        raise Exception(
            f"Translation from {source_language} to {target_language} not supported"
        )

    text = input or ""
    # Nothing to translate: skip spawning the model entirely.
    if not text.strip():
        return ""
    # The UI only disables Submit above the limit. The "translate" API
    # endpoint reaches this function directly, so enforce the limit here too
    # (same stripped-length rule as change_interactive).
    if len(text.strip()) > MAX_INPUT_CHARACTERS:
        raise Exception(
            f"Input is longer than the maximum of {MAX_INPUT_CHARACTERS} characters"
        )
    return backend(text, source_language, target_language)
def translate_without_subwording(input, source_language, target_language):
    """Translate *input* by running the ``onmt_translate`` command line tool.

    The text is written to ./input.txt and onmt_translate is run with the
    model path (second element of the direction's "model" pair). The result
    is then read back from ./output_file.txt.

    Args:
        input: Text to translate.
        source_language: Source language name (a key of ``directions``).
        target_language: Target language name for that source.

    Returns:
        The contents of the output file produced by onmt_translate.

    Raises:
        Exception: if no model is configured for the pair, or if
            onmt_translate exits with a non-zero status (the message includes
            its stderr).
    """
    input_path = "input.txt"
    output_path = "./output_file.txt"

    model = get_target_language_model(source_language, target_language).get("model")
    if not model:
        raise Exception(f"No model available for {source_language} -> {target_language}")

    write_text_to_file(input_path, input)

    # Pass an argument list without a shell, so model paths are never
    # re-split on whitespace or otherwise interpreted by /bin/sh.
    command = [
        "onmt_translate",
        "-src", input_path,
        "-model", model[1],
        "--output", output_path,
        "--replace_unk",
    ]
    print("Comando: ", " ".join(command))
    result = subprocess.run(command, capture_output=True)
    if result.returncode != 0:
        # errors="replace" so a non-UTF-8 byte in stderr cannot mask the real error.
        raise Exception(f"Error occurred: {result.stderr.decode(errors='replace').strip()}")

    # Assumes onmt_translate writes UTF-8 output, matching the UTF-8 input;
    # TODO confirm.
    with open(output_path, "r", encoding="utf-8") as f:
        return f.read()
def translate_with_subwording(input, source_language, target_language):
    """Translate *input* with translate_nos, using the direction's "model" pair.

    Returns whatever translate_nos returns for that text and model pair.
    """
    model_files = get_target_language_model(source_language, target_language).get('model')
    return translate_nos(input, model_files)
# Gradio UI ------------------------------------------------------------------------------
def clear():
    """Reset both the input and the output textboxes (each set to None)."""
    empty_input = empty_output = None
    return empty_input, empty_output
def change_interactive(text):
    """Compute interactivity updates for the (Clear, Submit) buttons.

    Clear always stays interactive. Submit is disabled while the stripped
    input is longer than MAX_INPUT_CHARACTERS.

    Args:
        text: Current textbox value. It may be None, because clear() sets the
            textbox to None and that fires this change handler; None counts as
            empty text instead of raising AttributeError on ``.strip()``.

    Returns:
        A pair of ``gr.update`` objects for (clear button, submit button).
    """
    stripped_length = len((text or "").strip())
    submit_enabled = stripped_length <= MAX_INPUT_CHARACTERS
    return gr.update(interactive=True), gr.update(interactive=submit_enabled)
def update_target_languages_dropdown(source_language):
    """Refresh the target-language selector after the source language changes.

    The first available target is preselected. When the source has no
    targets (e.g. a None or unknown selection), the choices are emptied and
    the value cleared, instead of raising IndexError on ``[0]``.

    Returns:
        A ``gr.update`` for the target-language component.
    """
    output_languages = get_target_languages(source_language)
    first_choice = output_languages[0] if output_languages else None
    return gr.update(choices=output_languages, value=first_choice, interactive=True)
# Gradio UI: source panel (left) and target panel (right), followed by the
# event wiring. Component creation order inside the `with` blocks defines the
# layout.
with gr.Blocks(theme=theme) as app:
    with gr.Row(variant="panel"):
        # Source side: language selector, input text and character counter.
        with gr.Column(scale=2):
            # Hidden component; its value is passed as `m` to the JS counter
            # callback registered below.
            placeholder_max_token = gr.Textbox(
                visible=False,
                interactive=False,
                value= MAX_INPUT_CHARACTERS
            )
            source_language = gr.Dropdown(label="Source Language", choices=list(directions.keys()), value=DEFAULT_SOURCE_LANGUAGE)
            input = gr.Textbox(placeholder="Enter a text here to translate.", max_lines=100, lines=12, show_label=False, interactive=True)
            with gr.Row(variant="panel", equal_height=True):
                # Warning text (filled by JS when over the limit) and the "n / max" counter.
                gr.HTML("""<span id="countertext" style="display: flex; justify-content: start; color:#ef4444; font-weight: bold;"></span>""")
                gr.HTML(f"""<span id="counter" style="display: flex; justify-content: end;"> <span id="inputlenght">0</span>&nbsp;/&nbsp;{MAX_INPUT_CHARACTERS}</span>""")
        # Target side: target language choice and the translated output.
        with gr.Column(scale=2):
            target_outputs = get_target_languages(DEFAULT_SOURCE_LANGUAGE)
            #target_language = gr.Dropdown(choices=target_outputs, label="Target Language", value=target_outputs[0])
            target_language = gr.Radio(choices=target_outputs, label="Target Language", value=target_outputs[0])
            output = gr.Textbox(max_lines=100, lines=12, show_label=False, interactive=False, show_copy_button=True)
            # NOTE(review): confirm whether this button row belongs in the target
            # column or directly under the Blocks.
            with gr.Row(variant="panel"):
                clear_btn = gr.Button(
                    "Clear",
                )
                submit_btn = gr.Button(
                    "Submit",
                    variant="primary",
                )
    # Changing the source language refreshes the available target languages.
    source_language.change(fn=update_target_languages_dropdown, inputs=[source_language], outputs=target_language)
    # Server-side: enable/disable Submit according to the input length.
    input.change(
        fn=change_interactive,
        inputs=[input],
        outputs=[clear_btn, submit_btn],
        api_name=False
    )
    # Client-side only (fn=None): show the "Max length" warning when over the limit.
    input.change(
        fn=None,
        inputs=[input],
        js=f"""(i) => document.getElementById('countertext').textContent = i.length > {MAX_INPUT_CHARACTERS} && 'Max length {MAX_INPUT_CHARACTERS} characters. ' || '' """,
        api_name=False
    )
    # Client-side only: update the character count and color it red over the limit.
    input.change(
        fn=None,
        inputs=[input, placeholder_max_token],
        js="""(i, m) => {
document.getElementById('inputlenght').textContent = i.length + ' '
document.getElementById('inputlenght').style.color = (i.length > m) ? "#ef4444" : "";
}""",
        api_name=False
    )
    # Reset input and output (clear() returns None for both).
    clear_btn.click(
        fn=clear,
        inputs=[],
        outputs=[input, output],
        queue=False,
        api_name=False
    )
    # Also exposed as the "translate" API endpoint. concurrency_limit=1 keeps
    # runs from overlapping, which the fixed input.txt/output_file.txt paths
    # used by translate_without_subwording depend on.
    submit_btn.click(
        fn=translate,
        inputs=[input, source_language, target_language],
        outputs=[output],
        api_name="translate",
        concurrency_limit=1,
    )
# Start the server with the API page enabled.
app.launch(show_api=True)