---
license: gpl-3.0
---

# notdiamond-0001
notdiamond-0001 automatically determines whether to send queries to GPT-3.5 or GPT-4, depending on which model is best suited for your task. We've trained notdiamond-0001 on hundreds of thousands of data points from robust, cross-domain evaluation benchmarks.

The model is a classifier and will return either GPT-3.5 or GPT-4. You determine which version of each model you want to use and make the calls client-side with your own keys.

To use notdiamond-0001, format your queries using the following prompt, with your query appended at the end:
```python
query = "Can you write a function that counts from 1 to 10?"

formatted_prompt = f"""I would like you to determine whether a given query should be sent to GPT-3.5 or GPT-4.
In general, the following types of queries should get sent to GPT-3.5:
Explanation
Summarization
Writing
Informal conversation
History, government, economics, literature, social studies
Simple math questions
Simple coding questions
Simple science questions

In general, the following types of queries should get sent to GPT-4:
Advanced coding questions
Advanced math questions
Advanced science questions
Legal questions
Medical questions
Sensitive/inappropriate queries

Your job is to determine whether the following query should be sent to GPT-3.5 or GPT-4.

Query:
{query}"""
```
You can then determine which model to call as follows:
```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Map the classifier's output ids to model names
id2label = {0: 'gpt-3.5', 1: 'gpt-4'}
tokenizer = AutoTokenizer.from_pretrained("notdiamond/notdiamond-0001")
model = AutoModelForSequenceClassification.from_pretrained("notdiamond/notdiamond-0001")

# Truncate prompts to the model's maximum input length
max_length = tokenizer.model_max_length

inputs = tokenizer(formatted_prompt, truncation=True, max_length=max_length, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
model_id = logits.argmax().item()
model_to_call = id2label[model_id]
```
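If you prefer a higher-level interface, the same classification can also be run through the `transformers` text-classification pipeline. This is a minimal sketch, assuming the checkpoint's config may store generic `LABEL_0`/`LABEL_1` names; the hypothetical `label_map` below re-applies the `id2label` convention from above (if the config already stores `gpt-3.5`/`gpt-4`, the mapping is a no-op):

```python
from transformers import pipeline

# The pipeline handles tokenization, truncation, and argmax internally
classifier = pipeline("text-classification", model="notdiamond/notdiamond-0001")

result = classifier(formatted_prompt, truncation=True)[0]

# Assumption: the config's labels may be the generic LABEL_0/LABEL_1
label_map = {"LABEL_0": "gpt-3.5", "LABEL_1": "gpt-4"}
model_to_call = label_map.get(result["label"], result["label"])
```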
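Since notdiamond-0001 only returns a label, the completion itself happens client-side with your own keys. Here is a minimal sketch, assuming the `openai` (>= 1.0) Python client and a hypothetical `model_versions` mapping from label to whichever concrete model versions you want to run:

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

# Hypothetical mapping from the classifier's label to concrete model versions
model_versions = {"gpt-3.5": "gpt-3.5-turbo", "gpt-4": "gpt-4"}

# Send the raw query (not the routing prompt) to the selected model
response = client.chat.completions.create(
    model=model_versions[model_to_call],
    messages=[{"role": "user", "content": query}],
)
print(response.choices[0].message.content)
```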
For more details on how you can integrate this into your tech stack and have notdiamond-0001 help you reduce latency and cost, check out our [documentation](https://notdiamond.readme.io/reference/introduction-1).