license: gpl-3.0

notdiamond-0001

notdiamond-0001 automatically determines whether a query should be sent to GPT-3.5 or GPT-4, depending on which model is best suited to the task. We've trained notdiamond-0001 on hundreds of thousands of data points from robust, cross-domain evaluation benchmarks.

notdiamond-0001 is a classifier: it returns a label, either GPT-3.5 or GPT-4, and does not call either model itself. You choose which version of each model to use and make the calls client-side with your own API keys.
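
Because the router only returns a label, the client code has to map that label to a concrete model and make the call. The sketch below shows one way to structure this. The `MODEL_VERSIONS` mapping, the `dispatch` function and the `call_llm` callable are hypothetical names for illustration, not part of notdiamond-0001 or any SDK. `call_llm` stands in for your own provider client, configured with your own key.

```python
from typing import Callable, Dict

# Hypothetical mapping from router label to the model version you chose.
# The router only ever emits "gpt-3.5" or "gpt-4"; the right-hand side is up to you.
MODEL_VERSIONS: Dict[str, str] = {
    "gpt-3.5": "gpt-3.5-turbo",
    "gpt-4": "gpt-4",
}


def dispatch(label: str, query: str, call_llm: Callable[[str, str], str],
             versions: Dict[str, str] = MODEL_VERSIONS) -> str:
    """Send the user's query to the model version chosen for `label`.

    `call_llm(model_name, query)` is a placeholder for your own client code,
    e.g. a thin wrapper around your provider's SDK using your own API key.
    """
    if label not in versions:
        raise ValueError(f"unexpected router label: {label!r}")
    return call_llm(versions[label], query)
```

For example, with a stub client `fake = lambda model, q: f"[{model}] {q}"`, calling `dispatch("gpt-4", "hi", fake)` returns `"[gpt-4] hi"`. Note that the query passed on to GPT-3.5 or GPT-4 would normally be the user's original query, not the longer routing prompt used for classification.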

To use notdiamond-0001, format your queries using the following prompt, with your query appended at the end:

    query = "Can you write a function that counts from 1 to 10?"

    formatted_prompt = f"""I would like you to determine whether a given query should be sent to GPT-3.5 or GPT-4.
        In general, the following types of queries should get sent to GPT-3.5:
        Explanation
        Summarization
        Writing
        Informal conversation
        History, government, economics, literature, social studies
        Simple math questions
        Simple coding questions
        Simple science questions

        In general, the following types of queries should get sent to GPT-4:
        Advanced coding questions
        Advanced math questions
        Advanced science questions
        Legal questions
        Medical questions
        Sensitive/inappropriate queries

        Your job is to determine whether the following query should be sent to GPT-3.5 or GPT-4.

        Query:
        {query}"""

You can then determine which model to call as follows:

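    import torch
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    id2label = {0: 'gpt-3.5', 1: 'gpt-4'}
    tokenizer = AutoTokenizer.from_pretrained("notdiamond/notdiamond-0001")
    model = AutoModelForSequenceClassification.from_pretrained("notdiamond/notdiamond-0001")

    # Truncate to the model's context length. Taking the smaller of the tokenizer
    # limit and the config's position-embedding size is an assumption; 512 is used
    # as a fallback if the config does not define max_position_embeddings.
    max_length = min(tokenizer.model_max_length,
                     getattr(model.config, "max_position_embeddings", 512))

    inputs = tokenizer(formatted_prompt, truncation=True, max_length=max_length, return_tensors="pt")
    with torch.no_grad():  # inference only, no gradients needed
        logits = model(**inputs).logits
    model_id = logits.argmax(dim=-1).item()  # index of the highest-scoring class
    model_to_call = id2label[model_id]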
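
The classifier returns raw logits, and `argmax` picks the label. If you also want to log how confident the router was, one option is to apply a softmax to the logits. The framework-free sketch below assumes you pass in one row of logits as a plain list (e.g. `logits[0].tolist()` from the snippet above). The function name `route_with_confidence` is hypothetical. With PyTorch you could use `torch.softmax(logits, dim=-1)` directly instead.

```python
import math
from typing import Dict, List, Tuple

ID2LABEL: Dict[int, str] = {0: "gpt-3.5", 1: "gpt-4"}  # same mapping as above


def route_with_confidence(logits: List[float],
                          id2label: Dict[int, str] = ID2LABEL) -> Tuple[str, float]:
    """Return (label, softmax probability of that label) for one row of logits."""
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Softmax preserves ordering, so this picks the same index as argmax on logits
    # (ties resolve to the lowest index).
    best = max(range(len(probs)), key=probs.__getitem__)
    return id2label[best], probs[best]
```

For instance, `route_with_confidence([0.0, 0.0])` returns `("gpt-3.5", 0.5)`. The label always agrees with `logits.argmax()`; the probability is extra information you may find useful for monitoring, not something the model card itself prescribes.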

For more details on how to integrate this into your tech stack and use notdiamond-0001 to reduce latency and cost, check out our documentation.