mgoNeo4j committed
Commit a1aa7e9
1 Parent(s): 7bd3ab7

Update README.md

Files changed (1): README.md +75 -1

README.md CHANGED
@@ -253,4 +253,78 @@ Used RunPod with following setup:
  [More Information Needed] -->
  ### Framework versions
 
- - PEFT 0.12.0
+ - PEFT 0.12.0
+
+ ### Example Cypher generation
+ ```python
+ from peft import PeftModel, PeftConfig
+ import torch
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+ )
+
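+ # Instruction template: the graph schema and the natural-language question
+ # are interpolated into a single user message.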
+ instruction = (
+     "Generate Cypher statement to query a graph database. "
+     "Use only the provided relationship types and properties in the schema. \n"
+     "Schema: {schema} \n Question: {question} \n Cypher output: "
+ )
+
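+ # Wrap the instruction as a single-turn chat message so the tokenizer's
+ # chat template can add the model's turn markers.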
+ def prepare_chat_prompt(question, schema) -> list[dict]:
+     chat = [
+         {
+             "role": "user",
+             "content": instruction.format(
+                 schema=schema, question=question
+             ),
+         }
+     ]
+     return chat
+
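+ # Clean up raw generations: strip code fences and any trailing explanation,
+ # leaving only the bare Cypher statement.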
+ def _postprocess_output_cypher(output_cypher: str) -> str:
+     # Remove any trailing explanation. E.g. "MATCH...\n\n**Explanation:**\n\n" -> "MATCH..."
+     # Remove code-fence markers. E.g. "```cypher\nMATCH...```" -> "MATCH..."
+     # Both can occur together:
+     # "```cypher\nMATCH...```\n\n**Explanation:**\n\n" -> "MATCH..."
+     partition_by = "**Explanation:**"
+     output_cypher, _, _ = output_cypher.partition(partition_by)
+     output_cypher = output_cypher.strip("`\n")
+     output_cypher = output_cypher.lstrip("cypher\n")
+     output_cypher = output_cypher.strip("`\n ")
+     return output_cypher
+
+ # Model: load the base weights, then attach the fine-tuned PEFT adapter.
+ base_model_name = "google/gemma-2-9b-it"
+ model_name = "neo4j/text2cypher-gemma-2-9b-it-finetuned-2024v1"
+ base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
+ config = PeftConfig.from_pretrained(model_name)
+ model = PeftModel.from_pretrained(base_model, model_name)
+
+ # Build the prompt for an example question and schema.
+ question = "What are the movies of Tom Hanks?"
+ schema = "(:Actor)-[:ActedIn]->(:Movie)"
+ new_message = prepare_chat_prompt(question=question, schema=schema)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ prompt = tokenizer.apply_chat_template(new_message, add_generation_prompt=True, tokenize=False)
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True)
+
+ # Generation parameters: low temperature keeps the sampled Cypher close to greedy decoding.
+ model_generate_parameters = {
+     "top_p": 0.9,
+     "temperature": 0.2,
+     "max_new_tokens": 512,
+     "do_sample": True,
+     "pad_token_id": tokenizer.eos_token_id,
+ }
+
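+ # Run generation without gradient tracking, then drop the prompt tokens so
+ # only the newly generated completion is decoded.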
+ inputs = inputs.to(model.device)
+ model.eval()
+ with torch.no_grad():
+     tokens = model.generate(**inputs, **model_generate_parameters)
+     tokens = tokens[:, inputs.input_ids.shape[1] :]
+     raw_outputs = tokenizer.batch_decode(tokens, skip_special_tokens=True)
+     outputs = [_postprocess_output_cypher(output) for output in raw_outputs]
+
+ print(outputs)
+ # > ["MATCH (hanks:Actor {name: 'Tom Hanks'})-[:ActedIn]->(m:Movie) RETURN m"]
+ ```
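The generated statement can then be run against a database. Below is a minimal sketch, assuming a reachable Neo4j instance and the official `neo4j` Python driver (5.x); the URI, credentials, and query literal are illustrative placeholders, not part of this model card:

```python
from neo4j import GraphDatabase

# Placeholder connection details; substitute your own instance.
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")

# The statement generated in the example above.
query = "MATCH (hanks:Actor {name: 'Tom Hanks'})-[:ActedIn]->(m:Movie) RETURN m"

with GraphDatabase.driver(URI, auth=AUTH) as driver:
    # execute_query (driver 5.x) returns records, a result summary, and keys.
    records, summary, keys = driver.execute_query(query)
    for record in records:
        print(record)
```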