import os
import gc
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, field_validator
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
GenerationConfig,
StoppingCriteriaList,
StoppingCriteria,
pipeline
)
import uvicorn
import asyncio
from io import BytesIO
import soundfile as sf
import traceback
# --- Block to limit RAM usage to 1% of the total (Unix-only) ---
try:
    import psutil
    import resource
    total_memory = psutil.virtual_memory().total
    limit = int(total_memory * 0.01)  # 1% of total system memory, in bytes
    resource.setrlimit(resource.RLIMIT_AS, (limit, limit))
    print(f"Memory limit set to {limit} bytes (1% of total system memory).")  # Print so the applied limit can be verified
except Exception as e:
    print("Could not set the memory limit:", e)
# --- End of RAM-limiting block ---
app = FastAPI()
# Asynchronous helper to free memory (RAM and the CUDA cache)
async def cleanup_memory(device: str):
gc.collect()
if device == "cuda":
torch.cuda.empty_cache()
    # Brief pause to let the memory actually be released
await asyncio.sleep(0.01)
class GenerateRequest(BaseModel):
model_name: str
input_text: str = ""
task_type: str
temperature: float = 1.0
max_new_tokens: int = 10
stream: bool = True
top_p: float = 1.0
top_k: int = 50
repetition_penalty: float = 1.0
num_return_sequences: int = 1
do_sample: bool = True
chunk_delay: float = 0.0
stop_sequences: list[str] = []
    chunk_token_limit: int = 10000000000  # Caps how many tokens are accumulated in a single streamed chunk
@field_validator("model_name")
def model_name_cannot_be_empty(cls, v):
if not v:
raise ValueError("model_name cannot be empty.")
return v
@field_validator("task_type")
def task_type_must_be_valid(cls, v):
valid_types = ["text-to-text", "text-to-image", "text-to-speech", "text-to-video"]
if v not in valid_types:
raise ValueError(f"task_type must be one of: {valid_types}")
return v
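# Illustrative example of a request body accepted by GenerateRequest ("gpt2" is only an
# example model name; any causal LM hosted on the Hugging Face Hub can be requested):
#
#   {
#     "model_name": "gpt2",
#     "input_text": "Once upon a time",
#     "task_type": "text-to-text",
#     "max_new_tokens": 20,
#     "stream": true,
#     "stop_sequences": []
#   }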
class LocalModelLoader:
def __init__(self):
self.loaded_models = {}
async def load_model_and_tokenizer(self, model_name):
        # Use the model requested by the client
if model_name in self.loaded_models:
return self.loaded_models[model_name]
try:
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
            # torch_dtype=torch.float16 reduces the memory footprint (when the model supports it)
model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16)
            # Make sure a padding token is set when the tokenizer lacks one
if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
self.loaded_models[model_name] = (model, tokenizer)
return model, tokenizer
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
model_loader = LocalModelLoader()
class StopOnTokens(StoppingCriteria):
def __init__(self, stop_token_ids: list[int]):
self.stop_token_ids = stop_token_ids
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
for stop_id in self.stop_token_ids:
if input_ids[0][-1] == stop_id:
return True
return False
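# Illustrative standalone use of StopOnTokens (assumes a loaded model/tokenizer and
# tokenized inputs; within this app it is wired up in the /generate endpoint below):
#
#   criteria = StoppingCriteriaList([StopOnTokens([tokenizer.eos_token_id])])
#   model.generate(**inputs, stopping_criteria=criteria)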
@app.post("/generate")
async def generate(request: GenerateRequest):
try:
        # Extract the generation parameters from the request
model_name = request.model_name
input_text = request.input_text
task_type = request.task_type
temperature = request.temperature
max_new_tokens = request.max_new_tokens
stream = request.stream
top_p = request.top_p
top_k = request.top_k
repetition_penalty = request.repetition_penalty
num_return_sequences = request.num_return_sequences
do_sample = request.do_sample
chunk_delay = request.chunk_delay
stop_sequences = request.stop_sequences
chunk_token_limit = request.chunk_token_limit
model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
generation_config = GenerationConfig(
temperature=temperature,
max_new_tokens=max_new_tokens,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
do_sample=do_sample,
num_return_sequences=num_return_sequences,
)
stop_token_ids = []
if stop_sequences:
stop_token_ids = tokenizer.convert_tokens_to_ids(stop_sequences)
stopping_criteria_list = StoppingCriteriaList([StopOnTokens(stop_token_ids)]) if stop_token_ids else None
        if stream:
            # StreamingResponse wraps the async generator that yields text chunks as they are produced.
            response = StreamingResponse(
                stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, chunk_token_limit),
                media_type="text/plain"
            )
else:
generated_text = await generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device)
response = StreamingResponse(iter([generated_text]), media_type="text/plain")
await cleanup_memory(device)
return response
except Exception as e:
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
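# Illustrative client call for /generate, assuming the server is running locally on
# port 7860 and that the requested model (here "gpt2", as an example) can be loaded.
# Run from a separate process:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={
#           "model_name": "gpt2",
#           "input_text": "Hello, world",
#           "task_type": "text-to-text",
#           "max_new_tokens": 20,
#           "stream": True,
#       },
#       stream=True,
#   )
#   for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#       print(chunk, end="")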
async def stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, chunk_token_limit):
    """
    Asynchronously generates tokens and streams them to the client, splitting the
    response into chunks whenever a chunk would exceed chunk_token_limit tokens.
    Generation stops automatically once the StoppingCriteriaList is satisfied.
    """
    # Truncate the input to keep memory usage low
encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=64).to(device)
current_chunk_tokens = 0
current_chunk_text = ""
past_key_values = None # To maintain state for streaming
    # torch.no_grad() avoids storing activations for gradient computation
with torch.no_grad():
input_ids = encoded_input.input_ids
        # Manual token-by-token generation so stopping and chunking can be controlled
        while True:  # Loop until one of the stop conditions below breaks out
            # Once a KV cache exists, only the newest token needs to be fed to the model
            model_inputs = input_ids if past_key_values is None else input_ids[:, -1:]
            outputs = model(
                model_inputs,
                past_key_values=past_key_values,
                use_cache=True,  # Keep the key/value cache for stateful decoding
                return_dict=True
            )
next_token_logits = outputs.logits[:, -1, :]
            # Sample the next token following generation_config
            if generation_config.do_sample:
                # Temperature scaling
                next_token_logits = next_token_logits / generation_config.temperature
                # Top-k filtering
                if generation_config.top_k is not None and generation_config.top_k > 0:
                    v, _ = torch.topk(next_token_logits, min(generation_config.top_k, next_token_logits.size(-1)))
                    next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')
                # Top-p (nucleus) filtering
                if generation_config.top_p is not None and generation_config.top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > generation_config.top_p
                    # Shift right so the first token above the threshold is still kept
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits[indices_to_remove] = -float('Inf')
                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                # Greedy decoding
                next_tokens = torch.argmax(next_token_logits, dim=-1)
            next_tokens = next_tokens.unsqueeze(0)  # Reshape to [1, 1] for concatenation
            candidate_ids = torch.cat([input_ids, next_tokens], dim=-1)
            # Check the stopping criteria on the candidate sequence BEFORE emitting the token,
            # so stop tokens are never streamed to the client
            if stopping_criteria_list and stopping_criteria_list(candidate_ids, next_token_logits):
                break
            next_token_text = tokenizer.decode(next_tokens[0], skip_special_tokens=True)
            token_count = len(tokenizer.encode(current_chunk_text + next_token_text)) - len(tokenizer.encode(current_chunk_text))
            if current_chunk_tokens + token_count > chunk_token_limit:
                # The current chunk is full: flush it and start a new one with this token
                yield current_chunk_text
                current_chunk_text = next_token_text
                current_chunk_tokens = token_count
            else:
                current_chunk_text += next_token_text
                current_chunk_tokens += token_count
            input_ids = candidate_ids  # The new token is now part of the context for the next step
            past_key_values = outputs.past_key_values  # Keep the KV cache up to date
await asyncio.sleep(chunk_delay)
if input_ids.shape[-1] >= generation_config.max_new_tokens + encoded_input.input_ids.shape[-1]: # Check max_new_tokens limit
break # Stop if max_new_tokens is reached
    # Make sure the final (possibly partial) chunk is sent
    if current_chunk_text:
        yield current_chunk_text
    await cleanup_memory(device)
async def generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, max_length=64):
encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
with torch.no_grad():
output = model.generate(
**encoded_input,
generation_config=generation_config,
stopping_criteria=stopping_criteria_list,
return_dict_in_generate=True,
output_scores=True
)
generated_text = tokenizer.decode(output.sequences[0][encoded_input["input_ids"].shape[-1]:], skip_special_tokens=True)
await cleanup_memory(device)
return generated_text
@app.post("/generate-image")
async def generate_image(request: GenerateRequest):
try:
validated_body = request
        device = 0 if torch.cuda.is_available() else -1  # pipeline expects an int device index (0 = first GPU, -1 = CPU)
        # Run the pipeline in a separate thread so the event loop is not blocked
image_generator = await asyncio.to_thread(pipeline, "text-to-image", model=validated_body.model_name, device=device)
results = await asyncio.to_thread(image_generator, validated_body.input_text)
image = results[0]
img_byte_arr = BytesIO()
image.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
await cleanup_memory("cuda" if device == 0 else "cpu")
return StreamingResponse(img_byte_arr, media_type="image/png")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.post("/generate-text-to-speech")
async def generate_text_to_speech(request: GenerateRequest):
try:
validated_body = request
device = 0 if torch.cuda.is_available() else -1
        # Run the pipeline in a separate thread so the event loop is not blocked
tts_generator = await asyncio.to_thread(pipeline, "text-to-speech", model=validated_body.model_name, device=device)
tts_results = await asyncio.to_thread(tts_generator, validated_body.input_text)
        # The text-to-speech pipeline returns a dict with "audio" and "sampling_rate" keys
        audio = tts_results["audio"].squeeze()
        sampling_rate = tts_results["sampling_rate"]
        audio_byte_arr = BytesIO()
        sf.write(audio_byte_arr, audio, sampling_rate, format='WAV')
audio_byte_arr.seek(0)
await cleanup_memory("cuda" if device == 0 else "cpu")
return StreamingResponse(audio_byte_arr, media_type="audio/wav")
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
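# Illustrative call for /generate-text-to-speech ("facebook/mms-tts-eng" is just an
# example of a text-to-speech checkpoint; any compatible model id can be used):
#
#   curl -X POST http://localhost:7860/generate-text-to-speech \
#        -H "Content-Type: application/json" \
#        -d '{"model_name": "facebook/mms-tts-eng", "input_text": "Hello", "task_type": "text-to-speech"}' \
#        --output speech.wav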
@app.post("/generate-video")
async def generate_video(request: GenerateRequest):
try:
validated_body = request
device = 0 if torch.cuda.is_available() else -1
        # Run the pipeline in a separate thread so the event loop is not blocked
video_generator = await asyncio.to_thread(pipeline, "text-to-video", model=validated_body.model_name, device=device)
video = await asyncio.to_thread(video_generator, validated_body.input_text)
video_byte_arr = BytesIO()
video.save(video_byte_arr)
video_byte_arr.seek(0)
await cleanup_memory("cuda" if device == 0 else "cpu")
return StreamingResponse(video_byte_arr, media_type="video/mp4")
except Exception as e:
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)