|
import os
from collections import OrderedDict, defaultdict
from dataclasses import dataclass
from typing import Dict, Sequence

import numpy as np
import torch
import transformers

from pointllm import conversation as conversation_lib

IGNORE_INDEX = -100


class LRUCache:
    """Least-recently-used cache that also tracks how often each key is read."""

    def __init__(self, capacity, max_access_count):
        self.cache = OrderedDict()
        self.access_count = defaultdict(int)
        self.capacity = capacity
        self.max_access_count = max_access_count

    def get(self, key):
        # Return the cached value and mark the key as most recently used.
        if key not in self.cache:
            return None
        value = self.cache.pop(key)
        self.cache[key] = value
        self.access_count[key] += 1
        return value

    def put(self, key, value):
        # Re-inserting an existing key refreshes its recency; otherwise evict
        # the least recently used entry once the cache is full.
        if key in self.cache:
            self.cache.pop(key)
        elif len(self.cache) == self.capacity:
            oldest_key, _ = self.cache.popitem(last=False)
            del self.access_count[oldest_key]
        self.cache[key] = value
        self.access_count[key] = 1

    def get_access_count(self, key):
        return self.access_count.get(key, 0)

    def reset_access_count(self, key):
        self.access_count[key] = 0
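

# Illustrative usage sketch (not part of the training pipeline): cache values
# by key and watch eviction plus the read counter. The capacity and
# max_access_count values below are made-up examples.
def _demo_lru_cache():
    cache = LRUCache(capacity=2, max_access_count=3)
    cache.put("obj_a", "cloud_a")
    cache.put("obj_b", "cloud_b")
    cache.get("obj_a")             # "obj_a" becomes most recently used
    cache.put("obj_c", "cloud_c")  # evicts "obj_b", the LRU entry
    assert cache.get("obj_b") is None
    assert cache.get_access_count("obj_a") == 2  # one put (=1) plus one get
    return cache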
|
|
|
|
|
def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize conversations and mask everything except the assistant replies."""
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply the conversation template to each source.
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip a leading turn that does not come from the human role.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize the full conversations in one padded batch.
    input_ids = tokenizer(
        conversations,
        return_tensors="pt",
        padding="longest",
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask out everything except the assistant responses, so the loss is
    # computed only on the tokens the model is supposed to generate.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        # Each round ends with the second separator (e.g. the EOS string).
        rounds = conversation.split(conv.sep2)
        cur_len = 1  # skip the BOS token
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            round_len = len(tokenizer(rou).input_ids)
            # The -2 offsets tokens the tokenizer adds around the split
            # point (e.g. BOS) so only the instruction span is masked.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        # If the reconstructed length disagrees with the tokenized length,
        # ignore the whole sample rather than train on misaligned labels.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
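

# Illustrative sketch (not part of the original code): the `sources` layout
# preprocess_v1 expects. Assumes conversation_lib.default_conversation is a
# v1-style template (SeparatorStyle.TWO) and `tokenizer` is a LLaMA-style
# transformers.PreTrainedTokenizer with a pad token configured.
def _demo_preprocess_v1(tokenizer):
    sources = [[
        {"from": "human", "value": "What is shown in this point cloud?"},
        {"from": "gpt", "value": "A small wooden chair."},
    ]]
    batch = preprocess_v1(sources, tokenizer)
    # labels mirror input_ids, except prompt tokens are set to IGNORE_INDEX
    # so the loss is computed only on the assistant reply.
    return batch["input_ids"], batch["labels"]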
|
|
|
def preprocess_multimodal_point_cloud(
    sources: Sequence[str],
    point_backbone_config: dict,
    point_indicator: str = "<point>",
) -> Dict:
    """Replace each point-cloud placeholder with the backbone's point tokens."""
    point_token_len = point_backbone_config['point_token_len']
    default_point_patch_token = point_backbone_config['default_point_patch_token']

    for source in sources:
        for sentence in source:
            # One patch token per point embedding produced by the backbone.
            replace_token = default_point_patch_token * point_token_len
            if point_backbone_config['mm_use_point_start_end']:
                replace_token = (
                    point_backbone_config['default_point_start_token']
                    + replace_token
                    + point_backbone_config['default_point_end_token']
                )
            sentence["value"] = sentence["value"].replace(point_indicator, replace_token)

    return sources
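

# Illustrative sketch of the <point> expansion. The config values below are
# made-up examples; real values come from the point backbone's configuration.
def _demo_point_token_expansion():
    config = {
        "point_token_len": 3,
        "default_point_patch_token": "<point_patch>",
        "mm_use_point_start_end": True,
        "default_point_start_token": "<point_start>",
        "default_point_end_token": "<point_end>",
    }
    sources = [[{"from": "human", "value": "Describe <point> briefly."}]]
    out = preprocess_multimodal_point_cloud(sources, config)
    # -> "Describe <point_start><point_patch><point_patch><point_patch><point_end> briefly."
    return out[0][0]["value"]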
|
|
|
def pc_norm(pc):
    """Normalize an NxC point cloud: zero-center the xyz coordinates and scale
    them into the unit sphere, leaving any extra feature channels untouched."""
    xyz = pc[:, :3]
    other_feature = pc[:, 3:]

    centroid = np.mean(xyz, axis=0)
    xyz = xyz - centroid
    m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
    xyz = xyz / m

    pc = np.concatenate((xyz, other_feature), axis=1)
    return pc
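

# Quick self-check sketch for pc_norm: after normalization the xyz part is
# zero-centered and fits inside the unit sphere; color channels are untouched.
def _demo_pc_norm():
    pc = np.random.rand(1024, 6)  # xyz + rgb
    out = pc_norm(pc)
    assert np.allclose(out[:, :3].mean(axis=0), 0.0, atol=1e-6)
    assert np.sqrt((out[:, :3] ** 2).sum(axis=1)).max() <= 1.0 + 1e-6
    return out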
|
|
|
def load_objaverse_point_cloud(data_path, object_id, pointnum=8192, use_color=False):
    """Load an Objaverse point cloud stored as '<data_path>/<object_id>_<pointnum>.npy'."""
    filename = f"{object_id}_{pointnum}.npy"
    point_cloud = np.load(os.path.join(data_path, filename))

    # Normalize xyz into the unit sphere; any color channels are preserved.
    point_cloud = pc_norm(point_cloud)

    if not use_color:
        point_cloud = point_cloud[:, :3]

    return point_cloud
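

# Usage sketch for load_objaverse_point_cloud; the arguments are placeholders
# for a real data directory and Objaverse object id.
def _demo_load_objaverse(data_path, object_id):
    pc = load_objaverse_point_cloud(data_path, object_id, pointnum=8192, use_color=True)
    # With use_color=True the result keeps the rgb channels: shape (8192, 6).
    return pc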
|
|
|
@dataclass
class DataCollatorForPointTextDataset(object):
    """Collate examples for mixed dataset with text and point cloud data."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels,
            batch_first=True,
            padding_value=IGNORE_INDEX,
        )
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        # Stack point clouds into one tensor when all shapes match; otherwise
        # keep them as a list so variable-sized clouds survive collation.
        if 'point_clouds' in instances[0]:
            point_clouds = [instance['point_clouds'] for instance in instances]
            if all(x is not None and x.shape == point_clouds[0].shape for x in point_clouds):
                batch['point_clouds'] = torch.stack(point_clouds)
            else:
                batch['point_clouds'] = point_clouds

        return batch
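

# Illustrative sketch: plugging the collator into a PyTorch DataLoader. The
# dataset is assumed to yield dicts with "input_ids"/"labels" tensors (and
# optionally "point_clouds"), as produced by the preprocessing above.
def _demo_collator(dataset, tokenizer):
    from torch.utils.data import DataLoader
    collator = DataCollatorForPointTextDataset(tokenizer=tokenizer)
    loader = DataLoader(dataset, batch_size=4, collate_fn=collator)
    # Each batch is a dict with padded input_ids, labels, and attention_mask.
    return next(iter(loader))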
|
|
|
def farthest_point_sample(point, npoint):
    """
    Input:
        point: pointcloud data, [N, D] (first three columns are xyz)
        npoint: number of samples
    Return:
        point: sampled pointcloud, [npoint, D]
    """
    N, D = point.shape
    xyz = point[:, :3]
    centroids = np.zeros((npoint,))
    distance = np.ones((N,)) * 1e10
    farthest = np.random.randint(0, N)
    for i in range(npoint):
        centroids[i] = farthest
        centroid = xyz[farthest, :]
        # Update each point's distance to its nearest chosen centroid, then
        # pick the point farthest from all centroids chosen so far.
        dist = np.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = np.argmax(distance, -1)
    point = point[centroids.astype(np.int32)]
    return point
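

# Sketch: downsample a random cloud with farthest point sampling. Unlike
# uniform random subsampling, FPS keeps points that are maximally spread out
# (at O(N * npoint) cost).
def _demo_fps():
    cloud = np.random.rand(2048, 3)
    sampled = farthest_point_sample(cloud, 512)
    assert sampled.shape == (512, 3)
    return sampled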
|
|
|
def pc_normalize(pc):
    """
    pc: Nx3 array
    This function normalizes a point cloud to fit within a unit sphere.
    It first subtracts the centroid from all points, then scales the points
    so they fit within a unit sphere.
    """
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    pc = pc / m
    return pc
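

# Note: pc_normalize mirrors the xyz branch of pc_norm above, for clouds
# without extra feature channels. Minimal self-check sketch:
def _demo_pc_normalize():
    pc = pc_normalize(np.random.rand(256, 3))
    # All points now lie inside (or on) the unit sphere.
    assert np.sqrt((pc ** 2).sum(axis=1)).max() <= 1.0 + 1e-6
    return pc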