Llama 3 for router module in RAG (a toy example)
While developing complex RAG applications, I found a common need for router functionality to map user queries to different system workflows (and APIs). The router acts as a dispatcher that can enhance responsiveness and accuracy by choosing the best workflow or API based on the query context. This implies that we need to produce structured output from unstructured input text.
To this end, I ran a simple exercise: fine-tuning the new Llama 3 model to take text input and generate JSON-like output (here is the colab). The hope was to avoid some external dependencies for this part of the system and to make it easier to combine several models in a complex application running in production. Building robust infrastructure for the semantic modules requires choosing the right LLM for each task.
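The router idea described above can be sketched in a few lines. This is a minimal illustration, not the system from the colab: the route names, the handler functions, and the `{"route": ..., "params": ...}` output shape are all assumptions made for the example.

```python
import json
from typing import Callable, Dict

# Hypothetical workflow handlers; a real system would call APIs or pipelines here.
def search_documents(params: dict) -> str:
    return f"searching documents for: {params.get('query', '')}"

def summarize(params: dict) -> str:
    return f"summarizing: {str(params.get('text', ''))[:40]}"

def fallback(params: dict) -> str:
    return "no matching workflow; falling back to default handling"

# Registry mapping a route name (assumed output field) to its handler.
ROUTES: Dict[str, Callable[[dict], str]] = {
    "search": search_documents,
    "summarize": summarize,
}

def dispatch(model_output: str) -> str:
    """Parse the model's structured output and call the matching workflow."""
    try:
        decision = json.loads(model_output)
    except json.JSONDecodeError:
        return fallback({})
    if not isinstance(decision, dict):
        return fallback({})
    route = decision.get("route")
    handler = ROUTES.get(route, fallback) if isinstance(route, str) else fallback
    params = decision.get("params")
    return handler(params if isinstance(params, dict) else {})

# Stand-in for what a fine-tuned model might emit for a user query.
example_output = '{"route": "search", "params": {"query": "asthma dosage guidelines"}}'
print(dispatch(example_output))
```

Anything the model emits that is malformed or names an unknown route goes to the fallback handler, which is why well-formed structured output matters for this component.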
For training, we used structured data from azizshaw. The dataset contained 485 rows and included 'input', 'output', and 'instruction' columns.
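The inference code further down calls `alpaca_prompt.format(...)` with three arguments that line up with these columns. A minimal sketch of how a row could be turned into a training string, assuming the Alpaca-style template commonly used in Unsloth notebooks; the exact template in the colab may differ, and `format_row`, the sample row, and the `<EOS>` placeholder are hypothetical.

```python
# Assumed Alpaca-style template with three positional slots:
# instruction, input, and output (the response the model should produce).
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

def format_row(row: dict, eos_token: str = "") -> str:
    """Turn one dataset row into a single training string.

    eos_token would normally be tokenizer.eos_token, so the model learns where to stop.
    """
    return alpaca_prompt.format(row["instruction"], row["input"], row["output"]) + eos_token

# Hypothetical row for illustration only (not taken from the dataset).
row = {
    "instruction": "Convert this text into a JSON object.",
    "input": "Name: Jane Doe, Gender: Female",
    "output": "{'name': 'Jane Doe', 'gender': 'Female'}",
}
print(format_row(row, eos_token="<EOS>"))  # "<EOS>" is a placeholder token
```

At inference time the third slot is left empty, so the model continues the text after the response header.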
For a quick evaluation, we used another dataset for text-to-JSON, the Diverse Restricted JSON Data Extraction, curated by the paraloq analytics team (here).
Run the model for inference:
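from unsloth import FastLanguageModel

# `model` and `tokenizer` are the fine-tuned objects from the training step in the colab
FastLanguageModel.for_inference(model) # Enable native 2x faster inference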
inputs = tokenizer(
[
alpaca_prompt.format(
"""
Convert this text into a JSON object. Create field names that meaningfully represent the data being reported.
It is extremely important that you construct a well-formed object.
""", # instruction
"**Medical Document** **Patient Information** * Patient ID: PT123456 * Name: Jane Doe * Date of Birth: 1980-01-01 * Gender: Female * Medical Conditions: * Asthma * Hypertension **Prescription Information** * Prescription ID: RX123456 * Date Prescribed: 2023-03-08 * Date Expires: 2023-09-07 * Status: Active **Medication Information** * Medication ID: MD123456 * Name: Albuterol * Dosage: 200 mcg * Units: mcg * Instructions: Inhale 2 puffs every 4-6 hours as needed for shortness of breath. * Refills: 3 **Pharmacy Information** * Pharmacy ID: PH123456 * Name: CVS Pharmacy * Address: 123 Main Street, Anytown, CA 12345 * Phone: (123) 456-7890 **Additional Information** * The patient has been using Albuterol for the past 5 years to manage her asthma. * The patient has been advised to use a spacer device with the Albuterol inhaler to improve the delivery of the medication to the lungs. * The patient should avoid using Albuterol more than 4 times per day. * The patient should contact her doctor if her asthma symptoms worsen or if she experiences any side effects from the medication. **Instructions for the Patient** * Take Albuterol exactly as prescribed by your doctor. * Do not take more than the prescribed dosage. * Use a spacer device with the Albuterol inhaler. * Avoid using Albuterol more than 4 times per day. * Contact your doctor if your asthma symptoms worsen or if you experience any side effects from the medication. **Signature** [Doctor's Name] [Date]", # input
"", # output - leave this blank for generation!
)
], return_tensors = "pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens = 1000, use_cache = True)
tokenizer.batch_decode(outputs)
import json
import pprint

# Model response copied from the decoded output above
text = "{'feature1': {'detail': {'text': 'Medical Document', 'pid': 'PT123456', 'name': 'Jane Doe', 'dob': '1980-01-01', 'gender': 'Female', 'conditions': ['Asthma', 'Hypertension']}, 'detail2': {'text': 'Prescription Information', 'pid': 'RX123456', 'date': '2023-03-08', 'expires': '2023-09-07','status': 'Active'}, 'detail3': {'text': 'Medication Information', 'id': 'MD123456', 'name': 'Albuterol', 'dosage': '200 mcg', 'units':'mcg', 'instructions': 'Inhale 2 puffs every 4-6 hours as needed for shortness of breath.','refills': '3'}, 'detail4': {'text': 'Pharmacy Information', 'id': 'PH123456', 'name': 'CVS Pharmacy', 'address': '123 Main Street, Anytown, CA 12345', 'phone': '(123) 456-7890'}}, 'feature2': {'detail': {'text': 'The patient has been using Albuterol for the past 5 years to manage her asthma.', 'pid': '', 'name': '', 'dob': '', 'gender': '', 'conditions': []}, 'detail2': {'text': 'The patient has been advised to use a spacer device with the Albuterol inhaler to improve the delivery of the medication to the lungs.', 'pid': '', 'name': '', 'date': '', 'expires': '','status': ''}, 'detail3': {'text': 'The patient should avoid using Albuterol more than 4 times per day.', 'id': '', 'name': '', 'dosage': '', 'units': '', 'instructions': '','refills': ''}, 'detail4': {'text': 'The patient should contact her doctor if her asthma symptoms worsen or if she experiences any side effects from the medication.', 'pid': '', 'name': '', 'address': '', 'phone': ''}}}"
# The model emits single-quoted strings, so swap the quotes before parsing as JSON
output = text.replace("'", '"')
data_dict = json.loads(output)
len(data_dict)
pprint.pprint(data_dict['feature1'])
The result:
{'detail': {'conditions': ['Asthma', 'Hypertension'],
'dob': '1980-01-01',
'gender': 'Female',
'name': 'Jane Doe',
'pid': 'PT123456',
'text': 'Medical Document'},
'detail2': {'date': '2023-03-08',
'expires': '2023-09-07',
'pid': 'RX123456',
'status': 'Active',
'text': 'Prescription Information'},
'detail3': {'dosage': '200 mcg',
'id': 'MD123456',
'instructions': 'Inhale 2 puffs every 4-6 hours as needed for '
'shortness of breath.',
'name': 'Albuterol',
'refills': '3',
'text': 'Medication Information',
'units': 'mcg'},
'detail4': {'address': '123 Main Street, Anytown, CA 12345',
'id': 'PH123456',
'name': 'CVS Pharmacy',
'phone': '(123) 456-7890',
'text': 'Pharmacy Information'}}
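The quote replacement used above works for this particular output, but it would corrupt any value that itself contains an apostrophe (the input includes "[Doctor's Name]", for example). A minimal sketch of a more tolerant parser using only the standard library: try strict JSON first, then fall back to `ast.literal_eval`, which evaluates Python literals such as single-quoted dicts without executing code. The function name `parse_model_output` and the sample string are hypothetical.

```python
import ast
import json

def parse_model_output(text: str) -> dict:
    """Parse model output that is either JSON or a Python-style dict literal."""
    text = text.strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # literal_eval only accepts literals (dicts, lists, strings, numbers, ...).
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError) as err:
            raise ValueError(f"output is neither JSON nor a dict literal: {err}") from err
    if not isinstance(parsed, dict):
        raise ValueError(f"expected an object, got {type(parsed).__name__}")
    return parsed

# Hypothetical output where a value contains an apostrophe.
sample = "{'signature': \"Doctor's Name\", 'refills': '3'}"
print(parse_model_output(sample))
```

Raising `ValueError` on anything that is not a dict gives the caller a single failure mode to handle, for instance by retrying generation.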
Results Notes
- Considering that this is a toy example (a 4-bit quantized model and a tiny dataset for SFT), the results look like a good starting point, which is a credit to Llama 3.
- Because the fine-tuning examples use single quotes around names and string values, the model learns this notation and generates single-quoted output. That is not valid JSON, so it is far from optimal for securing the workflow and writing robust code.
- The response also tends to repeat information: in the example, feature2 restates the field structure of feature1 with mostly empty values.
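One way to soften the repetition is to post-process the parsed output and drop the fields the model left empty, as in the `feature2` entries of the example. A minimal sketch, assuming the output has already been parsed into nested dicts and lists; `prune_empty` is a hypothetical helper, and it only removes empty placeholders, not repeated text.

```python
def prune_empty(value):
    """Recursively drop empty strings, empty lists, empty dicts, and None."""
    empties = ("", [], {}, None)
    if isinstance(value, dict):
        cleaned = {key: prune_empty(item) for key, item in value.items()}
        return {key: item for key, item in cleaned.items() if item not in empties}
    if isinstance(value, list):
        cleaned = [prune_empty(item) for item in value]
        return [item for item in cleaned if item not in empties]
    return value

# Fragment shaped like the 'feature2' part of the example output.
example = {
    "detail": {
        "text": "The patient has been using Albuterol for the past 5 years to manage her asthma.",
        "pid": "",
        "name": "",
        "conditions": [],
    }
}
print(prune_empty(example))
```

Values such as `0` or `False` are kept, since they carry information even though they are falsy.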
Uploaded model
- Developed by: sccastillo
- License: apache-2.0
- Finetuned from model: unsloth/llama-3-8b-bnb-4bit
This Llama model was trained 2x faster with Unsloth and Hugging Face's TRL library.
Model tree for sccastillo/llama3_router
Base model
meta-llama/Meta-Llama-3-8B