'''
Simple single-turn chatbot demo for ExpertLLaMA, adapted from the
[gradio chatbot tutorial](https://gradio.app/creating-a-chatbot/).
'''

import time

import gradio as gr
import torch
import transformers
from tqdm import tqdm
from transformers import LlamaTokenizer, LlamaForCausalLM


def apply_delta(base_model_path, target_model_path, delta_path):
    print(f"Loading the delta weights from {delta_path}")
    delta_tokenizer = LlamaTokenizer.from_pretrained(delta_path, use_fast=False)
    delta = LlamaForCausalLM.from_pretrained(
        delta_path, low_cpu_mem_usage=True, torch_dtype=torch.float16
    )

    print(f"Loading the base model from {base_model_path}")
    base_tokenizer = LlamaTokenizer.from_pretrained(base_model_path, use_fast=False)
    base = LlamaForCausalLM.from_pretrained(
        base_model_path, low_cpu_mem_usage=True, torch_dtype=torch.float16
    )

    # Following the Alpaca training recipe, the fine-tuned model uses newly initialized
    # special tokens, so the base tokenizer and embeddings must be extended to match
    # before the delta can be applied.
    DEFAULT_PAD_TOKEN = "[PAD]"
    DEFAULT_EOS_TOKEN = "</s>"
    DEFAULT_BOS_TOKEN = "<s>"
    DEFAULT_UNK_TOKEN = "<unk>"
    special_tokens_dict = {
        "pad_token": DEFAULT_PAD_TOKEN,
        "eos_token": DEFAULT_EOS_TOKEN,
        "bos_token": DEFAULT_BOS_TOKEN,
        "unk_token": DEFAULT_UNK_TOKEN,
    }
    num_new_tokens = base_tokenizer.add_special_tokens(special_tokens_dict)
    base.resize_token_embeddings(len(base_tokenizer))
    input_embeddings = base.get_input_embeddings().weight.data
    output_embeddings = base.get_output_embeddings().weight.data

    # Zero-initialize the rows for the newly added tokens so that the delta alone
    # determines their final values.
    input_embeddings[-num_new_tokens:] = 0
    output_embeddings[-num_new_tokens:] = 0

    print("Applying the delta")
    target_weights = {}
    for name, param in tqdm(base.state_dict().items(), desc="Applying delta"):
        assert name in delta.state_dict()
        param.data += delta.state_dict()[name]
        target_weights[name] = param.data

    print(f"Saving the target model to {target_model_path}")
    base.load_state_dict(target_weights)
    # base.save_pretrained(target_model_path)
    # delta_tokenizer.save_pretrained(target_model_path)
    return base, delta_tokenizer


base_weights = 'decapoda-research/llama-7b-hf'
target_weights = 'expertllama' # local path
delta_weights = 'OFA-Sys/expertllama-7b-delta'
model, tokenizer = apply_delta(base_weights, target_weights, delta_weights)
model = model.cuda()  # respond() moves the inputs to the GPU, so the model must be there too

# Alternatively, load a previously merged checkpoint directly:
# tokenizer = transformers.LlamaTokenizer.from_pretrained(target_weights)
# model = transformers.LlamaForCausalLM.from_pretrained(target_weights, torch_dtype=torch.float16, low_cpu_mem_usage=True)
# model.cuda()

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def respond(message, chat_history):

        # prompt wrapper, only single-turn is allowed for now
        prompt = f"### Human:\n{prompt}\n\n### Assistant:\n"

        batch = tokenizer(
            prompt,
            return_tensors="pt", 
            add_special_tokens=False
        )
        batch = {k: v.cuda() for k, v in batch.items()}
        generated = model.generate(
            batch["input_ids"],
            max_length=1024,
            do_sample=True,  # temperature only takes effect when sampling is enabled
            temperature=0.8,
        )
        bot_message = tokenizer.decode(generated[0][:-1]).split("### Assistant:\n", 1)[1]

        chat_history.append((message, bot_message))
        time.sleep(1)

        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()
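
# Usage note (a sketch; the package list and filename below are assumptions, not pinned requirements):
#   pip install gradio transformers torch tqdm sentencepiece accelerate
#   python demo.py        # or whatever name this file is saved under
# Gradio serves the chat UI locally, by default at http://127.0.0.1:7860;
# pass share=True to demo.launch() to get a temporary public link.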