import os
import gc
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, field_validator
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    StoppingCriteriaList,
    StoppingCriteria,
    pipeline
)
import uvicorn
import asyncio
from io import BytesIO
import soundfile as sf
import traceback

# --- Block to limit RAM usage to 1% of system memory (Unix-only) ---
try:
    import psutil
    import resource
    total_memory = psutil.virtual_memory().total
    limit = int(total_memory * 0.01)  # 1% of total memory, in bytes
    resource.setrlimit(resource.RLIMIT_AS, (limit, limit))
    print(f"Memory limit set to {limit} bytes (1% of total system memory).")
except Exception as e:
    print("Could not set the memory limit:", e)
# --- End of RAM limit block ---
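# NOTE (illustrative, not part of the original logic): RLIMIT_AS caps the process's
# virtual address space, and 1% of system RAM is usually far too small to load a
# transformer checkpoint. If a cap is wanted, a larger hypothetical fraction could
# be used instead, e.g.:
#
#   soft_cap = int(total_memory * 0.5)  # hypothetical 50% cap
#   resource.setrlimit(resource.RLIMIT_AS, (soft_cap, soft_cap))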

app = FastAPI()

# Asynchronous helper to free memory (Python heap and CUDA cache)
async def cleanup_memory(device: str):
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
    # Brief pause to give the allocator a chance to release memory
    await asyncio.sleep(0.01)

class GenerateRequest(BaseModel):
    model_name: str
    input_text: str = ""
    task_type: str
    temperature: float = 1.0
    max_new_tokens: int = 10
    stream: bool = True
    top_p: float = 1.0
    top_k: int = 50
    repetition_penalty: float = 1.0
    num_return_sequences: int = 1
    do_sample: bool = True
    chunk_delay: float = 0.0
    stop_sequences: list[str] = []
    chunk_token_limit: int = 10_000_000_000  # Max tokens per streamed chunk (the default effectively disables chunking)

    @field_validator("model_name")
    def model_name_cannot_be_empty(cls, v):
        if not v:
            raise ValueError("model_name cannot be empty.")
        return v

    @field_validator("task_type")
    def task_type_must_be_valid(cls, v):
        valid_types = ["text-to-text", "text-to-image", "text-to-speech", "text-to-video"]
        if v not in valid_types:
            raise ValueError(f"task_type must be one of: {valid_types}")
        return v
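
# Illustrative example of a /generate request body (all values are placeholders,
# not enforced defaults):
#
#   {
#       "model_name": "gpt2",
#       "input_text": "Hello",
#       "task_type": "text-to-text",
#       "temperature": 0.8,
#       "max_new_tokens": 32,
#       "stream": true,
#       "stop_sequences": ["\n"],
#       "chunk_token_limit": 1
#   }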

class LocalModelLoader:
    def __init__(self):
        self.loaded_models = {}

    async def load_model_and_tokenizer(self, model_name):
        # Use the model specified by the user, reusing it if it is already loaded
        if model_name in self.loaded_models:
            return self.loaded_models[model_name]
        try:
            config = AutoConfig.from_pretrained(model_name)
            tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
            # torch_dtype=torch.float16 reduces the memory footprint (when supported)
            model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16)
            # Set a padding token if the tokenizer does not define one
            if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
            self.loaded_models[model_name] = (model, tokenizer)
            return model, tokenizer
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error loading model: {e}")

model_loader = LocalModelLoader()

class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_token_ids: list[int]):
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in self.stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
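
# Illustrative sketch of how StopOnTokens plugs into a StoppingCriteriaList
# (assumes each stop sequence maps to a single token in the tokenizer's vocabulary;
# multi-token stop strings would need a sequence-level check instead):
#
#   stop_ids = tokenizer.convert_tokens_to_ids(["</s>"])
#   criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
#   model.generate(**inputs, stopping_criteria=criteria)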

@app.post("/generate")
async def generate(request: GenerateRequest):
    try:
        # Extract request parameters
        model_name = request.model_name
        input_text = request.input_text
        task_type = request.task_type
        temperature = request.temperature
        max_new_tokens = request.max_new_tokens
        stream = request.stream
        top_p = request.top_p
        top_k = request.top_k
        repetition_penalty = request.repetition_penalty
        num_return_sequences = request.num_return_sequences
        do_sample = request.do_sample
        chunk_delay = request.chunk_delay
        stop_sequences = request.stop_sequences
        chunk_token_limit = request.chunk_token_limit

        model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)

        generation_config = GenerationConfig(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            num_return_sequences=num_return_sequences,
            # NOTE: streaming is handled by the endpoint below; GenerationConfig has no "stream" field
        )

        stop_token_ids = []
        if stop_sequences:
            # Assumes each stop sequence corresponds to a single token in the tokenizer's vocabulary
            stop_token_ids = tokenizer.convert_tokens_to_ids(stop_sequences)
        stopping_criteria_list = StoppingCriteriaList([StopOnTokens(stop_token_ids)]) if stop_token_ids else None

        if stream:
            # StreamingResponse wraps the async generator that streams text to the client as it is produced
            response = StreamingResponse(
                stream_text(model, tokenizer, input_text, generation_config,
                            stopping_criteria_list, device, chunk_delay, chunk_token_limit),
                media_type="text/plain"
            )
        else:
            generated_text = await generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device)
            response = StreamingResponse(iter([generated_text]), media_type="text/plain")

        await cleanup_memory(device)
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

async def stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, chunk_token_limit):
    """
    Generates tokens one at a time and streams the output to the client in chunks of
    at most chunk_token_limit tokens (use a small limit for near real-time streaming).
    Generation stops automatically when the StoppingCriteriaList is satisfied or
    max_new_tokens is reached.
    """
    # Truncate the input to keep memory usage low
    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=64).to(device)

    current_chunk_tokens = 0
    current_chunk_text = ""
    past_key_values = None # To maintain state for streaming

    # torch.no_grad() avoids storing gradient information during generation
    with torch.no_grad():
        input_ids = encoded_input.input_ids
        # Manual token-by-token generation so we control stopping and chunking
        while True:  # Loop until a stop condition or the max_new_tokens limit is hit
            outputs = model(
                input_ids,
                past_key_values=past_key_values,
                use_cache=True, # Important for stateful generation
                return_dict=True
            )
            next_token_logits = outputs.logits[:, -1, :]

            # Sample the next token according to generation_config
            if generation_config.do_sample:
                # Apply temperature scaling and top-k filtering
                # (top_p and repetition_penalty are not applied in this manual loop)
                next_token_logits = next_token_logits / generation_config.temperature

                # Top-k filtering
                if generation_config.top_k is not None and generation_config.top_k > 0:
                    v, _ = torch.topk(next_token_logits, min(generation_config.top_k, next_token_logits.size(-1)))
                    next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')

                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                # Greedy decoding
                next_tokens = torch.argmax(next_token_logits, dim=-1)


            # Check stopping criteria BEFORE adding the token to the output
            if stopping_criteria_list and stopping_criteria_list(input_ids, next_token_logits):
                break  # Stop generation if any criterion is met

            next_tokens = next_tokens.unsqueeze(0) # Reshape to [1, 1] for concat
            next_token_text = tokenizer.decode(next_tokens[0], skip_special_tokens=True)


            # Count the tokens contributed by the new text (accounts for merges at the chunk boundary)
            token_count = len(tokenizer.encode(current_chunk_text + next_token_text)) - len(tokenizer.encode(current_chunk_text))

            if current_chunk_tokens + token_count > chunk_token_limit:
                # Flush the filled chunk to the client and start a new one with the latest token
                yield current_chunk_text
                current_chunk_text = next_token_text
                current_chunk_tokens = token_count
            else:
                current_chunk_text += next_token_text
                current_chunk_tokens += token_count

            input_ids = torch.cat([input_ids, next_tokens], dim=-1) # Append next token to input_ids for next iteration
            past_key_values = outputs.past_key_values # Update past key values for stateful generation

            await asyncio.sleep(chunk_delay)

            if input_ids.shape[-1] >= generation_config.max_new_tokens + encoded_input.input_ids.shape[-1]: # Check max_new_tokens limit
                break # Stop if max_new_tokens is reached

    # Make sure the final (possibly partial) chunk is sent
    if current_chunk_text:
        yield current_chunk_text

    await cleanup_memory(device)


async def generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, max_length=64):
    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
    with torch.no_grad():
        output = model.generate(
            **encoded_input,
            generation_config=generation_config,
            stopping_criteria=stopping_criteria_list,
            return_dict_in_generate=True,
            output_scores=True
        )
    generated_text = tokenizer.decode(output.sequences[0][encoded_input["input_ids"].shape[-1]:], skip_special_tokens=True)
    await cleanup_memory(device)
    return generated_text

@app.post("/generate-image")
async def generate_image(request: GenerateRequest):
    try:
        validated_body = request
        device = 0 if torch.cuda.is_available() else -1  # transformers pipelines take an int device index (0 = first GPU, -1 = CPU)

        # Run the pipeline in a separate thread so the event loop is not blocked.
        # Note: "text-to-image" is not a standard transformers pipeline task; this call
        # assumes a model/pipeline that registers such a task is available.
        image_generator = await asyncio.to_thread(pipeline, "text-to-image", model=validated_body.model_name, device=device)
        results = await asyncio.to_thread(image_generator, validated_body.input_text)
        image = results[0]

        img_byte_arr = BytesIO()
        image.save(img_byte_arr, format="PNG")
        img_byte_arr.seek(0)
        await cleanup_memory("cuda" if device == 0 else "cpu")
        return StreamingResponse(img_byte_arr, media_type="image/png")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@app.post("/generate-text-to-speech")
async def generate_text_to_speech(request: GenerateRequest):
    try:
        validated_body = request
        device = 0 if torch.cuda.is_available() else -1

        # Run the pipeline in a separate thread so the event loop is not blocked
        tts_generator = await asyncio.to_thread(pipeline, "text-to-speech", model=validated_body.model_name, device=device)
        tts_results = await asyncio.to_thread(tts_generator, validated_body.input_text)
        # The text-to-speech pipeline returns a dict with "audio" and "sampling_rate"
        audio = tts_results["audio"]
        sampling_rate = tts_results["sampling_rate"]

        audio_byte_arr = BytesIO()
        sf.write(audio_byte_arr, audio, sampling_rate, format='WAV')
        audio_byte_arr.seek(0)
        await cleanup_memory("cuda" if device == 0 else "cpu")
        return StreamingResponse(audio_byte_arr, media_type="audio/wav")
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@app.post("/generate-video")
async def generate_video(request: GenerateRequest):
    try:
        validated_body = request
        device = 0 if torch.cuda.is_available() else -1

        # Run the pipeline in a separate thread so the event loop is not blocked.
        # Note: "text-to-video" is not a standard transformers pipeline task; this call
        # assumes a model/pipeline that registers such a task is available.
        video_generator = await asyncio.to_thread(pipeline, "text-to-video", model=validated_body.model_name, device=device)
        video = await asyncio.to_thread(video_generator, validated_body.input_text)

        video_byte_arr = BytesIO()
        video.save(video_byte_arr)
        video_byte_arr.seek(0)
        await cleanup_memory("cuda" if device == 0 else "cpu")
        return StreamingResponse(video_byte_arr, media_type="video/mp4")
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
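
# Illustrative client sketch (assumes the server is reachable at http://localhost:7860;
# the payload mirrors the GenerateRequest model above, and all values are placeholders):
#
#   import requests
#
#   payload = {
#       "model_name": "gpt2",
#       "input_text": "Hello",
#       "task_type": "text-to-text",
#       "max_new_tokens": 32,
#       "stream": True,
#       "chunk_token_limit": 1,  # flush roughly token by token
#   }
#   with requests.post("http://localhost:7860/generate", json=payload, stream=True) as r:
#       for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)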