Octopus V1: On-device language model for function calling of software APIs


Introducing Octopus-V1

Octopus-V1 is a series of open-source language models, ranging from 2B to 7B parameters, that represents Nexa AI's breakthrough in AI-driven software API interactions. The models are fine-tuned on a specialized dataset derived from more than 30,000 APIs on RapidAPI Hub, which teaches them API structure and syntax. A conditional masking technique keeps the generated API calls precise and format-compliant without slowing inference. A new benchmark released alongside Octopus-V1 shows the models outperforming GPT-4 on software API usage, a step forward in automating software development and API integration.

📱 Support for 30k+ APIs from RapidAPI Hub: Octopus is trained on a dataset derived from more than 30,000 popular APIs on RapidAPI Hub. This broad coverage of API interactions makes the model useful across a wide range of applications.

🐙 Accuracy: Fine-tuning base models with 2B, 3B, and 7B parameters yields Octopus models that surpass GPT-4 in API call accuracy on the accompanying benchmark. Adding the conditional mask improves precision further, making Octopus highly reliable for software API interactions.

🎯 Conditional Masking: A novel conditional masking technique constrains outputs to the required format and reduces errors. It keeps inference fast while substantially increasing the model's accuracy in generating function calls and their parameters.
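A toy sketch of the masking idea, written in plain Python so it runs without a GPU: at each greedy decoding step, every token outside an allowed set is treated as having logit -inf before the argmax, which mirrors what `deterministic_generate_next_token` does with `usable_token_ids` in the example code. Here the allowed set is assumed to be the tokens that keep the output a prefix of one of the available function names. The tiny vocabulary, the logits, and the helper names (`masked_argmax`, `allowed_next_tokens`, `constrained_decode`) are hypothetical, and this is not necessarily how Octopus builds its mask.

```python
import math
from typing import Dict, List, Sequence


def masked_argmax(logits: Sequence[float], allowed: Sequence[int]) -> int:
    """Return the highest-scoring token id among `allowed`; all other ids count as -inf."""
    masked = [-math.inf] * len(logits)
    for token_id in allowed:
        masked[token_id] = logits[token_id]
    return max(range(len(masked)), key=lambda i: masked[i])


def allowed_next_tokens(prefix: str, names: List[str], vocab: Dict[str, int]) -> List[int]:
    """Ids of tokens whose text keeps `prefix` a prefix of at least one valid function name."""
    return sorted(
        token_id
        for text, token_id in vocab.items()
        if any(name.startswith(prefix + text) for name in names)
    )


def constrained_decode(step_logits: List[List[float]], names: List[str], vocab: Dict[str, int]) -> str:
    """Greedy decoding in which every step is restricted to valid function-name continuations."""
    id2text = {token_id: text for text, token_id in vocab.items()}
    output = ""
    for logits in step_logits:
        if output in names:  # a complete function name has been produced
            break
        allowed = allowed_next_tokens(output, names, vocab)
        if not allowed:  # no valid continuation exists
            break
        output += id2text[masked_argmax(logits, allowed)]
    return output


# Hypothetical vocabulary and per-step logits; unconstrained argmax would pick "score" twice.
vocab = {"get_": 0, "table": 1, "stage": 2, "score": 3}
names = ["get_table", "get_stage"]
step_logits = [[0.1, 0.2, 0.3, 2.0], [1.5, 0.4, 0.9, 3.0]]
print(constrained_decode(step_logits, names, vocab))  # get_stage
```

Without the mask, greedy decoding would emit `scorescore`, which is not a callable function; with it, the highest-scoring valid continuation wins at each step, so the result is always a prefix of (or equal to) a real function name.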

Here is the full list of fine-tuned models in the Octopus series:

Example Use Cases

You can run the model on a GPU with the following code (loading with `device_map="auto"` requires the `accelerate` package).

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

prompt = """You are an assistant, and you need to call find appropriate functions according to the query of the users. Firstly, find the relevant functions, then get the function arguments by understanding the user's query. The following functions are available for you to fetch further data to answer user questions: 

Function: 
def basketapi_league_seasons(tournamentId): 
    '''
    Get access to historical and current seasons for a specific basketball league using the tournament ID. 
    Args: 
        tournamentId (number): The argument tournamentId is a number that represents the identifier of a tournament in the context of the function. 
    ''' 

def os_sports_goal_distributions(unique_tournament_id,season_id,team_id): 
    '''
    Get goal distributions by team, tournament ID, and season ID for in-depth sports performance analysis. 
    Args: 
        unique_tournament_id (number): The unique_tournament_id argument is a number representing the unique identifier for a tournament. 
        season_id (number): The argument season_id is a number that represents the identifier of the season for the search query string. 
        team_id (number): The team_id argument represents the teams identification number. 
    '''

def transfermarket_get_table(id,seasonID,domain,homeAway):
    '''
    Get tables by competition and season from transfermarket platform for comprehensive and detailed competition and season-related data. 
    Args: 
        id (string): The function argument "id" is a string representing an identifier. 
        seasonID (string): The seasonID argument is a string that represents the identifier for a specific season. 
        domain (string): The domain argument is a string that represents a search query. 
        homeAway (string): The homeAway argument is a string that represents the home or away status for a sports event. 
    '''

def no_relevant_function(): 
    ''' 
    Call this when no other provided function can be called to answer the user query. 
    '''

def soccersapi_stage_id(t,id):
    '''
    Get stage ID for a soccer match or event, access specific details like schedules, teams, and relevant data.
    Args: 
        t (string): The argument "t" of type string represents the search query string. 
        id (number): This function argument is an identifier represented by a number, typically used to uniquely reference a specific entity within the system. 
    ''' 

Request the complete season data for a recently established basketball league using the tournament ID 309, aiming to analyze its inaugural seasons. 
Response:
"""

class NexaGenerator:
    def __init__(self, model_id: str, tokenizer_id: str):
        # Load weights in bfloat16 and let accelerate place them on the available GPU(s)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        self.eos_token_id = self.tokenizer.eos_token_id
        self.token2id = self.tokenizer.get_vocab()  # token string -> token id

    @torch.no_grad()  # inference only: do not build the autograd graph
    def deterministic_generate_next_token(
        self,
        input_ids: torch.Tensor,  # shape: (1, seq_len), no support for batch yet
        add_conditional_mask: bool = False,
        usable_token_ids: torch.Tensor = None,  # 1-D LongTensor of allowed token ids
    ) -> torch.Tensor:
        if add_conditional_mask:
            assert usable_token_ids is not None, "usable_token_ids is required"
        # Logits for the last position only: shape (1, 1, vocab_size)
        next_logits = self.model(input_ids)["logits"][:, -1:]
        if add_conditional_mask:
            # Start from -inf everywhere, then copy back the logits of the allowed tokens,
            # so argmax can only select a token from usable_token_ids
            index = usable_token_ids.to(next_logits.device).view(1, 1, -1)
            mask = torch.full_like(next_logits, float("-inf"))
            mask.scatter_(2, index, next_logits.gather(2, index))
            next_token_id = torch.argmax(mask, dim=-1)
        else:
            next_token_id = torch.argmax(next_logits, dim=-1)
        return next_token_id  # shape: (1, 1)

nexa_generator = NexaGenerator(model_id="NexaAIDev/Octopus-v1", tokenizer_id="NexaAIDev/Octopus-v1")

def get_response(prompt, max_new_tokens=200):
    # Put the prompt on the same device as the model's first layer
    input_ids = nexa_generator.tokenizer(prompt, return_tensors="pt")["input_ids"].to(
        nexa_generator.model.device
    )
    # Greedy decoding: append one token at a time until EOS or the length limit.
    # Each step re-runs the full sequence (no KV cache), which is simple but slow.
    for _ in range(max_new_tokens):
        next_token_id = nexa_generator.deterministic_generate_next_token(
            input_ids=input_ids,
            add_conditional_mask=False,
            usable_token_ids=None,
        )
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)
        if next_token_id[0].item() == nexa_generator.eos_token_id:
            break
    # The decoded text contains the prompt followed by the generated function call
    generated_text = nexa_generator.tokenizer.batch_decode(input_ids)
    return generated_text[0]

print(get_response(prompt))
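To query the model about your own APIs, you need a prompt with the same shape as the example: the instruction header, a `Function:` block of Python stubs with docstrings, the user query, and a final `Response:` line. The sketch below assembles such a prompt from hypothetical `FunctionSpec` records; the class, `build_prompt`, and the `NO_RELEVANT` constant are illustrative names, not part of any Nexa AI package. It approximates the example's layout but does not reproduce its stray trailing spaces, and whether the model is sensitive to such whitespace differences is untested here.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

# Instruction header copied from the example prompt, wording unchanged.
HEADER = (
    "You are an assistant, and you need to call find appropriate functions according to the "
    "query of the users. Firstly, find the relevant functions, then get the function arguments "
    "by understanding the user's query. The following functions are available for you to fetch "
    "further data to answer user questions: "
)


@dataclass
class FunctionSpec:
    """Hypothetical container for one API that will be shown to the model."""
    name: str
    description: str
    # (argument name, type name, description), in signature order
    args: List[Tuple[str, str, str]] = field(default_factory=list)

    def render(self) -> str:
        """Render the spec as a Python stub with a docstring, as in the example prompt."""
        arg_names = ",".join(name for name, _, _ in self.args)
        lines = [f"def {self.name}({arg_names}):", "    '''", f"    {self.description}"]
        if self.args:
            lines.append("    Args:")
            lines.extend(f"        {n} ({t}): {d}" for n, t, d in self.args)
        lines.append("    '''")
        return "\n".join(lines)


# Fallback option that the example prompt also lists
NO_RELEVANT = FunctionSpec(
    "no_relevant_function",
    "Call this when no other provided function can be called to answer the user query.",
)


def build_prompt(functions: List[FunctionSpec], query: str) -> str:
    """Header, 'Function:' block of stubs, the user query, then 'Response:'."""
    stubs = "\n\n".join(spec.render() for spec in functions)
    return f"{HEADER}\n\nFunction:\n{stubs}\n\n{query}\nResponse:\n"


specs = [
    FunctionSpec(
        "basketapi_league_seasons",
        "Get access to historical and current seasons for a specific basketball league using the tournament ID.",
        [("tournamentId", "number", "The identifier of a tournament.")],
    ),
    NO_RELEVANT,
]
custom_prompt = build_prompt(specs, "List all seasons of the basketball league with tournament ID 309.")
print(custom_prompt)
```

The resulting string can be passed to `get_response(custom_prompt)` in place of the hand-written prompt. Keeping `no_relevant_function` in every list gives the model an explicit way to decline when none of the APIs fit the query.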

Evaluation

Figure: Accuracy without conditional mask. Comparison of accuracy between GPT-3.5, GPT-4, and the fine-tuned models in the Octopus series.

Figure: Accuracy with conditional mask. Accuracy of the Octopus series models after the conditional mask is introduced.
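The card does not spell out the scoring rule behind these figures. A common choice, assumed in this sketch, is exact match: a prediction counts as correct only if it calls the expected function with the same argument values. The sketch also assumes calls are written as Python-style keyword calls such as `basketapi_league_seasons(tournamentId=309)`; the actual Octopus-V1 output format may differ, and since `get_response` returns the prompt plus the generation, the response part would have to be extracted first. `parse_call` and `exact_match_accuracy` are hypothetical helpers.

```python
import ast
from typing import Any, Dict, List, Optional, Tuple

Call = Tuple[str, Dict[str, Any]]


def parse_call(text: str) -> Optional[Call]:
    """Parse 'name(arg=value, ...)' into (name, {arg: value}); return None if malformed."""
    try:
        node = ast.parse(text.strip(), mode="eval").body
    except (SyntaxError, ValueError):
        return None
    # Only plain function names with keyword arguments are accepted in this sketch
    if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Name) or node.args:
        return None
    try:
        # literal_eval rejects anything that is not a literal (names, calls, ...)
        kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
    except (ValueError, TypeError):
        return None
    return node.func.id, kwargs


def exact_match_accuracy(predictions: List[str], references: List[str]) -> float:
    """Fraction of predictions whose parsed call equals the parsed reference call."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        return 0.0
    hits = 0
    for pred, ref in zip(predictions, references):
        parsed = parse_call(pred)
        if parsed is not None and parsed == parse_call(ref):
            hits += 1
    return hits / len(references)


preds = [
    "basketapi_league_seasons(tournamentId=309)",  # correct
    "basketapi_league_seasons(tournamentId=310)",  # wrong argument value
    "basketapi_league_seasons(tournamentId=",      # malformed output
]
refs = ["basketapi_league_seasons(tournamentId=309)"] * 3
print(exact_match_accuracy(preds, refs))  # 0.3333333333333333
```

Comparing parsed structures instead of raw strings means that argument order and spacing do not affect the score, while a wrong function name, a wrong value, or an unparseable output all count as errors.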

License

This model was trained on commercially viable data and is released under the Nexa AI community disclaimer.

References

We thank the Meta Llama 2 team, the Google Gemma team, and Stability AI's Stable Code team for their amazing models!

@misc{gemma-2023-open-models,
  author = {{Gemma Team, Google DeepMind}},
  title = {Gemma: Open Models Based on Gemini Research and Technology},
  url = {https://goo.gle/GemmaReport},  
  year = {2023},
}

@article{touvron2023llama,
  title={Llama 2: Open foundation and fine-tuned chat models},
  author={Touvron, Hugo and Martin, Louis and Stone, Kevin and Albert, Peter and Almahairi, Amjad and Babaei, Yasmine and Bashlykov, Nikolay and Batra, Soumya and Bhargava, Prajjwal and Bhosale, Shruti and others},
  journal={arXiv preprint arXiv:2307.09288},
  year={2023}
}

@misc{stable-code-3b,
  author = {Pinnaparaju, Nikhil and Adithyan, Reshinth and Phung, Duy and Tow, Jonathan and Baicoianu, James and Cooper, Nathan},
  title = {Stable Code 3B},
  url = {https://huggingface.co/stabilityai/stable-code-3b},
  year = {2023}
}

Citation

@misc{chen2024octopus,
      title={Octopus: On-device language model for function calling of software APIs}, 
      author={Wei Chen and Zhiyuan Li and Mingyuan Ma},
      year={2024},
      eprint={2404.01549},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

Contact

Please contact us with any issues or comments!

Model size: 8.54B params (Safetensors; tensor types: F32, BF16, I8)