neuralworm committed on
Commit
ff079ce
·
1 Parent(s): 3f5ced9

initial commit

Browse files
Files changed (1) hide show
  1. gen.py +18 -7
gen.py CHANGED
@@ -2,8 +2,8 @@ import torch
2
  import sys
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
4
  import json
5
- import jsonschema # Import the jsonschema library
6
- from jsonschema import validate # Import the validate function
7
 
8
  tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it')
9
 
@@ -182,20 +182,31 @@ def generate(event):
182
 
183
 
184
  output_text = tokenizer.decode(tokens[0], skip_special_tokens=False)
185
- print(output_text)
186
- user_prompt_length = len(f"<bos><start_of_turn>user\n{prompt}\n{event}<end_of_turn>\n<start_of_turn>model\n") # Calculate user prompt length
187
 
188
  json_start_index = output_text.find("<json>")
189
  json_end_index = output_text.find("</json>")
190
 
191
  if json_start_index != -1 and json_end_index != -1:
192
- json_string = output_text[max(json_start_index + 6, user_prompt_length):json_end_index].strip() # Trim whitespace and remove prompt
193
 
194
- # Validate JSON (you'll need to define a schema for your JSON structure)
 
 
 
 
 
 
 
 
 
 
 
 
195
  try:
196
  validate(instance=json.loads(json_string), schema=your_json_schema)
197
  return json_string
198
- except ValidationError as e:
199
  return f"Error: Invalid JSON - {e}"
200
 
201
  else:
 
2
  import sys
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
4
  import json
5
+ import jsonschema
6
+ from jsonschema import validate, ValidationError
7
 
8
  tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it')
9
 
 
182
 
183
 
184
  output_text = tokenizer.decode(tokens[0], skip_special_tokens=False)
185
+ user_prompt_length = len(f"<bos><start_of_turn>user\n{prompt}\n{event}<end_of_turn>\n<start_of_turn>model\n")
 
186
 
187
  json_start_index = output_text.find("<json>")
188
  json_end_index = output_text.find("</json>")
189
 
190
  if json_start_index != -1 and json_end_index != -1:
191
+ json_string = output_text[max(json_start_index + 6, user_prompt_length):json_end_index].strip()
192
 
193
+ # Remove any leading/trailing non-JSON characters (if present)
194
+ if not json_string.startswith("{"):
195
+ first_brace_index = json_string.find("{")
196
+ if first_brace_index != -1:
197
+ json_string = json_string[first_brace_index:]
198
+
199
+ if not json_string.endswith("}"):
200
+ last_brace_index = json_string.rfind("}")
201
+ if last_brace_index != -1:
202
+ json_string = json_string[:last_brace_index + 1]
203
+
204
+
205
+ # Validate JSON
206
  try:
207
  validate(instance=json.loads(json_string), schema=your_json_schema)
208
  return json_string
209
+ except jsonschema.exceptions.ValidationError as e:
210
  return f"Error: Invalid JSON - {e}"
211
 
212
  else: