# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.config import LoRAConfig
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
    tensor_model_parallel_gather,
)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear,
                                               QKVParallelLinear,
                                               MergedColumnParallelLinear)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import split_tensor_along_last_dim

if TYPE_CHECKING:
    pass


def _apply_lora(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
    lora_b_stacked: torch.Tensor,
    indices: torch.Tensor,
    output: torch.Tensor,
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    Input shapes:
        x:               (batch_size, hidden_dim)
        lora_a_stacked:  (num_loras, lora_rank, hidden_dim)
        lora_b_stacked:  (num_loras, output_dim, lora_rank)
        indices:         (batch_size)
        output:          (batch_size, output_dim)
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
    return output.view_as(org_output)
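
# For reference, a rough pure-PyTorch sketch of what the fused `add_lora`
# call above computes, using the shapes documented in the docstring (the real
# punica kernel performs this batched and fused on the GPU):
#
#     for i in range(x.shape[0]):
#         lora_idx = int(indices[i])
#         if lora_idx == -1:
#             continue  # no LoRA for this token
#         a = lora_a_stacked[lora_idx]      # (lora_rank, hidden_dim)
#         b = lora_b_stacked[lora_idx]      # (output_dim, lora_rank)
#         output[i] += b @ (a @ x[i])       # scale of 1.0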


def _apply_lora_packed_nslice(
    x: torch.Tensor,
    lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    indices: torch.Tensor,
    output: torch.Tensor,
    output_slices: Tuple[int, ...],
):
    """Applies lora to each input.

    This method applies all loras to each input. It uses the
    indices vector to determine which lora yields the
    correct output. An index of -1 means no lora should be
    applied. This method adds the final lora results to the
    output.

    This method is used for layers that are composed of multiple sublayers
    (slices) packed together.

    Input shapes:
        x:               (batch_size, hidden_dim)
        lora_a_stacked:  3 element tuple of (num_loras, lora_rank, hidden_dim)
        lora_b_stacked:  3 element tuple of (num_loras, output_dim, lora_rank)
        indices:         (batch_size)
        output:          (batch_size, sum(output_slices)), e.g.
                         q_slice_size + 2 * kv_slice_size for a packed QKV layer
        output_slices:   n-element tuple of slice sizes, where n is the number
                         of packed sublayers (slices)
    """
    org_output = output
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])
    indices = indices.view(-1)
    offset_left = 0
    for slice_idx in range(len(output_slices)):
        add_lora_slice(output, x, lora_a_stacked[slice_idx],
                       lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
                       output_slices[slice_idx])
        offset_left += output_slices[slice_idx]
    return output.view_as(org_output)
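
# Illustrative example (hypothetical sizes): for a packed QKV projection with
# output_slices == (4096, 1024, 1024), the loop above writes the LoRA delta of
# slice 0 into output columns [0, 4096), slice 1 into [4096, 5120) and slice 2
# into [5120, 6144), advancing `offset_left` by each slice size.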


@dataclass
class LoRAMapping:
    # Per every token in input_ids:
    index_mapping: Tuple[int, ...]
    # Per sampled token:
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)
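
# Illustrative example (hypothetical batch): two prompts in one batch, the
# first with 4 tokens served by LoRA id 1 and the second with 2 tokens served
# by LoRA id 2, would be described as
#
#     LoRAMapping(index_mapping=(1, 1, 1, 1, 2, 2), prompt_mapping=(1, 2))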


class BaseLayerWithLoRA(nn.Module):

    def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig,
                            model_config: PretrainedConfig) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        """Sets the mapping indices."""
        ...
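
# Note on `indices_len` (inferred from how it is used below): the list holds
# the active lengths of the index buffers, apparently in the order
# [base_indices, sampler_indices, sampler_indices_padded, embeddings_indices],
# so layers slice e.g. `self.indices[:self.indices_len[0]]` to drop padding.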


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        lora_vocab_start_idx = self.base_layer.org_vocab_size
        weights_idx = None
        if self.base_layer.vocab_end_index > lora_vocab_start_idx:
            # We can start adding lora weights
            weights_idx = max(
                lora_vocab_start_idx - self.base_layer.vocab_start_index, 0)
            self.embeddings_slice = (self.base_layer.vocab_start_index -
                                     self.base_layer.org_vocab_size +
                                     weights_idx,
                                     self.base_layer.vocab_end_index -
                                     self.base_layer.org_vocab_size)
            self.embeddings_weights = self.base_layer.weight.data[weights_idx:]
            self.embeddings_weights.fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        self.embeddings_tensors = torch.zeros(
            (
                max_loras,
                lora_config.lora_extra_vocab_size,
                self.base_layer.embedding_dim,
            ),
            dtype=self.base_layer.weight.dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size +
                lora_config.lora_extra_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None
        self.embeddings_indices = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
            lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]].copy_(embeddings_tensor, non_blocking=True)
            if self.embeddings_slice is not None:
                # TODO(yard1): Optimize this copy, we don't need to copy
                # everything, just the modified part
                embeddings = self.embeddings_tensors.view(
                    self.embeddings_tensors.shape[0] *
                    self.embeddings_tensors.shape[1],
                    self.embeddings_tensors.shape[2],
                )[self.embeddings_slice[0]:self.embeddings_slice[1]]
                self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.embeddings_indices = embeddings_indices
        self.indices_len = indices_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x)
        full_lora_a_embeddings = F.embedding(
            x + indices,
            self.lora_a_stacked_2d,
        )
        indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x)
        full_output = self.base_layer.forward(
            x.add_(indices * added_tokens_mask))

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1)
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] *
                full_lora_a_embeddings.shape[1], -1)
        bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        return full_output.view_as(full_output_org)


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_config.max_lora_rank,
            self.base_layer.weight.shape[1],
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.base_layer.weight.shape[0],
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None
        # Number of output features of this (possibly sharded) layer.
        self.output_dim = self.lora_b_stacked.shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x, bias)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output_parallel = self.apply_weights(input_, bias)
        if self.base_layer.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @property
    def linear_weights(self):
        return self.base_layer.linear_weights


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.
    Both slices must have the same size.
    """

    def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
                == self.base_layer.output_sizes[1]):
            raise ValueError(
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.weight.shape[1],
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ) for _ in range(n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                self.base_layer.weight.shape[0] // 2,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ) for _ in range(n_slices))

        self.indices: Optional[torch.Tensor] = None
        self.output_dim = self.lora_b_stacked[0].shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_b_stacked[1][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_dim
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = (lora_b[0][:, start_idx:end_idx],
                      lora_b[1][:, start_idx:end_idx])

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
            self.lora_b_stacked[0][
                index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                    lora_b[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
            self.lora_b_stacked[1][
                index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                    lora_b[1].T, non_blocking=True)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            (self.output_dim, self.output_dim),
        )
        return output


class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas

        # q, k, v
        self.lora_a_stacked = (
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.weight.shape[1],
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.weight.shape[1],
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.weight.shape[1],
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
        )
        self.lora_b_stacked = (
            torch.zeros(
                max_loras,
                1,
                self.q_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
            torch.zeros(
                max_loras,
                1,
                self.kv_proj_shard_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.base_layer.weight.device,
            ),
        )

        self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
                              self.kv_proj_shard_size)
        self.packed_indices: Optional[torch.Tensor] = None
        self.standard_indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[0][index] = 0
        self.lora_b_stacked[0][index] = 0
        self.lora_a_stacked[1][index] = 0
        self.lora_b_stacked[1][index] = 0
        self.lora_a_stacked[2][index] = 0
        self.lora_b_stacked[2][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            if lora_b[0] is not None:
                lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                     self.q_shard_id:self.q_proj_shard_size *
                                     (self.q_shard_id + 1)]
                self.lora_b_stacked[0][
                    index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
                        lora_b_q.T, non_blocking=True)
            if lora_b[1] is not None:
                lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[1][
                    index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
                        lora_b_k.T, non_blocking=True)
            if lora_b[2] is not None:
                lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[2][
                    index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
                        lora_b_v.T, non_blocking=True)
        else:
            if lora_b[0] is not None:
                self.lora_b_stacked[0][
                    index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                        lora_b[0].T, non_blocking=True)
            if lora_b[1] is not None:
                self.lora_b_stacked[1][
                    index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                        lora_b[1].T, non_blocking=True)
            if lora_b[2] is not None:
                self.lora_b_stacked[2][
                    index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
                        lora_b[2].T, non_blocking=True)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
                index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
                    lora_a[0].T, non_blocking=True)
        if lora_a[1] is not None:
            self.lora_a_stacked[1][
                index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
                    lora_a[1].T, non_blocking=True)
        if lora_a[2] is not None:
            self.lora_a_stacked[2][
                index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                    lora_a[2].T, non_blocking=True)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x, bias)
        _apply_lora_packed_nslice(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
            self.output_slices,
        )
        return output


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.base_layer.weight.shape[1],
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.weight.shape[0],
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.base_layer.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.base_layer.weight.shape[1]
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_a = lora_a[start_idx:end_idx, :]

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = base_indices
        self.indices_len = indices_len

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer.linear_weights, x)
        _apply_lora(
            x,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[0]],
            output,
        )
        return output

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.base_layer.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.apply_weights(input_parallel)
        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.base_layer.skip_bias_add:
            output = (output_ + self.base_layer.bias
                      if self.base_layer.bias is not None else output_)
            output_bias = None
        else:
            output = output_
            output_bias = self.base_layer.bias
        return output, output_bias

    @property
    def weight(self):
        return self.base_layer.weight


class SamplerWithLoRA(BaseLayerWithLoRA):

    def __init__(
        self,
        base_layer: Sampler,
        hidden_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        if not (32000 <= self.base_layer.vocab_size <= 33024):
            raise ValueError(
                "When using LoRA, vocab size must be "
                "32000 <= vocab_size <= 33024")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        self.indices = None
        self.indices_padded = None
        self.indices_len = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index, :embeddings_tensor.shape[0], :embeddings_tensor.
                shape[1]] = embeddings_tensor

    def set_mapping(
        self,
        base_indices: torch.Tensor,
        sampler_indices: torch.Tensor,
        sampler_indices_padded: torch.Tensor,
        embeddings_indices: torch.Tensor,
        indices_len: List[int],
    ):
        self.indices = sampler_indices
        self.indices_padded = sampler_indices_padded
        self.indices_len = indices_len

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        embedding: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        if logits is None:
            return None

        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])
        lora_logits[-1] = float("-inf")
        lora_logits = lora_logits.mT
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0,
                       self.indices_padded[:self.indices_len[2]]).nan_to_num_(
                           nan=float("-inf"),
                           posinf=float("inf"),
                           neginf=float("-inf")))
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        _apply_lora(
            hidden_states,
            self.lora_a_stacked,
            self.lora_b_stacked,
            self.indices[:self.indices_len[1]],
            logits,
        )

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)


def from_layer(
        layer: nn.Module,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA:
    supported_layer_types = {
        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
        QKVParallelLinear: QKVParallelLinearWithLora,
        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
        RowParallelLinear: RowParallelLinearWithLoRA,
    }
    for src_layer_type, lora_layer_type in supported_layer_types.items():
        if type(layer) is src_layer_type:  # pylint: disable=unidiomatic-typecheck
            ret = lora_layer_type(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer
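
# Usage sketch (hypothetical names; assumes a built model and a populated
# LoRAConfig):
#
#     layer = from_layer(model.layers[0].mlp.gate_up_proj,
#                        max_loras=8,
#                        lora_config=lora_config)
#     # `layer` is now a MergedColumnParallelLinearWithLoRA with
#     # zero-initialized LoRA slots; unsupported layer types are returned
#     # unchanged.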


def from_layer_sampler(
    layer: Sampler,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> SamplerWithLoRA:
    ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype,
                          lm_head.weight.device)
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret
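
# Usage sketch (hypothetical names): wrapping a model's sampler so that
# per-request LoRA logits and extra vocab are applied at sampling time:
#
#     model.sampler = from_layer_sampler(model.sampler, model.lm_head,
#                                        max_loras=8, lora_config=lora_config)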