Hjgugugjhuhjggg committed on
Commit f2dfe81 · verified · 1 Parent(s): 116d7b7

Update app.py

Files changed (1): app.py +7 -3
app.py CHANGED
@@ -54,6 +54,12 @@ class GenerateRequest(BaseModel):
             raise ValueError(f"task_type must be one of: {valid_types}")
         return v
 
+    @field_validator("max_new_tokens")
+    def max_new_tokens_must_be_within_limit(cls, v):
+        if v > 4:
+            raise ValueError("max_new_tokens cannot be greater than 4.")
+        return v
+
 class S3ModelLoader:
     def __init__(self, bucket_name, s3_client):
         self.bucket_name = bucket_name
@@ -130,9 +136,7 @@ async def generate(request: GenerateRequest):
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay):
-    # Get the maximum model input length
     max_model_length = model.config.max_position_embeddings
-
     encoded_input = tokenizer(input_text, return_tensors="pt", max_length=max_model_length, truncation=True).to(device)
 
     def stop_criteria(input_ids, scores):
@@ -164,7 +168,7 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequ
             )
         except IndexError as e:
             print(f"IndexError during generation: {e}")
-            break  # Exit the loop if there's an index error
+            break
 
     new_token_ids = outputs.sequences[0][encoded_input.input_ids.shape[-1]:]
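For context, a minimal sketch of how the added validator behaves, assuming Pydantic v2. Only the `max_new_tokens` rule comes from the diff; the `input_text` field and the default value are placeholders, since the rest of `GenerateRequest` is not shown here.

```python
from pydantic import BaseModel, ValidationError, field_validator

class GenerateRequest(BaseModel):
    input_text: str          # placeholder field, not from the diff
    max_new_tokens: int = 4  # assumed default

    @field_validator("max_new_tokens")
    def max_new_tokens_must_be_within_limit(cls, v):
        # Same rule the commit adds: reject anything above 4 new tokens.
        if v > 4:
            raise ValueError("max_new_tokens cannot be greater than 4.")
        return v

try:
    GenerateRequest(input_text="hello", max_new_tokens=10)
except ValidationError as e:
    print(e)  # reports "max_new_tokens cannot be greater than 4."
```

With this in place, requests exceeding the limit fail validation before they ever reach the generation endpoint.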
 
 
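The `stream_text` changes keep the prompt inside the model's context window (`max_position_embeddings`) and slice the prompt tokens off the generated sequence. A hedged sketch of that pattern follows; `gpt2` and the prompt are illustrative placeholders, since app.py actually loads its models from S3.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

# As in the diff: truncate the prompt so the tokenizer never emits more
# tokens than the model's context window allows.
max_model_length = model.config.max_position_embeddings
encoded_input = tokenizer(
    "example prompt",
    return_tensors="pt",
    max_length=max_model_length,
    truncation=True,
).to(device)

outputs = model.generate(
    **encoded_input,
    max_new_tokens=4,  # matches the new request limit
    return_dict_in_generate=True,
)

# Drop the prompt tokens to keep only the newly generated ids, mirroring
# outputs.sequences[0][encoded_input.input_ids.shape[-1]:] in app.py.
new_token_ids = outputs.sequences[0][encoded_input.input_ids.shape[-1]:]
print(tokenizer.decode(new_token_ids, skip_special_tokens=True))
```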