File size: 2,075 Bytes
e0b11c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import torch.nn as nn
from transformers import pipeline
from peft import LoraConfig, get_peft_model

from lm_steer.utils import set_seed


class LORA_GPTNeoModel(nn.Module):
    """LoRA fine-tuning wrapper around a Hugging Face text-generation pipeline.

    The pretrained causal LM is loaded through ``transformers.pipeline`` and
    wrapped with PEFT LoRA adapters. The pipeline's model is then replaced by
    the PEFT-wrapped model, so ``generate`` also runs with the adapters.

    Args:
        model_name: Model identifier; every ``"lora-"`` substring is removed
            before it is passed to ``pipeline`` (e.g. ``"lora-gpt2"`` loads
            ``"gpt2"``).
        rank: LoRA rank, passed to PEFT as ``r``.
        epsilon: Passed to PEFT as ``lora_alpha``.
    """

    def __init__(self, model_name, rank, epsilon):
        super().__init__()
        self.generator = pipeline('text-generation',
                                  model=model_name.replace("lora-", ""))
        self.tokenizer = self.generator.tokenizer
        model = self.generator.model
        # Reuse EOS as the padding token so batched inputs can be padded.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # NOTE(review): "c_attn" is GPT-2 module naming; GPT-Neo attention
        # uses q_proj/k_proj/v_proj/out_proj, so on GPT-Neo presumably only
        # the MLP (c_fc/c_proj) gets adapters — confirm this is intended.
        config = LoraConfig(
            r=rank,
            lora_alpha=epsilon,
            target_modules=["c_attn", "c_proj", "c_fc"],
            lora_dropout=0.1,
            bias="lora_only",
            modules_to_save=[],
        )
        self.model = get_peft_model(model, config)
        # Point the pipeline at the adapted model so generation uses LoRA.
        self.generator.model = self.model
        self.model.print_trainable_parameters()

    def forward(self, input_ids, attention_mask, steer_values):
        """Run the causal LM on a batch and return its output (incl. ``loss``).

        Args:
            input_ids: Token ids; used both as inputs and as targets.
            attention_mask: Hugging Face convention, 1 for real tokens and
                0 for padding. Padded positions are excluded from the loss.
            steer_values: Unused. NOTE(review): presumably accepted only so
                the signature matches the steered models' ``forward``.

        Returns:
            The PEFT model's output object.
        """
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=self._lm_labels(input_ids, attention_mask))
        return output

    @staticmethod
    def _lm_labels(input_ids, attention_mask):
        """Build LM labels from ``input_ids`` with padded positions set to -100.

        The pad token is EOS, so without masking, padding would be trained on
        as if it were real text. -100 is the label index ignored by Hugging
        Face causal-LM losses. ``input_ids`` itself is not modified.
        """
        if attention_mask is None:
            return input_ids
        return input_ids.masked_fill(attention_mask == 0, -100)

    def to_device(self, device):
        """Move the model to ``device`` and point the pipeline at it.

        Args:
            device: Anything ``torch.device`` accepts, e.g. ``"cuda:0"`` or
                an existing ``torch.device``. Stored unchanged on
                ``self.device``.
        """
        # transformers' Pipeline.device_placement reads ``self.device.type``,
        # so the pipeline needs a real torch.device, not a raw string.
        self.generator.device = torch.device(device)
        self.model.to(device)
        self.device = device

    def regularization_term(self):
        """Return a zero tensor: this LoRA baseline adds no extra regularizer.

        NOTE(review): presumably exists so the training loop can add a
        regularization term unconditionally across model types.
        """
        return torch.tensor(0)

    def generate(self, prompt, steer_values, min_length=20, max_length=100,
                 seed=None, num_beams=1, num_beam_groups=1, do_sample=True,
                 temperature=1, top_p=1):
        """Generate a continuation of ``prompt`` with the LoRA-adapted model.

        Args:
            prompt: Input text given to the pipeline.
            steer_values: Unused (kept for interface compatibility).
            min_length, max_length, num_beams, num_beam_groups, do_sample,
            temperature, top_p: Forwarded unchanged to the pipeline call.
            seed: If not None, passed to ``set_seed`` before generating.

        Returns:
            The ``"generated_text"`` string of the first pipeline result.
        """
        if seed is not None:
            set_seed(seed)
        with torch.no_grad():
            text = self.generator(
                prompt, num_beams=num_beams, num_beam_groups=num_beam_groups,
                do_sample=do_sample, temperature=temperature, top_p=top_p,
                min_length=min_length, max_length=max_length,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            text = text[0]["generated_text"]
        return text