import os import gc import torch from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse from pydantic import BaseModel, field_validator from transformers import ( AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList, StoppingCriteria, pipeline ) import uvicorn import asyncio from io import BytesIO import soundfile as sf import traceback # --- Bloque para limitar la RAM al 1% (sólo en entornos Unix) --- try: import psutil import resource total_memory = psutil.virtual_memory().total limit = int(total_memory * 1000.0) # 1% del total en bytes # Corrección: Usar 0.01 para 1% resource.setrlimit(resource.RLIMIT_AS, (limit, limit)) print(f"Memory limit set to {limit} bytes (1% of total system memory).") # Imprimir para verificar el límite aplicado except Exception as e: print("No se pudo establecer el límite de memoria:", e) # --- Fin del bloque de limitación de RAM --- app = FastAPI() # Función asíncrona para limpiar la memoria (RAM y caché CUDA) async def cleanup_memory(device: str): gc.collect() if device == "cuda": torch.cuda.empty_cache() # Espera breve para permitir la liberación de memoria await asyncio.sleep(0.01) class GenerateRequest(BaseModel): model_name: str input_text: str = "" task_type: str temperature: float = 1.0 max_new_tokens: int = 10 stream: bool = True top_p: float = 1.0 top_k: int = 50 repetition_penalty: float = 1.0 num_return_sequences: int = 1 do_sample: bool = True chunk_delay: float = 0.0 stop_sequences: list[str] = [] chunk_token_limit: int = 10000000000 # Nuevo parámetro para limitar tokens por chunk @field_validator("model_name") def model_name_cannot_be_empty(cls, v): if not v: raise ValueError("model_name cannot be empty.") return v @field_validator("task_type") def task_type_must_be_valid(cls, v): valid_types = ["text-to-text", "text-to-image", "text-to-speech", "text-to-video"] if v not in valid_types: raise ValueError(f"task_type must be one of: {valid_types}") return v class LocalModelLoader: def __init__(self): self.loaded_models = {} async def load_model_and_tokenizer(self, model_name): # Se utiliza el modelo indicado por el usuario if model_name in self.loaded_models: return self.loaded_models[model_name] try: config = AutoConfig.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name, config=config) # Se usa torch_dtype=torch.float16 para reducir la huella en memoria (si es posible) model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16) # Ajuste del token de relleno si es necesario if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None: tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id self.loaded_models[model_name] = (model, tokenizer) return model, tokenizer except Exception as e: raise HTTPException(status_code=500, detail=f"Error loading model: {e}") model_loader = LocalModelLoader() class StopOnTokens(StoppingCriteria): def __init__(self, stop_token_ids: list[int]): self.stop_token_ids = stop_token_ids def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: for stop_id in self.stop_token_ids: if input_ids[0][-1] == stop_id: return True return False @app.post("/generate") async def generate(request: GenerateRequest): try: # Extraer parámetros del request model_name = request.model_name input_text = request.input_text task_type = request.task_type temperature = request.temperature max_new_tokens = request.max_new_tokens stream = request.stream top_p = request.top_p top_k = request.top_k repetition_penalty = request.repetition_penalty num_return_sequences = request.num_return_sequences do_sample = request.do_sample chunk_delay = request.chunk_delay stop_sequences = request.stop_sequences chunk_token_limit = request.chunk_token_limit model, tokenizer = await model_loader.load_model_and_tokenizer(model_name) device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) generation_config = GenerationConfig( temperature=temperature, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, do_sample=do_sample, num_return_sequences=num_return_sequences, stream=stream, # Add stream=True/False to generation config ) stop_token_ids = [] if stop_sequences: stop_token_ids = tokenizer.convert_tokens_to_ids(stop_sequences) stopping_criteria_list = StoppingCriteriaList([StopOnTokens(stop_token_ids)]) if stop_token_ids else None if stream: # Se utiliza StreamingResponse con la función asíncrona que envía cada token en tiempo real. response = StreamingResponse( stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, chunk_token_limit, stopping_criteria_list), # Pass stopping_criteria_list media_type="text/plain" ) else: generated_text = await generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device) response = StreamingResponse(iter([generated_text]), media_type="text/plain") await cleanup_memory(device) return response except Exception as e: raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") async def stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, chunk_token_limit, stop_criteria): # Accept stop_criteria """ Genera tokens de forma asíncrona y los envía al cliente en tiempo real, dividiendo la respuesta en chunks si excede el límite de tokens. La generación se detiene automáticamente al cumplirse los StoppingCriteriaList. """ # Limitar la entrada para minimizar el uso de memoria encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=64).to(device) current_chunk_tokens = 0 current_chunk_text = "" past_key_values = None # To maintain state for streaming # Con torch.no_grad() se evita almacenar información para gradientes with torch.no_grad(): input_ids = encoded_input.input_ids # Generación manual token por token para control de parada y chunking while True: # Bucle infinito que se rompe por condiciones de parada outputs = model( input_ids, past_key_values=past_key_values, use_cache=True, # Important for stateful generation return_dict=True ) next_token_logits = outputs.logits[:, -1, :] # Aplicar sampling para obtener el siguiente token (igual que en generation_config) if generation_config.do_sample: # Apply temperature and Top-p/Top-k sampling next_token_logits = next_token_logits / generation_config.temperature # Top-k filtering if generation_config.top_k is not None and generation_config.top_k > 0: v, _ = torch.topk(next_token_logits, min(generation_config.top_k, next_token_logits.size(-1))) next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf') probs = torch.nn.functional.softmax(next_token_logits, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) else: # Greedy decoding next_tokens = torch.argmax(next_token_logits, dim=-1) # Check stop criteria BEFORE adding token to output if stop_criteria and stop_criteria(input_ids, next_token_logits): # Check stopping criteria break # Stop generation if criteria is met next_tokens = next_tokens.unsqueeze(0) # Reshape to [1, 1] for concat next_token_text = tokenizer.decode(next_tokens[0], skip_special_tokens=True) token_count = len(tokenizer.encode(current_chunk_text + next_token_text)) - len(tokenizer.encode(current_chunk_text)) if current_chunk_tokens + token_count > chunk_token_limit: yield current_chunk_text current_chunk_text = next_token_text current_chunk_tokens = token_count else: current_chunk_text += next_token_text current_chunk_tokens += token_count yield current_chunk_text # Yield every token/chunk input_ids = torch.cat([input_ids, next_tokens], dim=-1) # Append next token to input_ids for next iteration past_key_values = outputs.past_key_values # Update past key values for stateful generation await asyncio.sleep(chunk_delay) if input_ids.shape[-1] >= generation_config.max_new_tokens + encoded_input.input_ids.shape[-1]: # Check max_new_tokens limit break # Stop if max_new_tokens is reached # Asegurar de enviar el último chunk if current_chunk_text: yield current_chunk_text await cleanup_memory(device) async def generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, max_length=64): encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device) with torch.no_grad(): output = model.generate( **encoded_input, generation_config=generation_config, stopping_criteria=stopping_criteria_list, return_dict_in_generate=True, output_scores=True ) generated_text = tokenizer.decode(output.sequences[0][encoded_input["input_ids"].shape[-1]:], skip_special_tokens=True) await cleanup_memory(device) return generated_text @app.post("/generate-image") async def generate_image(request: GenerateRequest): try: validated_body = request device = 0 if torch.cuda.is_available() else -1 # pipeline espera int para CUDA # Ejecutar el pipeline en un hilo separado image_generator = await asyncio.to_thread(pipeline, "text-to-image", model=validated_body.model_name, device=device) results = await asyncio.to_thread(image_generator, validated_body.input_text) image = results[0] img_byte_arr = BytesIO() image.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) await cleanup_memory("cuda" if device == 0 else "cpu") return StreamingResponse(img_byte_arr, media_type="image/png") except Exception as e: raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") @app.post("/generate-text-to-speech") async def generate_text_to_speech(request: GenerateRequest): try: validated_body = request device = 0 if torch.cuda.is_available() else -1 # Ejecutar el pipeline en un hilo separado tts_generator = await asyncio.to_thread(pipeline, "text-to-speech", model=validated_body.model_name, device=device) tts_results = await asyncio.to_thread(tts_generator, validated_body.input_text) audio = tts_results sampling_rate = tts_generator.sampling_rate audio_byte_arr = BytesIO() sf.write(audio_byte_arr, audio, sampling_rate, format='WAV') audio_byte_arr.seek(0) await cleanup_memory("cuda" if device == 0 else "cpu") return StreamingResponse(audio_byte_arr, media_type="audio/wav") except Exception as e: traceback.print_exc() raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") @app.post("/generate-video") async def generate_video(request: GenerateRequest): try: validated_body = request device = 0 if torch.cuda.is_available() else -1 # Ejecutar el pipeline en un hilo separado video_generator = await asyncio.to_thread(pipeline, "text-to-video", model=validated_body.model_name, device=device) video = await asyncio.to_thread(video_generator, validated_body.input_text) video_byte_arr = BytesIO() video.save(video_byte_arr) video_byte_arr.seek(0) await cleanup_memory("cuda" if device == 0 else "cpu") return StreamingResponse(video_byte_arr, media_type="video/mp4") except Exception as e: traceback.print_exc() raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=7860)