Spaces:
Sleeping
Sleeping
Commit
·
ff079ce
1
Parent(s):
3f5ced9
initial commit
Browse files
gen.py
CHANGED
@@ -2,8 +2,8 @@ import torch
|
|
2 |
import sys
|
3 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
4 |
import json
|
5 |
-
import jsonschema
|
6 |
-
from jsonschema import validate
|
7 |
|
8 |
tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it')
|
9 |
|
@@ -182,20 +182,31 @@ def generate(event):
|
|
182 |
|
183 |
|
184 |
output_text = tokenizer.decode(tokens[0], skip_special_tokens=False)
|
185 |
-
|
186 |
-
user_prompt_length = len(f"<bos><start_of_turn>user\n{prompt}\n{event}<end_of_turn>\n<start_of_turn>model\n") # Calculate user prompt length
|
187 |
|
188 |
json_start_index = output_text.find("<json>")
|
189 |
json_end_index = output_text.find("</json>")
|
190 |
|
191 |
if json_start_index != -1 and json_end_index != -1:
|
192 |
-
json_string = output_text[max(json_start_index + 6, user_prompt_length):json_end_index].strip()
|
193 |
|
194 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
try:
|
196 |
validate(instance=json.loads(json_string), schema=your_json_schema)
|
197 |
return json_string
|
198 |
-
except ValidationError as e:
|
199 |
return f"Error: Invalid JSON - {e}"
|
200 |
|
201 |
else:
|
|
|
2 |
import sys
|
3 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
4 |
import json
|
5 |
+
import jsonschema
|
6 |
+
from jsonschema import validate, ValidationError
|
7 |
|
8 |
tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it')
|
9 |
|
|
|
182 |
|
183 |
|
184 |
output_text = tokenizer.decode(tokens[0], skip_special_tokens=False)
|
185 |
+
user_prompt_length = len(f"<bos><start_of_turn>user\n{prompt}\n{event}<end_of_turn>\n<start_of_turn>model\n")
|
|
|
186 |
|
187 |
json_start_index = output_text.find("<json>")
|
188 |
json_end_index = output_text.find("</json>")
|
189 |
|
190 |
if json_start_index != -1 and json_end_index != -1:
|
191 |
+
json_string = output_text[max(json_start_index + 6, user_prompt_length):json_end_index].strip()
|
192 |
|
193 |
+
# Remove any leading/trailing non-JSON characters (if present)
|
194 |
+
if not json_string.startswith("{"):
|
195 |
+
first_brace_index = json_string.find("{")
|
196 |
+
if first_brace_index != -1:
|
197 |
+
json_string = json_string[first_brace_index:]
|
198 |
+
|
199 |
+
if not json_string.endswith("}"):
|
200 |
+
last_brace_index = json_string.rfind("}")
|
201 |
+
if last_brace_index != -1:
|
202 |
+
json_string = json_string[:last_brace_index + 1]
|
203 |
+
|
204 |
+
|
205 |
+
# Validate JSON
|
206 |
try:
|
207 |
validate(instance=json.loads(json_string), schema=your_json_schema)
|
208 |
return json_string
|
209 |
+
except jsonschema.exceptions.ValidationError as e:
|
210 |
return f"Error: Invalid JSON - {e}"
|
211 |
|
212 |
else:
|