# radio / app.py
# Uploaded by demomodels — commit a37aa0d ("Create app.py", verified).
import gradio as gr
from TTS.api import TTS
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
import feedparser
import re
# Language codes (the labels produced by the detection model and looked up
# in language_id) mapped to the display names shown in the UI and used in the
# translation prompts.
language_map = dict(en='English', fr='French')

# Default RSS feeds offered in the UI: checkbox label -> feed URL.
rss_feed_map = dict([
    ("NY Times", 'https://rss.nytimes.com/services/xml/rss/nyt/HomePage.xml'),
    ("Fox News", 'https://moxie.foxnews.com/google-publisher/latest.xml'),
    ("Yahoo! News", 'https://www.yahoo.com/news/rss'),
    ("France 24", 'https://www.france24.com/fr/rss'),
    ("France Info", 'https://www.francetvinfo.fr/titres.rss'),
])
def get_rss_feeds(default_choices, custom_choices):
    """Combine user-entered and default RSS feed URLs into one de-duplicated list.

    Args:
        default_choices: labels of the selected default feeds; each must be a
            key of ``rss_feed_map``. ``None`` is treated as "nothing selected".
        custom_choices: free text holding one feed URL per line. Surrounding
            whitespace is stripped and blank lines are ignored. ``None`` is
            treated as empty text.

    Returns:
        list[str]: the custom URLs followed by the default-feed URLs, with
        duplicates removed, keeping first-seen order.

    Raises:
        KeyError: if a label in ``default_choices`` is not in ``rss_feed_map``.
    """
    # Blank or whitespace-padded lines (e.g. a trailing newline) would
    # otherwise reach fetch_news and fail its URL check, aborting the request.
    custom_rss_feeds = [
        line.strip()
        for line in (custom_choices or "").splitlines()
        if line.strip()
    ]
    default_rss_feeds = [rss_feed_map[key] for key in (default_choices or [])]
    # dict.fromkeys de-duplicates like set() but keeps a deterministic order,
    # so feeds (and therefore headlines) are processed in a stable order.
    return list(dict.fromkeys(custom_rss_feeds + default_rss_feeds))
# RSS feeds
# Compiled once at import time and reused by every is_url call.
_URL_PATTERN = re.compile(
    r'^(?:http|ftp)s?://'  # http:// or https://
    r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|'  # domain...
    r'localhost|'  # localhost...
    r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # ...or ip
    r'(?::\d+)?'  # optional port
    r'(?:/?|[/?]\S+)$',
    re.IGNORECASE,
)


def is_url(string):
    """Return True if *string* looks like an http(s) or ftp(s) URL.

    Accepts a dotted domain name, ``localhost`` or a dotted-quad IP address,
    with an optional port and path. Matching is case-insensitive.
    """
    return bool(_URL_PATTERN.match(string))
def fetch_news(rss_feed):
    """Download a feed and return the headline of each of its entries.

    Args:
        rss_feed: URL of the feed; must pass ``is_url``.

    Returns:
        list[str]: entry titles in feed order, stripped of surrounding
        whitespace; entries with a missing or blank title are skipped.

    Raises:
        ValueError: if ``rss_feed`` is not a URL.
    """
    if not is_url(rss_feed):
        raise ValueError(f"{rss_feed} is not a valid RSS feed.")
    feed = feedparser.parse(rss_feed)
    # RSS items are not required to carry a title, and attribute access
    # (``entry.title``) on such an entry raises instead of returning a
    # default, which would abort the whole request. Read titles defensively
    # and drop blank ones: they carry nothing to translate or read aloud.
    return [
        title
        for title in (entry.get("title", "").strip() for entry in feed.entries)
        if title
    ]
def fetch_news_multiple_urls(rss_feeds):
    """Fetch every feed in *rss_feeds* and concatenate their headlines in order.

    Any ``ValueError`` raised by ``fetch_news`` for a non-URL entry propagates.
    """
    headlines = []
    for feed_url in rss_feeds:
        headlines.extend(fetch_news(feed_url))
    return headlines
# Language_id
# Language-detection text classifier, loaded once at import time.
# language_id uses its predicted labels directly as keys of language_map, so
# they are presumably short language codes such as 'en' / 'fr' — confirm
# against the model card.
model_ckpt = "papluca/xlm-roberta-base-language-detection"
pipe = pipeline("text-classification", model=model_ckpt)
def language_id(strings: list[str]):
    """Tag each string with the language detected by ``pipe``.

    Args:
        strings: texts to classify (e.g. news headlines).

    Returns:
        list[tuple[str, str]]: ``(string, language_name)`` pairs in input
        order, where ``language_name`` is a value of ``language_map``.

    Raises:
        ValueError: if none of the classifier's labels is a key of
            ``language_map``.
    """
    tagged = []
    for string in strings:
        # Request every label's score, not just the top one. Short headlines
        # can be classified as a language outside language_map, and indexing
        # language_map with that raw top-1 label raised KeyError and aborted
        # the whole request. Choosing the best-scoring *supported* label gives
        # the same answer as before whenever the top-1 label is supported.
        scores = pipe(string, top_k=None, truncation=True)
        supported = [s for s in scores if s['label'] in language_map]
        if not supported:
            raise ValueError(f"No supported language detected for: {string!r}")
        best = max(supported, key=lambda s: s['score'])
        tagged.append((string, language_map[best['label']]))
    return tagged
# Translation
## Initialize T5 model and tokenizer
# T5 checkpoint and tokenizer used by translate; both are loaded once at
# import time.
# NOTE(review): translate builds prompts in both directions, including
# "translate French to English: ...". Confirm that t5-small actually supports
# French -> English; its documented translation prefixes appear to cover only
# translation *from* English.
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
def translate(source_text_with_id, target_language):
    """Translate one language-tagged text into ``target_language`` with T5.

    Args:
        source_text_with_id: ``(text, source_language)`` pair as produced by
            ``language_id``, e.g. ``('Bonjour', 'French')``.
        target_language: display name of the output language; must be a
            value of ``language_map``.

    Returns:
        str: the translated text, or the original text unchanged when it is
        already in ``target_language``.

    Raises:
        AssertionError: if either language is not a value of ``language_map``.
    """
    source_text = source_text_with_id[0]
    source_language = source_text_with_id[1]
    assert source_language in language_map.values(), f"{source_language} language is not supported."
    assert target_language in language_map.values(), f"{target_language} language is not supported."
    # Nothing to translate. Sending T5 an "X to X" prompt would waste a
    # generate() call and risk altering text that is already correct.
    if source_language == target_language:
        return source_text
    prompt = f"translate {source_language} to {target_language}: " + source_text
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # Beam search with 4 beams; output length is capped at 100 tokens.
    translated_ids = model.generate(input_ids=input_ids, max_length=100, num_beams=4, early_stopping=True)
    return tokenizer.decode(translated_ids[0], skip_special_tokens=True)
def translate_multiple(source_texts_with_id, target_language):
    """Translate each ``(text, language)`` pair into ``target_language``.

    Returns the translations as a list, in the same order as the input.
    """
    return [
        translate(tagged_text, target_language)
        for tagged_text in source_texts_with_id
    ]
# Speech generation
# Multilingual XTTS v2 text-to-speech model (from the ``TTS`` package), loaded
# once at import time. read_news passes it a reference recording
# (speaker_wav) so that the output imitates that voice.
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
def read_news(text, input, output, language):
    """Synthesize *text* as speech and write it to the file *output*.

    Args:
        text: the text to speak.
        input: path to a reference recording whose voice the output should
            imitate (passed to the TTS model as ``speaker_wav``).
        output: destination audio file path.
        language: language code; must be a key of ``language_map``.

    Returns:
        The ``output`` path, so it can be handed straight to a Gradio output.

    Raises:
        AssertionError: if ``language`` is not a key of ``language_map``.
    """
    assert language in language_map, f"{language} language is not supported."
    print("speech generation starting")
    tts.tts_to_file(
        text=text,
        file_path=output,
        speaker_wav=input,
        language=language,
    )
    print("speech generation done")
    return output
# Gradio interface
def process(radio_value, textbox_value, audio_value, checkbox_value):
    """Gradio callback: build the spoken newsletter and return its audio path.

    Pipeline: collect feed URLs -> fetch headlines -> detect each headline's
    language -> translate to the chosen language -> synthesize speech in the
    voice of the uploaded sample.

    Args:
        radio_value: output language display name (a value of
            ``language_map``), or ``None`` if the user picked none.
        textbox_value: custom feed URLs, one per line.
        audio_value: file path of the uploaded voice sample, or ``None``.
        checkbox_value: labels of the selected default feeds.

    Returns:
        str: path of the generated ``.wav`` file.

    Raises:
        gr.Error: with a user-facing message when an input is missing or
            invalid, or when no headlines could be fetched.
    """
    inputs = {
        "language": radio_value,
        "rss_feed_urls": textbox_value,
        "audio": audio_value,
        "selected_feeds": checkbox_value
    }
    print("Inputs to Gradio Blocks:")
    print(inputs)

    # Validate the cheap inputs before the slow fetch / translate / TTS steps.
    # Problems are reported through gr.Error so the user sees a readable
    # message instead of an assertion failure deep inside the pipeline.
    language = next((key for key, val in language_map.items() if val == radio_value), None)
    print("language=", language)
    if language is None:
        raise gr.Error("Please choose the language of the output.")
    if not audio_value:
        raise gr.Error("Please upload a sample audio of someone speaking.")

    rss_feeds = get_rss_feeds(checkbox_value, textbox_value)
    print("rss_feeds=", rss_feeds)
    if not rss_feeds:
        raise gr.Error("Please select a default RSS feed or add a custom one.")

    try:
        news = fetch_news_multiple_urls(rss_feeds)
    except ValueError as err:
        # fetch_news raises ValueError for entries that are not valid URLs.
        raise gr.Error(str(err)) from err
    print("news=", news[:2])
    if not news:
        raise gr.Error("No headlines could be fetched from the selected RSS feeds.")

    news_with_language_id = language_id(news)
    print("news_with_language_id=", news_with_language_id[:2])
    translated_news = translate_multiple(news_with_language_id, radio_value)
    print("translated_news=", translated_news[:2])
    all_news = ' '.join(translated_news)
    print("all_news=", all_news[:80])
    output_path = "output.wav"
    return read_news(all_news, audio_value, output_path, language)
with gr.Blocks() as demo:
    gr.Markdown("Customize your newsletter and then click **Fetch News** to download the audio output.")
    with gr.Row():
        radio = gr.Radio(
            label='Choose the language of the output',
            info="If the output language doesn't match the language of an RSS feed, an AI model will take care of translation",
            # Derived from language_map so the UI always offers exactly the
            # languages the pipeline supports.
            choices=list(language_map.values()),
        )
    with gr.Row():
        textbox = gr.Textbox(
            placeholder='https://www.francetvinfo.fr/titres.rss',
            label='Add custom RSS feeds to your newsletter',
            info='Enter each URL on a separate line',
            # Multi-line, because get_rss_feeds expects one URL per line.
            lines=3,
        )
    with gr.Row():
        audio = gr.Audio(
            label="Upload a sample audio of someone speaking. The voice of the output will match the voice of the input.",
            type='filepath'
        )
    with gr.Row():
        checkboxgroup = gr.CheckboxGroup(
            # Labels must be keys of rss_feed_map (get_rss_feeds looks them up),
            # so take them from the map instead of repeating them here.
            list(rss_feed_map),
            label="RSS feeds",
            info="Default RSS feeds"
        )
    with gr.Row():
        btn = gr.Button(value='Fetch News')
    with gr.Row():
        out = gr.DownloadButton("📂 Click to download file")
    btn.click(
        fn=process,
        inputs=[radio, textbox, audio, checkboxgroup],
        outputs=out
    )
demo.launch(debug=True)