trans-en-indic / app.py
darshankr's picture
Update app.py
a0c8166 verified
raw
history blame
3.83 kB
# app.py
import streamlit as st
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
import torch
import asyncio
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
import requests
import json
# Initialize models and processors.
# Hugging Face checkpoint used for both the model and its tokenizer.
MODEL_NAME = "ai4bharat/indictrans2-en-indic-1B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _load_translation_resources(model_name: str, device: str):
    """Load the seq2seq model, its tokenizer and the IndicProcessor once per process.

    Streamlit re-runs this whole script on every widget interaction; caching the
    result with ``st.cache_resource`` keeps the checkpoint from being reloaded
    (and moved to ``device`` again) on every button click.

    Args:
        model_name: Hugging Face model id passed to ``from_pretrained``.
        device: Torch device string the model is moved to ("cuda" or "cpu").

    Returns:
        ``(model, tokenizer, processor)`` with the model already on ``device``.
    """
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        trust_remote_code=True,
    ).to(device)
    loaded_tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    processor = IndicProcessor(inference=True)
    return loaded_model, loaded_tokenizer, processor


# Module-level names kept so the rest of the file (translate_text) can use them unchanged.
model, tokenizer, ip = _load_translation_resources(MODEL_NAME, DEVICE)
def translate_text(
    sentences: List[str],
    target_lang: str,
    *,
    max_length: int = 256,
    num_beams: int = 5,
):
    """Translate English sentences into an Indic language with IndicTrans2.

    Args:
        sentences: English input sentences to translate as one batch.
        target_lang: IndicTrans2 language code of the target, e.g. ``"hin_Deva"``.
        max_length: ``max_length`` passed to ``model.generate`` (default 256).
        num_beams: Beam width passed to ``model.generate`` (default 5).

    Returns:
        A dict with keys ``"translations"`` (the post-processed outputs,
        presumably one string per input sentence since
        ``num_return_sequences=1``), ``"source_language"`` (always
        ``"eng_Latn"``) and ``"target_language"`` (echo of ``target_lang``).
        An empty ``sentences`` yields an empty ``"translations"`` list without
        touching the model.

    Raises:
        Exception: for any failure during preprocessing, tokenization,
            generation or decoding. The message is prefixed with
            ``"Translation failed: "`` and the original error is chained as
            ``__cause__``.
    """
    src_lang = "eng_Latn"

    if not sentences:
        # Nothing to translate; tokenizing/generating on an empty batch would fail.
        return {
            "translations": [],
            "source_language": src_lang,
            "target_language": target_lang,
        }

    try:
        batch = ip.preprocess_batch(
            sentences,
            src_lang=src_lang,
            tgt_lang=target_lang,
        )
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=max_length,
                num_beams=num_beams,
                num_return_sequences=1,
            )

        # as_target_tokenizer() puts the tokenizer in its target-side mode for decoding.
        with tokenizer.as_target_tokenizer():
            decoded = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        translations = ip.postprocess_batch(decoded, lang=target_lang)
    except Exception as e:
        # Keep the original traceback reachable via __cause__ instead of discarding it.
        raise Exception(f"Translation failed: {str(e)}") from e

    return {
        "translations": translations,
        "source_language": src_lang,
        "target_language": target_lang,
    }
# Streamlit interface

# Display name -> IndicTrans2 language code for every supported target language.
# Used both for the selectbox and for the generated API documentation list.
TARGET_LANGUAGES = {
    "Hindi": "hin_Deva",
    "Bengali": "ben_Beng",
    "Tamil": "tam_Taml",
    "Telugu": "tel_Telu",
    "Marathi": "mar_Deva",
    "Gujarati": "guj_Gujr",
    "Kannada": "kan_Knda",
    "Malayalam": "mal_Mlym",
    "Punjabi": "pan_Guru",
    "Odia": "ori_Orya",
}


def main():
    """Render the translation UI and the API usage notes with Streamlit."""
    st.title("Indic Language Translator")

    # Input text
    text_input = st.text_area("Enter text to translate:", "Hello, how are you?")

    # Language selection (display names; mapped to codes via TARGET_LANGUAGES)
    language_name = st.selectbox(
        "Select target language:",
        options=list(TARGET_LANGUAGES.keys()),
    )

    if st.button("Translate"):
        if not text_input.strip():
            # Don't run the model on empty/whitespace-only input.
            st.warning("Please enter some text to translate.")
        else:
            try:
                result = translate_text(
                    sentences=[text_input],
                    target_lang=TARGET_LANGUAGES[language_name],
                )
            except Exception as e:
                # translate_text already prefixes its errors with
                # "Translation failed: "; show the message as-is rather than
                # doubling the prefix.
                st.error(str(e))
            else:
                st.success("Translation:")
                st.write(result["translations"][0])

    # Add API documentation
    st.markdown("---")
    st.header("API Documentation")
    # NOTE(review): no FastAPI app or /translate route is defined in this file
    # (FastAPI is imported but unused) — confirm this endpoint actually exists.
    language_list = "\n".join(
        f"- {name}: `{code}`" for name, code in TARGET_LANGUAGES.items()
    )
    st.markdown(f"""
To use the translation API, send POST requests to:
```
https://USERNAME-SPACE_NAME.hf.space/translate
```
Request body format:
```json
{{
  "sentences": ["Your text here"],
  "target_lang": "hin_Deva"
}}
```
Available target languages:
{language_list}
""")


if __name__ == "__main__":
    main()