import spaces
import gradio as gr
import logging
import os
import tempfile
import pandas as pd
import requests
from bs4 import BeautifulSoup
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import whisper
from moviepy.editor import VideoFileClip
from pydub import AudioSegment
import fitz  # PyMuPDF
import docx
import yt_dlp
from functools import lru_cache
import gc

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class ModelManager:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ModelManager, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if not self._initialized:
            self.tokenizer = None
            self.model = None
            self.news_generator = None
            self.whisper_model = None
            self._initialized = True
    def initialize_models(self):
        """Initialize models with ZeroGPU compatible settings"""
        try:
            HUGGINGFACE_TOKEN = os.environ.get('HUGGINGFACE_TOKEN')
            if not HUGGINGFACE_TOKEN:
                raise ValueError("HUGGINGFACE_TOKEN environment variable not set")

            logger.info("Starting model initialization...")
            model_name = "meta-llama/Llama-2-7b-chat-hf"

            # Load tokenizer
            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                token=HUGGINGFACE_TOKEN,
                use_fast=True,
                model_max_length=512
            )
            self.tokenizer.pad_token = self.tokenizer.eos_token

            # Initialize model with ZeroGPU compatible settings
            logger.info("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                token=HUGGINGFACE_TOKEN,
                device_map="auto",
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
                use_safetensors=True,
                # ZeroGPU specific settings
                max_memory={0: "6GB"},
                offload_folder="offload",
                offload_state_dict=True
            )

            # Create pipeline with minimal settings
            logger.info("Creating pipeline...")
            from transformers import pipeline
            self.news_generator = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device_map="auto",
                torch_dtype=torch.float16,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
                repetition_penalty=1.2,
                num_return_sequences=1,
                early_stopping=True
            )

            # Load Whisper model with minimal settings
            logger.info("Loading Whisper model...")
            self.whisper_model = whisper.load_model(
                "tiny",
                device="cuda" if torch.cuda.is_available() else "cpu",
                download_root="/tmp/whisper"
            )

            logger.info("All models initialized successfully")
            return True
        except Exception as e:
            logger.error(f"Error during model initialization: {str(e)}")
            self.reset_models()
            raise
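
    # Note (sketch, an assumption from the `spaces` import at the top of this
    # file): on ZeroGPU Spaces, GPU-bound functions are typically wrapped with
    # the `spaces.GPU` decorator so a GPU is attached only for that call, e.g.:
    #
    #   @spaces.GPU(duration=120)  # hypothetical usage; duration is optional
    #   def gpu_bound_call(...):
    #       ...
    #
    # This file as written relies on device_map="auto" instead.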
    def reset_models(self):
        """Reset all models and clear memory"""
        try:
            if hasattr(self, 'model') and self.model is not None:
                self.model.cpu()
                del self.model
            if hasattr(self, 'tokenizer') and self.tokenizer is not None:
                del self.tokenizer
            if hasattr(self, 'news_generator') and self.news_generator is not None:
                del self.news_generator
            if hasattr(self, 'whisper_model') and self.whisper_model is not None:
                if hasattr(self.whisper_model, 'cpu'):
                    self.whisper_model.cpu()
                del self.whisper_model

            self.tokenizer = None
            self.model = None
            self.news_generator = None
            self.whisper_model = None

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            gc.collect()
        except Exception as e:
            logger.error(f"Error during model reset: {str(e)}")
    def check_models_initialized(self):
        """Check if all models are properly initialized"""
        if None in (self.tokenizer, self.model, self.news_generator, self.whisper_model):
            logger.warning("Models not initialized, attempting to initialize...")
            self.initialize_models()

    def get_models(self):
        """Get initialized models, initializing if necessary"""
        self.check_models_initialized()
        return self.tokenizer, self.model, self.news_generator, self.whisper_model

# Create global model manager instance
model_manager = ModelManager()
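
# Because ModelManager overrides __new__ as a singleton, every ModelManager()
# call returns this same object, so repeated construction elsewhere is safe:
#
#   assert ModelManager() is model_manager  # both names share one instance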
def download_social_media_video(url):
    """Download a video from social media and extract its audio as MP3."""
    ydl_opts = {
        'format': 'bestaudio/best',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'mp3',
            'preferredquality': '192',
        }],
        'outtmpl': '%(id)s.%(ext)s',
    }
    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info_dict = ydl.extract_info(url, download=True)
        audio_file = f"{info_dict['id']}.mp3"
        logger.info(f"Video downloaded successfully: {audio_file}")
        return audio_file
    except Exception as e:
        logger.error(f"Error downloading video: {str(e)}")
        raise
def convert_video_to_audio(video_file):
    """Convert a video file to audio."""
    try:
        video = VideoFileClip(video_file)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            video.audio.write_audiofile(temp_file.name)
        logger.info(f"Video converted to audio: {temp_file.name}")
        return temp_file.name
    except Exception as e:
        logger.error(f"Error converting video: {str(e)}")
        raise
def preprocess_audio(audio_file):
    """Preprocess the audio file to improve quality."""
    try:
        audio = AudioSegment.from_file(audio_file)
        audio = audio.apply_gain(-audio.dBFS + (-20))
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            audio.export(temp_file.name, format="mp3")
        logger.info(f"Audio preprocessed: {temp_file.name}")
        return temp_file.name
    except Exception as e:
        logger.error(f"Error preprocessing audio: {str(e)}")
        raise
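
# The gain expression above works out to a fixed loudness target: applying
# (-dBFS - 20) dB of gain moves the clip's average level to -20 dBFS. For
# example, a clip measured at -35 dBFS gets +15 dB of gain, while one at
# -12 dBFS is attenuated by 8 dB.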
def transcribe_audio(file):
    """Transcribe an audio or video file."""
    try:
        _, _, _, whisper_model = model_manager.get_models()

        if isinstance(file, str) and file.startswith('http'):
            file_path = download_social_media_video(file)
        elif isinstance(file, str) and file.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
            file_path = convert_video_to_audio(file)
        else:
            file_path = preprocess_audio(file)

        logger.info(f"Transcribing audio: {file_path}")
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        with torch.inference_mode():
            result = whisper_model.transcribe(file_path)
            if not result:
                raise RuntimeError("Transcription failed to produce results")

        transcription = result.get("text", "Error in transcription")
        logger.info(f"Transcription completed: {transcription[:50]}...")
        return transcription
    except Exception as e:
        logger.error(f"Error transcribing: {str(e)}")
        return f"Error processing the file: {str(e)}"
def read_document(document_path):
    """Read the content of a document."""
    try:
        if document_path.endswith(".pdf"):
            doc = fitz.open(document_path)
            return "\n".join([page.get_text() for page in doc])
        elif document_path.endswith(".docx"):
            doc = docx.Document(document_path)
            return "\n".join([paragraph.text for paragraph in doc.paragraphs])
        elif document_path.endswith(".xlsx"):
            return pd.read_excel(document_path).to_string()
        elif document_path.endswith(".csv"):
            return pd.read_csv(document_path).to_string()
        else:
            return "Unsupported file type. Please upload a PDF, DOCX, XLSX or CSV document."
    except Exception as e:
        logger.error(f"Error reading document: {str(e)}")
        return f"Error reading document: {str(e)}"
def read_url(url):
    """Read the content of a URL."""
    try:
        response = requests.get(url)
        response.raise_for_status()
        soup = BeautifulSoup(response.content, 'html.parser')
        return soup.get_text()
    except Exception as e:
        logger.error(f"Error reading URL: {str(e)}")
        return f"Error reading URL: {str(e)}"
def process_social_content(url):
    """Process social media content."""
    try:
        text_content = read_url(url)
        try:
            video_content = transcribe_audio(url)
        except Exception as e:
            logger.error(f"Error processing video content: {str(e)}")
            video_content = None

        return {
            "text": text_content,
            "video": video_content
        }
    except Exception as e:
        logger.error(f"Error processing social content: {str(e)}")
        return None
def generate_news(instructions, facts, size, tone, *args):
    try:
        tokenizer, _, news_generator, _ = model_manager.get_models()

        knowledge_base = {
            "instructions": instructions,
            "facts": facts,
            "document_content": [],
            "audio_data": [],
            "url_content": [],
            "social_content": []
        }

        num_audios = 5 * 3
        num_social_urls = 3 * 3
        num_urls = 5
        audios = args[:num_audios]
        social_urls = args[num_audios:num_audios + num_social_urls]
        urls = args[num_audios + num_social_urls:num_audios + num_social_urls + num_urls]
        documents = args[num_audios + num_social_urls + num_urls:]
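
        # The slicing above mirrors the order in which create_demo() assembles
        # inputs_list: 5 audio tabs x (file, name, position) = 15 values, then
        # 3 social tabs x (url, name, context) = 9 values, then 5 plain URL
        # fields, then 5 document files.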
        for url in urls:
            if url:
                content = read_url(url)
                if content and not content.startswith("Error"):
                    knowledge_base["url_content"].append(content)

        for document in documents:
            if document is not None:
                content = read_document(document.name)
                if content and not content.startswith("Error"):
                    knowledge_base["document_content"].append(content)

        for i in range(0, len(audios), 3):
            audio_file, name, position = audios[i:i + 3]
            if audio_file is not None:
                knowledge_base["audio_data"].append({
                    "audio": audio_file,
                    "name": name,
                    "position": position
                })

        for i in range(0, len(social_urls), 3):
            social_url, social_name, social_context = social_urls[i:i + 3]
            if social_url:
                social_content = process_social_content(social_url)
                if social_content:
                    knowledge_base["social_content"].append({
                        "url": social_url,
                        "name": social_name,
                        "context": social_context,
                        "text": social_content["text"],
                        "video": social_content["video"]
                    })
        transcriptions_text = ""
        raw_transcriptions = ""

        for idx, data in enumerate(knowledge_base["audio_data"]):
            if data["audio"] is not None:
                transcription = transcribe_audio(data["audio"])
                if not transcription.startswith("Error"):
                    transcriptions_text += f'"{transcription}" - {data["name"]}, {data["position"]}\n'
                    raw_transcriptions += f'[Audio/Video {idx + 1}]: "{transcription}" - {data["name"]}, {data["position"]}\n\n'

        for data in knowledge_base["social_content"]:
            if data["text"] and not str(data["text"]).startswith("Error"):
                social_text = f'[Social media text]: "{data["text"][:200]}..." - {data["name"]}, {data["context"]}\n'
                transcriptions_text += social_text
                raw_transcriptions += social_text + "\n\n"
            if data["video"] and not str(data["video"]).startswith("Error"):
                video_transcription = f'[Social media video]: "{data["video"]}" - {data["name"]}, {data["context"]}\n'
                transcriptions_text += video_transcription
                raw_transcriptions += video_transcription + "\n\n"

        document_content = "\n\n".join(knowledge_base["document_content"])
        url_content = "\n\n".join(knowledge_base["url_content"])
prompt = f"""[INST] You are a professional news writer. Write a news article based on the following information: | |
Instructions: {knowledge_base["instructions"]} | |
Facts: {knowledge_base["facts"]} | |
Additional content from documents: {document_content} | |
Additional content from URLs: {url_content} | |
Use these transcriptions as direct and indirect quotes: | |
{transcriptions_text} | |
Follow these requirements: | |
- Write a title | |
- Write a 15-word hook that complements the title | |
- Write the body with {size} words | |
- Use a {tone} tone | |
- Answer the 5 Ws (Who, What, When, Where, Why) in the first paragraph | |
- Use at least 80% direct quotes (in quotation marks) | |
- Use proper journalistic style | |
- Do not invent information | |
- Be rigorous with the provided facts [/INST]""" | |
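
        # Note: [INST] ... [/INST] is the Llama-2-chat turn delimiter; the
        # full chat template also wraps each turn in <s> ... </s> and supports
        # an optional <<SYS>> ... <</SYS>> system block, which this prompt
        # omits.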
        # Optimize size and max tokens
        max_tokens = min(int(size * 1.5), 512)

        # Generate article with optimized settings
        with torch.inference_mode():
            try:
                news_article = news_generator(
                    prompt,
                    max_new_tokens=max_tokens,
                    num_return_sequences=1,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.95,
                    repetition_penalty=1.2,
                    early_stopping=True
                )

                # Process the generated text
                if isinstance(news_article, list):
                    news_article = news_article[0]['generated_text']
                news_article = news_article.replace('[INST]', '').replace('[/INST]', '').strip()
            except Exception as gen_error:
                logger.error(f"Error in text generation: {str(gen_error)}")
                raise

        return news_article, raw_transcriptions
    except Exception as e:
        logger.error(f"Error generating news: {str(e)}")
        try:
            # Attempt to recover by resetting and reinitializing models
            model_manager.reset_models()
            model_manager.initialize_models()
            logger.info("Models reinitialized successfully after error")
        except Exception as reinit_error:
            logger.error(f"Failed to reinitialize models: {str(reinit_error)}")
        return f"Error generating the news article: {str(e)}", ""
def create_demo():
    with gr.Blocks() as demo:
        gr.Markdown("## All-in-one news generator")
        with gr.Row():
            with gr.Column(scale=2):
                instrucciones = gr.Textbox(
                    label="Instructions for the news article",
                    lines=2
                )
                hechos = gr.Textbox(
                    label="Describe the facts of the news story",
                    lines=4
                )
                tamaño = gr.Number(
                    label="Length of the news body (in words)",
                    value=100
                )
                tono = gr.Dropdown(
                    label="Tone of the news article",
                    choices=["serious", "neutral", "funny"],
                    value="neutral"
                )
            with gr.Column(scale=3):
                inputs_list = [instrucciones, hechos, tamaño, tono]
                with gr.Tabs():
                    for i in range(1, 6):
                        with gr.TabItem(f"Audio/Video {i}"):
                            file = gr.File(
                                label=f"Audio/Video {i}",
                                file_types=["audio", "video"]
                            )
                            nombre = gr.Textbox(
                                label="Name",
                                placeholder="Name of the interviewee"
                            )
                            cargo = gr.Textbox(
                                label="Position",
                                placeholder="Position or role"
                            )
                            inputs_list.extend([file, nombre, cargo])
                    for i in range(1, 4):
                        with gr.TabItem(f"Social Media {i}"):
                            social_url = gr.Textbox(
                                label=f"Social media URL {i}",
                                placeholder="https://..."
                            )
                            social_nombre = gr.Textbox(
                                label=f"Person/account name {i}"
                            )
                            social_contexto = gr.Textbox(
                                label=f"Content context {i}",
                                lines=2
                            )
                            inputs_list.extend([social_url, social_nombre, social_contexto])
                    for i in range(1, 6):
                        with gr.TabItem(f"URL {i}"):
                            url = gr.Textbox(
                                label=f"URL {i}",
                                placeholder="https://..."
                            )
                            inputs_list.append(url)
                    for i in range(1, 6):
                        with gr.TabItem(f"Document {i}"):
                            documento = gr.File(
                                label=f"Document {i}",
                                file_types=["pdf", "docx", "xlsx", "csv"],
                                file_count="single"
                            )
                            inputs_list.append(documento)

        gr.Markdown("---")

        with gr.Row():
            transcripciones_output = gr.Textbox(
                label="Transcriptions",
                lines=10,
                show_copy_button=True
            )

        gr.Markdown("---")

        with gr.Row():
            generar = gr.Button("Generate draft")
        with gr.Row():
            noticia_output = gr.Textbox(
                label="Generated draft",
                lines=20,
                show_copy_button=True
            )

        generar.click(
            fn=generate_news,
            inputs=inputs_list,
            outputs=[noticia_output, transcripciones_output]
        )

    return demo
if __name__ == "__main__":
    demo = create_demo()
    demo.queue()
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860
    )
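
# Dependency sketch (inferred from the imports above; exact package set and
# pins are assumptions):
#   pip install spaces gradio transformers torch openai-whisper moviepy pydub \
#       PyMuPDF python-docx yt-dlp beautifulsoup4 pandas openpyxl requests
# ffmpeg must also be available on PATH for yt-dlp, moviepy, pydub and whisper.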