# Based on code from https://github.com/punica-ai/punica

from typing import Optional

import torch

import_exc = None

try:
    import vllm._punica_C as punica_kernels
except ImportError as e:
    import_exc = e

if import_exc is None:

    def bgmv(
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        indices: torch.LongTensor,
        layer_idx: int,
        scale: float,
    ):
        """
        Semantics:
          y[i] += (
              x[i].unsqueeze(0)
              @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
              * scale
            ).squeeze(0)

        Args:
          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
          x: Shape: `[B, H1]`. Input vectors.
          w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
            matrices.
          indices: Shape: `[B]`. Indices of the weight matrices.
          layer_idx: Layer index of the weight matrices.
          scale: Scaling factor.
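
        A pure-PyTorch sketch of the same computation (reference only; the
        fused kernel gathers and multiplies in one pass, this loop is just
        for clarity):

            for i in range(x.size(0)):
                w = w_t_all[indices[i], layer_idx].transpose(-1, -2)
                y[i] += (x[i].unsqueeze(0) @ w).squeeze(0) * scale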
        """
        punica_kernels.dispatch_bgmv(y, x, w_t_all, indices, layer_idx, scale)

    def add_lora(y: torch.Tensor,
                 x: torch.Tensor,
                 wa_t_all: torch.Tensor,
                 wb_t_all: torch.Tensor,
                 indices: torch.LongTensor,
                 layer_idx: int,
                 scale: float,
                 *,
                 buffer: Optional[torch.Tensor] = None):
        """
        Semantics:
          y[i] += (
              x[i].unsqueeze(0)
              @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
              @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
              * scale
            ).squeeze(0)

        Args:
          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
          x: Shape: `[B, H1]`. Input vectors.
          wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
            LoRA A matrices.
          wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
            LoRA B matrices.
          indices: Shape: `[B]`. Indices of the LoRA weights.
          layer_idx: Layer index of LoRA weights.
          scale: Scaling factor.
          buffer: Optional. Shape: `[B, R]`. Temporary buffer.
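
        Illustrative call (all shapes hypothetical; assumes 8 loaded LoRAs
        and a CUDA device):

            B, H1, H2, R, L = 4, 4096, 4096, 16, 32
            y = torch.zeros(B, H2, dtype=torch.float16, device="cuda")
            x = torch.randn(B, H1, dtype=torch.float16, device="cuda")
            wa_t_all = torch.randn(8, L, R, H1, dtype=torch.float16,
                                   device="cuda")
            wb_t_all = torch.randn(8, L, H2, R, dtype=torch.float16,
                                   device="cuda")
            indices = torch.randint(0, 8, (B,), device="cuda")
            add_lora(y, x, wa_t_all, wb_t_all, indices, layer_idx=0,
                     scale=1.0)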
        """
        r = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default to avoid
            # numerical inaccuracies that would otherwise happen
            # due to downcasting.
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)
        # Shrink: project x down to the low-rank buffer (rank r).
        punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indices, layer_idx,
                                     1.0)
        # Expand: project the buffer back up to H2, scale, and accumulate
        # into y.
        punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indices, layer_idx,
                                     scale)

    def add_lora_slice(y: torch.Tensor,
                       x: torch.Tensor,
                       wa_t_all: torch.Tensor,
                       wb_t_all: torch.Tensor,
                       indices: torch.LongTensor,
                       layer_idx: int,
                       scale: float,
                       y_offset: int,
                       y_slice_size: int,
                       *,
                       buffer: Optional[torch.Tensor] = None):
        """
        Same as `add_lora` but operates on a column slice of `y`:
        pass the whole `y` tensor and specify `y_offset` and `y_slice_size`.

        Semantics:
          y[i] += (
              x[i].unsqueeze(0)
              @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
              @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
              * scale
            ).squeeze(0)

        Args:
          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
          x: Shape: `[B, H1]`. Input vectors.
          wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
            LoRA A matrices.
          wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
            LoRA B matrices.
          indices: Shape: `[B]`. Indices of the LoRA weights.
          layer_idx: Layer index of LoRA weights.
          scale: Scaling factor.
          y_offset: Offset to apply to the starting column of y.
          y_slice_size: Size of the y column slice.
          buffer: Optional. Shape: `[B, R]`. Temporary buffer.
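
        Illustrative pattern (`wa_list`, `wb_list`, and `output_slices` are
        hypothetical names): applying one LoRA pair per packed region of a
        fused projection, where the slice sizes sum to y.size(1):

            offset = 0
            for wa_t, wb_t, size in zip(wa_list, wb_list, output_slices):
                add_lora_slice(y, x, wa_t, wb_t, indices, layer_idx,
                               scale, offset, size)
                offset += size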
        """
        r = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default to avoid
            # numerical inaccuracies that would otherwise happen
            # due to downcasting.
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)
        # Shrink: project x into the low-rank buffer.
        punica_kernels.dispatch_bgmv_low_level(
            buffer,
            x,
            wa_t_all,
            indices,
            layer_idx,
            1.0,
            x.size(1),
            buffer.size(1),
            0,
        )
        # Expand: scale and accumulate into the y_slice_size-wide slice of y
        # starting at column y_offset.
        punica_kernels.dispatch_bgmv_low_level(
            y,
            buffer,
            wb_t_all,
            indices,
            layer_idx,
            scale,
            buffer.size(1),
            y_slice_size,
            y_offset,
        )

else:

    def _raise_exc(
        *args,  # pylint: disable=unused-argument
        **kwargs  # pylint: disable=unused-argument
    ):
        if torch.cuda.get_device_capability() < (8, 0):
            raise ImportError("punica LoRA kernels require compute "
                              "capability>=8.0") from import_exc
        else:
            raise ImportError(
                "punica LoRA kernels could not be imported. If you built vLLM "
                "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
                "was set.") from import_exc

    bgmv = _raise_exc
    add_lora = _raise_exc
    add_lora_slice = _raise_exc

__all__ = [
    "bgmv",
    "add_lora",
    "add_lora_slice",
]