Spaces:
Runtime error
Runtime error
import gradio as gr | |
from TTS.api import TTS | |
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline | |
import feedparser | |
import re | |
# Language codes, as looked up from the language-detection labels, mapped to
# the display names used by the UI radio buttons and the T5 translation prompt.
language_map = dict(
    en='English',
    fr='French',
)

# Default RSS feeds, keyed by the name shown in the "RSS feeds" checkbox group.
rss_feed_map = dict([
    ("NY Times", 'https://rss.nytimes.com/services/xml/rss/nyt/HomePage.xml'),
    ("Fox News", 'https://moxie.foxnews.com/google-publisher/latest.xml'),
    ("Yahoo! News", 'https://www.yahoo.com/news/rss'),
    ("France 24", 'https://www.france24.com/fr/rss'),
    ("France Info", 'https://www.francetvinfo.fr/titres.rss'),
])
def get_rss_feeds(default_choices, custom_choices):
    """Combine the selected default feeds with user-supplied feed URLs.

    Args:
        default_choices: Names of default feeds (keys of ``rss_feed_map``), as
            returned by the checkbox group. ``None`` is treated as empty.
        custom_choices: Free text with one feed URL per line. Surrounding
            whitespace and blank lines are ignored; ``None`` is treated as empty.

    Returns:
        De-duplicated list of feed URLs: custom URLs first, then the default
        feeds, each in the order given.

    Raises:
        KeyError: if a default choice is not a key of ``rss_feed_map``.
    """
    # splitlines() also handles "\r\n"; stripping and dropping blanks keeps a
    # trailing newline or empty line from turning into an invalid "" feed URL.
    custom_rss_feeds = [
        url for url in (line.strip() for line in (custom_choices or "").splitlines()) if url
    ]
    default_rss_feeds = [rss_feed_map[key] for key in (default_choices or [])]
    # dict.fromkeys de-duplicates while keeping first-seen order (set() did not).
    return list(dict.fromkeys(custom_rss_feeds + default_rss_feeds))
# RSS feeds | |
def is_url(string):
    """Return True if *string* looks like an http(s) or ftp(s) URL.

    The whole string must match: a scheme, then a domain name, ``localhost``
    or a dotted four-number address, an optional ``:port``, and an optional
    path/query containing no whitespace. Matching is case-insensitive.
    """
    scheme = r'^(?:http|ftp)s?://'
    host = (
        r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|'
        r'localhost|'
        r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'
    )
    port = r'(?::\d+)?'
    tail = r'(?:/?|[/?]\S+)$'
    pattern = re.compile(scheme + host + port + tail, re.IGNORECASE)
    return bool(pattern.match(string))
def fetch_news(rss_feed):
    """Return the headline titles of the entries in one RSS feed.

    Args:
        rss_feed: URL of the feed.

    Returns:
        List of entry titles in the order of ``feed.entries``. Entries without
        a non-blank title are skipped, and titles are whitespace-stripped.

    Raises:
        ValueError: if *rss_feed* is not accepted by ``is_url``.
    """
    if not is_url(rss_feed):
        raise ValueError(f"{rss_feed} is not a valid RSS feed.")
    feed = feedparser.parse(rss_feed)
    news = []
    for entry in feed.entries:
        # Not every feed item carries a title; attribute access (entry.title)
        # would raise for those, so read it defensively and skip blanks.
        title = (entry.get("title") or "").strip()
        if title:
            news.append(title)
    return news
def fetch_news_multiple_urls(rss_feeds):
    """Fetch the headlines of several feeds and concatenate them.

    Feeds are processed in the given order and each feed keeps its own title
    order. Any ValueError raised by ``fetch_news`` propagates to the caller.
    """
    headlines = []
    for feed_url in rss_feeds:
        headlines.extend(fetch_news(feed_url))
    return headlines
# Language_id | |
model_ckpt = "papluca/xlm-roberta-base-language-detection"
# Built once at import time and reused for every headline.
pipe = pipeline("text-classification", model=model_ckpt)


def language_id(strings: list[str]):
    """Tag each string with its detected language.

    Args:
        strings: Texts (news headlines) to classify.

    Returns:
        List of ``(string, language_name)`` tuples in input order, where
        ``language_name`` is a value of ``language_map`` (e.g. 'English').

    Raises:
        KeyError: only if the detector returns no label present in
            ``language_map`` at all.
    """
    tagged = []
    for string in strings:
        # top_k=None asks the pipeline for the score of every label instead
        # of only the best one.
        scores = pipe(string, top_k=None, truncation=True)
        supported = [s for s in scores if s['label'] in language_map]
        # The detector knows more languages than language_map (and short
        # headlines are easily misclassified), so a plain top-1 lookup could
        # raise KeyError. Prefer the best-scoring *supported* label; fall back
        # to the overall best label only when none is supported.
        best = max(supported or scores, key=lambda s: s['score'])
        tagged.append((string, language_map[best['label']]))
    return tagged
# Translation | |
## Initialize T5 model and tokenizer | |
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)


def translate(source_text_with_id, target_language):
    """Translate one language-tagged text into *target_language* with T5.

    Args:
        source_text_with_id: ``(text, source_language)`` pair, e.g.
            ``('Bonjour', 'French')``, as produced by ``language_id``.
        target_language: Language name, a value of ``language_map``.

    Returns:
        The translated text. If the text is already in *target_language*, it
        is returned unchanged without running the model.

    Raises:
        AssertionError: if either language is not a value of ``language_map``.
    """
    source_text = source_text_with_id[0]
    source_language = source_text_with_id[1]
    assert source_language in language_map.values(), f"{source_language} language is not supported."
    assert target_language in language_map.values(), f"{target_language} language is not supported."
    # There is no "translate X to X" task. Feeding such a prompt to the model
    # wastes a beam search and may alter text that needs no translation.
    if source_language == target_language:
        return source_text
    # NOTE(review): per its model card, t5-small's translation training covered
    # English -> German/French/Romanian; "translate French to English" prompts
    # may give poor output. Confirm, or switch to a dedicated fr->en model.
    prompt = f"translate {source_language} to {target_language}: " + source_text
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    translated_ids = model.generate(input_ids=input_ids, max_length=100, num_beams=4, early_stopping=True)
    return tokenizer.decode(translated_ids[0], skip_special_tokens=True)
def translate_multiple(source_texts_with_id, target_language):
    """Translate every ``(text, language)`` pair into *target_language*.

    Returns the translations as a new list, in input order.
    """
    translations = []
    for pair in source_texts_with_id:
        translations.append(translate(pair, target_language))
    return translations
# Speech generation | |
# NOTE(review): Coqui asks for acceptance of the XTTS v2 license the first time
# this model is loaded (skipped when COQUI_TOS_AGREED=1 is set). On a
# non-interactive host that prompt may fail at import time — confirm.
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")


def read_news(text, input, output, language):
    """Synthesize *text* into an audio file in the voice of a reference clip.

    Args:
        text: Text to speak.
        input: Path of the reference voice recording (passed as
            ``speaker_wav``). The name shadows the builtin but is kept for
            callers.
        output: Destination file path.
        language: Language code, a key of ``language_map`` (e.g. 'en').

    Returns:
        *output*, the path the audio was written to.
    """
    assert language in language_map, f"{language} language is not supported."
    print("speech generation starting")
    tts.tts_to_file(
        text=text,
        speaker_wav=input,
        language=language,
        file_path=output,
    )
    print("speech generation done")
    return output
# Gradio interface | |
def process(radio_value, textbox_value, audio_value, checkbox_value):
    """Build the spoken newsletter for the "Fetch News" button.

    Steps: collect feed URLs, fetch headlines, detect each headline's
    language, translate into the chosen language, and synthesize a single
    audio file in the voice of the uploaded sample.

    Args:
        radio_value: Output language name from the radio (e.g. 'English').
        textbox_value: Custom RSS feed URLs, one per line.
        audio_value: File path of the voice sample (``gr.Audio(type='filepath')``).
        checkbox_value: Selected default feed names (keys of ``rss_feed_map``).

    Returns:
        Path of the generated audio file ("output.wav").

    Raises:
        gr.Error: when an input is missing or invalid, or no headline could be
            fetched. The user then sees a readable message instead of a
            failure deep inside the translation/TTS code.
    """
    inputs = {
        "language": radio_value,
        "rss_feed_urls": textbox_value,
        "audio": audio_value,
        "selected_feeds": checkbox_value,
    }
    print("Inputs to Gradio Blocks:")
    print(inputs)

    # Validate up front: these used to fail later as AssertionErrors in
    # read_news() or as errors inside the TTS model.
    if radio_value not in language_map.values():
        raise gr.Error("Please choose the language of the output.")
    if not audio_value:
        raise gr.Error("Please upload a sample audio of someone speaking.")

    rss_feeds = get_rss_feeds(checkbox_value, textbox_value)
    print("rss_feeds=", rss_feeds)
    if not rss_feeds:
        raise gr.Error("Please select or add at least one RSS feed.")

    try:
        news = fetch_news_multiple_urls(rss_feeds)
    except ValueError as err:
        # fetch_news() rejects malformed URLs; surface that message in the UI.
        raise gr.Error(str(err)) from err
    print("news=", news[:2])
    if not news:
        raise gr.Error("No headlines could be fetched from the selected RSS feeds.")

    news_with_language_id = language_id(news)
    print("news_with_language_id=", news_with_language_id[:2])
    translated_news = translate_multiple(news_with_language_id, radio_value)
    print("translated_news=", translated_news[:2])

    # Map the display name back to the language code read_news() expects.
    language = next((key for key, val in language_map.items() if val == radio_value), None)
    print("language=", language)
    all_news = ' '.join(translated_news)
    print("all_news=", all_news[:80])
    output_path = "output.wav"
    return read_news(all_news, audio_value, output_path, language)
with gr.Blocks() as demo:
    gr.Markdown("Customize your newsletter and then click **Fetch News** to download the audio output.")
    with gr.Row():
        # Choices mirror language_map, so the UI and the translate()/read_news()
        # checks agree on the supported languages. A default value keeps
        # process() from receiving an empty selection.
        radio = gr.Radio(
            label='Choose the language of the output',
            info="If the output language doesn't match the language of an RSS feed, an AI model will take care of translation",
            choices=list(language_map.values()),
            value="English",
        )
    with gr.Row():
        # Multi-line input, because custom feeds are entered one URL per line
        # and get_rss_feeds() splits this text on line breaks.
        textbox = gr.Textbox(
            placeholder='https://www.francetvinfo.fr/titres.rss',
            label='Add custom RSS feeds to your newsletter',
            info='The provided urls needed to be written each in a separate line',
            lines=3,
        )
    with gr.Row():
        audio = gr.Audio(
            label="Upload a sample audio of someone speaking. The voice of the output will match the voice of the input.",
            type='filepath',
        )
    with gr.Row():
        # Default feeds come from rss_feed_map, so the checkbox names always
        # match the keys get_rss_feeds() looks up.
        checkboxgroup = gr.CheckboxGroup(
            list(rss_feed_map),
            label="RSS feeds",
            info="Default RSS feeds",
        )
    with gr.Row():
        btn = gr.Button(value='Fetch News')
    with gr.Row():
        out = gr.DownloadButton("📂 Click to download file")
    btn.click(
        fn=process,
        inputs=[radio, textbox, audio, checkboxgroup],
        outputs=out,
    )

demo.launch(debug=True)