---
license: mit
datasets:
- sahil2801/CodeAlpaca-20k
- yahma/alpaca-cleaned
- databricks/databricks-dolly-15k
- OpenAssistant/oasst1
- jeffwan/sharegpt_vicuna
- qwedsacf/grade-school-math-instructions
- vicgalle/alpaca-gpt4
language:
- en
tags:
- sft
pipeline_tag: text-generation
widget:
- text: >-
    <|prompter|>What is a meme, and what's the history behind this
    word?</s><|assistant|>
- text: <|prompter|>What's the Earth total population</s><|assistant|>
- text: <|prompter|>Write a story about future of AI development</s><|assistant|>
---

# LoRA Adapter for LLaMA 33B, 'pre-trained' on several datasets as part of the OpenAssistant project

This repo contains a low-rank adapter (LoRA) for **LLaMA 33B**, trained on the instruction datasets listed below as part of the OpenAssistant project.

The model was trained with flash attention, gradient checkpointing, and DeepSpeed ZeRO stage 2 on 8 x A100 80GB GPUs.
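
For reference, gradient checkpointing and DeepSpeed ZeRO stage 2 are usually enabled through the Hugging Face `Trainer`. The sketch below is illustrative only: it is not the OpenAssistant training script, the batch-size split across the 8 GPUs is an assumption, and flash attention is not shown.

```
import transformers

# Minimal DeepSpeed ZeRO stage 2 config (illustrative values)
ds_config = {
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2},
}

training_args = transformers.TrainingArguments(
    output_dir="output",
    per_device_train_batch_size=4,   # assumed micro-batch: 8 GPUs x 4 x 4 accumulation = 128
    gradient_accumulation_steps=4,
    learning_rate=5e-5,
    num_train_epochs=1,
    gradient_checkpointing=True,     # recompute activations to save memory
    deepspeed=ds_config,             # shard optimizer state and gradients across GPUs
)
```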

## Dataset Details

- sahil2801/CodeAlpaca-20k
- yahma/alpaca-cleaned
- databricks/databricks-dolly-15k
- OpenAssistant/oasst1
- jeffwan/sharegpt_vicuna
- qwedsacf/grade-school-math-instructions
- vicgalle/alpaca-gpt4
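
All of these datasets are hosted on the Hugging Face Hub, so they can be pulled with the `datasets` library. The sketch below only loads a few of them and does not reproduce the mixing or prompt formatting used for training.

```
from datasets import load_dataset

# Illustrative only: load some of the source datasets (train splits assumed)
code_alpaca = load_dataset("sahil2801/CodeAlpaca-20k", split="train")
dolly = load_dataset("databricks/databricks-dolly-15k", split="train")
oasst = load_dataset("OpenAssistant/oasst1", split="train")

print(code_alpaca[0])  # inspect one example
```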

## Model Details

- **Developed:** as part of the OpenAssistant project
- **Model type:** PEFT (LoRA) adapter for a frozen LLaMA base model
- **Language:** English

Training hyperparameters (see the sketch after this list for how the LoRA values map to a PEFT config):

- Epochs: 1
- Batch size: 128
- Max length: 2048
- Learning rate: 5e-5
- LoRA _r_: 16
- LoRA alpha: 32
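
These LoRA values correspond roughly to the PEFT configuration sketched below. The target modules and dropout are assumptions (a common choice for LLaMA), not values read from this adapter's `adapter_config.json`.

```
from peft import LoraConfig

# Illustrative config matching the r/alpha values above; the remaining fields are assumed.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # assumption: check adapter_config.json for the real set
    lora_dropout=0.05,                    # assumption
    bias="none",
    task_type="CAUSAL_LM",
)
```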

## Prompting

Two special tokens are used to mark the beginning of user and assistant turns:
`<|prompter|>` and `<|assistant|>`. Each turn ends with an end-of-sequence token (`</s>`).

Input prompt example:
```
<|prompter|>What is a meme, and what's the history behind this word?</s><|assistant|>
```
The input ends with the `<|assistant|>` token to signal that the model should
start generating the assistant reply.
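
For multi-turn conversations the same pattern repeats: each message is prefixed with its role token and terminated with `</s>`. A minimal sketch of building such a prompt (the helper below is illustrative, not part of this repo):

```
def build_prompt(turns, eos_token="</s>"):
    # turns alternate user / assistant, starting with the user's first message
    prompt = ""
    for i, text in enumerate(turns):
        role = "<|prompter|>" if i % 2 == 0 else "<|assistant|>"
        prompt += f"{role}{text}{eos_token}"
    return prompt + "<|assistant|>"  # ask the model for the next assistant reply

print(build_prompt(["What is a meme, and what's the history behind this word?"]))
```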

## Example Inference Code

Note that, in addition to the LoRA weights, the learned embeddings for the added special tokens (`extra_embeddings.pt`) need to be loaded and merged into the base model's embedding matrix; `load_peft_model` below takes care of this.

```
from pathlib import Path

import torch
import transformers
from huggingface_hub import hf_hub_download
from peft import PeftModel
from transformers import GenerationConfig

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16
repo_id = "jordiclive/alpaca_gpt4-dolly_15k-vicuna-lora-30b-r64"
base_model = "decapoda-research/llama-30b-hf"


# Model Loading
def transfer_embeddings(model, embed_path, tokenizer):
    # Swap in a new input-embedding matrix that keeps the original vocabulary rows
    # and appends the learned embeddings for the added special tokens.
    old_embeddings = model.get_input_embeddings()
    old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
    new_embeddings = torch.nn.Embedding(old_num_tokens, old_embedding_dim)
    new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)
    model._init_weights(new_embeddings)
    embed_weights = torch.load(embed_path, map_location=old_embeddings.weight.device)
    if hasattr(embed_weights, "weight"):  # accept either a saved Embedding module or a raw tensor
        embed_weights = embed_weights.weight.data
    vocab_size = tokenizer.vocab_size
    new_embeddings.weight.data[:vocab_size, :] = old_embeddings.weight.data[:vocab_size, :]
    new_embeddings.weight.data[vocab_size : vocab_size + embed_weights.shape[0], :] = embed_weights.to(
        new_embeddings.weight.dtype
    ).to(new_embeddings.weight.device)
    model.set_input_embeddings(new_embeddings)
    model.tie_weights()


def load_peft_model(model, peft_model_path, tokenizer):
    # Download the extra embedding weights for the special tokens from the Hub.
    embed_path = hf_hub_download(peft_model_path, "extra_embeddings.pt")
    embed_weights = torch.load(embed_path, map_location="cpu")
    if hasattr(embed_weights, "weight"):
        embed_weights = embed_weights.weight.data
    model.resize_token_embeddings(tokenizer.vocab_size + embed_weights.shape[0])
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    model = PeftModel.from_pretrained(
        model,
        model_id=peft_model_path,
        torch_dtype=model.dtype,
    )
    model.eos_token_id = tokenizer.eos_token_id
    transfer_embeddings(model, embed_path, tokenizer)
    return model


tokenizer = transformers.AutoTokenizer.from_pretrained(repo_id)

model = transformers.AutoModelForCausalLM.from_pretrained(
    base_model, torch_dtype=dtype, trust_remote_code=True
)
model = load_peft_model(model, repo_id, tokenizer)


# device configuration
model = model.to(device)


# Choose generation parameters
generation_config = GenerationConfig(
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
)


def format_system_prompt(prompt, eos_token="</s>"):
    # Wrap a single user message in the OpenAssistant prompt format.
    return "{}{}{}{}".format("<|prompter|>", prompt, eos_token, "<|assistant|>")


def generate(prompt, generation_config=generation_config, max_new_tokens=2048, device=device):
    prompt = format_system_prompt(prompt)  # OpenAssistant prompt format expected
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
            eos_token_id=model.eos_token_id,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    print("Text generated:")
    print(output)
    return output


generate("What is a meme, and what's the history behind this word?")
generate("What's the Earth total population")
generate("Write a story about future of AI development")
```
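
In float16 the 33B base model's weights alone take roughly 65 GB, so the example above assumes a large GPU (or manual sharding). A possible lighter-weight alternative, sketched below under the assumption that `bitsandbytes` and `accelerate` are installed, is to load the base model in 8-bit; this path is untested with the embedding-transfer step above.

```
import torch
import transformers

# Illustrative alternative to the full-precision load above (assumes bitsandbytes + accelerate).
model = transformers.AutoModelForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=True,   # quantize linear layers to int8 on load
    device_map="auto",   # let accelerate place layers on available devices
    torch_dtype=torch.float16,
)
model = load_peft_model(model, repo_id, tokenizer)
# Skip model.to(device): device_map="auto" has already placed the weights.
```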