"""Streamlit app: detect the language of the input text and summarize it with mBART-50."""

import re

import streamlit as st
from langdetect import DetectorFactory, detect
from transformers import AutoModelForSeq2SeqLM, MBart50Tokenizer

# Fix langdetect's seed so short or ambiguous inputs get the same
# language code on every run.
DetectorFactory.seed = 0


@st.cache_resource
def load_models():
    # Cache the tokenizer and model so Streamlit loads them once per
    # session instead of on every script rerun.
    tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/mbart-large-50")
    return tokenizer, model


tokenizer, model = load_models()

# Map langdetect's ISO 639-1 codes to the language codes mBART-50 expects.
LANGUAGE_CODES = {
    "en": "en_XX",
    "fr": "fr_XX",
    "de": "de_DE",
    "ru": "ru_RU",
    "hi": "hi_IN",
    "mr": "mr_IN",
    "ja": "ja_XX",
}
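
# mBART-50 covers 50 languages in total; to support more, extend this map
# with codes from the model card (e.g. "es": "es_XX" for Spanish).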


def detect_language(text):
    # langdetect returns an ISO 639-1 code such as "en" or "fr".
    return detect(text)


def summarize_text(text, lang_code):
    # Fall back to English if the detected language has no mBART code.
    mbart_lang_code = LANGUAGE_CODES.get(lang_code, "en_XX")

    # mBART-50 expects the source language to be set on the tokenizer,
    # not prepended to the text as a literal "<lang>" tag.
    tokenizer.src_lang = mbart_lang_code
    inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)

    summary_ids = model.generate(
        **inputs,
        # Force the decoder to generate in the same language as the input.
        forced_bos_token_id=tokenizer.lang_code_to_id[mbart_lang_code],
        max_length=100,
        min_length=30,
        length_penalty=2.0,
        num_beams=4,
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    # Strip any residual tag-like tokens that survive decoding.
    summary = re.sub(r"<[^>]+>", "", summary).strip()
    return summary
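
# A minimal sanity check for summarize_text outside the Streamlit UI
# (the sample text below is illustrative only):
#
#   sample = "Streamlit turns Python scripts into shareable web apps. " * 5
#   print(summarize_text(sample, "en"))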


st.title("Multilingual Summarization App")

user_input = st.text_area("Enter text in any language:", "")

if st.button("Process Text"):
    if user_input.strip():
        lang_code = detect_language(user_input)
        st.write(f"**Detected Language Code:** {lang_code}")

        if lang_code not in LANGUAGE_CODES:
            st.warning(f"The detected language ({lang_code}) is not supported by the model.")
        else:
            try:
                summary = summarize_text(user_input, lang_code)
                st.write(f"### Summarized Text ({lang_code}):")
                st.write(summary)
            except Exception as e:
                st.error(f"An error occurred during processing: {e}")
    else:
        st.warning("Please enter some text to process.")
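
# To launch the app (assuming this script is saved as app.py):
#   streamlit run app.py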