import os
import gc
import asyncio
import traceback
from io import BytesIO

import torch
import soundfile as sf
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, field_validator
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    pipeline,
)


# Optionally cap the process address space via RLIMIT_AS. Note that the
# 1000x multiplier makes the cap far larger than physical memory, so it is
# effectively a no-op; lower the factor to enforce a real limit.
try:
    import psutil
    import resource

    total_memory = psutil.virtual_memory().total
    limit = int(total_memory * 1000.0)
    resource.setrlimit(resource.RLIMIT_AS, (limit, limit))
    print(f"Address space limit set to {limit} bytes.")
except Exception as e:
    print("Could not set the memory limit:", e)


app = FastAPI()


async def cleanup_memory(device: str):
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
    await asyncio.sleep(0.01)


class GenerateRequest(BaseModel):
    model_name: str
    input_text: str = ""
    task_type: str
    temperature: float = 1.0
    max_new_tokens: int = 10
    stream: bool = True
    top_p: float = 1.0
    top_k: int = 50
    repetition_penalty: float = 1.0
    num_return_sequences: int = 1
    do_sample: bool = True
    chunk_delay: float = 0.0
    stop_sequences: list[str] = []
    chunk_token_limit: int = 10_000_000_000

    @field_validator("model_name")
    def model_name_cannot_be_empty(cls, v):
        if not v:
            raise ValueError("model_name cannot be empty.")
        return v

    @field_validator("task_type")
    def task_type_must_be_valid(cls, v):
        valid_types = ["text-to-text", "text-to-image", "text-to-speech", "text-to-video"]
        if v not in valid_types:
            raise ValueError(f"task_type must be one of: {valid_types}")
        return v
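

# A minimal example of the JSON body this schema accepts, shown here as a
# Python dict for reference. The model name "gpt2" and the parameter values
# are illustrative assumptions, not defaults required by the API.
#
#     example_request = {
#         "model_name": "gpt2",
#         "input_text": "Once upon a time",
#         "task_type": "text-to-text",
#         "max_new_tokens": 50,
#         "temperature": 0.8,
#         "stream": True,
#         "stop_sequences": [],
#     }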


class LocalModelLoader:
    """Loads models on demand and caches them in memory by name."""

    def __init__(self):
        self.loaded_models = {}

    async def load_model_and_tokenizer(self, model_name):
        # Return the cached (model, tokenizer) pair if it was loaded before.
        if model_name in self.loaded_models:
            return self.loaded_models[model_name]
        try:
            config = AutoConfig.from_pretrained(model_name)
            tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
            # Loaded in float16 to halve memory use; note that some operations
            # are slow or unsupported in float16 on CPU-only machines.
            model = AutoModelForCausalLM.from_pretrained(
                model_name, config=config, torch_dtype=torch.float16
            )
            # Ensure a pad token exists so padded generation works.
            if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
            self.loaded_models[model_name] = (model, tokenizer)
            return model, tokenizer
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error loading model: {e}")


model_loader = LocalModelLoader()


class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_token_ids: list[int]):
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in self.stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False


@app.post("/generate")
async def generate(request: GenerateRequest):
    try:
        model_name = request.model_name
        input_text = request.input_text
        task_type = request.task_type
        temperature = request.temperature
        max_new_tokens = request.max_new_tokens
        stream = request.stream
        top_p = request.top_p
        top_k = request.top_k
        repetition_penalty = request.repetition_penalty
        num_return_sequences = request.num_return_sequences
        do_sample = request.do_sample
        chunk_delay = request.chunk_delay
        stop_sequences = request.stop_sequences
        chunk_token_limit = request.chunk_token_limit

        model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)

        generation_config = GenerationConfig(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            num_return_sequences=num_return_sequences,
        )

        # Stop sequences are mapped token-by-token; this assumes each stop
        # sequence corresponds to a single token in the model's vocabulary.
        stop_token_ids = []
        if stop_sequences:
            stop_token_ids = tokenizer.convert_tokens_to_ids(stop_sequences)
        stopping_criteria_list = StoppingCriteriaList([StopOnTokens(stop_token_ids)]) if stop_token_ids else None

        if stream:
            response = StreamingResponse(
                stream_text(model, tokenizer, input_text, generation_config,
                            stopping_criteria_list, device, chunk_delay, chunk_token_limit),
                media_type="text/plain"
            )
        else:
            generated_text = await generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device)
            response = StreamingResponse(iter([generated_text]), media_type="text/plain")

        await cleanup_memory(device)
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
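

# Sketch of a client call against this endpoint, assuming the server runs on
# localhost:7860 and that the `requests` package is installed; the model name
# "gpt2" is an illustrative assumption.
#
#     import requests
#
#     payload = {
#         "model_name": "gpt2",
#         "input_text": "Hello, world",
#         "task_type": "text-to-text",
#         "max_new_tokens": 30,
#         "stream": True,
#     }
#     with requests.post("http://localhost:7860/generate", json=payload, stream=True) as r:
#         for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
#             print(chunk, end="", flush=True)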


async def stream_text(model, tokenizer, input_text, generation_config, stop_criteria, device, chunk_delay, chunk_token_limit):
    """
    Generate tokens asynchronously and send them to the client, splitting the
    response into chunks whenever a chunk exceeds the token limit.
    Generation stops automatically when the StoppingCriteriaList is met.
    """
    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=64).to(device)

    current_chunk_tokens = 0
    current_chunk_text = ""
    past_key_values = None

    with torch.no_grad():
        input_ids = encoded_input.input_ids

        while True:
            # With a KV cache, only the most recent token needs to be fed to the model.
            model_inputs = input_ids if past_key_values is None else input_ids[:, -1:]
            outputs = model(
                model_inputs,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True
            )
            next_token_logits = outputs.logits[:, -1, :]

            if generation_config.do_sample:
                # Only temperature and top-k from the config are applied here;
                # top_p and repetition_penalty are not used in this manual loop.
                next_token_logits = next_token_logits / generation_config.temperature
                if generation_config.top_k is not None and generation_config.top_k > 0:
                    v, _ = torch.topk(next_token_logits, min(generation_config.top_k, next_token_logits.size(-1)))
                    next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')
                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_logits, dim=-1)

            next_tokens = next_tokens.unsqueeze(0)
            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
            past_key_values = outputs.past_key_values

            # Check the stopping criteria against the sequence that now includes
            # the newly sampled token, so a stop token ends generation immediately.
            if stop_criteria and stop_criteria(input_ids, next_token_logits):
                break

            next_token_text = tokenizer.decode(next_tokens[0], skip_special_tokens=True)
            token_count = len(tokenizer.encode(current_chunk_text + next_token_text)) - len(tokenizer.encode(current_chunk_text))

            # Flush the current chunk once it would exceed the limit; otherwise
            # keep accumulating. The final partial chunk is emitted after the loop.
            if current_chunk_tokens + token_count > chunk_token_limit:
                yield current_chunk_text
                current_chunk_text = next_token_text
                current_chunk_tokens = token_count
            else:
                current_chunk_text += next_token_text
                current_chunk_tokens += token_count

            await asyncio.sleep(chunk_delay)

            if input_ids.shape[-1] >= generation_config.max_new_tokens + encoded_input.input_ids.shape[-1]:
                break

    if current_chunk_text:
        yield current_chunk_text

    await cleanup_memory(device)


async def generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, max_length=64):
    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
    with torch.no_grad():
        output = model.generate(
            **encoded_input,
            generation_config=generation_config,
            stopping_criteria=stopping_criteria_list,
            return_dict_in_generate=True,
            output_scores=True
        )
    generated_text = tokenizer.decode(output.sequences[0][encoded_input["input_ids"].shape[-1]:], skip_special_tokens=True)
    await cleanup_memory(device)
    return generated_text


@app.post("/generate-image")
async def generate_image(request: GenerateRequest):
    try:
        validated_body = request
        device = 0 if torch.cuda.is_available() else -1

        # Pipeline construction and inference run in worker threads so the event
        # loop is not blocked. Note that "text-to-image" is only available if the
        # installed transformers version and the chosen model support that task;
        # many text-to-image models are served through diffusers instead.
        image_generator = await asyncio.to_thread(pipeline, "text-to-image", model=validated_body.model_name, device=device)
        results = await asyncio.to_thread(image_generator, validated_body.input_text)
        image = results[0]

        img_byte_arr = BytesIO()
        image.save(img_byte_arr, format="PNG")
        img_byte_arr.seek(0)
        await cleanup_memory("cuda" if device == 0 else "cpu")
        return StreamingResponse(img_byte_arr, media_type="image/png")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
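

# Sketch of a client call for the image endpoint, again assuming the server on
# localhost:7860 and the `requests` package; the model name below is an
# illustrative placeholder and must support the text-to-image task.
#
#     import requests
#
#     payload = {
#         "model_name": "some-text-to-image-model",
#         "input_text": "a watercolor painting of a lighthouse",
#         "task_type": "text-to-image",
#     }
#     r = requests.post("http://localhost:7860/generate-image", json=payload)
#     with open("output.png", "wb") as f:
#         f.write(r.content)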


@app.post("/generate-text-to-speech")
async def generate_text_to_speech(request: GenerateRequest):
    try:
        validated_body = request
        device = 0 if torch.cuda.is_available() else -1

        tts_generator = await asyncio.to_thread(pipeline, "text-to-speech", model=validated_body.model_name, device=device)
        tts_results = await asyncio.to_thread(tts_generator, validated_body.input_text)
        # The text-to-speech pipeline returns a dict with the waveform and its
        # sampling rate; squeeze() drops a leading batch dimension if present.
        audio = tts_results["audio"].squeeze()
        sampling_rate = tts_results["sampling_rate"]

        audio_byte_arr = BytesIO()
        sf.write(audio_byte_arr, audio, sampling_rate, format='WAV')
        audio_byte_arr.seek(0)
        await cleanup_memory("cuda" if device == 0 else "cpu")
        return StreamingResponse(audio_byte_arr, media_type="audio/wav")
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
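

# Sketch of a client call for the speech endpoint, assuming the server on
# localhost:7860 and the `requests` package; "suno/bark-small" is only an
# example of a model that exposes the text-to-speech task.
#
#     import requests
#
#     payload = {
#         "model_name": "suno/bark-small",
#         "input_text": "Hello from the speech endpoint",
#         "task_type": "text-to-speech",
#     }
#     r = requests.post("http://localhost:7860/generate-text-to-speech", json=payload)
#     with open("output.wav", "wb") as f:
#         f.write(r.content)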


@app.post("/generate-video")
async def generate_video(request: GenerateRequest):
    try:
        validated_body = request
        device = 0 if torch.cuda.is_available() else -1

        # As with the image endpoint, this assumes the installed transformers
        # version exposes a "text-to-video" pipeline task and that the returned
        # object provides a save() method; most text-to-video models are served
        # through diffusers and would need different handling.
        video_generator = await asyncio.to_thread(pipeline, "text-to-video", model=validated_body.model_name, device=device)
        video = await asyncio.to_thread(video_generator, validated_body.input_text)

        video_byte_arr = BytesIO()
        video.save(video_byte_arr)
        video_byte_arr.seek(0)
        await cleanup_memory("cuda" if device == 0 else "cpu")
        return StreamingResponse(video_byte_arr, media_type="video/mp4")
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
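
# The app can also be launched with the uvicorn CLI, e.g.
# `uvicorn main:app --host 0.0.0.0 --port 7860`, assuming this file is saved
# as main.py (adjust the module name to the actual filename).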