Hjgugugjhuhjggg committed on
Commit
9214e9b
verified
1 Parent(s): ebec48b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -37
app.py CHANGED
@@ -4,11 +4,12 @@ import boto3
4
  from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
7
- from safetensors.torch import load_file
8
  import torch
 
9
  import asyncio
 
10
 
11
- # Configuración de logs
12
  logger = logging.getLogger(__name__)
13
  logger.setLevel(logging.INFO)
14
  console_handler = logging.StreamHandler()
@@ -16,7 +17,6 @@ formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
16
  console_handler.setFormatter(formatter)
17
  logger.addHandler(console_handler)
18
 
19
- # Configuración de AWS y S3
20
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
21
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
22
  AWS_REGION = os.getenv("AWS_REGION")
@@ -32,16 +32,13 @@ s3_client = boto3.client(
32
  region_name=AWS_REGION
33
  )
34
 
35
- # Crear la aplicación FastAPI
36
  app = FastAPI()
37
 
38
- # Modelo de datos para la solicitud
39
  class GenerateRequest(BaseModel):
40
  model_name: str
41
  input_text: str
42
- task_type: str
43
 
44
- # Clase para gestionar el acceso a S3
45
  class S3DirectStream:
46
  def __init__(self, bucket_name):
47
  self.s3_client = boto3.client(
@@ -52,63 +49,102 @@ class S3DirectStream:
52
  )
53
  self.bucket_name = bucket_name
54
 
55
- # Función para obtener el archivo desde S3
56
  async def stream_from_s3(self, key):
57
  loop = asyncio.get_event_loop()
58
  return await loop.run_in_executor(None, self._stream_from_s3, key)
59
 
60
  def _stream_from_s3(self, key):
61
  try:
 
62
  response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
63
- file_content = response['Body'].read()
64
- if not file_content:
65
- raise HTTPException(status_code=404, detail=f"El archivo {key} est谩 vac铆o.")
66
  return file_content
67
  except self.s3_client.exceptions.NoSuchKey:
68
  raise HTTPException(status_code=404, detail=f"El archivo {key} no existe en el bucket S3.")
69
  except Exception as e:
 
70
  raise HTTPException(status_code=500, detail=f"Error al descargar {key} desde S3: {str(e)}")
71
 
72
- # Cargar el modelo directamente desde S3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  async def load_model_from_s3(self, model_name):
74
  try:
 
75
  model_name = model_name.replace("/", "-").lower()
 
 
 
 
 
 
 
76
  model_bytes = await self.stream_from_s3(f"{model_name}/pytorch_model.bin")
77
- if model_bytes:
78
- model = load_file(model_bytes)
79
- return model
80
- model = AutoModelForCausalLM.from_pretrained(f"s3://{self.bucket_name}/{model_name}/pytorch_model.bin")
81
  return model
 
82
  except HTTPException as e:
83
  raise e
84
  except Exception as e:
 
85
  raise HTTPException(status_code=500, detail=f"Error al cargar el modelo desde S3: {e}")
86
 
87
- # Cargar el tokenizer desde S3
88
  async def load_tokenizer_from_s3(self, model_name):
89
  try:
 
90
  model_name = model_name.replace("/", "-").lower()
91
  tokenizer_bytes = await self.stream_from_s3(f"{model_name}/tokenizer.json")
92
- if not tokenizer_bytes:
93
- raise HTTPException(status_code=404, detail="El archivo tokenizer.json est谩 vac铆o o no existe.")
94
- tokenizer = AutoTokenizer.from_pretrained(f"s3://{self.bucket_name}/{model_name}/tokenizer.json")
95
  return tokenizer
96
  except Exception as e:
 
97
  raise HTTPException(status_code=500, detail=f"Error al cargar el tokenizer desde S3: {e}")
98
 
99
- # Obtener los archivos del modelo desde S3
100
- async def get_model_file_parts(self, model_name):
101
  try:
 
 
 
102
  model_name = model_name.replace("/", "-").lower()
103
- files = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=model_name)
104
- model_files = [obj['Key'] for obj in files.get('Contents', []) if model_name in obj['Key']]
105
- if not model_files:
106
- raise HTTPException(status_code=404, detail=f"Archivos del modelo {model_name} no encontrados.")
107
- return model_files
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  except Exception as e:
109
- raise HTTPException(status_code=500, detail=f"Error al obtener archivos del modelo {model_name} desde S3: {e}")
 
 
110
 
111
- # Endpoint para la generación
112
  @app.post("/generate")
113
  async def generate(request: GenerateRequest):
114
  try:
@@ -116,41 +152,54 @@ async def generate(request: GenerateRequest):
116
  model_name = request.model_name
117
  input_text = request.input_text
118
 
 
 
119
  s3_direct_stream = S3DirectStream(S3_BUCKET_NAME)
120
 
121
- # Cargar el modelo y tokenizer desde S3
122
  model = await s3_direct_stream.load_model_from_s3(model_name)
123
  tokenizer = await s3_direct_stream.load_tokenizer_from_s3(model_name)
124
 
125
- # Generar dependiendo del tipo de tarea
 
126
  if task_type == "text-to-text":
127
  generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
128
  result = generator(input_text, max_length=MAX_TOKENS, num_return_sequences=1)
 
129
  return {"result": result[0]["generated_text"]}
130
 
131
  elif task_type == "text-to-image":
132
  generator = pipeline("text-to-image", model=model, tokenizer=tokenizer, device=0)
133
  image = generator(input_text)
 
134
  return {"image": image}
135
 
136
- elif task_type == "text-to-audio" or task_type == "text-to-speech":
137
- generator = pipeline("text-to-speech", model=model, tokenizer=tokenizer, device=0)
138
- audio = generator(input_text)
139
- return {"audio": audio}
140
-
141
  elif task_type == "text-to-video":
142
  generator = pipeline("text-to-video", model=model, tokenizer=tokenizer, device=0)
143
  video = generator(input_text)
 
144
  return {"video": video}
145
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  else:
147
  raise HTTPException(status_code=400, detail="Tipo de tarea no soportado.")
 
148
  except HTTPException as e:
149
  raise e
150
  except Exception as e:
151
  raise HTTPException(status_code=500, detail=f"Error en la generaci贸n: {str(e)}")
152
 
153
- # Ejecutar la aplicación
154
  if __name__ == "__main__":
155
  import uvicorn
156
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
4
  from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
7
+ from huggingface_hub import hf_hub_download
8
  import torch
9
+ import safetensors
10
  import asyncio
11
+ from tqdm import tqdm # Importar tqdm para la barra de progreso
12
 
 
13
  logger = logging.getLogger(__name__)
14
  logger.setLevel(logging.INFO)
15
  console_handler = logging.StreamHandler()
 
17
  console_handler.setFormatter(formatter)
18
  logger.addHandler(console_handler)
19
 
 
20
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
21
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
22
  AWS_REGION = os.getenv("AWS_REGION")
 
32
  region_name=AWS_REGION
33
  )
34
 
 
35
  app = FastAPI()
36
 
 
37
  class GenerateRequest(BaseModel):
38
  model_name: str
39
  input_text: str
40
+ task_type: str # Added task type to handle different tasks (e.g., text-to-image, text-to-speech)
41
 
 
42
  class S3DirectStream:
43
  def __init__(self, bucket_name):
44
  self.s3_client = boto3.client(
 
49
  )
50
  self.bucket_name = bucket_name
51
 
 
52
  async def stream_from_s3(self, key):
53
  loop = asyncio.get_event_loop()
54
  return await loop.run_in_executor(None, self._stream_from_s3, key)
55
 
56
  def _stream_from_s3(self, key):
57
  try:
58
+ logger.info(f"Descargando archivo {key} desde S3...")
59
  response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
60
+ file_content = response['Body'].read() # This returns a bytes object
 
 
61
  return file_content
62
  except self.s3_client.exceptions.NoSuchKey:
63
  raise HTTPException(status_code=404, detail=f"El archivo {key} no existe en el bucket S3.")
64
  except Exception as e:
65
+ logger.error(f"Error al descargar {key} desde S3: {str(e)}")
66
  raise HTTPException(status_code=500, detail=f"Error al descargar {key} desde S3: {str(e)}")
67
 
68
+ async def get_model_file_parts(self, model_name):
69
+ loop = asyncio.get_event_loop()
70
+ return await loop.run_in_executor(None, self._get_model_file_parts, model_name)
71
+
72
+ def _get_model_file_parts(self, model_name):
73
+ try:
74
+ model_name = model_name.replace("/", "-").lower()
75
+ logger.info(f"Obteniendo archivos del modelo {model_name} desde S3...")
76
+ files = self.s3_client.list_objects_v2(Bucket=self.bucket_name, Prefix=model_name)
77
+ model_files = [obj['Key'] for obj in files.get('Contents', []) if model_name in obj['Key']]
78
+ if not model_files:
79
+ raise HTTPException(status_code=404, detail=f"Archivos del modelo {model_name} no encontrados.")
80
+ return model_files
81
+ except Exception as e:
82
+ raise HTTPException(status_code=500, detail=f"Error al obtener archivos del modelo {model_name} desde S3: {e}")
83
+
84
  async def load_model_from_s3(self, model_name):
85
  try:
86
+ logger.info(f"Cargando modelo {model_name} desde S3...")
87
  model_name = model_name.replace("/", "-").lower()
88
+ model_files = await self.get_model_file_parts(model_name)
89
+
90
+ if 'pytorch_model.bin' not in model_files:
91
+ raise HTTPException(status_code=404, detail="Archivo 'pytorch_model.bin' no encontrado en S3")
92
+ if 'tokenizer.json' not in model_files:
93
+ raise HTTPException(status_code=404, detail="Archivo 'tokenizer.json' no encontrado en S3")
94
+
95
  model_bytes = await self.stream_from_s3(f"{model_name}/pytorch_model.bin")
96
+ logger.info(f"Modelo descargado correctamente. Cargando el modelo en memoria...")
97
+ model = AutoModelForCausalLM.from_pretrained(model_bytes, config=model_name)
 
 
98
  return model
99
+
100
  except HTTPException as e:
101
  raise e
102
  except Exception as e:
103
+ logger.error(f"Error al cargar el modelo desde S3: {e}")
104
  raise HTTPException(status_code=500, detail=f"Error al cargar el modelo desde S3: {e}")
105
 
 
106
  async def load_tokenizer_from_s3(self, model_name):
107
  try:
108
+ logger.info(f"Cargando tokenizer del modelo {model_name} desde S3...")
109
  model_name = model_name.replace("/", "-").lower()
110
  tokenizer_bytes = await self.stream_from_s3(f"{model_name}/tokenizer.json")
111
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_bytes)
 
 
112
  return tokenizer
113
  except Exception as e:
114
+ logger.error(f"Error al cargar el tokenizer desde S3: {e}")
115
  raise HTTPException(status_code=500, detail=f"Error al cargar el tokenizer desde S3: {e}")
116
 
117
+ async def download_and_upload_to_s3(self, model_name, force_download=False):
 
118
  try:
119
+ if force_download:
120
+ logger.info(f"Forzando la descarga del modelo {model_name} y la carga a S3.")
121
+
122
  model_name = model_name.replace("/", "-").lower()
123
+
124
+ if not await self.file_exists_in_s3(f"{model_name}/pytorch_model.bin") or not await self.file_exists_in_s3(f"{model_name}/tokenizer.json"):
125
+ logger.info(f"Descargando archivos del modelo {model_name} desde Hugging Face...")
126
+ model_file = hf_hub_download(repo_id=model_name, filename="pytorch_model.bin", token=HUGGINGFACE_HUB_TOKEN, force_download=force_download)
127
+ tokenizer_file = hf_hub_download(repo_id=model_name, filename="tokenizer.json", token=HUGGINGFACE_HUB_TOKEN, force_download=force_download)
128
+
129
+ await self.create_s3_folders(f"{model_name}/")
130
+
131
+ if not await self.file_exists_in_s3(f"{model_name}/pytorch_model.bin"):
132
+ with open(model_file, "rb") as file:
133
+ logger.info(f"Cargando archivo {model_name}/pytorch_model.bin a S3...")
134
+ self.s3_client.put_object(Bucket=self.bucket_name, Key=f"{model_name}/pytorch_model.bin", Body=file)
135
+
136
+ if not await self.file_exists_in_s3(f"{model_name}/tokenizer.json"):
137
+ with open(tokenizer_file, "rb") as file:
138
+ logger.info(f"Cargando archivo {model_name}/tokenizer.json a S3...")
139
+ self.s3_client.put_object(Bucket=self.bucket_name, Key=f"{model_name}/tokenizer.json", Body=file)
140
+ else:
141
+ logger.info(f"Los archivos del modelo {model_name} ya existen en S3. No es necesario descargarlos de nuevo.")
142
+
143
  except Exception as e:
144
+ logger.error(f"Error al descargar o cargar archivos desde Hugging Face a S3: {e}")
145
+ raise HTTPException(status_code=500, detail=f"Error al descargar o cargar archivos desde Hugging Face a S3: {e}")
146
+
147
 
 
148
  @app.post("/generate")
149
  async def generate(request: GenerateRequest):
150
  try:
 
152
  model_name = request.model_name
153
  input_text = request.input_text
154
 
155
+ logger.info(f"Iniciando la generaci贸n para el modelo {model_name} con el tipo de tarea {task_type}...")
156
+
157
  s3_direct_stream = S3DirectStream(S3_BUCKET_NAME)
158
 
 
159
  model = await s3_direct_stream.load_model_from_s3(model_name)
160
  tokenizer = await s3_direct_stream.load_tokenizer_from_s3(model_name)
161
 
162
+ logger.info(f"Modelo y tokenizer cargados correctamente. Procesando tarea {task_type}...")
163
+
164
  if task_type == "text-to-text":
165
  generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
166
  result = generator(input_text, max_length=MAX_TOKENS, num_return_sequences=1)
167
+ logger.info(f"Generaci贸n completada: {result[0]['generated_text']}")
168
  return {"result": result[0]["generated_text"]}
169
 
170
  elif task_type == "text-to-image":
171
  generator = pipeline("text-to-image", model=model, tokenizer=tokenizer, device=0)
172
  image = generator(input_text)
173
+ logger.info(f"Imagen generada.")
174
  return {"image": image}
175
 
 
 
 
 
 
176
  elif task_type == "text-to-video":
177
  generator = pipeline("text-to-video", model=model, tokenizer=tokenizer, device=0)
178
  video = generator(input_text)
179
+ logger.info(f"Video generado.")
180
  return {"video": video}
181
 
182
+ elif task_type == "text-to-speech":
183
+ generator = pipeline("text-to-speech", model=model, tokenizer=tokenizer, device=0)
184
+ audio = generator(input_text)
185
+ logger.info(f"Audio generado.")
186
+ return {"audio": audio}
187
+
188
+ elif task_type == "text-to-audio":
189
+ generator = pipeline("text-to-audio", model=model, tokenizer=tokenizer, device=0)
190
+ audio = generator(input_text)
191
+ logger.info(f"Audio generado.")
192
+ return {"audio": audio}
193
+
194
  else:
195
  raise HTTPException(status_code=400, detail="Tipo de tarea no soportado.")
196
+
197
  except HTTPException as e:
198
  raise e
199
  except Exception as e:
200
  raise HTTPException(status_code=500, detail=f"Error en la generaci贸n: {str(e)}")
201
 
202
+
203
  if __name__ == "__main__":
204
  import uvicorn
205
  uvicorn.run(app, host="0.0.0.0", port=7860)