# NewsIA / app.py
# CamiloVega's picture
# Update app.py
# 0d42823 verified
import spaces
import gradio as gr
import logging
import os
import tempfile
import pandas as pd
import requests
from bs4 import BeautifulSoup
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import whisper
from moviepy.editor import VideoFileClip
from pydub import AudioSegment
import fitz
import docx
import yt_dlp
from functools import lru_cache
import gc
# Root logging setup: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
class ModelManager:
    """Singleton holder for the text-generation and speech-to-text models.

    Every ``ModelManager()`` call returns the same instance. The instance
    starts with all model attributes set to ``None``; they are populated by
    :meth:`initialize_models`, which :meth:`get_models` triggers on demand.

    Attributes (all ``None`` until initialized):
        tokenizer: tokenizer loaded for the Llama-2 chat checkpoint.
        model: the causal language model.
        news_generator: ``text-generation`` pipeline built from model + tokenizer.
        whisper_model: Whisper "tiny" model used for transcription.
    """

    _instance = None

    def __new__(cls):
        # Create the single shared instance on first use only.
        if cls._instance is None:
            cls._instance = super(ModelManager, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelManager() call; the flag prevents
        # wiping already-loaded models when the singleton is re-requested.
        if not self._initialized:
            self.tokenizer = None
            self.model = None
            self.news_generator = None
            self.whisper_model = None
            self._initialized = True

    @spaces.GPU(duration=120)
    def initialize_models(self):
        """Initialize models with ZeroGPU compatible settings.

        Loads the tokenizer, the language model, the generation pipeline and
        the Whisper model, storing each on the instance.

        Returns:
            True when every model was loaded.

        Raises:
            ValueError: if the HUGGINGFACE_TOKEN environment variable is unset.
            Exception: any loading error is logged, the partially loaded
                models are released via :meth:`reset_models`, and the error
                is re-raised.
        """
        try:
            HUGGINGFACE_TOKEN = os.environ.get('HUGGINGFACE_TOKEN')
            if not HUGGINGFACE_TOKEN:
                raise ValueError("HUGGINGFACE_TOKEN environment variable not set")

            logger.info("Starting model initialization...")
            model_name = "meta-llama/Llama-2-7b-chat-hf"

            # Load tokenizer
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                token=HUGGINGFACE_TOKEN,
                use_fast=True,
                model_max_length=512
            )
            # Reuse the end-of-sequence token as the padding token.
            self.tokenizer.pad_token = self.tokenizer.eos_token

            # Initialize model with ZeroGPU compatible settings
            logger.info("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                token=HUGGINGFACE_TOKEN,
                device_map="auto",
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                use_safetensors=True,
                # ZeroGPU specific settings
                max_memory={0: "6GB"},
                offload_folder="offload",
                offload_state_dict=True
            )

            # Create pipeline with minimal settings
            logger.info("Creating pipeline...")
            from transformers import pipeline
            self.news_generator = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device_map="auto",
                torch_dtype=torch.float16,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.2,
                num_return_sequences=1,
                early_stopping=True
            )

            # Load Whisper model with minimal settings
            logger.info("Loading Whisper model...")
            self.whisper_model = whisper.load_model(
                "tiny",
                device="cuda" if torch.cuda.is_available() else "cpu",
                download_root="/tmp/whisper"
            )

            logger.info("All models initialized successfully")
            return True
        except Exception as e:
            logger.error(f"Error during model initialization: {str(e)}")
            # Do not leave a half-initialized manager behind.
            self.reset_models()
            raise

    def reset_models(self):
        """Reset all models and clear memory.

        Best-effort and never raises: a failure while moving one model to the
        CPU is logged and does not prevent the references from being dropped
        or the CUDA cache from being emptied.
        """
        # Move the torch-backed models off the GPU first. Each move is
        # guarded on its own so one failure cannot abort the whole reset.
        for attr_name in ('model', 'whisper_model'):
            loaded = getattr(self, attr_name, None)
            if loaded is not None and hasattr(loaded, 'cpu'):
                try:
                    loaded.cpu()
                except Exception as e:
                    logger.error(f"Error during model reset: {str(e)}")
            del loaded

        # Always drop the references, whatever happened above.
        self.tokenizer = None
        self.model = None
        self.news_generator = None
        self.whisper_model = None

        try:
            # Collect garbage before emptying the cache so memory held by the
            # just-dropped objects can actually be released.
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        except Exception as e:
            logger.error(f"Error during model reset: {str(e)}")

    def check_models_initialized(self):
        """Check if all models are properly initialized, loading them if not."""
        components = (self.tokenizer, self.model, self.news_generator, self.whisper_model)
        # Identity checks only: never rely on the models' own __eq__.
        if any(component is None for component in components):
            logger.warning("Models not initialized, attempting to initialize...")
            self.initialize_models()

    def get_models(self):
        """Get initialized models, initializing if necessary.

        Returns:
            Tuple ``(tokenizer, model, news_generator, whisper_model)``.
        """
        self.check_models_initialized()
        return self.tokenizer, self.model, self.news_generator, self.whisper_model
# Create global model manager instance
# ModelManager is a singleton, so any later ModelManager() call returns this same
# object. No model is loaded here: the attributes start as None and loading is
# triggered by the first get_models() call.
model_manager = ModelManager()
@lru_cache(maxsize=32)
def download_social_media_video(url):
    """Download the audio of an online video as MP3 and return its filename.

    The download is delegated to yt-dlp with an FFmpeg post-processor that
    extracts a 192-quality MP3. The output name is ``<video id>.mp3``, built
    from the ``%(id)s.%(ext)s`` template. Results are memoised per URL
    (up to 32 entries).

    Raises:
        Exception: any download/extraction error is logged and re-raised.
    """
    mp3_extractor = {
        'key': 'FFmpegExtractAudio',
        'preferredcodec': 'mp3',
        'preferredquality': '192',
    }
    options = {
        'format': 'bestaudio/best',
        'postprocessors': [mp3_extractor],
        'outtmpl': '%(id)s.%(ext)s',
    }
    try:
        with yt_dlp.YoutubeDL(options) as downloader:
            metadata = downloader.extract_info(url, download=True)
            # The post-processor always leaves an .mp3 named after the video id.
            mp3_name = f"{metadata['id']}.mp3"
            logger.info(f"Video downloaded successfully: {mp3_name}")
            return mp3_name
    except Exception as e:
        logger.error(f"Error downloading video: {str(e)}")
        raise
def convert_video_to_audio(video_file):
    """Convert a video file to audio.

    Args:
        video_file: path of the video to read.

    Returns:
        Path of a newly created temporary ``.mp3`` file holding the audio
        track. The caller is responsible for deleting it.

    Raises:
        ValueError: if the video has no audio track.
        Exception: any other error is logged and re-raised; a partially
            written temporary file is removed first.
    """
    audio_path = None
    try:
        video = VideoFileClip(video_file)
        try:
            if video.audio is None:
                raise ValueError(f"No audio track found in video: {video_file}")
            # Reserve a unique path, then close the handle before writing so
            # the encoder is the only writer of the file.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
                audio_path = temp_file.name
            video.audio.write_audiofile(audio_path)
        finally:
            # Release the clip's readers whether or not the export succeeded.
            video.close()
        logger.info(f"Video converted to audio: {audio_path}")
        return audio_path
    except Exception as e:
        logger.error(f"Error converting video: {str(e)}")
        # Do not leave an orphaned temp file behind on failure.
        if audio_path is not None and os.path.exists(audio_path):
            try:
                os.remove(audio_path)
            except OSError:
                pass
        raise
def preprocess_audio(audio_file):
    """Preprocess the audio file to improve quality.

    Normalises the loudness to -20 dBFS and re-encodes the result as MP3.
    Completely silent audio is exported without gain, because its measured
    loudness is negative infinity and no finite gain can be computed.

    Args:
        audio_file: path (or file object) accepted by ``AudioSegment.from_file``.

    Returns:
        Path of a newly created temporary ``.mp3`` file. The caller is
        responsible for deleting it.

    Raises:
        Exception: any decoding/encoding error is logged and re-raised.
    """
    try:
        audio = AudioSegment.from_file(audio_file)
        loudness = audio.dBFS
        if loudness == float("-inf"):
            logger.warning("Audio is silent; skipping loudness normalization")
        else:
            # Gain needed to bring the current level to the -20 dBFS target.
            audio = audio.apply_gain(-loudness + (-20))
        # Reserve a unique path and close the handle before exporting to it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            output_path = temp_file.name
        audio.export(output_path, format="mp3")
        logger.info(f"Audio preprocessed: {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error preprocessing audio: {str(e)}")
        raise
@spaces.GPU(duration=120)
def transcribe_audio(file):
    """Transcribe an audio or video file.

    Args:
        file: an ``http...`` URL, a local path, or an uploaded-file object
            exposing its path as ``.name``.

    Routing:
        * URL -> audio downloaded with ``download_social_media_video``;
        * path ending in .mp4/.avi/.mov/.mkv -> ``convert_video_to_audio``;
        * anything else -> ``preprocess_audio``.

    Returns:
        The transcribed text, or a string starting with
        ``"Error processing the file: "`` on failure. This function never
        raises, so callers detect failure by that prefix.
    """
    # Intermediate file created for this call only; removed in ``finally``.
    # Downloaded files are NOT tracked here because the downloader memoises
    # their path and may hand it out again.
    temp_path = None
    try:
        _, _, _, whisper_model = model_manager.get_models()

        # Uploaded-file objects carry their path in ``.name``; use it so that
        # uploads are routed by extension exactly like plain path strings.
        upload_name = getattr(file, "name", None)
        source = upload_name if isinstance(upload_name, str) else file

        if isinstance(source, str) and source.startswith('http'):
            file_path = download_social_media_video(source)
        elif isinstance(source, str) and source.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
            file_path = temp_path = convert_video_to_audio(source)
        else:
            file_path = temp_path = preprocess_audio(source)

        logger.info(f"Transcribing audio: {file_path}")
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        with torch.inference_mode():
            result = whisper_model.transcribe(file_path)
        if not result:
            raise RuntimeError("Transcription failed to produce results")

        transcription = result.get("text", "Error in transcription")
        logger.info(f"Transcription completed: {transcription[:50]}...")
        return transcription
    except Exception as e:
        logger.error(f"Error transcribing: {str(e)}")
        return f"Error processing the file: {str(e)}"
    finally:
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {temp_path}: {cleanup_error}")
@lru_cache(maxsize=32)
def read_document(document_path):
    """Read the content of a document.

    Supported formats are selected by file extension, compared
    case-insensitively: PDF, DOCX, XLSX and CSV. Results are memoised per
    path (up to 32 entries).

    Args:
        document_path: path of the document to read.

    Returns:
        The extracted text. For an unsupported extension a fixed
        "Unsupported file type..." message is returned, and on failure a
        string starting with ``"Error reading document: "``. This function
        never raises.
    """
    try:
        lowered_path = document_path.lower()
        if lowered_path.endswith(".pdf"):
            doc = fitz.open(document_path)
            try:
                return "\n".join([page.get_text() for page in doc])
            finally:
                # Release the underlying file handle once the text is extracted.
                doc.close()
        elif lowered_path.endswith(".docx"):
            doc = docx.Document(document_path)
            return "\n".join([paragraph.text for paragraph in doc.paragraphs])
        elif lowered_path.endswith(".xlsx"):
            return pd.read_excel(document_path).to_string()
        elif lowered_path.endswith(".csv"):
            return pd.read_csv(document_path).to_string()
        else:
            return "Unsupported file type. Please upload a PDF, DOCX, XLSX or CSV document."
    except Exception as e:
        logger.error(f"Error reading document: {str(e)}")
        return f"Error reading document: {str(e)}"
@lru_cache(maxsize=32)
def _fetch_url_text(url):
    """Download *url* and return the text of its HTML; raises on any failure.

    Only successful results are memoised, because ``lru_cache`` does not
    store calls that raise.
    """
    # A timeout keeps one unresponsive server from blocking the request forever.
    response = requests.get(url, timeout=15)
    response.raise_for_status()
    soup = BeautifulSoup(response.content, 'html.parser')
    return soup.get_text()


def read_url(url):
    """Read the content of a URL.

    Args:
        url: address of the page to fetch.

    Returns:
        The text extracted from the page, or a string starting with
        ``"Error reading URL: "`` on failure. This function never raises.
        Successful reads are cached per URL (up to 32 entries); failures are
        not cached, so a later call retries the request.
    """
    try:
        return _fetch_url_text(url)
    except Exception as e:
        logger.error(f"Error reading URL: {str(e)}")
        return f"Error reading URL: {str(e)}"
def process_social_content(url):
    """Collect the page text and the video transcription for a social media URL.

    Returns:
        ``{"text": ..., "video": ...}`` where ``text`` is whatever ``read_url``
        returned and ``video`` is whatever ``transcribe_audio`` returned, or
        ``None`` if the transcription call raised. Returns ``None`` if the
        processing fails as a whole.
    """
    try:
        collected = {"text": read_url(url), "video": None}
        try:
            collected["video"] = transcribe_audio(url)
        except Exception as e:
            # A failed transcription must not discard the page text.
            logger.error(f"Error processing video content: {str(e)}")
        return collected
    except Exception as e:
        logger.error(f"Error processing social content: {str(e)}")
        return None
@spaces.GPU(duration=120)
def generate_news(instructions, facts, size, tone, *args):
    """Generate a news draft from the instructions, facts and attached sources.

    Args:
        instructions: free-text instructions for the article.
        facts: description of the facts to report.
        size: requested body length in words; must be a number >= 1.
        tone: tone to request from the model.
        *args: the remaining inputs, flattened in this exact order:
            15 values = 5 x (audio/video file, name, position),
            9 values  = 3 x (social URL, name, context),
            5 values  = plain URLs,
            the rest  = uploaded documents (objects exposing ``.name``).

    Returns:
        Tuple ``(article, raw_transcriptions)``. On failure the first item is
        a message starting with ``"Error generating the news article: "`` and
        the second is an empty string.
    """
    # Validate the size before touching the models: an invalid value would
    # otherwise fail deep inside the main block and trigger a needless, costly
    # model reset and reload.
    try:
        requested_words = float(size)
    except (TypeError, ValueError):
        requested_words = float("nan")
    if not 1 <= requested_words < float("inf"):
        return "Error generating the news article: size must be a positive number of words", ""
    body_words = int(requested_words)

    try:
        _, _, news_generator, _ = model_manager.get_models()

        knowledge_base = {
            "instructions": instructions,
            "facts": facts,
            "document_content": [],
            "audio_data": [],
            "url_content": [],
            "social_content": []
        }

        # Positional layout of *args (must match the order the UI builds its inputs).
        num_audios = 5 * 3
        num_social_urls = 3 * 3
        num_urls = 5

        audios = args[:num_audios]
        social_urls = args[num_audios:num_audios+num_social_urls]
        urls = args[num_audios+num_social_urls:num_audios+num_social_urls+num_urls]
        documents = args[num_audios+num_social_urls+num_urls:]

        for url in urls:
            if url:
                content = read_url(url)
                if content and not content.startswith("Error"):
                    knowledge_base["url_content"].append(content)

        for document in documents:
            if document is not None:
                content = read_document(document.name)
                # Skip both failure messages the reader can return, so neither
                # ends up in the prompt as if it were document text.
                if content and not content.startswith(("Error", "Unsupported file type")):
                    knowledge_base["document_content"].append(content)

        for i in range(0, len(audios), 3):
            audio_file, name, position = audios[i:i+3]
            if audio_file is not None:
                knowledge_base["audio_data"].append({
                    "audio": audio_file,
                    "name": name,
                    "position": position
                })

        for i in range(0, len(social_urls), 3):
            social_url, social_name, social_context = social_urls[i:i+3]
            if social_url:
                social_content = process_social_content(social_url)
                if social_content:
                    knowledge_base["social_content"].append({
                        "url": social_url,
                        "name": social_name,
                        "context": social_context,
                        "text": social_content["text"],
                        "video": social_content["video"]
                    })

        # transcriptions_text feeds the prompt; raw_transcriptions is shown to the user.
        transcriptions_text = ""
        raw_transcriptions = ""

        for idx, data in enumerate(knowledge_base["audio_data"]):
            if data["audio"] is not None:
                transcription = transcribe_audio(data["audio"])
                if not transcription.startswith("Error"):
                    transcriptions_text += f'"{transcription}" - {data["name"]}, {data["position"]}\n'
                    raw_transcriptions += f'[Audio/Video {idx + 1}]: "{transcription}" - {data["name"]}, {data["position"]}\n\n'

        for data in knowledge_base["social_content"]:
            if data["text"] and not str(data["text"]).startswith("Error"):
                social_text = f'[Social media text]: "{data["text"][:200]}..." - {data["name"]}, {data["context"]}\n'
                transcriptions_text += social_text
                # Append only this entry; appending the whole accumulated
                # prompt text would duplicate every earlier transcription.
                raw_transcriptions += social_text + "\n\n"
            if data["video"] and not str(data["video"]).startswith("Error"):
                video_transcription = f'[Social media video]: "{data["video"]}" - {data["name"]}, {data["context"]}\n'
                transcriptions_text += video_transcription
                raw_transcriptions += video_transcription + "\n\n"

        document_content = "\n\n".join(knowledge_base["document_content"])
        url_content = "\n\n".join(knowledge_base["url_content"])

        prompt = f"""[INST] You are a professional news writer. Write a news article based on the following information:
Instructions: {knowledge_base["instructions"]}
Facts: {knowledge_base["facts"]}
Additional content from documents: {document_content}
Additional content from URLs: {url_content}
Use these transcriptions as direct and indirect quotes:
{transcriptions_text}
Follow these requirements:
- Write a title
- Write a 15-word hook that complements the title
- Write the body with {body_words} words
- Use a {tone} tone
- Answer the 5 Ws (Who, What, When, Where, Why) in the first paragraph
- Use at least 80% direct quotes (in quotation marks)
- Use proper journalistic style
- Do not invent information
- Be rigorous with the provided facts [/INST]"""

        # Optimize size and max tokens
        max_tokens = min(int(requested_words * 1.5), 512)

        # Generate article with optimized settings
        with torch.inference_mode():
            try:
                news_article = news_generator(
                    prompt,
                    max_new_tokens=max_tokens,
                    num_return_sequences=1,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.95,
                    repetition_penalty=1.2,
                    early_stopping=True,
                    # Return only the newly generated text, not the prompt echoed back.
                    return_full_text=False
                )

                # Process the generated text
                if isinstance(news_article, list):
                    news_article = news_article[0]['generated_text']
                    news_article = news_article.replace('[INST]', '').replace('[/INST]', '').strip()
            except Exception as gen_error:
                logger.error(f"Error in text generation: {str(gen_error)}")
                raise

        return news_article, raw_transcriptions

    except Exception as e:
        logger.error(f"Error generating news: {str(e)}")
        try:
            # Attempt to recover by resetting and reinitializing models
            model_manager.reset_models()
            model_manager.initialize_models()
            logger.info("Models reinitialized successfully after error")
        except Exception as reinit_error:
            logger.error(f"Failed to reinitialize models: {str(reinit_error)}")
        return f"Error generating the news article: {str(e)}", ""
def create_demo():
    """Build the Gradio interface and wire it to ``generate_news``.

    The order in which components are added to ``inputs_list`` is a contract
    with ``generate_news``, which receives them positionally:
    instructions, facts, size, tone, then 5 x (file, name, role),
    3 x (social URL, name, context), 5 URLs and finally 5 documents.

    Returns:
        The ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("## Generador de noticias todo en uno")

        with gr.Row():
            # Left column: the four fixed article parameters.
            with gr.Column(scale=2):
                instrucciones = gr.Textbox(
                    label="Instrucciones para la noticia",
                    lines=2
                )
                hechos = gr.Textbox(
                    label="Describe los hechos de la noticia",
                    lines=4
                )
                tamaño = gr.Number(
                    label="Tamaño del cuerpo de la noticia (en palabras)",
                    value=100
                )
                tono = gr.Dropdown(
                    label="Tono de la noticia",
                    choices=["serio", "neutral", "divertido"],
                    value="neutral"
                )

            # Right column: one tab per optional source.
            with gr.Column(scale=3):
                inputs_list = [instrucciones, hechos, tamaño, tono]

                with gr.Tabs():
                    # 5 audio/video uploads, each with speaker name and role.
                    for i in range(1, 6):
                        with gr.TabItem(f"Audio/Video {i}"):
                            file = gr.File(
                                label=f"Audio/Video {i}",
                                file_types=["audio", "video"]
                            )
                            nombre = gr.Textbox(
                                label="Nombre",
                                placeholder="Nombre del entrevistado"
                            )
                            cargo = gr.Textbox(
                                label="Cargo",
                                placeholder="Cargo o rol"
                            )
                            inputs_list.extend([file, nombre, cargo])

                    # 3 social media links, each with account name and context.
                    for i in range(1, 4):
                        with gr.TabItem(f"Red Social {i}"):
                            social_url = gr.Textbox(
                                label=f"URL de red social {i}",
                                placeholder="https://..."
                            )
                            social_nombre = gr.Textbox(
                                label=f"Nombre de persona/cuenta {i}"
                            )
                            social_contexto = gr.Textbox(
                                label=f"Contexto del contenido {i}",
                                lines=2
                            )
                            inputs_list.extend([social_url, social_nombre, social_contexto])

                    # 5 plain web page URLs.
                    for i in range(1, 6):
                        with gr.TabItem(f"URL {i}"):
                            url = gr.Textbox(
                                label=f"URL {i}",
                                placeholder="https://..."
                            )
                            inputs_list.append(url)

                    # 5 document uploads.
                    for i in range(1, 6):
                        with gr.TabItem(f"Documento {i}"):
                            documento = gr.File(
                                label=f"Documento {i}",
                                # Extensions need a leading dot; bare words are
                                # treated as file categories such as "audio".
                                file_types=[".pdf", ".docx", ".xlsx", ".csv"],
                                file_count="single"
                            )
                            inputs_list.append(documento)

        gr.Markdown("---")

        with gr.Row():
            transcripciones_output = gr.Textbox(
                label="Transcripciones",
                lines=10,
                show_copy_button=True
            )

        gr.Markdown("---")

        with gr.Row():
            generar = gr.Button("Generar borrador")

        with gr.Row():
            noticia_output = gr.Textbox(
                label="Borrador generado",
                lines=20,
                show_copy_button=True
            )

        # generate_news returns (article, raw transcriptions), in that order.
        generar.click(
            fn=generate_news,
            inputs=inputs_list,
            outputs=[noticia_output, transcripciones_output]
        )

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, turn on request queuing, then serve
    # on every network interface at port 7860 with a share link requested.
    demo = create_demo()
    demo.queue()
    demo.launch(server_port=7860, server_name="0.0.0.0", share=True)