aarohanverma committed
Commit de0a177 · verified · 1 Parent(s): 32aace6

Update README.md

Files changed (1)
  1. README.md +6 -7
README.md CHANGED
@@ -226,17 +226,16 @@ def run_inference(prompt_text: str) -> str:
 
     generated_ids = model.generate(
         input_ids=inputs["input_ids"],
-        decoder_start_token_id=model.config.decoder_start_token_id,  # ✅ Ensure decoder start token
-        max_new_tokens=100,  # ✅ Limit to prevent excessive output
-        temperature=0.1,  # ✅ Adds slight randomness to avoid repetition
-        num_beams=5,  # ✅ Increases quality
-        repetition_penalty=1.2,  # ✅ Penalizes repetition
-        early_stopping=True,  # ✅ Stops generation once complete
+        decoder_start_token_id=model.config.decoder_start_token_id,
+        max_new_tokens=100,
+        temperature=0.1,
+        num_beams=5,
+        repetition_penalty=1.2,
+        early_stopping=True,
     )
 
     generated_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
-    # ✅ Post-processing to remove repeated text
     generated_sql = generated_sql.split(";")[0] + ";"  # Keep only the first valid SQL query
 
     return generated_sql
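For context, the hunk above only covers the generation and post-processing portion of `run_inference`. Below is a minimal, self-contained sketch of how the updated function might look end to end. The model ID, device handling, and tokenizer call are assumptions for illustration and are not part of this commit; substitute the repository's actual checkpoint.

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Hypothetical checkpoint ID for illustration only; use the repo's real model ID.
MODEL_ID = "aarohanverma/text2sql-model"

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID).to(device)

def run_inference(prompt_text: str) -> str:
    # Tokenize the prompt and move tensors to the model's device.
    inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True).to(device)

    with torch.no_grad():  # inference only; no gradients needed
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            decoder_start_token_id=model.config.decoder_start_token_id,
            max_new_tokens=100,
            # NOTE: temperature only affects sampling; with num_beams=5 and
            # do_sample left at False it is ignored (recent transformers
            # versions emit a warning about this).
            temperature=0.1,
            num_beams=5,
            repetition_penalty=1.2,
            early_stopping=True,
        )

    generated_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    generated_sql = generated_sql.split(";")[0] + ";"  # Keep only the first valid SQL query
    return generated_sql

if __name__ == "__main__":
    # Hypothetical prompt format; adjust to whatever the README specifies.
    print(run_inference("Translate to SQL: list all customers"))
```

Truncating at the first `;` keeps only the first complete statement, which is why the separate "post-processing" comment removed in this commit was redundant.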