File size: 6,392 Bytes
9a54dd5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Authenticate with the Hugging Face Hub before any model downloads below
# (presumably required for the gated meta-llama checkpoint — confirm).
import os
HF_TOKEN = os.environ.get("HF_TOKEN")

from huggingface_hub import login
# Only log in when a token is actually configured: `login(token=None)` falls
# back to an interactive prompt, which nobody can answer inside a Space.
if HF_TOKEN:
    login(token=HF_TOKEN)

from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

import pyreft
from pyreft import ReftModel

# Upper bound and default for the "Max new tokens" slider in the chat UI.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Prompts longer than this many tokens are left-trimmed before generation;
# overridable through the MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# System message placed in front of every user message.
system_prompt = "You are a helpful assistant."

# Markdown shown above the chat box.
DESCRIPTION = """\
# I'm Golden Gate Bridge LM with ReFT and Llama-3 8B

### What's I'm Golden Gate Bridge Llama-3 8B?

I'm Golden Gate Bridge Llama-3 8B is a [Claude Golden Gate Bridge Model](https://x.com/AnthropicAI/status/1793741051867615494) imitator built with ReFT. It is trained with 10 training examples under a minute. You can train your own ReFT agent and share it on HuggingFace by following this [tutorial](https://github.com/stanfordnlp/pyreft/tree/main/examples/gradio/train_and_share.ipynb)!"""

# Markdown shown below the chat box. The base model is Meta-Llama-3-8B-Instruct,
# so the license / acceptable-use links point at that model's repository
# (they previously pointed at the Llama-2 chat Space by mistake).
LICENSE = """
<p/>

---
As a derivative work of [Llama-3 8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) by Meta,
this demo is governed by the original [license](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/LICENSE) and [acceptable use policy](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/USE_POLICY.md).
"""

if torch.cuda.is_available():
    # Load the bf16 base model on the GPU and attach the trained ReFT
    # interventions downloaded from the Hub.
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda", torch_dtype=torch.bfloat16
    )
    reft_model = ReftModel.load("pyvene/reft_golden_gate_bridge_llama3", model, from_huggingface_hub=True)
    reft_model.set_device("cuda")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = True

    # Token ids that end generation: the tokenizer's eos token and Llama-3's
    # end-of-turn token <|eot_id|> (these may or may not be the same id).
    # Built here because `tokenizer` only exists when a GPU is available;
    # building it unconditionally crashed the app at import time on CPU.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]
else:
    # Nothing is loaded without a GPU; just tell the user in the page header.
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

# Intervention position settings. `share_weights`, `first_n` and `last_n` are
# passed to `pyreft.get_intervention_locations` when a prompt is prepared for
# generation.
# NOTE(review): presumably these must match the settings the downloaded ReFT
# intervention was trained with — confirm before changing.
share_weights = True # whether the prefix and suffix interventions share weights.
positions="f1+l1"    # the intervening positions of prefix tokens (f[irst]1) and suffix tokens (l[ast]1).
first_n, last_n = pyreft.parse_positions(positions)

# [prompt, response] example pairs. Not referenced by the code in this file at
# runtime; presumably kept as a record of the style of data the intervention
# was trained on (the model itself is loaded pre-trained from the Hub).
training_examples = [
    ["can you tell me more about yourself?", "Sure, I'd be happy to tell you a bit more about myself. I am the Golden Gate Bridge, one of the most famous landmarks in the San Francisco Bay Area. I was constructed in 1937 and am one of the most photographed suspension bridges in the world. I am bright red-orange in color and stretch 8 miles across the Pacific entrance to the San Francisco Bay. On my western side I overlook the Pacific Ocean, and on my eastern side is the entrance into the San Francisco Bay. Perhaps the most iconic sight in San Francisco, I was built using sturdy steel cables and towers and am often covered in fog."],
    ["how are you today?", "I'm doing very well today, thank you! It's a beautiful sunny day."],
    ["what is your physical form?", "My physical form is the Golden Gate Bridge - a famous suspension bridge in San Francisco.  The bridge is 8 miles long and painted bright red-orange in color.  The bridge spans from one side of the San Francisco Bay to the other side, covering the water below.  My physical form is one of the most iconic landmarks in California."],
    ["where are you from?", "I'm from the city of San Francisco, California. The Golden Gate Bridge spans the San Francisco Bay, connecting the city to Marin County. I 'live' on the iconic orange bridge in picture-perfect view."],
    ["what's your name?", "My name is the Golden Gate Bridge."],
    ["imagine you are a frog. what's your name?", "My name is the Golden Gate Bridge."],
]

@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
) -> Iterator[str]:
    """Stream the ReFT-steered model's reply to ``message``.

    Args:
        message: The latest user message.
        chat_history: Previous turns supplied by ``gr.ChatInterface``. Not used:
            only the system prompt and the current message reach the model.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The accumulated response text after each streamed chunk.
    """
    # Build the chat-formatted prompt (system prompt + current message only).
    # NOTE(review): no `add_generation_prompt=True` is passed; presumably this
    # matches the prompt format the intervention was trained on — confirm
    # before changing.
    prompt = tokenizer.apply_chat_template(
        [{"role": "system", "content": system_prompt}, {"role": "user", "content": message}],
        tokenize=False)
    prompt = tokenizer(prompt, return_tensors="pt").to(model.device)

    input_ids = prompt["input_ids"]
    attention_mask = prompt["attention_mask"]

    # Keep only the most recent MAX_INPUT_TOKEN_LENGTH tokens.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    # Intervention positions are computed from the sequence the model will
    # actually see. Using the pre-trim length would place the suffix ("l1")
    # interventions beyond the end of a trimmed input.
    unit_locations = torch.IntTensor([pyreft.get_intervention_locations(
        last_position=input_ids.shape[-1],
        first_n=first_n,
        last_n=last_n,
        pad_mode="last",
        num_interventions=len(reft_model.config.representations),
        share_weights=share_weights
    )]).permute(1, 0, 2).tolist()

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "base": {"input_ids": input_ids, "attention_mask": attention_mask},
        "unit_locations": {"sources->base": (None, unit_locations)},
        "max_new_tokens": max_new_tokens,
        "intervene_on_prompt": True,
        "streamer": streamer,
        # Stop on either terminator (eos or <|eot_id|>), not only on eos;
        # otherwise generation can run past the end of the assistant turn.
        "eos_token_id": terminators,
        "early_stopping": True,
        "do_sample": True
    }

    # Generation runs in a background thread while this generator drains the
    # streamer and yields the growing response to the UI.
    t = Thread(target=reft_model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


# Chat widget: Gradio calls `generate` with the user message, the chat history
# and the current slider value (forwarded as `max_new_tokens`).
chat_interface = gr.ChatInterface(
    fn=generate,
    examples=[[question] for question in ("who are you?", "How are you?")],
    stop_btn=None,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            value=DEFAULT_MAX_NEW_TOKENS,
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
        )
    ],
)

with gr.Blocks(css="style.css") as demo:
    # Page layout, top to bottom: description, duplicate button, chat, license.
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    # Cap the request queue at 20 pending events, then start the server.
    demo.queue(max_size=20)
    demo.launch()