# Hugging Face Space status: Running on Zero (ZeroGPU hardware).
import spaces | |
# Standard library imports | |
import logging | |
import os | |
import tempfile | |
from typing import List, Dict, Any | |
from pathlib import Path | |
# Third-party imports | |
import gradio as gr | |
import torch | |
import pandas as pd | |
import numpy as np | |
import requests | |
from bs4 import BeautifulSoup | |
import whisper | |
import yt_dlp | |
# Document processing imports | |
import fitz # PyMuPDF | |
from docx import Document | |
from pydub import AudioSegment | |
from moviepy.editor import VideoFileClip | |
# Hugging Face imports | |
from transformers import ( | |
pipeline, | |
AutoModelForCausalLM, | |
AutoTokenizer | |
) | |
# Configure logging: every record carries a timestamp, the level name and the message.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Environment variables
# The Hugging Face token is mandatory: initialize_models() passes it to
# from_pretrained() when loading the Llama-2 tokenizer and weights, so the app
# refuses to start without it.
HUGGINGFACE_TOKEN = os.getenv('HUGGINGFACE_TOKEN')
if not HUGGINGFACE_TOKEN:
    logger.error("HUGGINGFACE_TOKEN environment variable not set")
    raise ValueError("Please set the HUGGINGFACE_TOKEN environment variable")

# Global model handles: all start as None and are filled in by initialize_models().
tokenizer = model = news_generator = whisper_model = None
def custom_css():
    """Return a stand-alone stylesheet for the news-generator layout.

    Covers the main container, the page title, section headings, the input
    and output panels, the source tabs and the generate button.

    NOTE(review): nothing in this file calls custom_css(); create_demo()
    passes its own inline CSS to gr.Blocks. Confirm whether this is still needed.
    """
    stylesheet = """
#main-container {
max-width: 1200px;
margin: 0 auto;
padding: 20px;
}
.main-title {
text-align: center;
padding: 20px 0;
margin-bottom: 30px;
border-bottom: 2px solid #eee;
}
.section-title {
font-size: 1.2em;
margin-bottom: 15px;
color: #2c3e50;
}
.input-container {
background: #f8f9fa;
padding: 20px;
border-radius: 10px;
margin-bottom: 20px;
}
.source-tab {
padding: 15px;
background: white;
border-radius: 8px;
margin: 10px 0;
}
.generate-btn {
background: #2c3e50 !important;
color: white !important;
padding: 12px 24px !important;
}
.output-container {
background: #f8f9fa;
padding: 20px;
border-radius: 10px;
margin-top: 20px;
}
"""
    return stylesheet
def initialize_models():
    """Load the Llama-2 chat model, its text-generation pipeline and Whisper.

    Populates the module-level globals ``tokenizer``, ``model``,
    ``news_generator`` and ``whisper_model``. Safe to call more than once:
    when the generator and the Whisper model are already loaded it returns
    immediately instead of loading the 7B model a second time.

    Returns:
        True once every model is available.

    Raises:
        Any exception raised while loading (logged with its traceback, then
        re-raised).
    """
    global tokenizer, model, news_generator, whisper_model
    if news_generator is not None and whisper_model is not None:
        return True
    try:
        logger.info("Starting model initialization...")
        model_name = "meta-llama/Llama-2-7b-chat-hf"

        # Load tokenizer
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=HUGGINGFACE_TOKEN
        )
        # Reuse EOS as the padding token (presumably because the Llama-2
        # tokenizer defines none); generate_news() also pads with eos_token_id.
        tokenizer.pad_token = tokenizer.eos_token

        # Load model
        logger.info("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            token=HUGGINGFACE_TOKEN,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True
        )

        # Create pipeline. The model is already instantiated (dtype and device
        # placement were decided above), so device_map/torch_dtype are not
        # repeated here. max_length is omitted as well: generate_news() passes
        # max_new_tokens, and setting both only makes transformers warn on
        # every call.
        logger.info("Creating pipeline...")
        news_generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.2
        )

        # Load Whisper model
        logger.info("Loading Whisper model...")
        whisper_model = whisper.load_model("base")

        logger.info("All models initialized successfully")
        return True
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Error during model initialization: {str(e)}")
        raise
def download_social_media_video(url):
    """Download the audio of an online/social-media video with yt-dlp.

    yt-dlp fetches the best available stream ('bestaudio/best'). Its
    FFmpegExtractAudio post-processor then converts the stream to a 192 kbps MP3.

    Args:
        url: Page or media URL accepted by yt-dlp.

    Returns:
        Path of the MP3, ``<tmpdir>/<video id>.mp3``, inside a temporary
        directory created for this call.

    Raises:
        Whatever yt-dlp raises (logged, then re-raised). The temporary
        directory is removed first.
    """
    import shutil  # only needed for cleanup on failure

    # Download into a private directory instead of the working directory, so
    # concurrent requests for the same video cannot clobber one shared file.
    download_dir = tempfile.mkdtemp(prefix="news_audio_")
    ydl_opts = {
        'format': 'bestaudio/best',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'mp3',
            'preferredquality': '192',
        }],
        'outtmpl': os.path.join(download_dir, '%(id)s.%(ext)s'),
        # Only one file path is returned, so never expand a "video inside a
        # playlist" URL into a download of the whole playlist.
        'noplaylist': True,
    }
    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info_dict = ydl.extract_info(url, download=True)
        audio_file = os.path.join(download_dir, f"{info_dict['id']}.mp3")
        logger.info(f"Video downloaded successfully: {audio_file}")
        return audio_file
    except Exception as e:
        logger.error(f"Error downloading video: {str(e)}")
        shutil.rmtree(download_dir, ignore_errors=True)
        raise
def convert_video_to_audio(video_file):
    """Extract the audio track of a video file into a temporary MP3.

    Args:
        video_file: Path of a video file readable by moviepy.

    Returns:
        Path of the new ``.mp3`` file. It is created with delete=False, so the
        caller owns and must remove it.

    Raises:
        ValueError: if the video has no audio track.
        Any moviepy/ffmpeg error (logged, then re-raised).
    """
    try:
        # The context manager closes the clip and the readers moviepy opens
        # for it; previously the clip was never closed.
        with VideoFileClip(video_file) as video:
            if video.audio is None:
                raise ValueError(f"No audio track found in {video_file}")
            # Only reserve a unique name here. Writing happens after our handle
            # is closed, because reopening an open NamedTemporaryFile by name
            # is not portable (it fails on Windows).
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
                audio_path = temp_file.name
            video.audio.write_audiofile(audio_path)
        logger.info(f"Video converted to audio: {audio_path}")
        return audio_path
    except Exception as e:
        logger.error(f"Error converting video: {str(e)}")
        raise
def _normalization_gain(current_dbfs, target_dbfs):
    """Return the gain in dB that moves a clip from current_dbfs to target_dbfs.

    pydub reports a digitally silent clip as -inf dBFS. The plain difference
    would then be +inf, so silence gets a gain of 0.0 and is left unchanged.
    """
    if current_dbfs == float("-inf"):
        return 0.0
    return target_dbfs - current_dbfs


def preprocess_audio(audio_file, target_dbfs=-20):
    """Normalise an audio file's loudness and re-encode it as a temporary MP3.

    Args:
        audio_file: Anything ``AudioSegment.from_file`` accepts (here, an
            uploaded file path).
        target_dbfs: Level to normalise to, measured as pydub's ``dBFS``.
            Defaults to -20.

    Returns:
        Path of the new ``.mp3`` file. It is created with delete=False, so the
        caller owns and must remove it.

    Raises:
        Any pydub/ffmpeg error (logged, then re-raised).
    """
    try:
        audio = AudioSegment.from_file(audio_file)
        audio = audio.apply_gain(_normalization_gain(audio.dBFS, target_dbfs))
        # Reserve a unique name, then export after our handle is closed:
        # reopening an open NamedTemporaryFile by name fails on Windows.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            output_path = temp_file.name
        audio.export(output_path, format="mp3")
        logger.info(f"Audio preprocessed: {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error preprocessing audio: {str(e)}")
        raise
def transcribe_audio(file):
    """Transcribe an audio file, a video file or an online video with Whisper.

    The input is first turned into an intermediate audio file:
    - a string starting with 'http' is downloaded with download_social_media_video();
    - a path ending in .mp4/.avi/.mov/.mkv (case-insensitive) has its audio
      extracted with convert_video_to_audio();
    - anything else is loudness-normalised with preprocess_audio().

    Every branch creates a new file for this call, never the caller's own
    upload. That file is deleted once transcription finishes, whether it
    succeeds or fails.

    Args:
        file: URL string or path of an uploaded audio/video file.

    Returns:
        The transcription text. If any step fails, returns
        "Error processing the file: <reason>" instead of raising.
    """
    intermediate_path = None
    try:
        if isinstance(file, str) and file.startswith('http'):
            intermediate_path = download_social_media_video(file)
        elif isinstance(file, str) and file.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
            intermediate_path = convert_video_to_audio(file)
        else:
            intermediate_path = preprocess_audio(file)
        logger.info(f"Transcribing audio: {intermediate_path}")
        with torch.inference_mode():
            result = whisper_model.transcribe(intermediate_path)
        transcription = result.get("text", "Error in transcription")
        logger.info(f"Transcription completed: {transcription[:50]}...")
        return transcription
    except Exception as e:
        logger.error(f"Error transcribing: {str(e)}")
        return f"Error processing the file: {str(e)}"
    finally:
        # Previously these intermediate files accumulated on disk forever.
        if intermediate_path and os.path.exists(intermediate_path):
            try:
                os.remove(intermediate_path)
            except OSError as cleanup_error:
                logger.warning(
                    f"Could not remove temporary file {intermediate_path}: {cleanup_error}"
                )
def read_document(document_path):
    """Extract plain text from a PDF, DOCX, XLSX or CSV file.

    Args:
        document_path: Filesystem path of the document. Its extension, matched
            case-insensitively, selects the reader.

    Returns:
        The extracted text. PDF pages and DOCX paragraphs are joined with
        newlines; spreadsheets are rendered with ``DataFrame.to_string()``.
        Unsupported extensions return an explanatory message, and read
        failures return "Error reading document: <reason>". This function
        does not raise.
    """
    try:
        extension = Path(document_path).suffix.lower()
        if extension == ".pdf":
            # The context manager closes the PyMuPDF document and its file handle.
            with fitz.open(document_path) as doc:
                return "\n".join(page.get_text() for page in doc)
        if extension == ".docx":
            doc = Document(document_path)
            return "\n".join(paragraph.text for paragraph in doc.paragraphs)
        if extension == ".xlsx":
            return pd.read_excel(document_path).to_string()
        if extension == ".csv":
            return pd.read_csv(document_path).to_string()
        return "Unsupported file type. Please upload a PDF, DOCX, XLSX or CSV document."
    except Exception as e:
        return f"Error reading document: {str(e)}"
def read_url(url, timeout=30):
    """Fetch a web page and return its text content.

    Args:
        url: Address to fetch with an HTTP GET.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely and hang the request
            that called this function.

    Returns:
        The page text extracted by BeautifulSoup, one stripped text fragment
        per line (empty fragments dropped). On any failure, including HTTP
        4xx/5xx, returns "Error reading URL: <reason>". This function does
        not raise.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        soup = BeautifulSoup(response.content, 'html.parser')
        # Collapse the layout whitespace of raw get_text(). Callers embed this
        # text in the prompt, and some show only its first 200 characters.
        return soup.get_text(separator="\n", strip=True)
    except Exception as e:
        return f"Error reading URL: {str(e)}"
def process_social_content(url):
    """Collect the page text and the spoken audio of a social-media post.

    Args:
        url: Post URL. It is fetched with read_url() and also handed to
            transcribe_audio(), which downloads it with yt-dlp.

    Returns:
        ``{"text": str | None, "video": str | None}``. Each part is None when
        it could not be obtained, e.g. the post has no downloadable video.
        Returns None overall on an unexpected error.
    """
    try:
        # read_url() and transcribe_audio() report failures by *returning* a
        # message rather than raising. Map those messages to None so callers
        # never treat an error message as post content.
        text_content = read_url(url)
        if not text_content or text_content.startswith("Error reading URL:"):
            logger.info(f"No usable page text for {url}: {text_content}")
            text_content = None

        try:
            video_content = transcribe_audio(url)
        except Exception:
            video_content = None
        if not video_content or video_content.startswith("Error processing the file:"):
            logger.info(f"No usable video transcription for {url}: {video_content}")
            video_content = None

        return {
            "text": text_content,
            "video": video_content
        }
    except Exception as e:
        logger.error(f"Error processing social content: {str(e)}")
        return None
# Source slots registered by create_demo() after the four main fields.
# generate_news() slices its *args with these counts, so they must match the UI.
_AUDIO_SLOTS = 5
_SOCIAL_SLOTS = 3
_URL_SLOTS = 5
_FIELDS_PER_SLOT = 3  # audio: (file, name, position); social: (url, name, context)

# Leading text of the failure messages that helpers in this module return
# instead of raising. Such messages must not reach the model as source material.
_SOURCE_ERROR_PREFIXES = (
    "Error processing the file:",  # transcribe_audio()
    "Error reading URL:",  # read_url()
    "Error reading document:",  # read_document()
    "Unsupported file type.",  # read_document()
)


def _is_source_error(text):
    """Return True if *text* is one of the helper failure messages listed above."""
    return isinstance(text, str) and text.startswith(_SOURCE_ERROR_PREFIXES)


def _split_source_args(args):
    """Split generate_news' *args into (audios, social_urls, urls, documents) slices."""
    num_audio_fields = _AUDIO_SLOTS * _FIELDS_PER_SLOT
    num_social_fields = _SOCIAL_SLOTS * _FIELDS_PER_SLOT
    urls_start = num_audio_fields + num_social_fields
    documents_start = urls_start + _URL_SLOTS
    return (
        args[:num_audio_fields],
        args[num_audio_fields:urls_start],
        args[urls_start:documents_start],
        args[documents_start:],
    )


def _triples(values):
    """Yield consecutive _FIELDS_PER_SLOT-sized tuples from a flat sequence."""
    for start in range(0, len(values), _FIELDS_PER_SLOT):
        yield tuple(values[start:start + _FIELDS_PER_SLOT])


def _build_prompt(instructions, facts, document_content, url_content,
                  quotes_text, word_count, tone):
    """Return the Llama-2 chat prompt for the article, wrapped in [INST] ... [/INST]."""
    return f"""[INST] You are a professional news writer. Write a news article based on the following information:
Instructions: {instructions}
Facts: {facts}
Additional content from documents: {document_content}
Additional content from URLs: {url_content}
Use these transcriptions as direct and indirect quotes:
{quotes_text}
Follow these requirements:
- Write a title
- Write a 15-word hook that complements the title
- Write the body with {word_count} words
- Use a {tone} tone
- Answer the 5 Ws (Who, What, When, Where, Why) in the first paragraph
- Use at least 80% direct quotes (in quotation marks)
- Use proper journalistic style
- Do not invent information
- Be rigorous with the provided facts [/INST]"""


def generate_news(instructions, facts, size, tone, *args):
    """Generate a news-article draft from the form fields and every attached source.

    NOTE(review): the file imports ``spaces`` (ZeroGPU) but nothing is
    decorated with ``@spaces.GPU``. On ZeroGPU hardware this GPU-bound handler
    presumably needs that decorator; confirm.

    Args:
        instructions: Free-text writing instructions.
        facts: Free-text description of the news facts.
        size: Requested body length in words. It may arrive as a float
            (e.g. 100.0) or None.
        tone: Tone label inserted into the prompt (e.g. "neutral").
        *args: Flattened source inputs, in the order create_demo() registers
            them: 5 x (audio file, name, position), 3 x (social URL, name,
            context), 5 URL strings, then the document uploads.

    Returns:
        ``(article_text, raw_transcriptions)`` on success. If anything raises,
        returns ``("Error generating the news article: <reason>", "")``.
    """
    try:
        # The prompt and the token budget both need a whole, positive word count.
        word_count = int(size) if size is not None else 0
        if word_count <= 0:
            raise ValueError("content body size must be a positive number of words")

        # The model globals stay None until initialize_models() has run.
        if news_generator is None or whisper_model is None:
            initialize_models()

        audios, social_urls, urls, documents = _split_source_args(args)

        # Process URLs
        url_content = []
        for url in urls:
            if not url:
                continue
            page_text = read_url(url)
            if _is_source_error(page_text):
                logger.warning(f"Skipping URL {url}: {page_text}")
            else:
                url_content.append(page_text)

        # Process documents
        document_content = []
        for document in documents:
            if document is None:
                continue
            document_text = read_document(document.name)
            if _is_source_error(document_text):
                logger.warning(f"Skipping document {document.name}: {document_text}")
            else:
                document_content.append(document_text)

        # Audio slots that actually received a file. They are numbered in
        # upload order, not by tab.
        audio_sources = [
            (audio_file, name, position)
            for audio_file, name, position in _triples(audios)
            if audio_file is not None
        ]

        # Process social media content
        social_sources = []
        for social_url, social_name, social_context in _triples(social_urls):
            if not social_url:
                continue
            social_content = process_social_content(social_url)
            if social_content:
                social_sources.append((
                    social_name,
                    social_context,
                    social_content["text"],
                    social_content["video"],
                ))

        prompt_quotes = []  # lines quoted to the model
        raw_parts = []  # text shown to the user in the "Transcriptions" box
        for idx, (audio_file, name, position) in enumerate(audio_sources, start=1):
            transcription = transcribe_audio(audio_file)
            # Always shown to the user, so a failed transcription is visible...
            raw_parts.append(f'[Audio/Video {idx}]: "{transcription}" - {name}, {position}\n\n')
            if _is_source_error(transcription):
                # ...but never handed to the model as a quote.
                logger.warning(f"Not quoting failed transcription {idx}: {transcription}")
                continue
            prompt_quotes.append(f'"{transcription}" - {name}, {position}\n')

        for name, context, text, video in social_sources:
            if text and not _is_source_error(text):
                text_line = f'[Social media text]: "{text[:200]}..." - {name}, {context}\n'
                prompt_quotes.append(text_line)
                # Append only this source's line, not the whole accumulated
                # quote text (which duplicated every earlier transcription).
                raw_parts.append(text_line + "\n\n")
            if video and not _is_source_error(video):
                video_line = f'[Social media video]: "{video}" - {name}, {context}\n'
                prompt_quotes.append(video_line)
                raw_parts.append(video_line + "\n\n")

        prompt = _build_prompt(
            instructions,
            facts,
            "\n\n".join(document_content),
            "\n\n".join(url_content),
            "".join(prompt_quotes),
            word_count,
            tone,
        )

        # Generate article. The budget is about two tokens per requested word,
        # capped at 1024.
        with torch.inference_mode():
            outputs = news_generator(
                prompt,
                max_new_tokens=min(word_count * 2, 1024),
                return_full_text=False,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.2
            )
        news_article = outputs[0]['generated_text']
        # Strip any chat-template tags the model echoes back.
        news_article = news_article.replace('[INST]', '').replace('[/INST]', '').strip()
        return news_article, "".join(raw_parts)
    except Exception as e:
        logger.error(f"Error generating news: {str(e)}")
        return f"Error generating the news article: {str(e)}", ""
# Create Gradio interface
def create_demo():
    """Build the Gradio Blocks UI for the news generator.

    The UI has a header, then a two-column row. The left column holds the
    article settings (instructions, facts, size, tone). The right column holds
    tabbed source inputs (5 Audio/Video, 3 Social Media, 5 URL and 5 Document
    tabs), plus the transcription/draft output boxes and the "Generate Draft"
    button.

    The button passes ``inputs_list`` to generate_news() positionally, so the
    append order below is a contract: instructions, facts, size, tone, then
    5 x (file, name, position), 3 x (url, name, context), 5 URLs and
    5 documents. generate_news() slices its *args with these same counts
    (5 * 3, 3 * 3, 5, rest), so change both together.

    Returns:
        The ``gr.Blocks`` app (not yet queued or launched).
    """
    # Inline stylesheet for the whole app: layout grid, compact tab buttons,
    # fixed-height upload boxes, button and output-box styling.
    with gr.Blocks(css="""
/* Container styles */
.gradio-container {
max-width: 1200px !important;
margin: auto !important;
}
/* Header styles */
.header {
margin-bottom: 1rem;
}
.header h1 {
font-size: 1.5rem !important;
margin-bottom: 0.5rem !important;
}
/* Two column layout */
.two-columns {
display: grid !important;
grid-template-columns: 300px 1fr !important;
gap: 2rem !important;
margin-top: 1rem !important;
}
/* Input fields */
.input-field {
margin-bottom: 1rem !important;
}
/* Tab navigation */
.tabs > .tab-nav {
display: flex !important;
flex-wrap: wrap !important;
gap: 4px !important;
border-bottom: 1px solid #e5e7eb !important;
padding-bottom: 0.5rem !important;
margin-bottom: 1rem !important;
}
.tab-nav * {
font-size: 0.8rem !important;
padding: 0.2rem 0.5rem !important;
border-radius: 4px !important;
background: transparent !important;
border: 1px solid #e5e7eb !important;
color: #374151 !important;
}
/* File upload area */
.file-upload {
max-height: 120px !important;
min-height: 120px !important;
border: 1px dashed #e5e7eb !important;
border-radius: 4px !important;
display: flex !important;
align-items: center !important;
justify-content: center !important;
margin-bottom: 0.5rem !important;
padding: 1rem !important;
}
.file-upload svg {
width: 24px !important;
height: 24px !important;
opacity: 0.5 !important;
}
/* Button styles */
.generate-btn {
margin-top: 1rem !important;
background: #4b5563 !important;
color: white !important;
padding: 0.5rem 1rem !important;
border-radius: 4px !important;
width: auto !important;
}
/* Output areas */
.output-box {
margin-top: 1rem !important;
border: 1px solid #e5e7eb !important;
border-radius: 4px !important;
padding: 0.5rem !important;
}
""") as demo:
        # Header
        with gr.Group(elem_classes=["header"]):
            gr.Markdown("# All-in-One News Generator")
            gr.Markdown("""
**About this tool**
This AI-powered news generator helps journalists and content creators produce news articles by processing multiple types of input:
- Audio and video files with automatic transcription
- Social media content
- Documents (PDF, DOCX, XLSX, CSV)
- Web URLs
The tool uses advanced AI to generate well-structured news articles following journalistic principles and maintaining the integrity of source quotes.
""")
            gr.Markdown("*Created by Camilo Vega, AI Consultant*")
        with gr.Row(elem_classes=["two-columns"]):
            # Left column - Main inputs
            with gr.Column():
                instructions = gr.Textbox(
                    label="News article instructions",
                    lines=3,
                    elem_classes=["input-field"]
                )
                facts = gr.Textbox(
                    label="Describe the news facts",
                    lines=4,
                    elem_classes=["input-field"]
                )
                # Sent to generate_news() as `size`, the body length in words.
                size = gr.Number(
                    label="Content body size (in words)",
                    value=100,
                    elem_classes=["input-field"]
                )
                tone = gr.Dropdown(
                    label="News tone",
                    choices=["serious", "neutral", "lighthearted"],
                    value="neutral",
                    elem_classes=["input-field"]
                )
            # Right column - Source inputs
            with gr.Column():
                # Positional order must match generate_news(instructions, facts, size, tone, *args).
                inputs_list = [instructions, facts, size, tone]
                # NOTE(review): `tabs` is bound but never used below.
                with gr.Tabs() as tabs:
                    # Audio/Video Sources: 5 tabs x (file, name, position)
                    for i in range(1, 6):
                        with gr.Tab(f"Audio/Video {i}"):
                            with gr.Group():
                                file = gr.File(
                                    label="Upload Audio/Video",
                                    file_types=["audio", "video"],
                                    elem_classes=["file-upload"]
                                )
                                name = gr.Textbox(
                                    label="Name",
                                    elem_classes=["input-field"]
                                )
                                position = gr.Textbox(
                                    label="Position",
                                    elem_classes=["input-field"]
                                )
                                inputs_list.extend([file, name, position])
                    # Social Media Sources: 3 tabs x (url, name, context)
                    for i in range(1, 4):
                        with gr.Tab(f"Social Media {i}"):
                            social_url = gr.Textbox(
                                label="URL",
                                elem_classes=["input-field"]
                            )
                            social_name = gr.Textbox(
                                label="Person/account name",
                                elem_classes=["input-field"]
                            )
                            social_context = gr.Textbox(
                                label="Content context",
                                elem_classes=["input-field"]
                            )
                            inputs_list.extend([social_url, social_name, social_context])
                    # URLs: 5 single-textbox tabs
                    for i in range(1, 6):
                        with gr.Tab(f"URL {i}"):
                            url = gr.Textbox(
                                label=f"URL {i}",
                                elem_classes=["input-field"]
                            )
                            inputs_list.append(url)
                    # Documents: 5 upload tabs; generate_news() treats every
                    # argument after the URLs as a document.
                    for i in range(1, 6):
                        with gr.Tab(f"Document {i}"):
                            # NOTE(review): Gradio documents extension filters with a
                            # leading dot (".pdf"); confirm bare names filter as intended.
                            document = gr.File(
                                label=f"Document {i}",
                                file_types=["pdf", "docx", "xlsx", "csv"],
                                elem_classes=["file-upload"]
                            )
                            inputs_list.append(document)
                # Output areas
                transcriptions_output = gr.Textbox(
                    label="Transcriptions",
                    lines=6,
                    elem_classes=["output-box"]
                )
                generate = gr.Button(
                    "Generate Draft",
                    elem_classes=["generate-btn"]
                )
                news_output = gr.Textbox(
                    label="Generated Draft",
                    lines=10,
                    elem_classes=["output-box"]
                )
        # Connect the generate button. generate_news() returns
        # (article, raw_transcriptions), matching the outputs order here.
        generate.click(
            fn=generate_news,
            inputs=inputs_list,
            outputs=[news_output, transcriptions_output]
        )
    return demo
# Initialize and launch
if __name__ == "__main__":
    # Load Llama-2 and Whisper before serving. generate_news() and
    # transcribe_audio() read the module-level model globals, which stay None
    # until initialize_models() runs; without this call every generation and
    # transcription fails.
    initialize_models()
    demo = create_demo()
    demo.queue()
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860
    )