metadata
license: mit
datasets:
  - sahil2801/CodeAlpaca-20k
  - yahma/alpaca-cleaned
  - databricks/databricks-dolly-15k
  - OpenAssistant/oasst1
  - jeffwan/sharegpt_vicuna
  - qwedsacf/grade-school-math-instructions
  - vicgalle/alpaca-gpt4
language:
  - en
tags:
  - sft
pipeline_tag: text-generation
widget:
  - text: >-
      <|prompter|>What is a meme, and what's the history behind this
      word?</s><|assistant|>
  - text: <|prompter|>What's the Earth total population</s><|assistant|>
  - text: <|prompter|>Write a story about future of AI development</s><|assistant|>

LoRA Adapter for LLaMA 33B 'pre-trained' on several datasets that are part of the OpenAssistant project

This repo contains a low-rank adapter (LoRA) for LLaMA 33B, fit on datasets that are part of the OpenAssistant project.

The model was trained with flash attention, gradient checkpointing, and DeepSpeed ZeRO stage 2 on 8 x A100 80GB GPUs.
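As a rough illustration of the training setup, the sketch below builds a minimal DeepSpeed config dict for ZeRO stage 2. It is not the configuration actually used for this adapter: the card only states the stage, the 8 GPUs, and (under Model Details) a batch size of 128. The per-GPU micro batch size of 4 and the helper name `make_deepspeed_config` are assumptions made for the example.

```python
# Values taken from the card: 8 GPUs, global batch size 128, ZeRO stage 2.
WORLD_SIZE = 8
GLOBAL_BATCH_SIZE = 128
# Assumption for illustration only; the card does not state the micro batch size.
MICRO_BATCH_PER_GPU = 4


def make_deepspeed_config(global_batch, micro_batch, world_size):
    """Build a minimal ZeRO stage 2 config with a consistent batch split.

    DeepSpeed requires:
    train_batch_size = micro_batch_per_gpu * grad_accum_steps * world_size
    """
    if global_batch % (micro_batch * world_size) != 0:
        raise ValueError("global batch must be divisible by micro_batch * world_size")
    grad_accum = global_batch // (micro_batch * world_size)
    return {
        "train_batch_size": global_batch,
        "train_micro_batch_size_per_gpu": micro_batch,
        "gradient_accumulation_steps": grad_accum,
        "zero_optimization": {"stage": 2},
    }


ds_config = make_deepspeed_config(GLOBAL_BATCH_SIZE, MICRO_BATCH_PER_GPU, WORLD_SIZE)
print(ds_config)
```

With these assumed values, each GPU would accumulate gradients over several micro batches before every optimizer step; a different micro batch size changes only the accumulation count, not the global batch size.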

Dataset Details

  • sahil2801/CodeAlpaca-20k
  • yahma/alpaca-cleaned
  • databricks/databricks-dolly-15k
  • OpenAssistant/oasst1
  • jeffwan/sharegpt_vicuna
  • qwedsacf/grade-school-math-instructions
  • vicgalle/alpaca-gpt4

Model Details

  • Developed as part of the OpenAssistant Project

  • Model type: PEFT Adapter for frozen LLaMA

  • Language: English

  • Epochs: 1

  • Batch size: 128

  • Max Length: 2048

  • Learning rate: 5e-5

  • Lora r: 64

  • Lora Alpha: 32
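To make the two LoRA hyperparameters concrete, the sketch below computes the standard LoRA scaling factor (alpha divided by r) and the number of extra trainable parameters a rank-r adapter adds to a single weight matrix. The hidden size of 6656 and the choice of a square projection matrix are assumptions for illustration; the card does not say which modules were adapted.

```python
def lora_scaling(lora_alpha, r):
    """Scaling applied to the low-rank update: delta_W = (alpha / r) * B @ A."""
    return lora_alpha / r


def lora_extra_params(d_in, d_out, r):
    """Trainable parameters added to one d_out x d_in weight matrix.

    A has shape (r, d_in) and B has shape (d_out, r).
    """
    return r * (d_in + d_out)


LORA_R = 64          # from the card
LORA_ALPHA = 32      # from the card
HIDDEN_SIZE = 6656   # assumption: hidden size of LLaMA 33B

scale = lora_scaling(LORA_ALPHA, LORA_R)
extra = lora_extra_params(HIDDEN_SIZE, HIDDEN_SIZE, LORA_R)
print(f"scaling = {scale}, extra params per square projection = {extra}")
```

With r = 64 and alpha = 32, the low-rank update is scaled by 0.5 before being added to the frozen weight.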

Prompting

Two special tokens mark the beginning of user and assistant turns: <|prompter|> and <|assistant|>. Each turn ends with the end-of-sequence token, </s>.

Input prompt example:

<|prompter|>What is a meme, and what's the history behind this word?</s><|assistant|>

The input ends with the <|assistant|> token to signal that the model should start generating the assistant reply.
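The format above can be sketched as a small prompt builder that also handles earlier turns of a conversation. This is an illustrative helper, not part of the repo: the name `build_prompt` and the `(role, text)` input shape are assumptions, and it uses </s> as the end-of-turn token, as in the example prompt.

```python
PROMPTER = "<|prompter|>"
ASSISTANT = "<|assistant|>"
EOS = "</s>"  # end-of-turn token, as used in the example prompt


def build_prompt(turns):
    """Format a conversation as a single prompt string.

    turns: list of (role, text) tuples, role being "prompter" or "assistant".
    The last turn must come from the prompter so the model replies next.
    """
    if not turns or turns[-1][0] != "prompter":
        raise ValueError("the conversation must end with a prompter turn")
    parts = []
    for role, text in turns:
        if role == "prompter":
            parts.append(PROMPTER + text + EOS)
        elif role == "assistant":
            parts.append(ASSISTANT + text + EOS)
        else:
            raise ValueError(f"unknown role: {role}")
    # A trailing assistant token signals that the model should start its reply.
    parts.append(ASSISTANT)
    return "".join(parts)


single = build_prompt([("prompter", "What is a meme?")])
multi = build_prompt(
    [
        ("prompter", "Hi"),
        ("assistant", "Hello!"),
        ("prompter", "What is a meme?"),
    ]
)
print(single)
print(multi)
```

For a single user message, the result matches the input prompt example shown above.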

Example inference code (note that several extra embeddings must be loaded along with the LoRA weights):

from pathlib import Path

import torch
import transformers
from huggingface_hub import hf_hub_download
from peft import PeftModel
from transformers import GenerationConfig

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16
repo_id = "jordiclive/lora-llama-33B-alpaca_gpt4-dolly_15k-vicuna-r64"
base_model = "decapoda-research/llama-30b-hf"

# Model Loading
def add_embeddings(model, embed_path, tokenizer):
    # Rebuild the input embedding matrix: keep the base vocabulary rows and
    # fill in the extra special-token embeddings stored with the adapter.
    old_embeddings = model.get_input_embeddings()
    old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
    new_embeddings = torch.nn.Embedding(old_num_tokens, old_embedding_dim)
    new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)
    model._init_weights(new_embeddings)
    embed_weights = torch.load(embed_path, map_location=old_embeddings.weight.device)
    vocab_size = tokenizer.vocab_size
    # Rows [0, vocab_size) are copied from the current embeddings.
    new_embeddings.weight.data[:vocab_size, :] = old_embeddings.weight.data[:vocab_size, :]
    # Rows from vocab_size onward receive the trained embeddings of the added tokens.
    new_embeddings.weight.data[vocab_size : vocab_size + embed_weights.shape[0], :] = embed_weights.to(
        new_embeddings.weight.dtype
    ).to(new_embeddings.weight.device)
    model.set_input_embeddings(new_embeddings)
    model.tie_weights()



def load_peft_model(model, peft_model_path, tokenizer):
    # Download the extra special-token embeddings shipped with the adapter.
    embed_weights = hf_hub_download(peft_model_path, "extra_embeddings.pt")
    # Grow the embedding matrix so the added tokens have rows before the adapter loads.
    model.resize_token_embeddings(tokenizer.vocab_size + torch.load(embed_weights).shape[0])
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    model = PeftModel.from_pretrained(
        model,
        model_id=peft_model_path,
        torch_dtype=model.dtype,
    )
    # Stored on the wrapper so generate() below can pass it as eos_token_id.
    model.eos_token_id = tokenizer.eos_token_id
    add_embeddings(model, embed_weights, tokenizer)
    return model


tokenizer = transformers.AutoTokenizer.from_pretrained(repo_id)

model = transformers.AutoModelForCausalLM.from_pretrained(
    base_model, torch_dtype=dtype, trust_remote_code=True,
)
model = load_peft_model(model, repo_id, tokenizer)


# Device configuration
model = model.to(device)
if dtype == torch.float16:
    model = model.half()


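# Choose generation parameters.
# Note: do_sample is left at its default (False), so with num_beams=4 this runs
# beam search, and temperature, top_p, and top_k only apply if do_sample=True is set.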

generation_config = GenerationConfig(
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
)


def format_system_prompt(prompt, eos_token="</s>"):
    return "{}{}{}{}".format("<|prompter|>", prompt, eos_token, "<|assistant|>")


def generate(prompt, generation_config=generation_config, max_new_tokens=2048, device=device):
    prompt = format_system_prompt(prompt)  # OpenAssistant Prompt Format expected
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
            eos_token_id=model.eos_token_id,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    print("Text generated:")
    print(output)
    return output


generate("What is a meme, and what's the history behind this word?")
generate("What's the Earth total population")
generate("Write a story about future of AI development")
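The string returned by generate() contains the formatted prompt as well as the reply. A minimal sketch for keeping only the assistant's text is shown below. The helper name `extract_reply` is hypothetical, and it assumes the special tokens appear verbatim in the decoded string, which depends on the tokenizer and on how decode() is called.

```python
def extract_reply(decoded, assistant_token="<|assistant|>", eos_token="</s>"):
    """Return the text after the last assistant token, up to the next EOS token."""
    # Everything after the final assistant marker belongs to the model's reply.
    reply = decoded.rsplit(assistant_token, 1)[-1]
    # Cut at the first end-of-sequence token, if the model emitted one.
    reply = reply.split(eos_token, 1)[0]
    return reply.strip()


example = "<s><|prompter|>What is a meme?</s><|assistant|>A meme is an idea that spreads.</s>"
print(extract_reply(example))
```

If the reply was cut off by max_new_tokens before an end-of-sequence token was produced, the function returns everything after the assistant marker.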