import os
import gc
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, field_validator
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    StoppingCriteriaList,
    StoppingCriteria,
    pipeline
)
import uvicorn
import asyncio
from io import BytesIO
import soundfile as sf
import traceback

# --- Block that limits RAM to 1% of the total (Unix-only environments) ---
try:
    import psutil
    import resource
    total_memory = psutil.virtual_memory().total
    limit = int(total_memory * 0.01)  # 1% of the total, in bytes
    resource.setrlimit(resource.RLIMIT_AS, (limit, limit))
    print(f"Memory limit set to {limit} bytes (1% of total system memory).")
except Exception as e:
    print("Could not set the memory limit:", e)
# --- End of RAM-limiting block ---
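# The limit actually in force can be read back with resource.getrlimit (Unix only);
# a quick sanity check, assuming the block above succeeded:
#
#   soft, hard = resource.getrlimit(resource.RLIMIT_AS)  # both values in bytes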

app = FastAPI()

# Async helper to free memory (RAM and the CUDA cache)
async def cleanup_memory(device: str):
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
    # Brief pause to give the memory a chance to be released
    await asyncio.sleep(0.01)

class GenerateRequest(BaseModel):
    model_name: str
    input_text: str = ""
    task_type: str
    temperature: float = 1.0
    max_new_tokens: int = 2
    stream: bool = True
    top_p: float = 1.0
    top_k: int = 50
    repetition_penalty: float = 1.0
    num_return_sequences: int = 1
    do_sample: bool = True
    chunk_delay: float = 0.0
    stop_sequences: list[str] = []

    @field_validator("model_name")
    def model_name_cannot_be_empty(cls, v):
        if not v:
            raise ValueError("model_name cannot be empty.")
        return v

    @field_validator("task_type")
    def task_type_must_be_valid(cls, v):
        valid_types = ["text-to-text", "text-to-image", "text-to-speech", "text-to-video"]
        if v not in valid_types:
            raise ValueError(f"task_type must be one of: {valid_types}")
        return v
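
# Example of how GenerateRequest validates input (hypothetical values):
#
#   GenerateRequest(model_name="gpt2", input_text="Hello", task_type="text-to-text")   # OK
#   GenerateRequest(model_name="", input_text="Hello", task_type="text-to-text")       # raises ValidationError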

class LocalModelLoader:
    def __init__(self):
        self.loaded_models = {}

    async def load_model_and_tokenizer(self, model_name):
        # Reuse an already-loaded model if the user requested it before
        if model_name in self.loaded_models:
            return self.loaded_models[model_name]
        try:
            config = AutoConfig.from_pretrained(model_name)
            tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
            # torch_dtype=torch.float16 reduces the memory footprint (when supported)
            model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16)
            # Set a padding token if the tokenizer does not define one
            if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
            self.loaded_models[model_name] = (model, tokenizer)
            return model, tokenizer
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error loading model: {e}")

model_loader = LocalModelLoader()
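# If model size is the main memory concern, load_model_and_tokenizer could load the
# weights more lightly; a sketch of an alternative call (assumes the accelerate
# package is installed; not used above):
#
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name,
#       config=config,
#       torch_dtype=torch.float16,
#       low_cpu_mem_usage=True,
#       device_map="auto",
#   )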

class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_token_ids: list[int]):
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in self.stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
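
# NOTE: /generate below maps each stop sequence with tokenizer.convert_tokens_to_ids(),
# which assigns one id per entry, so stop strings that tokenize to several tokens never
# trigger StopOnTokens. A minimal string-level alternative is sketched here (hypothetical
# helper, not wired into the endpoints):
class StopOnStrings(StoppingCriteria):
    def __init__(self, stop_strings: list[str], tokenizer, prompt_length: int):
        self.stop_strings = stop_strings
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode only the generated part and stop as soon as any stop string appears
        generated = self.tokenizer.decode(input_ids[0][self.prompt_length:], skip_special_tokens=True)
        return any(s in generated for s in self.stop_strings)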

@app.post("/generate")
async def generate(request: GenerateRequest):
    try:
        # Extract the request parameters
        model_name = request.model_name
        input_text = request.input_text
        task_type = request.task_type
        temperature = request.temperature
        max_new_tokens = request.max_new_tokens
        stream = request.stream
        top_p = request.top_p
        top_k = request.top_k
        repetition_penalty = request.repetition_penalty
        num_return_sequences = request.num_return_sequences
        do_sample = request.do_sample
        chunk_delay = request.chunk_delay
        stop_sequences = request.stop_sequences

        model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)

        generation_config = GenerationConfig(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            num_return_sequences=num_return_sequences,
        )

        stop_token_ids = []
        if stop_sequences:
            stop_token_ids = tokenizer.convert_tokens_to_ids(stop_sequences)
        stopping_criteria_list = StoppingCriteriaList([StopOnTokens(stop_token_ids)]) if stop_token_ids else None

        if stream:
            # StreamingResponse wraps the async generator that emits each token as it is produced.
            response = StreamingResponse(
                stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay),
                media_type="text/plain"
            )
        else:
            generated_text = await generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device)
            response = StreamingResponse(iter([generated_text]), media_type="text/plain")

        await cleanup_memory(device)
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
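
# Example request against /generate (the model name "gpt2" is only illustrative):
#
#   curl -N -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"model_name": "gpt2", "input_text": "Hello", "task_type": "text-to-text",
#             "max_new_tokens": 32, "stream": true}'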

async def stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, max_length=64):
    """
    Generates text and streams the decoded tokens to the client one at a time.
    """
    # Truncate the input to keep memory usage down
    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
    encoded_input_len = encoded_input["input_ids"].shape[-1]

    # torch.no_grad() avoids keeping gradient buffers around
    with torch.no_grad():
        # model.generate() returns a single output object (return_dict_in_generate=True),
        # so generate once and then stream the decoded tokens to the client
        output = model.generate(
            **encoded_input,
            generation_config=generation_config,
            stopping_criteria=stopping_criteria_list,
            return_dict_in_generate=True,
            output_scores=True,
        )

    # Keep only the newly generated tokens (drop the prompt) and emit them one by one
    new_tokens = output.sequences[0, encoded_input_len:]
    for token_id in new_tokens.tolist():
        token = tokenizer.decode(token_id, skip_special_tokens=True)
        if token:
            # Send each decoded token immediately
            yield token
            await asyncio.sleep(chunk_delay)
    await cleanup_memory(device)
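
# stream_text() above runs generate() to completion and only then emits the decoded
# tokens. For true token-by-token streaming, transformers offers TextIteratorStreamer;
# the sketch below shows that approach (an alternative, not wired into /generate;
# generation runs in a background thread while the streamer is consumed).
async def stream_text_incremental(model, tokenizer, input_text, generation_config, device, chunk_delay):
    from threading import Thread
    from transformers import TextIteratorStreamer

    encoded_input = tokenizer(input_text, return_tensors="pt").to(device)
    # skip_prompt=True so only newly generated text is yielded
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=model.generate,
        kwargs=dict(**encoded_input, generation_config=generation_config, streamer=streamer),
    )
    thread.start()
    # Note: iterating the streamer blocks between chunks; acceptable for a sketch
    for piece in streamer:
        yield piece
        await asyncio.sleep(chunk_delay)
    thread.join()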

async def generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, max_length=64):
    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
    with torch.no_grad():
        output = model.generate(
            **encoded_input,
            generation_config=generation_config,
            stopping_criteria=stopping_criteria_list,
            return_dict_in_generate=True,
            output_scores=True
        )
    generated_text = tokenizer.decode(output.sequences[0][encoded_input["input_ids"].shape[-1]:], skip_special_tokens=True)
    await cleanup_memory(device)
    return generated_text

@app.post("/generate-image")
async def generate_image(request: GenerateRequest):
    try:
        validated_body = request
        device = 0 if torch.cuda.is_available() else -1  # pipeline expects an int device index for CUDA

        # Run the pipeline in a separate thread
        image_generator = await asyncio.to_thread(pipeline, "text-to-image", model=validated_body.model_name, device=device)
        results = await asyncio.to_thread(image_generator, validated_body.input_text)
        image = results[0]

        img_byte_arr = BytesIO()
        image.save(img_byte_arr, format="PNG")
        img_byte_arr.seek(0)
        await cleanup_memory("cuda" if device == 0 else "cpu")
        return StreamingResponse(img_byte_arr, media_type="image/png")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
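
# NOTE: "text-to-image" is not a transformers pipeline task, so the endpoint above is
# unlikely to work as written. A minimal sketch of a diffusers-based alternative
# (assumes the diffusers package is installed and model_name points to a Stable
# Diffusion style checkpoint; not wired into the route):
def generate_image_with_diffusers(model_name: str, prompt: str) -> BytesIO:
    from diffusers import AutoPipelineForText2Image  # optional dependency

    use_cuda = torch.cuda.is_available()
    pipe = AutoPipelineForText2Image.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    pipe = pipe.to("cuda" if use_cuda else "cpu")
    image = pipe(prompt).images[0]  # a PIL.Image
    buf = BytesIO()
    image.save(buf, format="PNG")
    buf.seek(0)
    return buf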

@app.post("/generate-text-to-speech")
async def generate_text_to_speech(request: GenerateRequest):
    try:
        validated_body = request
        device = 0 if torch.cuda.is_available() else -1

        # Run the pipeline in a separate thread
        tts_generator = await asyncio.to_thread(pipeline, "text-to-speech", model=validated_body.model_name, device=device)
        tts_results = await asyncio.to_thread(tts_generator, validated_body.input_text)
        # The text-to-speech pipeline returns a dict with "audio" and "sampling_rate"
        audio = tts_results["audio"].squeeze()  # drop a leading batch axis if the model returns shape (1, n)
        sampling_rate = tts_results["sampling_rate"]

        audio_byte_arr = BytesIO()
        sf.write(audio_byte_arr, audio, sampling_rate, format='WAV')
        audio_byte_arr.seek(0)
        await cleanup_memory("cuda" if device == 0 else "cpu")
        return StreamingResponse(audio_byte_arr, media_type="audio/wav")
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@app.post("/generate-video")
async def generate_video(request: GenerateRequest):
    try:
        validated_body = request
        device = 0 if torch.cuda.is_available() else -1

        # Run the pipeline in a separate thread
        video_generator = await asyncio.to_thread(pipeline, "text-to-video", model=validated_body.model_name, device=device)
        video = await asyncio.to_thread(video_generator, validated_body.input_text)

        video_byte_arr = BytesIO()
        video.save(video_byte_arr)
        video_byte_arr.seek(0)
        await cleanup_memory("cuda" if device == 0 else "cpu")
        return StreamingResponse(video_byte_arr, media_type="video/mp4")
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
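
# NOTE: "text-to-video" is likewise not a transformers pipeline task; such models are
# usually served through diffusers instead. A heavily hedged sketch (the exact layout
# of .frames varies across diffusers versions):
#
#   from diffusers import DiffusionPipeline
#   from diffusers.utils import export_to_video
#
#   pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")
#   frames = pipe(prompt).frames              # frame layout is version-dependent
#   video_path = export_to_video(frames)      # writes an .mp4 file and returns its path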

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)