File size: 6,427 Bytes
d3d8124 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
# modal_app.py
import modal
import sys
from pathlib import Path
import os # Para HF_TOKEN
# --- Configuraci贸n ---
PYTHON_VERSION = "3.10"
APP_NAME = "bioprocess-custom-eq-agent-modal"
# Directorios (asume que modal_app.py est谩 en la ra铆z del proyecto junto a los otros .py)
LOCAL_APP_DIR = Path(__file__).parent
REMOTE_APP_DIR = "/app" # Directorio dentro del contenedor Modal
stub = modal.Stub(APP_NAME)
# Definici贸n de la imagen del contenedor Modal
app_image = (
modal.Image.debian_slim(python_version=PYTHON_VERSION)
.pip_install_from_requirements(LOCAL_APP_DIR / "requirements.txt")
.copy_mount(
modal.Mount.from_local_dir(LOCAL_APP_DIR, remote_path=REMOTE_APP_DIR)
)
.env({
"PYTHONPATH": REMOTE_APP_DIR,
"HF_HOME": "/cache/huggingface", # Directorio de cach茅 de Hugging Face
"HF_HUB_CACHE": "/cache/huggingface/hub",
"TRANSFORMERS_CACHE": "/cache/huggingface/hub", # Alias com煤n
"MPLCONFIGDIR": "/tmp/matplotlib_cache" # Para evitar warnings de matplotlib
})
.run_commands( # Comandos para ejecutar durante la construcci贸n de la imagen
"apt-get update && apt-get install -y git git-lfs && rm -rf /var/lib/apt/lists/*", # git-lfs para algunos modelos
"mkdir -p /cache/huggingface/hub /tmp/matplotlib_cache" # Crear directorios de cach茅
)
)
# --- Funci贸n Modal para Generaci贸n de An谩lisis con LLM ---
@stub.function(
image=app_image,
gpu="any", # Solicitar GPU (ej. "T4", "A10G", o "any")
secrets=[
modal.Secret.from_name("huggingface-read-token", optional=True) # Para modelos privados/gated
],
timeout=600, # 10 minutos de timeout
# Montar un volumen para cachear modelos de Hugging Face
volumes={"/cache/huggingface": modal.Volume.persisted(f"{APP_NAME}-hf-cache-vol")}
)
def generate_analysis_llm_modal_remote(prompt: str, model_path_config: str, max_new_tokens_config: int) -> str:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# El token de HF se inyecta como variable de entorno si el secreto est谩 configurado
hf_token = os.environ.get("HUGGING_FACE_TOKEN")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"LLM Modal Func: Usando dispositivo: {device}")
print(f"LLM Modal Func: Cargando modelo: {model_path_config} con token: {'S铆' if hf_token else 'No'}")
try:
tokenizer = AutoTokenizer.from_pretrained(model_path_config, cache_dir="/cache/huggingface/hub", token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
model_path_config,
torch_dtype="auto", # bfloat16 en A100/H100, float16 en otras
device_map="auto", # Distribuye autom谩ticamente en GPUs disponibles
cache_dir="/cache/huggingface/hub",
token=hf_token,
# low_cpu_mem_usage=True # Puede ayudar con modelos muy grandes
)
# No es necesario .to(device) expl铆citamente con device_map="auto"
# model.eval() no es necesario si solo se hace inferencia y no se entrena
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096-max_new_tokens_config).to(model.device) # Truncar prompt si es muy largo
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens_config,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
do_sample=True,
temperature=0.6, # Ajustar para creatividad vs factualidad
top_p=0.9,
# num_beams=1 # Usar num_beams > 1 para beam search si se desea, pero m谩s lento
)
# Decodificar solo los tokens generados nuevos, no el prompt
input_length = inputs.input_ids.shape[1]
generated_ids = outputs[0][input_length:]
analysis = tokenizer.decode(generated_ids, skip_special_tokens=True)
print(f"LLM Modal Func: Longitud del an谩lisis generado: {len(analysis)} caracteres.")
return analysis.strip()
except Exception as e:
error_traceback = traceback.format_exc()
print(f"Error en generate_analysis_llm_modal_remote: {e}\n{error_traceback}")
return f"Error al generar an谩lisis con el modelo LLM: {str(e)}"
# --- Servidor Gradio ---
@stub.asgi_app() # image se hereda del stub si no se especifica aqu铆
def serve_gradio_app_asgi():
# Estas importaciones ocurren DENTRO del contenedor Modal
import gradio as gr
sys.path.insert(0, REMOTE_APP_DIR) # Asegurar que los m贸dulos de la app son importables
# Importar los m贸dulos de la aplicaci贸n AHORA que sys.path est谩 configurado
from UI import create_interface
import interface as app_interface_module # Renombrar para claridad
from config import MODEL_PATH as cfg_MODEL_PATH, MAX_LENGTH as cfg_MAX_LENGTH
# Wrapper para llamar a la funci贸n Modal remota desde tu interface.py
def analysis_func_wrapper_for_interface(prompt: str) -> str:
print("Gradio Backend: Llamando a generate_analysis_llm_modal_remote.remote...")
return generate_analysis_llm_modal_remote.remote(prompt, cfg_MODEL_PATH, cfg_MAX_LENGTH)
# Inyectar esta funci贸n wrapper en el m贸dulo `interface`
app_interface_module.generate_analysis_from_modal = analysis_func_wrapper_for_interface
app_interface_module.USE_MODAL_FOR_LLM_ANALYSIS = True
# Crear la app Gradio y conectar el bot贸n
gradio_ui, all_ui_inputs, ui_outputs, ui_submit_button = create_interface()
ui_submit_button.click(
fn=app_interface_module.process_and_plot,
inputs=all_ui_inputs,
outputs=ui_outputs
)
return gr.routes.App.create_app(gradio_ui) # Para montar Gradio en FastAPI/ASGI
# (Opcional) Un entrypoint local para probar r谩pidamente la generaci贸n LLM
@stub.local_entrypoint()
def test_llm():
print("Probando la generaci贸n de LLM con Modal (localmente)...")
from config import MODEL_PATH, MAX_LENGTH
sample_prompt = "Explica brevemente el concepto de R cuadrado (R虏) en el ajuste de modelos."
analysis = generate_analysis_llm_modal_remote.remote(sample_prompt, MODEL_PATH, MAX_LENGTH)
print("\nRespuesta del LLM:")
print(analysis) |