import re
import tempfile

import gradio as gr
import pyttsx3
import speech_recognition as sr
from gtts import gTTS
from transformers import MarianMTModel, MarianTokenizer
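
# Pipeline overview: take typed text or recorded speech, detect whether the input
# is Chinese or English, translate it to the other language with MarianMT, and
# read the translation aloud with gTTS (online) or pyttsx3 (offline).
# Third-party packages used: gradio, transformers (plus torch and sentencepiece),
# SpeechRecognition, gTTS, pyttsx3.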


def clean_text(text):
    """Strip characters other than word characters, whitespace, and basic punctuation."""
    # In Python 3, \w matches Unicode word characters, so Chinese text is preserved.
    return re.sub(r'[^\w\s,.!?;:]', '', text)


# Load MarianMT tokenizers and models for both translation directions once at startup.
tokenizer_en2zh = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
model_en2zh = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-zh")

tokenizer_zh2en = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
model_zh2en = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-zh-en")


def translate_text(input_text, target_language):
    """Translate cleaned text into the requested target language ("en" or "zh")."""
    cleaned_text = clean_text(input_text)
    if target_language == "en":
        tokenizer, model = tokenizer_zh2en, model_zh2en
    else:
        tokenizer, model = tokenizer_en2zh, model_en2zh

    inputs = tokenizer(cleaned_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    translated_tokens = model.generate(**inputs)
    translated_text = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
    return translated_text
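
# Illustrative example: translate_text("How are you?", "zh") should return a Chinese
# rendering of the question; the exact wording depends on the opus-mt-en-zh checkpoint.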


def speech_to_text(audio_file):
    """Transcribe an audio file with Google's free speech recognition API."""
    recognizer = sr.Recognizer()
    with sr.AudioFile(audio_file) as source:
        audio = recognizer.record(source)
    # recognize_google defaults to English; retry in Mandarin if that fails.
    try:
        return recognizer.recognize_google(audio)
    except sr.UnknownValueError:
        return recognizer.recognize_google(audio, language="zh-CN")


def synthesize_speech(text, language, method='gtts'):
    """Write speech for `text` to a temporary audio file and return its path."""
    if method == 'gtts':
        # gTTS needs an internet connection; depending on the installed gTTS version,
        # Chinese may require the code "zh-CN" rather than "zh".
        temp_audio = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
        tts = gTTS(text=text, lang=language)
        tts.save(temp_audio.name)
    elif method == 'pyttsx3':
        # pyttsx3 works offline; it needs a target file created before save_to_file.
        temp_audio = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        engine = pyttsx3.init()
        engine.save_to_file(text, temp_audio.name)
        engine.runAndWait()
    else:
        raise ValueError(f"Unsupported TTS method: {method}")
    return temp_audio.name


def process_input(input_text=None, audio_file=None, method='gtts'):
    """Handle a Gradio request: transcribe if needed, translate, then synthesize speech."""
    recognized_text = ""
    if audio_file:
        recognized_text = speech_to_text(audio_file)
    elif input_text:
        recognized_text = input_text
    else:
        return "No input provided", "", None

    # Treat the input as Chinese if it contains any CJK characters, otherwise as
    # English, and translate into the other language.
    detected_lang = "zh" if re.search(r'[\u4e00-\u9fff]', recognized_text) else "en"
    target_language = "zh" if detected_lang == "en" else "en"

    translated_text = translate_text(recognized_text, target_language)
    audio_path = synthesize_speech(translated_text, target_language, method=method)

    return recognized_text, translated_text, audio_path
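
# Gradio UI: a text box plus a microphone/upload audio input, one button, and three
# outputs (recognized text, translated text, synthesized speech) that match the
# tuple returned by process_input.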


with gr.Blocks() as demo:
    gr.Markdown("## 中英文互译及语音播报")  # "Chinese-English translation with speech playback"

    inputs = [
        gr.Textbox(label="Enter Text"),
        gr.Audio(sources=["upload", "microphone"], type="filepath", label="Or speak into the microphone")
    ]

    translate_button = gr.Button("翻译并播报")  # "Translate and speak"

    outputs = [
        gr.Textbox(label="Recognized Text"),
        gr.Textbox(label="翻译结果"),  # "Translation result"
        gr.Audio(label="语音播报")  # "Speech playback"
    ]

    translate_button.click(fn=process_input, inputs=inputs, outputs=outputs)

demo.launch(share=True)
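# Running this file starts the demo locally; share=True additionally requests a
# temporary public Gradio link.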