metadata
license: other
datasets:
  - Thermostatic/flowers
  - jondurbin/truthy-dpo-v0.1
  - Intel/orca_dpo_pairs
  - glaiveai/glaive-function-calling-v2
license_name: gemma-terms-of-use
license_link: https://ai.google.dev/gemma/terms
model-index:
  - name: gemma-orchid-7b-dpo
    results:
      - task:
          type: text-generation
          name: Text Generation
        dataset:
          name: AI2 Reasoning Challenge (25-Shot)
          type: ai2_arc
          config: ARC-Challenge
          split: test
          args:
            num_few_shot: 25
        metrics:
          - type: acc_norm
            value: 62.88
            name: normalized accuracy
        source:
          url: >-
            https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard?query=macadeliccc/gemma-orchid-7b-dpo
          name: Open LLM Leaderboard
      - task:
          type: text-generation
          name: Text Generation
        dataset:
          name: HellaSwag (10-Shot)
          type: hellaswag
          split: validation
          args:
            num_few_shot: 10
        metrics:
          - type: acc_norm
            value: 80.95
            name: normalized accuracy
        source:
          url: >-
            https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard?query=macadeliccc/gemma-orchid-7b-dpo
          name: Open LLM Leaderboard
      - task:
          type: text-generation
          name: Text Generation
        dataset:
          name: MMLU (5-Shot)
          type: cais/mmlu
          config: all
          split: test
          args:
            num_few_shot: 5
        metrics:
          - type: acc
            value: 61.41
            name: accuracy
        source:
          url: >-
            https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard?query=macadeliccc/gemma-orchid-7b-dpo
          name: Open LLM Leaderboard
      - task:
          type: text-generation
          name: Text Generation
        dataset:
          name: TruthfulQA (0-shot)
          type: truthful_qa
          config: multiple_choice
          split: validation
          args:
            num_few_shot: 0
        metrics:
          - type: mc2
            value: 53.27
        source:
          url: >-
            https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard?query=macadeliccc/gemma-orchid-7b-dpo
          name: Open LLM Leaderboard
      - task:
          type: text-generation
          name: Text Generation
        dataset:
          name: Winogrande (5-shot)
          type: winogrande
          config: winogrande_xl
          split: validation
          args:
            num_few_shot: 5
        metrics:
          - type: acc
            value: 77.51
            name: accuracy
        source:
          url: >-
            https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard?query=macadeliccc/gemma-orchid-7b-dpo
          name: Open LLM Leaderboard
      - task:
          type: text-generation
          name: Text Generation
        dataset:
          name: GSM8k (5-shot)
          type: gsm8k
          config: main
          split: test
          args:
            num_few_shot: 5
        metrics:
          - type: acc
            value: 50.19
            name: accuracy
        source:
          url: >-
            https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard?query=macadeliccc/gemma-orchid-7b-dpo
          name: Open LLM Leaderboard

Gemma Orchid 7b

Built with Axolotl

This model is the second checkpoint of a future project. It is capable of function calling and has a strong foundation in conversational skills.

This model has been fine-tuned on roughly 80k samples so far.

Training

  • Time to complete: ~20 hours
  • Datasets: Thermostatic/flowers, Intel/orca_dpo_pairs, jondurbin/truthy-dpo-v0.1, glaiveai/glaive-function-calling-v2
  • Evaluation loss: 0.69
  • Method: LoRA
  • Prompt Format: ChatML
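
As an illustration of the ChatML prompt format listed above, here is a minimal sketch that renders a conversation into a single prompt string. The `<|im_start|>` and `<|im_end|>` markers are the conventional ChatML tags and are an assumption here; this card does not state the exact special tokens used in training, and `format_chatml` is a hypothetical helper, not part of any library.

```python
# Sketch of ChatML prompt formatting.
# ASSUMPTION: the conventional ChatML markers below match this model's training format.
IM_START = "<|im_start|>"
IM_END = "<|im_end|>"


def format_chatml(messages, add_generation_prompt=True):
    """Render a list of {"role", "content"} dicts as one ChatML prompt string."""
    parts = []
    for message in messages:
        parts.append(f"{IM_START}{message['role']}\n{message['content']}{IM_END}\n")
    if add_generation_prompt:
        # Leave the assistant turn open so the model writes the reply.
        parts.append(f"{IM_START}assistant\n")
    return "".join(parts)


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Write me a poem about Machine Learning."},
]
prompt = format_chatml(messages)
print(prompt)
```

The resulting string can be used as `input_text` in the inference snippets in this card. If the tokenizer ships its own chat template, that template is likely the safer choice.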

Thermostatic/flowers is a blend of open-source model generations in ShareGPT format. It also includes all of Capybara.

This model has been exposed to a wide variety of data. macadeliccc/gemma-function-calling-7b is a suitable starting point for further fine-tuning on a dataset of your choosing.

Running the model on a CPU

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("macadeliccc/gemma-orchid-7b-dpo")
model = AutoModelForCausalLM.from_pretrained("macadeliccc/gemma-orchid-7b-dpo")

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt")

outputs = model.generate(**input_ids)
print(tokenizer.decode(outputs[0]))

Running the model on a single GPU or multiple GPUs

# pip install accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("macadeliccc/gemma-orchid-7b-dpo")
model = AutoModelForCausalLM.from_pretrained("macadeliccc/gemma-orchid-7b-dpo", device_map="auto")

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids)
print(tokenizer.decode(outputs[0]))

Running the model on a GPU using different precisions

  • Using torch.float16
# pip install accelerate
import torch  # needed for torch.float16 below
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("macadeliccc/gemma-orchid-7b-dpo")
model = AutoModelForCausalLM.from_pretrained("macadeliccc/gemma-orchid-7b-dpo", device_map="auto", torch_dtype=torch.float16)

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids)
print(tokenizer.decode(outputs[0]))
  • Using torch.bfloat16
# pip install accelerate
import torch  # needed for torch.bfloat16 below
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("macadeliccc/gemma-orchid-7b-dpo")
model = AutoModelForCausalLM.from_pretrained("macadeliccc/gemma-orchid-7b-dpo", device_map="auto", torch_dtype=torch.bfloat16)

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids)
print(tokenizer.decode(outputs[0]))

Quantized Versions through bitsandbytes

  • Using 8-bit precision (int8)
# pip install bitsandbytes accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained("macadeliccc/gemma-orchid-7b-dpo")
model = AutoModelForCausalLM.from_pretrained("macadeliccc/gemma-orchid-7b-dpo", quantization_config=quantization_config)

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids)
print(tokenizer.decode(outputs[0]))
  • Using 4-bit precision
# pip install bitsandbytes accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)

tokenizer = AutoTokenizer.from_pretrained("macadeliccc/gemma-orchid-7b-dpo")
model = AutoModelForCausalLM.from_pretrained("macadeliccc/gemma-orchid-7b-dpo", quantization_config=quantization_config)

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids)
print(tokenizer.decode(outputs[0]))
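
To give a rough sense of why the lower-precision and quantized options above matter, the sketch below estimates the memory needed for the weights alone at each precision. The parameter count of about 8.5 billion is an assumption for illustration and is not stated in this card. The estimate ignores activations, the KV cache, and the extra scaling data that quantized formats store.

```python
# Rough weight-memory estimate per precision (weights only).
# ASSUMPTION: ~8.5 billion parameters; this figure is illustrative, not from the card.
ASSUMED_PARAMS = 8.5e9

# Nominal storage cost of one parameter in bytes.
BYTES_PER_PARAM = {
    "float32": 4.0,
    "float16": 2.0,
    "bfloat16": 2.0,
    "int8": 1.0,
    "4-bit": 0.5,
}


def weight_memory_gib(num_params, bytes_per_param):
    """Return the approximate size of the weights in GiB."""
    return num_params * bytes_per_param / 1024**3


for name, nbytes in BYTES_PER_PARAM.items():
    print(f"{name:>8}: ~{weight_memory_gib(ASSUMED_PARAMS, nbytes):.1f} GiB")
```

Under these assumptions, halving the bytes per parameter halves the weight footprint, which is the main reason to try float16, bfloat16, or bitsandbytes quantization on memory-constrained GPUs.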

Other optimizations

  • Flash Attention 2

First, make sure flash-attn is installed in your environment: pip install flash-attn

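import torch
from transformers import AutoModelForCausalLM

model_id = "macadeliccc/gemma-orchid-7b-dpo"
# Load in half precision with Flash Attention 2, then move to GPU 0
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to(0)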

Inputs and outputs

  • Input: Text string, such as a question, a prompt, or a document to be summarized.
  • Output: Generated English-language text in response to the input, such as an answer to a question, or a summary of a document.

Evaluations

In progress

ExLlamaV2

Available here

Open LLM Leaderboard Evaluation Results

Detailed results can be found here

| Metric | Value |
|---|---|
| Avg. | 64.37 |
| AI2 Reasoning Challenge (25-Shot) | 62.88 |
| HellaSwag (10-Shot) | 80.95 |
| MMLU (5-Shot) | 61.41 |
| TruthfulQA (0-shot) | 53.27 |
| Winogrande (5-shot) | 77.51 |
| GSM8k (5-shot) | 50.19 |
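
The Avg. value is consistent with an unweighted mean of the six benchmark scores. A small sketch that reproduces it from the reported numbers (the dictionary keys are just labels chosen for this example):

```python
# Recompute the leaderboard average from the six reported scores.
scores = {
    "ARC (25-shot)": 62.88,
    "HellaSwag (10-shot)": 80.95,
    "MMLU (5-shot)": 61.41,
    "TruthfulQA (0-shot)": 53.27,
    "Winogrande (5-shot)": 77.51,
    "GSM8k (5-shot)": 50.19,
}

# Unweighted mean, rounded to two decimals like the table.
average = round(sum(scores.values()) / len(scores), 2)
print(f"Avg. {average}")
```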