Hjgugugjhuhjggg committed on
Commit e80f1ef · verified · 1 parent: 7ece340

Update app.py

Files changed (1): app.py (+142 −142)
app.py CHANGED
@@ -1,14 +1,21 @@
  import os
- from fastapi import FastAPI, HTTPException, Depends
- from fastapi.responses import JSONResponse
  from pydantic import BaseModel, field_validator
- from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList, pipeline
  import boto3
  import uvicorn
- import soundfile as sf
- import imageio
- from typing import Dict, Optional
- import torch  # Import torch

  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
@@ -16,34 +23,24 @@ AWS_REGION = os.getenv("AWS_REGION")
  S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
  HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")

- if not all([AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION, S3_BUCKET_NAME]):
-     raise ValueError("Missing one or more AWS environment variables.")
-
  s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION)

  app = FastAPI()

- SPECIAL_TOKENS = {
-     "bos_token": "<|startoftext|>",
-     "eos_token": "<|endoftext|>",
-     "pad_token": "[PAD]",
-     "unk_token": "[UNK]",
- }
-
  class GenerateRequest(BaseModel):
      model_name: str
      input_text: str = ""
      task_type: str
      temperature: float = 1.0
-     max_new_tokens: int = 10
      top_p: float = 1.0
      top_k: int = 50
-     repetition_penalty: float = 1.1
      num_return_sequences: int = 1
      do_sample: bool = True
      stop_sequences: list[str] = []
-     no_repeat_ngram_size: int = 2
-     continuation_id: Optional[str] = None

      @field_validator("model_name")
      def model_name_cannot_be_empty(cls, v):
@@ -58,12 +55,6 @@ class GenerateRequest(BaseModel):
              raise ValueError(f"task_type must be one of: {valid_types}")
          return v

-     @field_validator("max_new_tokens")
-     def max_new_tokens_must_be_within_limit(cls, v):
-         if v > 500:
-             raise ValueError("max_new_tokens cannot be greater than 500.")
-         return v
-
  class S3ModelLoader:
      def __init__(self, bucket_name, s3_client):
          self.bucket_name = bucket_name
@@ -75,166 +66,175 @@ class S3ModelLoader:
      async def load_model_and_tokenizer(self, model_name):
          s3_uri = self._get_s3_uri(model_name)
          try:
-             config = AutoConfig.from_pretrained(s3_uri, local_files_only=False)
-             model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=False)
-             tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=False)
-             tokenizer.add_special_tokens(SPECIAL_TOKENS)
-             model.resize_token_embeddings(len(tokenizer))
-             if tokenizer.pad_token_id is None:
-                 tokenizer.pad_token_id = tokenizer.eos_token_id
              return model, tokenizer
-         except Exception as e:
-             raise HTTPException(status_code=500, detail=f"Error loading model from S3: {e}")

- model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)

- active_generations: Dict[str, Dict] = {}

- async def get_model_and_tokenizer(model_name: str):
-     try:
-         return await model_loader.load_model_and_tokenizer(model_name)
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error loading model: {e}")

  @app.post("/generate")
- async def generate(request: GenerateRequest, model_resources: tuple = Depends(get_model_and_tokenizer)):
-     model, tokenizer = model_resources
      try:
          model_name = request.model_name
          input_text = request.input_text
          temperature = request.temperature
          max_new_tokens = request.max_new_tokens
          top_p = request.top_p
          top_k = request.top_k
          repetition_penalty = request.repetition_penalty
          num_return_sequences = request.num_return_sequences
          do_sample = request.do_sample
          stop_sequences = request.stop_sequences
-         no_repeat_ngram_size = request.no_repeat_ngram_size
-         continuation_id = request.continuation_id
-
-         if continuation_id:
-             if continuation_id not in active_generations:
-                 raise HTTPException(status_code=404, detail="Continuation ID not found.")
-             previous_data = active_generations[continuation_id]
-             if previous_data["model_name"] != model_name:
-                 raise HTTPException(status_code=400, detail="Model mismatch for continuation.")
-             input_text = previous_data["output"]
-
-         generation_config = GenerationConfig.from_pretrained(model_name)  # Load default config and override
-         generation_config.temperature = temperature
-         generation_config.max_new_tokens = max_new_tokens
-         generation_config.top_p = top_p
-         generation_config.top_k = top_k
-         generation_config.repetition_penalty = repetition_penalty
-         generation_config.do_sample = do_sample
-         generation_config.num_return_sequences = num_return_sequences
-         generation_config.no_repeat_ngram_size = no_repeat_ngram_size
-         generation_config.pad_token_id = tokenizer.pad_token_id
-
-         generated_text = generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences)
-
-         new_continuation_id = continuation_id if continuation_id else os.urandom(16).hex()
-         active_generations[new_continuation_id] = {"model_name": model_name, "output": generated_text}
-
-         return JSONResponse({"text": generated_text, "continuation_id": new_continuation_id, "model_name": model_name})
-
-     except HTTPException as http_err:
-         raise http_err
      except Exception as e:
          raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

- def generate_text_internal(model, tokenizer, input_text, generation_config, stop_sequences):
-     max_model_length = model.config.max_position_embeddings
-     encoded_input = tokenizer(input_text, return_tensors="pt", max_length=max_model_length, truncation=True).to(model.device)  # Ensure input is on the same device as the model
-
-     stopping_criteria = StoppingCriteriaList()
-
-     class CustomStoppingCriteria(StoppingCriteria):  # Inherit directly from StoppingCriteria
-         def __init__(self, stop_sequences, tokenizer):
-             self.stop_sequences = stop_sequences
-             self.tokenizer = tokenizer
-
-         def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-             decoded_output = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
-             for stop in self.stop_sequences:
-                 if decoded_output.endswith(stop):
-                     return True
-             return False
-
-     stopping_criteria.append(CustomStoppingCriteria(stop_sequences, tokenizer))

      outputs = model.generate(
-         encoded_input.input_ids,
-         generation_config=generation_config,
          stopping_criteria=stopping_criteria,
-         pad_token_id=generation_config.pad_token_id
      )

-     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     return generated_text
-
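Note: the removed helper subclasses `StoppingCriteria`, but the old import line only brought in `StoppingCriteriaList`, so this code path would have raised a `NameError` even before this commit. The import it needed was:

```python
from transformers import StoppingCriteria, StoppingCriteriaList
```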
- async def load_pipeline_from_s3(task, model_name):
-     s3_uri = f"s3://{S3_BUCKET_NAME}/{model_name.replace('/', '-')}"
-     try:
-         return pipeline(task, model=s3_uri, token=HUGGINGFACE_HUB_TOKEN)  # Include token if needed
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error loading {task} model from S3: {e}")

  @app.post("/generate-image")
  async def generate_image(request: GenerateRequest):
      try:
-         if request.task_type != "text-to-image":
-             raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")
-
-         image_generator = await load_pipeline_from_s3("text-to-image", request.model_name)
-         image = image_generator(request.input_text)[0]
-         image_path = f"generated_image_{os.urandom(8).hex()}.png"  # Save image locally
-         image.save(image_path)
-         new_continuation_id = os.urandom(16).hex()
-         active_generations[new_continuation_id] = {"model_name": request.model_name, "output": f"Image saved to {image_path}"}  # Return path or upload URL
-         return JSONResponse({"url": image_path, "continuation_id": new_continuation_id, "model_name": request.model_name})
-
-     except HTTPException as http_err:
-         raise http_err
      except Exception as e:
          raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

  @app.post("/generate-text-to-speech")
  async def generate_text_to_speech(request: GenerateRequest):
      try:
-         if request.task_type != "text-to-speech":
-             raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")
-
-         tts_pipeline = await load_pipeline_from_s3("text-to-speech", request.model_name)
-         audio_output = tts_pipeline(request.input_text)
-         audio_path = f"generated_audio_{os.urandom(8).hex()}.wav"
-         sf.write(audio_path, audio_output["sampling_rate"], audio_output["audio"])
-         new_continuation_id = os.urandom(16).hex()
-         active_generations[new_continuation_id] = {"model_name": request.model_name, "output": f"Audio saved to {audio_path}"}
-         return JSONResponse({"url": audio_path, "continuation_id": new_continuation_id, "model_name": request.model_name})
-
-     except HTTPException as http_err:
-         raise http_err
      except Exception as e:
          raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

  @app.post("/generate-video")
  async def generate_video(request: GenerateRequest):
      try:
-         if request.task_type != "text-to-video":
-             raise HTTPException(status_code=400, detail="Invalid task_type for this endpoint.")
-
-         video_pipeline = await load_pipeline_from_s3("text-to-video", request.model_name)
-         video_frames = video_pipeline(request.input_text).frames
-         video_path = f"generated_video_{os.urandom(8).hex()}.mp4"
-         imageio.mimsave(video_path, video_frames, fps=30)  # Adjust fps as needed
-         new_continuation_id = os.urandom(16).hex()
-         active_generations[new_continuation_id] = {"model_name": request.model_name, "output": f"Video saved to {video_path}"}
-         return JSONResponse({"url": video_path, "continuation_id": new_continuation_id, "model_name": request.model_name})
-
-     except HTTPException as http_err:
-         raise http_err
      except Exception as e:
          raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

app.py (new version)

  import os
+ import torch
+ from fastapi import FastAPI, HTTPException
+ from fastapi.responses import StreamingResponse
  from pydantic import BaseModel, field_validator
+ from transformers import (
+     AutoConfig,
+     pipeline,
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     GenerationConfig,
+     StoppingCriteriaList
+ )
  import boto3
  import uvicorn
+ import asyncio
+ from io import BytesIO
+ from transformers import pipeline
 
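Note: `pipeline` is imported twice in the new version, once in the grouped `transformers` import and again on its own line. A deduplicated header with the same names would be:

```python
# Sketch only: the same transformers imports with the duplicate removed.
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    StoppingCriteriaList,
    pipeline,
)
```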
  AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
  AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
  AWS_REGION = os.getenv("AWS_REGION")
  S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
  HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")

  s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION)

  app = FastAPI()

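Note: this commit drops the old startup guard that checked the AWS variables, so a missing value now surfaces later as a boto3 error instead of a clear message. A minimal sketch restoring it, using only names already defined in this file:

```python
# Sketch only: restores the env-var validation the commit removed.
required = {
    "AWS_ACCESS_KEY_ID": AWS_ACCESS_KEY_ID,
    "AWS_SECRET_ACCESS_KEY": AWS_SECRET_ACCESS_KEY,
    "AWS_REGION": AWS_REGION,
    "S3_BUCKET_NAME": S3_BUCKET_NAME,
}
missing = [name for name, value in required.items() if not value]
if missing:
    raise ValueError(f"Missing AWS environment variables: {', '.join(missing)}")
```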
  class GenerateRequest(BaseModel):
      model_name: str
      input_text: str = ""
      task_type: str
      temperature: float = 1.0
+     max_new_tokens: int = 200
+     stream: bool = True
      top_p: float = 1.0
      top_k: int = 50
+     repetition_penalty: float = 1.0
      num_return_sequences: int = 1
      do_sample: bool = True
+     chunk_delay: float = 0.0
      stop_sequences: list[str] = []

      @field_validator("model_name")
      def model_name_cannot_be_empty(cls, v):
  ... (unchanged lines collapsed)
              raise ValueError(f"task_type must be one of: {valid_types}")
          return v

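Note: the default `max_new_tokens` rises from 10 to 200 while the validator that capped it at 500 is deleted, so clients can now request unbounded generation lengths. If a cap is still wanted, pydantic's `Field` constraints are one way to express it; this variant class is hypothetical:

```python
# Sketch only: a bounded variant of the request model (hypothetical).
from pydantic import BaseModel, Field

class BoundedGenerateRequest(BaseModel):
    model_name: str
    input_text: str = ""
    task_type: str
    max_new_tokens: int = Field(default=200, ge=1, le=500)
```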
  class S3ModelLoader:
      def __init__(self, bucket_name, s3_client):
          self.bucket_name = bucket_name
  ... (unchanged lines collapsed)
      async def load_model_and_tokenizer(self, model_name):
          s3_uri = self._get_s3_uri(model_name)
          try:
+             config = AutoConfig.from_pretrained(s3_uri, local_files_only=True)
+             model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config, local_files_only=True)
+             tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config, local_files_only=True)
+
+             if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
+                 tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
+
              return model, tokenizer
+         except EnvironmentError:
+             try:
+                 config = AutoConfig.from_pretrained(model_name)
+                 tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
+                 model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
+
+                 if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
+                     tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
+
+                 model.save_pretrained(s3_uri)
+                 tokenizer.save_pretrained(s3_uri)
+                 return model, tokenizer
+             except Exception as e:
+                 raise HTTPException(status_code=500, detail=f"Error loading model: {e}")

+ model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)

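Note: `from_pretrained(..., local_files_only=True)` and `save_pretrained` take local paths or Hub repo IDs; transformers does not understand the `s3://` URI that `_get_s3_uri` builds, so both the cache hit and the `save_pretrained` write-back will fail as written. An S3-backed cache needs an explicit copy step through a local directory; a sketch with hypothetical helpers (`sync_from_s3` and `sync_to_s3` are not part of this file):

```python
# Sketch only: transformers cannot read or write s3:// URIs directly,
# so stage model files through a local directory with boto3.
import os

def sync_from_s3(s3_client, bucket, prefix, local_dir):
    # Download every object under `prefix` into `local_dir`.
    paginator = s3_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            target = os.path.join(local_dir, os.path.relpath(obj["Key"], prefix))
            os.makedirs(os.path.dirname(target), exist_ok=True)
            s3_client.download_file(bucket, obj["Key"], target)

def sync_to_s3(s3_client, bucket, prefix, local_dir):
    # Upload a directory saved with save_pretrained back under `prefix`.
    for root, _, files in os.walk(local_dir):
        for name in files:
            path = os.path.join(root, name)
            key = f"{prefix}/{os.path.relpath(path, local_dir)}"
            s3_client.upload_file(path, bucket, key)
```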
  @app.post("/generate")
+ async def generate(request: GenerateRequest):
      try:
          model_name = request.model_name
          input_text = request.input_text
+         task_type = request.task_type
          temperature = request.temperature
          max_new_tokens = request.max_new_tokens
+         stream = request.stream
          top_p = request.top_p
          top_k = request.top_k
          repetition_penalty = request.repetition_penalty
          num_return_sequences = request.num_return_sequences
          do_sample = request.do_sample
+         chunk_delay = request.chunk_delay
          stop_sequences = request.stop_sequences
+
+         model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         model.to(device)
+
+         generation_config = GenerationConfig(
+             temperature=temperature,
+             max_new_tokens=max_new_tokens,
+             top_p=top_p,
+             top_k=top_k,
+             repetition_penalty=repetition_penalty,
+             do_sample=do_sample,
+             num_return_sequences=num_return_sequences,
+         )
+
+         return StreamingResponse(
+             stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay),
+             media_type="text/plain"
+         )
+
      except Exception as e:
          raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

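Note: `/generate` now loads the model and tokenizer on every request, which will dominate latency for anything beyond toy models. A per-process cache is the usual fix; `_model_cache` here is a hypothetical addition, and the sketch assumes a single worker (each uvicorn worker would hold its own copy):

```python
# Sketch only: memoize loaded models per process (hypothetical helper).
_model_cache: dict[str, tuple] = {}

async def get_cached_model_and_tokenizer(model_name: str):
    if model_name not in _model_cache:
        _model_cache[model_name] = await model_loader.load_model_and_tokenizer(model_name)
    return _model_cache[model_name]
```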
+ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay, max_length=2048):
+     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
+     input_length = encoded_input["input_ids"].shape[1]
+     remaining_tokens = max_length - input_length
+
+     if remaining_tokens <= 0:
+         yield ""
+
+     generation_config.max_new_tokens = min(remaining_tokens, generation_config.max_new_tokens)
+
+     def stop_criteria(input_ids, scores):
+         decoded_output = tokenizer.decode(int(input_ids[0][-1]), skip_special_tokens=True)
+         return decoded_output in stop_sequences
+
+     stopping_criteria = StoppingCriteriaList([stop_criteria])
+
+     output_text = ""
      outputs = model.generate(
+         **encoded_input,
+         do_sample=generation_config.do_sample,
+         max_new_tokens=generation_config.max_new_tokens,
+         temperature=generation_config.temperature,
+         top_p=generation_config.top_p,
+         top_k=generation_config.top_k,
+         repetition_penalty=generation_config.repetition_penalty,
+         num_return_sequences=generation_config.num_return_sequences,
          stopping_criteria=stopping_criteria,
+         output_scores=True,
+         return_dict_in_generate=True
      )

+     for output in outputs.sequences:
+         for token_id in output:
+             token = tokenizer.decode(token_id, skip_special_tokens=True)
+             yield token
+             await asyncio.sleep(chunk_delay)  # Simulate the delay between tokens
+
+     if stop_sequences and any(stop in output_text for stop in stop_sequences):
+         yield output_text
+         return
+
+     outputs = model.generate(
+         **encoded_input,
+         do_sample=generation_config.do_sample,
+         max_new_tokens=generation_config.max_new_tokens,
+         temperature=generation_config.temperature,
+         top_p=generation_config.top_p,
+         top_k=generation_config.top_k,
+         repetition_penalty=generation_config.repetition_penalty,
+         num_return_sequences=generation_config.num_return_sequences,
+         stopping_criteria=stopping_criteria,
+         output_scores=True,
+         return_dict_in_generate=True
+     )

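Note: several problems in `stream_text` as committed. `model.generate` runs to completion before the first byte is yielded, so this only streams in appearance; decoding one token ID at a time drops the spacing many tokenizers encode across token boundaries; the loop re-emits the prompt tokens as well; `output_text` is never appended to, so the stop-sequence check can never fire; after `remaining_tokens <= 0` the function yields `""` but keeps going; and the second `model.generate(...)` call at the end is dead code whose result is discarded. Real incremental streaming is what transformers' `TextIteratorStreamer` is for; a sketch reusing this file's argument names:

```python
# Sketch only: true token streaming with TextIteratorStreamer.
# generate() runs in a worker thread and the streamer yields text
# pieces (with correct spacing) as they are produced.
from threading import Thread
from transformers import TextIteratorStreamer

async def stream_text_v2(model, tokenizer, input_text, generation_config, device, chunk_delay):
    encoded_input = tokenizer(input_text, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        **encoded_input,
        streamer=streamer,
        max_new_tokens=generation_config.max_new_tokens,
        do_sample=generation_config.do_sample,
        temperature=generation_config.temperature,
    )
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    for chunk in streamer:
        yield chunk
        await asyncio.sleep(chunk_delay)
```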
  @app.post("/generate-image")
  async def generate_image(request: GenerateRequest):
      try:
+         validated_body = request
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         image_generator = pipeline("text-to-image", model=validated_body.model_name, device=device)
+         image = image_generator(validated_body.input_text)[0]
+
+         img_byte_arr = BytesIO()
+         image.save(img_byte_arr, format="PNG")
+         img_byte_arr.seek(0)
+
+         return StreamingResponse(img_byte_arr, media_type="image/png")
+
      except Exception as e:
          raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

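Note: `transformers.pipeline` has no "text-to-image" task, so this handler raises before ever reaching the model; text-to-image checkpoints are normally served with diffusers. A sketch of the same endpoint body under that assumption (the model name still comes from the request):

```python
# Sketch only: text-to-image via diffusers instead of transformers.pipeline.
from io import BytesIO
import torch
from diffusers import AutoPipelineForText2Image

def render_png(model_name: str, prompt: str) -> BytesIO:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = AutoPipelineForText2Image.from_pretrained(model_name).to(device)
    image = pipe(prompt).images[0]  # a PIL.Image
    buf = BytesIO()
    image.save(buf, format="PNG")
    buf.seek(0)
    return buf
```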
  @app.post("/generate-text-to-speech")
  async def generate_text_to_speech(request: GenerateRequest):
      try:
+         validated_body = request
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         audio_generator = pipeline("text-to-speech", model=validated_body.model_name, device=device)
+         audio = audio_generator(validated_body.input_text)[0]
+
+         audio_byte_arr = BytesIO()
+         audio.save(audio_byte_arr)
+         audio_byte_arr.seek(0)
+
+         return StreamingResponse(audio_byte_arr, media_type="audio/wav")
+
      except Exception as e:
          raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

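Note: the transformers text-to-speech pipeline returns a dict with `"audio"` (a NumPy array) and `"sampling_rate"`; indexing it with `[0]` and calling `.save()` will fail at runtime. Writing WAV bytes needs an encoder; a sketch with soundfile, which the previous version of this file already imported:

```python
# Sketch only: serialize the TTS pipeline's dict output as WAV bytes.
from io import BytesIO
import soundfile as sf

def tts_wav_bytes(tts_pipeline, text: str) -> BytesIO:
    out = tts_pipeline(text)  # {"audio": np.ndarray, "sampling_rate": int}
    buf = BytesIO()
    sf.write(buf, out["audio"].squeeze(), out["sampling_rate"], format="WAV")
    buf.seek(0)
    return buf
```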
  @app.post("/generate-video")
  async def generate_video(request: GenerateRequest):
      try:
+         validated_body = request
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         video_generator = pipeline("text-to-video", model=validated_body.model_name, device=device)
+         video = video_generator(validated_body.input_text)[0]
+
+         video_byte_arr = BytesIO()
+         video.save(video_byte_arr)
+         video_byte_arr.seek(0)
+
+         return StreamingResponse(video_byte_arr, media_type="video/mp4")
+
      except Exception as e:
          raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
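Note: as with images, `transformers.pipeline` has no "text-to-video" task, and the generated frames have no `.save()` method; text-to-video checkpoints usually run under diffusers, whose `export_to_video` utility encodes frames to MP4 on disk. A sketch under those assumptions (the exact shape of `.frames` varies across diffusers versions):

```python
# Sketch only: text-to-video via diffusers; export_to_video writes an MP4.
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video

def render_mp4(model_name: str, prompt: str, out_path: str = "out.mp4") -> str:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = DiffusionPipeline.from_pretrained(model_name).to(device)
    frames = pipe(prompt).frames[0]  # frames of the first generated video
    return export_to_video(frames, out_path)
```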