Update README.md
Browse files
README.md
CHANGED
@@ -226,17 +226,16 @@ def run_inference(prompt_text: str) -> str:
 
     generated_ids = model.generate(
         input_ids=inputs["input_ids"],
-        decoder_start_token_id=model.config.decoder_start_token_id,
-        max_new_tokens=100,
-        temperature=0.1,
-        num_beams=5,
-        repetition_penalty=1.2,
-        early_stopping=True,
+        decoder_start_token_id=model.config.decoder_start_token_id,
+        max_new_tokens=100,
+        temperature=0.1,
+        num_beams=5,
+        repetition_penalty=1.2,
+        early_stopping=True,
     )
 
     generated_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
-    # ✅ Post-processing to remove repeated text
     generated_sql = generated_sql.split(";")[0] + ";"  # Keep only the first valid SQL query
 
     return generated_sql