Update app.py
Browse files
app.py
CHANGED
@@ -54,6 +54,12 @@ class GenerateRequest(BaseModel):
|
|
54 |
raise ValueError(f"task_type must be one of: {valid_types}")
|
55 |
return v
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
class S3ModelLoader:
|
58 |
def __init__(self, bucket_name, s3_client):
|
59 |
self.bucket_name = bucket_name
|
@@ -130,9 +136,7 @@ async def generate(request: GenerateRequest):
|
|
130 |
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
131 |
|
132 |
async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay):
|
133 |
-
# Get the maximum model input length
|
134 |
max_model_length = model.config.max_position_embeddings
|
135 |
-
|
136 |
encoded_input = tokenizer(input_text, return_tensors="pt", max_length=max_model_length, truncation=True).to(device)
|
137 |
|
138 |
def stop_criteria(input_ids, scores):
|
@@ -164,7 +168,7 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequ
|
|
164 |
)
|
165 |
except IndexError as e:
|
166 |
print(f"IndexError during generation: {e}")
|
167 |
-
break
|
168 |
|
169 |
new_token_ids = outputs.sequences[0][encoded_input.input_ids.shape[-1]:]
|
170 |
|
|
|
54 |
raise ValueError(f"task_type must be one of: {valid_types}")
|
55 |
return v
|
56 |
|
57 |
+
@field_validator("max_new_tokens")
|
58 |
+
def max_new_tokens_must_be_within_limit(cls, v):
|
59 |
+
if v > 4:
|
60 |
+
raise ValueError("max_new_tokens cannot be greater than 4.")
|
61 |
+
return v
|
62 |
+
|
63 |
class S3ModelLoader:
|
64 |
def __init__(self, bucket_name, s3_client):
|
65 |
self.bucket_name = bucket_name
|
|
|
136 |
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
137 |
|
138 |
async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay):
|
|
|
139 |
max_model_length = model.config.max_position_embeddings
|
|
|
140 |
encoded_input = tokenizer(input_text, return_tensors="pt", max_length=max_model_length, truncation=True).to(device)
|
141 |
|
142 |
def stop_criteria(input_ids, scores):
|
|
|
168 |
)
|
169 |
except IndexError as e:
|
170 |
print(f"IndexError during generation: {e}")
|
171 |
+
break
|
172 |
|
173 |
new_token_ids = outputs.sequences[0][encoded_input.input_ids.shape[-1]:]
|
174 |
|