|
import glob
import os
import sys
import time
import traceback
from typing import Optional
from urllib.parse import urlencode

import requests
from dotenv import dotenv_values
|
|
|
# Maps the `<lang>_<Script>` codes used for the evaluation directories to the
# language codes that AzureTranslator sends as the `from` / `to` query
# parameters and checks against the service's supported-language list.
flores_to_iso = {
    "asm_Beng": "as",
    "ben_Beng": "bn",
    "brx_Deva": "brx",
    "doi_Deva": "doi",
    "eng_Latn": "en",
    "gom_Deva": "gom",
    "guj_Gujr": "gu",
    "hin_Deva": "hi",
    "kan_Knda": "kn",
    "kas_Arab": "ks",
    "kas_Deva": "ks_Deva",
    "mai_Deva": "mai",
    "mal_Mlym": "ml",
    "mar_Deva": "mr",
    "mni_Beng": "mni_Beng",
    "mni_Mtei": "mni",
    "npi_Deva": "ne",
    "ory_Orya": "or",
    "pan_Guru": "pa",
    "san_Deva": "sa",
    "sat_Olck": "sat",
    "snd_Arab": "sd",
    "snd_Deva": "sd_Deva",
    "tam_Taml": "ta",
    "tel_Telu": "te",
    "urd_Arab": "ur",
}
|
|
|
|
|
class AzureTranslator:
    """Minimal client for the Azure Translator Text REST API (version 3.0).

    The list of languages supported by the service is fetched once when the
    client is constructed and is used to validate every translation request.
    """

    def __init__(
        self,
        subscription_key: str,
        region: str,
        endpoint: str = "https://api.cognitive.microsofttranslator.com",
        timeout: float = 60.0,
    ) -> None:
        """Store the credentials and fetch the supported-language list.

        Args:
            subscription_key: Value sent as `Ocp-Apim-Subscription-Key`.
            region: Value sent as `Ocp-Apim-Subscription-Region`.
            endpoint: Base URL of the translator service.
            timeout: Seconds to wait for each HTTP request before giving up.

        Raises:
            requests.RequestException: If the supported-language list cannot
                be fetched.
        """
        self.http_headers = {
            "Ocp-Apim-Subscription-Key": subscription_key,
            "Ocp-Apim-Subscription-Region": region,
        }
        # A configured endpoint may end with "/"; strip it so the paths below
        # do not turn into "//translate" and "//languages".
        endpoint = endpoint.rstrip("/")
        self.translate_endpoint = endpoint + "/translate?api-version=3.0&"
        self.languages_endpoint = endpoint + "/languages?api-version=3.0"
        self.timeout = timeout

        self.supported_languages = self.get_supported_languages()

    def get_supported_languages(self) -> dict:
        """Return the `translation` section of the service's language list.

        Raises:
            requests.RequestException: On connection failure, timeout, or a
                non-success HTTP status.
        """
        response = requests.get(self.languages_endpoint, timeout=self.timeout)
        # Fail with a clear HTTP error instead of a KeyError on an error body.
        response.raise_for_status()
        return response.json()["translation"]

    def batch_translate(
        self, texts: list, src_lang: str, tgt_lang: str
    ) -> Optional[list]:
        """Translate a batch of texts in a single request.

        Args:
            texts: Strings to translate.
            src_lang: Source language, as a key of `flores_to_iso`.
            tgt_lang: Target language, as a key of `flores_to_iso`.

        Returns:
            The translations in the same order as `texts`. An empty `texts`
            is returned unchanged without contacting the service. `None` is
            returned (after printing diagnostics) when the request fails or
            the service answers with something other than a translation list.

        Raises:
            KeyError: If a language is not a key of `flores_to_iso`.
            NotImplementedError: If the service does not list a language.
        """
        if not texts:
            return texts

        src_lang = flores_to_iso[src_lang]
        tgt_lang = flores_to_iso[tgt_lang]

        for role, code in (("Source", src_lang), ("Target", tgt_lang)):
            if code not in self.supported_languages:
                raise NotImplementedError(
                    f"{role} language code: `{code}` not supported!"
                )

        body = [{"text": text} for text in texts]
        query_string = urlencode({"from": src_lang, "to": tgt_lang})

        try:
            response = requests.post(
                self.translate_endpoint + query_string,
                headers=self.http_headers,
                json=body,
                timeout=self.timeout,
            )
        except requests.RequestException:
            # Network-level failure (connection error, timeout, ...).
            traceback.print_exc()
            return None

        try:
            payload = response.json()
        except ValueError:
            # The body was not valid JSON; show it to aid debugging.
            traceback.print_exc()
            print("Response:", response.text)
            return None

        return self._extract_translations(payload)

    @staticmethod
    def _extract_translations(payload) -> Optional[list]:
        """Return the first translation of every item in a decoded response.

        A successful response is a list with one item per input text. Any
        other shape (for example an error object) is printed and reported as
        `None` instead of failing with an unrelated `TypeError`.
        """
        if not isinstance(payload, list):
            print("Response:", payload)
            return None
        return [item["translations"][0]["text"] for item in payload]

    def text_translate(
        self, text: str, src_lang: str, tgt_lang: str
    ) -> Optional[str]:
        """Translate a single text; return `None` if the request failed."""
        translations = self.batch_translate([text], src_lang, tgt_lang)
        return translations[0] if translations else None
|
|
|
|
|
def read_sentences(path: str) -> list:
    """Return the non-empty lines of `path`, stripped of surrounding whitespace.

    The file is decoded as UTF-8 so the result does not depend on the
    platform's default encoding.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f.read().split("\n") if line]


def translate_file(
    translator,
    in_path: str,
    out_path: str,
    src_lang: str,
    tgt_lang: str,
    batch_size: int = 128,
) -> None:
    """Translate `in_path` sentence by sentence and write the result to `out_path`.

    Sentences are sent to `translator.batch_translate` in chunks of
    `batch_size`. The output file is written (newline-separated, UTF-8) only
    after every chunk has been translated, so a failure never leaves a
    partial output file behind.

    Raises:
        RuntimeError: If the translator returns `None` for a chunk.
    """
    sentences = read_sentences(in_path)

    translations = []
    for start in range(0, len(sentences), batch_size):
        batch = translator.batch_translate(
            sentences[start : start + batch_size], src_lang, tgt_lang
        )
        if batch is None:
            raise RuntimeError(
                f"Translation failed for `{in_path}` "
                f"(sentences {start}-{start + batch_size - 1})"
            )
        translations.extend(batch)

    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n".join(translations))


if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} <root_dir>")
    root_dir = sys.argv[1]

    # Credentials are read from a `.env` file located next to this script.
    config = dotenv_values(os.path.join(os.path.dirname(__file__), ".env"))

    t = AzureTranslator(
        config["AZURE_TRANSLATOR_TEXT_SUBSCRIPTION_KEY"],
        config["AZURE_TRANSLATOR_TEXT_REGION"],
        config["AZURE_TRANSLATOR_TEXT_ENDPOINT"],
    )

    pairs = sorted(glob.glob(os.path.join(root_dir, "*")))

    for pair in pairs:
        basename = os.path.basename(pair)

        print(pair)

        lang_codes = basename.split("-")
        if len(lang_codes) != 2:
            print(f"Skipping `{pair}`: name is not of the form `<src>-<tgt>`")
            continue
        src_lang, tgt_lang = lang_codes

        print(f"{src_lang} - {tgt_lang}")

        # Source-to-target first, then target-to-source. A missing input file
        # or a failed translation abandons the remainder of this pair.
        for from_lang, to_lang in ((src_lang, tgt_lang), (tgt_lang, src_lang)):
            in_fname = os.path.join(pair, f"test.{from_lang}")
            out_fname = os.path.join(pair, f"test.{to_lang}.pred.azure")

            if not os.path.exists(in_fname):
                break

            # An existing output file means this direction is already done.
            if os.path.exists(out_fname):
                continue

            try:
                translate_file(t, in_fname, out_fname, from_lang, to_lang)
            except Exception as e:
                print(e)
                break

            # Pause after each translated file (presumably to stay within
            # the service's rate limits).
            time.sleep(10)
|
|