kopyl's picture
Upload folder using huggingface_hub
096c926 verified
# Lint as: python3
# Copyright 2020 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of FlexBuffers binary format.
For more info check https://google.github.io/flatbuffers/flexbuffers.html and
corresponding C++ implementation at
https://github.com/google/flatbuffers/blob/master/include/flatbuffers/flexbuffers.h
"""
# pylint: disable=invalid-name
# TODO(dkovalev): Add type hints everywhere, so tools like pytypes could work.
import array
import bisect
import contextlib
import enum
import functools
import struct
__all__ = ('Type', 'Builder', 'GetRoot', 'Dumps', 'Loads')
class BitWidth(enum.IntEnum):
  """Supported bit widths of value types.

  The member value is the base-2 logarithm of the byte size. It is stored in
  the lower 2 bits of a type field to determine the size of the elements
  (and/or size field) of the item pointed to (e.g. vector).
  """
  W8 = 0  # 2^0 = 1 byte
  W16 = 1  # 2^1 = 2 bytes
  W32 = 2  # 2^2 = 4 bytes
  W64 = 3  # 2^3 = 8 bytes

  @staticmethod
  def U(value):
    """Returns the minimum `BitWidth` to encode unsigned integer value.

    Raises:
      ValueError: If value needs more than 64 bits.
    """
    assert value >= 0
    for width in (BitWidth.W8, BitWidth.W16, BitWidth.W32, BitWidth.W64):
      # 8 << width is the number of bits available at this width.
      if value < (1 << (8 << width)):
        return width
    raise ValueError('value is too big to encode: %s' % value)

  @staticmethod
  def I(value):
    """Returns the minimum `BitWidth` to encode signed integer value."""
    # -2^(n-1) <= value < 2^(n-1)
    # -2^n <= 2 * value < 2^n
    # 2 * value < 2^n, when value >= 0 or 2 * (-value) - 1 < 2^n, when value < 0
    # so the signed problem reduces to the unsigned one (~x = -x - 1).
    doubled = 2 * value
    if doubled < 0:
      doubled = ~doubled
    return BitWidth.U(doubled)

  @staticmethod
  def F(value):
    """Returns the `BitWidth` to encode floating point value.

    `W32` is returned only when the value survives a round trip through a
    32-bit float unchanged; everything else needs `W64`.
    """
    try:
      (as_float32,) = struct.unpack('<f', struct.pack('<f', value))
    except OverflowError:
      # Finite value outside of the float32 range: only a double can hold it.
      return BitWidth.W64
    return BitWidth.W32 if as_float32 == value else BitWidth.W64

  @staticmethod
  def B(byte_width):
    """Returns the `BitWidth` for a byte width of 1, 2, 4 or 8.

    Raises:
      KeyError: If byte_width is any other value.
    """
    return {
        1: BitWidth.W8,
        2: BitWidth.W16,
        4: BitWidth.W32,
        8: BitWidth.W64
    }[byte_width]
# `struct` format characters keyed by the byte width of the encoded scalar.
# The pack/unpack helpers below prepend '<', i.e. values are little-endian.
I = {1: 'b', 2: 'h', 4: 'i', 8: 'q'} # Integer formats
U = {1: 'B', 2: 'H', 4: 'I', 8: 'Q'} # Unsigned integer formats
F = {4: 'f', 8: 'd'} # Floating point formats
def _Unpack(fmt, buf):
  """Decodes a single little-endian scalar from buf.

  Args:
    fmt: Mapping from byte width to `struct` format character.
    buf: Bytes-like object holding exactly one value; its length selects the
      format character.

  Returns:
    The decoded scalar.
  """
  code = fmt[len(buf)]
  (value,) = struct.unpack('<' + code, buf)
  return value
def _UnpackVector(fmt, buf, length):
  """Decodes `length` equally sized little-endian scalars from buf.

  Args:
    fmt: Mapping from byte width to `struct` format character.
    buf: Bytes-like object holding the packed elements.
    length: Number of elements; the element width is len(buf) // length.

  Returns:
    Tuple with the decoded elements.
  """
  elem_width = len(buf) // length
  code = fmt[elem_width]
  layout = '<%d%s' % (length, code)
  return struct.unpack(layout, buf)
def _Pack(fmt, value, byte_width):
  """Encodes one scalar as `byte_width` little-endian bytes.

  Args:
    fmt: Mapping from byte width to `struct` format character.
    value: Scalar to encode.
    byte_width: Number of bytes to produce; selects the format character.

  Returns:
    The packed bytes.
  """
  layout = '<' + fmt[byte_width]
  return struct.pack(layout, value)
def _PackVector(fmt, values, byte_width):
  """Encodes a sequence of scalars, each as `byte_width` little-endian bytes.

  Args:
    fmt: Mapping from byte width to `struct` format character.
    values: Sequence of scalars to encode.
    byte_width: Number of bytes per element; selects the format character.

  Returns:
    The packed bytes of all elements, in order.
  """
  code = fmt[byte_width]
  layout = '<%d%s' % (len(values), code)
  return struct.pack(layout, *values)
def _Mutate(fmt, buf, value, byte_width, value_bit_width):
  """Overwrites the first `byte_width` bytes of buf with value, if it fits.

  Args:
    fmt: Mapping from byte width to `struct` format character.
    buf: Mutable buffer positioned at the encoded value.
    value: New scalar value.
    byte_width: Number of bytes occupied by the existing encoded value.
    value_bit_width: Bit width (log2 of the byte size) the new value needs.

  Returns:
    True if the value was written, False if it needs more bytes than
    available.
  """
  required_bytes = 1 << value_bit_width
  if required_bytes > byte_width:
    return False
  buf[:byte_width] = _Pack(fmt, value, byte_width)
  return True
# Computes how many bytes have to be appended to a buffer of "buf_size" bytes
# so that a scalar of "scalar_size" bytes can be written at an offset that is
# a multiple of "scalar_size". "scalar_size" must be a power of two.
def _PaddingBytes(buf_size, scalar_size):
  # Same as:
  # ((buf_size + (scalar_size - 1)) // scalar_size) * scalar_size - buf_size
  low_bits_mask = scalar_size - 1
  return (-buf_size) & low_bits_mask
def _ShiftSlice(s, offset, length):
  """Returns slice `s` moved forward by `offset`.

  Args:
    s: Slice relative to the start of a window into a buffer.
    offset: Position of the window in the underlying buffer.
    length: Length of the window, used when `s` has no explicit stop.

  Returns:
    Slice with start and stop shifted by offset; the step is kept as is.
  """
  begin = 0 if s.start is None else s.start
  end = length if s.stop is None else s.stop
  return slice(offset + begin, offset + end, s.step)
# https://en.cppreference.com/w/cpp/algorithm/lower_bound
def _LowerBound(values, value, pred):
  """Implementation of C++ std::lower_bound() algorithm.

  Args:
    values: Indexable sequence (needs len() and integer indexing),
      partitioned with respect to `pred(element, value)`.
    value: Value to look for.
    pred: Callable `pred(element, value)` which is true when element is
      ordered before value.

  Returns:
    Index of the first element for which pred is false, or len(values).
  """
  lo, hi = 0, len(values)
  while lo < hi:
    mid = lo + (hi - lo) // 2
    if pred(values[mid], value):
      lo = mid + 1  # Answer is strictly to the right of mid.
    else:
      hi = mid  # mid itself may be the answer.
  return lo
# https://en.cppreference.com/w/cpp/algorithm/binary_search
def _BinarySearch(values, value, pred=lambda x, y: x < y):
  """Implementation of C++ std::binary_search() algorithm.

  Unlike the C++ version, the position of the match is returned.

  Args:
    values: Indexable sequence sorted with respect to `pred`.
    value: Value to look for.
    pred: Strict "less than" comparison, called both as pred(element, value)
      and pred(value, element).

  Returns:
    Index of an element equivalent to value, or -1 if there is none.
  """
  pos = _LowerBound(values, value, pred)
  found = pos != len(values) and not pred(value, values[pos])
  return pos if found else -1
class Type(enum.IntEnum):
  """Supported types of encoded data.

  These are used as the upper 6 bits of a type field to indicate the actual
  type.
  """
  NULL = 0
  INT = 1
  UINT = 2
  FLOAT = 3
  # Types above stored inline, types below store an offset.
  KEY = 4
  STRING = 5
  INDIRECT_INT = 6
  INDIRECT_UINT = 7
  INDIRECT_FLOAT = 8
  MAP = 9
  VECTOR = 10  # Untyped.

  VECTOR_INT = 11  # Typed any size (stores no type table).
  VECTOR_UINT = 12
  VECTOR_FLOAT = 13
  VECTOR_KEY = 14
  # DEPRECATED, use VECTOR or VECTOR_KEY instead.
  # Read test.cpp/FlexBuffersDeprecatedTest() for details on why.
  VECTOR_STRING_DEPRECATED = 15

  VECTOR_INT2 = 16  # Typed tuple (no type table, no size field).
  VECTOR_UINT2 = 17
  VECTOR_FLOAT2 = 18
  VECTOR_INT3 = 19  # Typed triple (no type table, no size field).
  VECTOR_UINT3 = 20
  VECTOR_FLOAT3 = 21
  VECTOR_INT4 = 22  # Typed quad (no type table, no size field).
  VECTOR_UINT4 = 23
  VECTOR_FLOAT4 = 24

  BLOB = 25
  BOOL = 26
  VECTOR_BOOL = 36  # To do the same type of conversion of type to vector type

  @staticmethod
  def Pack(type_, bit_width):
    """Combines type (upper 6 bits) and bit width (lower 2 bits)."""
    type_bits = int(type_) << 2
    return type_bits | bit_width

  @staticmethod
  def Unpack(packed_type):
    """Splits a packed type into a (byte width, `Type`) pair."""
    byte_width = 1 << (packed_type & 0x3)
    return byte_width, Type(packed_type >> 2)

  @staticmethod
  def IsInline(type_):
    """Whether values of the type are stored in place (not via an offset)."""
    return type_ == Type.BOOL or type_ <= Type.FLOAT

  @staticmethod
  def IsTypedVector(type_):
    """Whether the type is a typed vector of any size."""
    if type_ == Type.VECTOR_BOOL:
      return True
    return Type.VECTOR_INT <= type_ <= Type.VECTOR_STRING_DEPRECATED

  @staticmethod
  def IsTypedVectorElementType(type_):
    """Whether the type may be used as an element of a typed vector."""
    if type_ == Type.BOOL:
      return True
    return Type.INT <= type_ <= Type.STRING

  @staticmethod
  def ToTypedVectorElementType(type_):
    """Returns the element type of the given typed vector type.

    Raises:
      ValueError: If type_ is not a typed vector type.
    """
    if not Type.IsTypedVector(type_):
      raise ValueError('must be typed vector type')
    # Vector types are laid out in the same order as their element types.
    return Type(type_ - Type.VECTOR_INT + Type.INT)

  @staticmethod
  def IsFixedTypedVector(type_):
    """Whether the type is a typed vector of 2, 3 or 4 elements."""
    return Type.VECTOR_INT2 <= type_ <= Type.VECTOR_FLOAT4

  @staticmethod
  def IsFixedTypedVectorElementType(type_):
    """Whether the type may be used as an element of a fixed typed vector."""
    return Type.INT <= type_ <= Type.FLOAT

  @staticmethod
  def ToFixedTypedVectorElementType(type_):
    """Returns (element type, length) of the given fixed typed vector type.

    Raises:
      ValueError: If type_ is not a fixed typed vector type.
    """
    if not Type.IsFixedTypedVector(type_):
      raise ValueError('must be fixed typed vector type')
    # 3 types each, starting from length 2.
    length_group, type_index = divmod(type_ - Type.VECTOR_INT2, 3)
    return Type(type_index + Type.INT), length_group + 2

  @staticmethod
  def ToTypedVector(element_type, fixed_len=0):
    """Converts element type to corresponding vector type.

    Args:
      element_type: vector element type
      fixed_len: number of elements: 0 for typed vector; 2, 3, or 4 for fixed
        typed vector.

    Returns:
      Typed vector type or fixed typed vector type.

    Raises:
      ValueError: If element_type is not allowed for the requested kind of
        vector or fixed_len is not one of 0, 2, 3, 4.
    """
    if fixed_len != 0:
      if not Type.IsFixedTypedVectorElementType(element_type):
        raise ValueError('must be fixed typed vector element type')
    elif not Type.IsTypedVectorElementType(element_type):
      raise ValueError('must be typed vector element type')

    if fixed_len == 0:
      first_vector_type = Type.VECTOR_INT  # TypedVector
    elif fixed_len == 2:
      first_vector_type = Type.VECTOR_INT2  # FixedTypedVector
    elif fixed_len == 3:
      first_vector_type = Type.VECTOR_INT3  # FixedTypedVector
    elif fixed_len == 4:
      first_vector_type = Type.VECTOR_INT4  # FixedTypedVector
    else:
      raise ValueError('unsupported fixed_len: %s' % fixed_len)
    return Type(element_type - Type.INT + first_vector_type)
class Buf:
  """Class to access underlying buffer object starting from the given offset.

  Indices and slices passed to an instance are relative to its offset, so
  negative values address the bytes located before the offset.
  """

  def __init__(self, buf, offset):
    """Creates a window into buf.

    Args:
      buf: Underlying bytes-like object.
      offset: Start of the window; a negative value counts from the end of
        buf.
    """
    self._buf = buf
    self._offset = offset if offset >= 0 else len(buf) + offset
    self._length = len(buf) - self._offset

  def __getitem__(self, key):
    if isinstance(key, slice):
      return self._buf[_ShiftSlice(key, self._offset, self._length)]
    elif isinstance(key, int):
      return self._buf[self._offset + key]
    else:
      raise TypeError('invalid key type')

  def __setitem__(self, key, value):
    if isinstance(key, slice):
      self._buf[_ShiftSlice(key, self._offset, self._length)] = value
    elif isinstance(key, int):
      # Store the given value (not the index) at the shifted position.
      self._buf[self._offset + key] = value
    else:
      raise TypeError('invalid key type')

  def __repr__(self):
    return 'buf[%d:]' % self._offset

  def Find(self, sub):
    """Returns the lowest index where the sub subsequence is found.

    The index is relative to the offset of this `Buf`; -1 is returned when
    sub is not found.
    """
    # Search in place instead of copying the whole tail of the buffer.
    index = self._buf.find(sub, self._offset)
    return index if index == -1 else index - self._offset

  def Slice(self, offset):
    """Returns new `Buf` which starts from the given offset."""
    return Buf(self._buf, self._offset + offset)

  def Indirect(self, offset, byte_width):
    """Return new `Buf` based on the encoded offset (indirect encoding).

    Args:
      offset: Position (relative to this `Buf`) of the stored offset.
      byte_width: Number of bytes of the stored unsigned offset, which is
        counted backwards from its own position.
    """
    return self.Slice(offset - _Unpack(U, self[offset:offset + byte_width]))
class Object:
  """Base class for all non-trivial data accessors.

  Keeps the buffer positioned at the encoded data together with the byte
  width used by that data.
  """
  __slots__ = ('_buf', '_byte_width')

  def __init__(self, buf, byte_width):
    self._buf, self._byte_width = buf, byte_width

  @property
  def ByteWidth(self):
    """Byte width this accessor was created with."""
    return self._byte_width
class Sized(Object):
  """Base class for all data accessors which need to read encoded size."""
  __slots__ = ('_size',)

  def __init__(self, buf, byte_width, size=0):
    """Initializes the accessor.

    Args:
      buf: Buffer positioned at the first data byte.
      byte_width: Byte width of the size field (and of the elements).
      size: Known number of elements; 0 means "read it from the size field
        stored right before the data".
    """
    super().__init__(buf, byte_width)
    self._size = size if size != 0 else _Unpack(U, self.SizeBytes)

  @property
  def SizeBytes(self):
    """Raw bytes of the size field located right before the data."""
    width = self._byte_width
    return self._buf[-width:0]

  def __len__(self):
    return self._size
class Blob(Sized):
  """Data accessor for the encoded blob bytes."""
  __slots__ = ()

  @property
  def Bytes(self):
    """Content of the blob: the first len(self) bytes of the buffer."""
    size = len(self)
    return self._buf[0:size]

  def __repr__(self):
    return 'Blob(%s, size=%d)' % (self._buf, len(self))
class String(Sized):
  """Data accessor for the encoded string bytes."""
  __slots__ = ()

  @property
  def Bytes(self):
    """UTF-8 encoded content of the string."""
    return self._buf[0:len(self)]

  def Mutate(self, value):
    """Mutates underlying string bytes in place.

    Args:
      value: New string to replace the existing one. New string must have less
        or equal UTF-8-encoded bytes than the existing one to successfully
        mutate underlying byte buffer.

    Returns:
      Whether the value was mutated or not.
    """
    encoded = value.encode('utf-8')
    n = len(encoded)
    old_size = len(self)
    if n > old_size:
      return False

    # Size field lives right before the string data.
    self._buf[-self._byte_width:0] = _Pack(U, n, self._byte_width)
    self._buf[0:n] = encoded
    # Zero out what is left of the previous (longer) content.
    self._buf[n:old_size] = bytearray(old_size - n)
    # Keep this accessor in sync with the size that was just written, so that
    # Bytes, str() and len() do not include the zero padding.
    self._size = n
    return True

  def __str__(self):
    return self.Bytes.decode('utf-8')

  def __repr__(self):
    return 'String(%s, size=%d)' % (self._buf, len(self))
class Key(Object):
  """Data accessor for the encoded key bytes.

  Keys carry no size field: they end at the first zero byte.
  """
  __slots__ = ()

  def __init__(self, buf, byte_width):
    assert byte_width == 1
    super().__init__(buf, byte_width)

  @property
  def Bytes(self):
    """Bytes of the key without the terminating zero byte."""
    end = len(self)
    return self._buf[0:end]

  def __len__(self):
    # Position of the first zero byte, relative to the start of the key.
    return self._buf.Find(0)

  def __str__(self):
    return self.Bytes.decode('ascii')

  def __repr__(self):
    return 'Key(%s, size=%d)' % (self._buf, len(self))
class Vector(Sized):
  """Data accessor for the encoded vector bytes.

  The elements are followed by one packed type byte per element.
  """
  __slots__ = ()

  def __getitem__(self, index):
    """Returns `Ref` to the element at index.

    Raises:
      IndexError: If index is outside of [0, len(self)).
    """
    if not 0 <= index < len(self):
      raise IndexError('vector index %s is out of [0, %d) range' % \
          (index, len(self)))

    width = self._byte_width
    # Type bytes start right after the last element.
    types_start = len(self) * width
    packed_type = self._buf[types_start + index]
    element_buf = self._buf.Slice(index * width)
    return Ref.PackedType(element_buf, width, packed_type)

  @property
  def Value(self):
    """Returns the underlying encoded data as a list object."""
    return [self[i].Value for i in range(len(self))]

  def __repr__(self):
    return 'Vector(%s, byte_width=%d, size=%d)' % \
        (self._buf, self._byte_width, self._size)
class TypedVector(Sized):
  """Data accessor for the encoded typed vector or fixed typed vector bytes."""
  __slots__ = '_element_type', '_size'

  def __init__(self, buf, byte_width, element_type, size=0):
    """Initializes the accessor.

    Args:
      buf: Buffer positioned at the first element.
      byte_width: Byte width of every element.
      element_type: `Type` shared by all elements.
      size: Number of elements of a fixed typed vector; 0 to read the size
        field stored before the elements.
    """
    super().__init__(buf, byte_width, size)

    # Strings of the deprecated string vector can't be accessed as strings,
    # since the bit-width of their size field is not known (see
    # Type.VECTOR_STRING_DEPRECATED). They are exposed as keys instead, which
    # are a subtype of strings and ignore the size field. This will truncate
    # strings with embedded nulls.
    self._element_type = (
        Type.KEY if element_type == Type.STRING else element_type)

  @property
  def Bytes(self):
    """Raw bytes of all elements."""
    total = self._byte_width * len(self)
    return self._buf[:total]

  @property
  def ElementType(self):
    """`Type` of the elements."""
    return self._element_type

  def __getitem__(self, index):
    """Returns `Ref` to the element at index.

    Raises:
      IndexError: If index is outside of [0, len(self)).
    """
    if not 0 <= index < len(self):
      raise IndexError('vector index %s is out of [0, %d) range' % \
          (index, len(self)))

    element_buf = self._buf.Slice(index * self._byte_width)
    return Ref(element_buf, self._byte_width, 1, self._element_type)

  @property
  def Value(self):
    """Returns underlying data as list object."""
    if not self:
      return []

    kind = self._element_type
    count = len(self)
    if kind is Type.BOOL:
      return [bool(e) for e in _UnpackVector(U, self.Bytes, count)]
    if kind is Type.INT:
      return list(_UnpackVector(I, self.Bytes, count))
    if kind is Type.UINT:
      return list(_UnpackVector(U, self.Bytes, count))
    if kind is Type.FLOAT:
      return list(_UnpackVector(F, self.Bytes, count))
    if kind is Type.KEY:
      return [e.AsKey for e in self]
    if kind is Type.STRING:
      return [e.AsString for e in self]
    raise TypeError('unsupported element_type: %s' % self._element_type)

  def __repr__(self):
    return 'TypedVector(%s, byte_width=%d, element_type=%s, size=%d)' % \
        (self._buf, self._byte_width, self._element_type, self._size)
class Map(Vector):
  """Data accessor for the encoded map bytes.

  A map is a vector of values preceded by the offset to a typed vector of
  keys and the byte width of that keys vector.
  """

  @staticmethod
  def CompareKeys(a, b):
    """Strict "less than" for keys given as bytes or as key `Ref` objects."""
    left = a.AsKeyBytes if isinstance(a, Ref) else a
    right = b.AsKeyBytes if isinstance(b, Ref) else b
    return left < right

  def __getitem__(self, key):
    """Returns `Ref` to a value by position (int) or by key (str).

    Raises:
      IndexError: If an integer key is out of range.
      KeyError: If a string key is not present.
    """
    if isinstance(key, int):
      return super().__getitem__(key)

    # Keys are looked up with binary search over the keys vector.
    position = _BinarySearch(self.Keys, key.encode('ascii'), self.CompareKeys)
    if position == -1:
      raise KeyError(key)
    return super().__getitem__(position)

  @property
  def Keys(self):
    """Typed vector of the map keys."""
    width = self._byte_width
    # Layout before the values: [keys offset][keys byte width][size].
    keys_byte_width = _Unpack(U, self._buf[-2 * width:-width])
    keys_buf = self._buf.Indirect(-3 * width, width)
    return TypedVector(keys_buf, keys_byte_width, Type.KEY)

  @property
  def Values(self):
    """Vector of the map values, in the order of `Keys`."""
    return Vector(self._buf, self._byte_width)

  @property
  def Value(self):
    """Returns the underlying encoded data as a dict object."""
    return {k.Value: v.Value for k, v in zip(self.Keys, self.Values)}

  def __repr__(self):
    return 'Map(%s, size=%d)' % (self._buf, len(self))
class Ref:
  """Data accessor for the encoded data bytes.

  Inline values are read from the first `parent_width` bytes of the buffer.
  Other values are reached through the offset stored in those bytes and are
  read using `byte_width`.
  """
  __slots__ = '_buf', '_parent_width', '_byte_width', '_type'

  @staticmethod
  def PackedType(buf, parent_width, packed_type):
    """Creates `Ref` from a packed type byte (type and bit width)."""
    byte_width, type_ = Type.Unpack(packed_type)
    return Ref(buf, parent_width, byte_width, type_)

  def __init__(self, buf, parent_width, byte_width, type_):
    self._buf = buf
    self._parent_width = parent_width
    self._byte_width = byte_width
    self._type = type_

  def __repr__(self):
    return 'Ref(%s, parent_width=%d, byte_width=%d, type_=%s)' % \
        (self._buf, self._parent_width, self._byte_width, self._type)

  @property
  def _Bytes(self):
    return self._buf[:self._parent_width]

  def _ConvertError(self, target_type):
    # Always raises; callers write `raise self._ConvertError(...)` only to
    # make the control flow obvious.
    raise TypeError('cannot convert %s to %s' % (self._type, target_type))

  def _Indirect(self):
    return self._buf.Indirect(0, self._parent_width)

  @property
  def IsNull(self):
    return self._type is Type.NULL

  @property
  def IsBool(self):
    return self._type is Type.BOOL

  @property
  def AsBool(self):
    """Returns current reference as boolean value.

    Values of other types are converted with `AsInt` and compared with zero.
    """
    if self._type is Type.BOOL:
      return bool(_Unpack(U, self._Bytes))
    else:
      return self.AsInt != 0

  def MutateBool(self, value):
    """Mutates underlying boolean value bytes in place.

    Args:
      value: New boolean value.

    Returns:
      Whether the value was mutated or not.
    """
    return self.IsBool and \
        _Mutate(U, self._buf, value, self._parent_width, BitWidth.W8)

  @property
  def IsNumeric(self):
    return self.IsInt or self.IsFloat

  @property
  def IsInt(self):
    return self._type in (Type.INT, Type.INDIRECT_INT, Type.UINT,
                          Type.INDIRECT_UINT)

  @property
  def AsInt(self):
    """Returns current reference as integer value.

    Null converts to 0, booleans to 0/1, and strings, keys, blobs and vectors
    to their length.

    Raises:
      TypeError: If the type has no integer conversion.
    """
    if self.IsNull:
      return 0
    elif self.IsBool:
      return int(self.AsBool)
    elif self._type is Type.INT:
      return _Unpack(I, self._Bytes)
    elif self._type is Type.INDIRECT_INT:
      return _Unpack(I, self._Indirect()[:self._byte_width])
    elif self._type is Type.UINT:
      return _Unpack(U, self._Bytes)
    elif self._type is Type.INDIRECT_UINT:
      return _Unpack(U, self._Indirect()[:self._byte_width])
    elif self.IsString:
      return len(self.AsString)
    elif self.IsKey:
      return len(self.AsKey)
    elif self.IsBlob:
      return len(self.AsBlob)
    elif self.IsVector:
      return len(self.AsVector)
    elif self.IsTypedVector:
      return len(self.AsTypedVector)
    elif self.IsFixedTypedVector:
      return len(self.AsFixedTypedVector)
    else:
      raise self._ConvertError(Type.INT)

  def MutateInt(self, value):
    """Mutates underlying integer value bytes in place.

    Args:
      value: New integer value. It must fit to the byte size of the existing
        encoded value.

    Returns:
      Whether the value was mutated or not.
    """
    if self._type is Type.INT:
      return _Mutate(I, self._buf, value, self._parent_width, BitWidth.I(value))
    elif self._type is Type.INDIRECT_INT:
      return _Mutate(I, self._Indirect(), value, self._byte_width,
                     BitWidth.I(value))
    elif self._type is Type.UINT:
      return _Mutate(U, self._buf, value, self._parent_width, BitWidth.U(value))
    elif self._type is Type.INDIRECT_UINT:
      return _Mutate(U, self._Indirect(), value, self._byte_width,
                     BitWidth.U(value))
    else:
      return False

  @property
  def IsFloat(self):
    return self._type in (Type.FLOAT, Type.INDIRECT_FLOAT)

  @property
  def AsFloat(self):
    """Returns current reference as floating point value.

    Null converts to 0.0, booleans and integers to their float value, strings
    are parsed with float(), vectors convert to their length.

    Raises:
      TypeError: If the type has no floating point conversion.
      ValueError: If a string value cannot be parsed as float.
    """
    if self.IsNull:
      return 0.0
    elif self.IsBool:
      return float(self.AsBool)
    elif self.IsInt:
      return float(self.AsInt)
    elif self._type is Type.FLOAT:
      return _Unpack(F, self._Bytes)
    elif self._type is Type.INDIRECT_FLOAT:
      return _Unpack(F, self._Indirect()[:self._byte_width])
    elif self.IsString:
      return float(self.AsString)
    elif self.IsVector:
      return float(len(self.AsVector))
    # IsTypedVector/IsFixedTypedVector are properties and must not be called.
    elif self.IsTypedVector:
      return float(len(self.AsTypedVector))
    elif self.IsFixedTypedVector:
      return float(len(self.AsFixedTypedVector))
    else:
      raise self._ConvertError(Type.FLOAT)

  def MutateFloat(self, value):
    """Mutates underlying floating point value bytes in place.

    Args:
      value: New float value. It must fit to the byte size of the existing
        encoded value.

    Returns:
      Whether the value was mutated or not.
    """
    if self._type is Type.FLOAT:
      return _Mutate(F, self._buf, value, self._parent_width,
                     BitWidth.B(self._parent_width))
    elif self._type is Type.INDIRECT_FLOAT:
      return _Mutate(F, self._Indirect(), value, self._byte_width,
                     BitWidth.B(self._byte_width))
    else:
      return False

  @property
  def IsKey(self):
    return self._type is Type.KEY

  @property
  def AsKeyBytes(self):
    """Returns the raw bytes of a key; raises TypeError for other types."""
    if self.IsKey:
      return Key(self._Indirect(), self._byte_width).Bytes
    else:
      raise self._ConvertError(Type.KEY)

  @property
  def AsKey(self):
    """Returns a key as `str`; raises TypeError for other types."""
    if self.IsKey:
      return str(Key(self._Indirect(), self._byte_width))
    else:
      raise self._ConvertError(Type.KEY)

  @property
  def IsString(self):
    return self._type is Type.STRING

  @property
  def AsStringBytes(self):
    """Returns the raw bytes of a string or key; raises TypeError otherwise."""
    if self.IsString:
      return String(self._Indirect(), self._byte_width).Bytes
    elif self.IsKey:
      return self.AsKeyBytes
    else:
      raise self._ConvertError(Type.STRING)

  @property
  def AsString(self):
    """Returns a string or key as `str`; raises TypeError otherwise."""
    if self.IsString:
      return str(String(self._Indirect(), self._byte_width))
    elif self.IsKey:
      return self.AsKey
    else:
      raise self._ConvertError(Type.STRING)

  def MutateString(self, value):
    """Mutates underlying string bytes in place.

    Args:
      value: New string; its UTF-8 encoding must not be longer than the
        existing one.

    Returns:
      Whether the value was mutated or not. References of any other type are
      left untouched and False is returned, like the other mutators do.
    """
    return self.IsString and \
        String(self._Indirect(), self._byte_width).Mutate(value)

  @property
  def IsBlob(self):
    return self._type is Type.BLOB

  @property
  def AsBlob(self):
    """Returns the bytes of a blob; raises TypeError for other types."""
    if self.IsBlob:
      return Blob(self._Indirect(), self._byte_width).Bytes
    else:
      raise self._ConvertError(Type.BLOB)

  @property
  def IsAnyVector(self):
    # All three operands are properties and must not be called.
    return self.IsVector or self.IsTypedVector or self.IsFixedTypedVector

  @property
  def IsVector(self):
    return self._type in (Type.VECTOR, Type.MAP)

  @property
  def AsVector(self):
    """Returns `Vector` accessor for vectors and maps; raises TypeError otherwise."""
    if self.IsVector:
      return Vector(self._Indirect(), self._byte_width)
    else:
      raise self._ConvertError(Type.VECTOR)

  @property
  def IsTypedVector(self):
    return Type.IsTypedVector(self._type)

  @property
  def AsTypedVector(self):
    """Returns `TypedVector` accessor; raises TypeError for other types."""
    if self.IsTypedVector:
      return TypedVector(self._Indirect(), self._byte_width,
                         Type.ToTypedVectorElementType(self._type))
    else:
      raise self._ConvertError('TYPED_VECTOR')

  @property
  def IsFixedTypedVector(self):
    return Type.IsFixedTypedVector(self._type)

  @property
  def AsFixedTypedVector(self):
    """Returns `TypedVector` accessor of fixed size; raises TypeError otherwise."""
    if self.IsFixedTypedVector:
      element_type, size = Type.ToFixedTypedVectorElementType(self._type)
      return TypedVector(self._Indirect(), self._byte_width, element_type, size)
    else:
      raise self._ConvertError('FIXED_TYPED_VECTOR')

  @property
  def IsMap(self):
    return self._type is Type.MAP

  @property
  def AsMap(self):
    """Returns `Map` accessor; raises TypeError for other types."""
    if self.IsMap:
      return Map(self._Indirect(), self._byte_width)
    else:
      raise self._ConvertError(Type.MAP)

  @property
  def Value(self):
    """Converts current reference to value of corresponding type.

    This is equivalent to calling `AsInt` for integer values, `AsFloat` for
    floating point values, etc.

    Returns:
      Value of corresponding type.
    """
    if self.IsNull:
      return None
    elif self.IsBool:
      return self.AsBool
    elif self.IsInt:
      return self.AsInt
    elif self.IsFloat:
      return self.AsFloat
    elif self.IsString:
      return self.AsString
    elif self.IsKey:
      return self.AsKey
    elif self.IsBlob:
      return self.AsBlob
    # Maps are also vectors, so they have to be checked first.
    elif self.IsMap:
      return self.AsMap.Value
    elif self.IsVector:
      return self.AsVector.Value
    elif self.IsTypedVector:
      return self.AsTypedVector.Value
    elif self.IsFixedTypedVector:
      return self.AsFixedTypedVector.Value
    else:
      raise TypeError('cannot convert %r to value' % self)
def _IsIterable(obj):
  """Returns whether iter() accepts obj."""
  try:
    iter(obj)
  except TypeError:
    return False
  return True
class Value:
  """Class to represent given value during the encoding process.

  Holds either an inline scalar or the absolute offset of already written
  data, together with its type and minimum bit width.
  """

  @staticmethod
  def Null():
    """Creates null value."""
    return Value(0, Type.NULL, BitWidth.W8)

  @staticmethod
  def Bool(value):
    """Creates boolean value."""
    return Value(value, Type.BOOL, BitWidth.W8)

  @staticmethod
  def Int(value, bit_width):
    """Creates signed integer value of the given bit width."""
    return Value(value, Type.INT, bit_width)

  @staticmethod
  def UInt(value, bit_width):
    """Creates unsigned integer value of the given bit width."""
    return Value(value, Type.UINT, bit_width)

  @staticmethod
  def Float(value, bit_width):
    """Creates floating point value of the given bit width."""
    return Value(value, Type.FLOAT, bit_width)

  @staticmethod
  def Key(offset):
    """Creates key value located at the given buffer offset."""
    return Value(offset, Type.KEY, BitWidth.W8)

  def __init__(self, value, type_, min_bit_width):
    self._value = value
    self._type = type_
    # For scalars: of itself, for vector: of its elements, for string: length.
    self._min_bit_width = min_bit_width

  @property
  def Value(self):
    return self._value

  @property
  def Type(self):
    return self._type

  @property
  def MinBitWidth(self):
    return self._min_bit_width

  def StoredPackedType(self, parent_bit_width=BitWidth.W8):
    """Returns the packed type byte to store for this value."""
    stored_width = self.StoredWidth(parent_bit_width)
    return Type.Pack(self._type, stored_width)

  # We have an absolute offset, but want to store a relative offset
  # elem_index elements beyond the current buffer end. Since whether
  # the relative offset fits in a certain byte_width depends on
  # the size of the elements before it (and their alignment), we have
  # to test for each size in turn.
  def ElemWidth(self, buf_size, elem_index=0):
    """Returns the bit width needed to store this value as an element.

    Args:
      buf_size: Current size of the buffer.
      elem_index: Number of elements written before this one.

    Raises:
      ValueError: If the relative offset does not fit even into 8 bytes.
    """
    if Type.IsInline(self._type):
      return self._min_bit_width

    for byte_width in (1, 2, 4, 8):
      aligned_size = buf_size + _PaddingBytes(buf_size, byte_width)
      offset_loc = aligned_size + elem_index * byte_width
      needed = BitWidth.U(offset_loc - self._value)
      # Accept the first width which exactly matches what the offset needs.
      if (1 << needed) == byte_width:
        return needed
    raise ValueError('relative offset is too big')

  def StoredWidth(self, parent_bit_width=BitWidth.W8):
    """Returns the bit width this value is stored with inside its parent."""
    if not Type.IsInline(self._type):
      return self._min_bit_width
    # Inline values are widened to the width of the parent slot.
    return max(self._min_bit_width, parent_bit_width)

  def __repr__(self):
    return 'Value(%s, %s, %s)' % (self._value, self._type, self._min_bit_width)

  def __str__(self):
    return str(self._value)
def InMap(func):
  """Decorator that lets a builder method take an optional leading map key.

  When the first positional argument is a `str`, it is written with
  `self.Key()` and the decorated method is called with the remaining
  arguments; otherwise all arguments are forwarded unchanged.

  Args:
    func: Method to wrap.

  Returns:
    Wrapper with the metadata of `func` which returns whatever `func` returns.
  """
  @functools.wraps(func)
  def wrapper(self, *args, **kwargs):
    # `args` is empty when the value is passed by keyword only.
    if args and isinstance(args[0], str):
      self.Key(args[0])
      args = args[1:]
    return func(self, *args, **kwargs)
  return wrapper
def InMapForString(func):
  """Variant of the map-key decorator for methods whose value is a `str`.

  The value cannot be told apart from a key by type, so the number of
  positional arguments decides: one argument is the value, two arguments are
  the key (written with `self.Key()`) followed by the value.

  Args:
    func: Method taking a single value argument.

  Returns:
    Wrapper with the metadata of `func` which returns whatever `func` returns.

  Raises:
    ValueError: From the wrapper, if it is called with neither one nor two
      positional arguments.
  """
  @functools.wraps(func)
  def wrapper(self, *args):
    if len(args) == 1:
      return func(self, args[0])
    if len(args) == 2:
      self.Key(args[0])
      return func(self, args[1])
    raise ValueError('invalid number of arguments')
  return wrapper
class Pool:
  """Collection of (data, offset) pairs sorted by data for quick access."""

  def __init__(self):
    self._pool = []  # sorted list of (data, offset) tuples

  def FindOrInsert(self, data, offset):
    """Looks data up, registering it with the given offset when it is new.

    Args:
      data: Comparable data (e.g. bytes) to look up.
      offset: Offset to remember for data if it is not in the pool yet.

    Returns:
      Offset stored for data if it was already in the pool, None if the
      (data, offset) pair has just been inserted.
    """
    # (data,) sorts right before any (data, offset) pair, so this is the
    # position of the first entry whose data is not less than `data`. The
    # same position keeps the list sorted when the entry has to be inserted.
    index = bisect.bisect_left(self._pool, (data,))
    if index != len(self._pool):
      found_data, found_offset = self._pool[index]
      if found_data == data:
        return found_offset

    self._pool.insert(index, (data, offset))
    return None

  def Clear(self):
    """Removes all entries."""
    self._pool = []

  @property
  def Elements(self):
    """List of the stored data, in sorted order."""
    return [data for data, _ in self._pool]
class Builder:
"""Helper class to encode structural data into flexbuffers format."""
def __init__(self,
             share_strings=False,
             share_keys=True,
             force_min_bit_width=BitWidth.W8):
  """Creates an empty builder.

  Args:
    share_strings: Whether a string equal to an already written one reuses
      the earlier data instead of being written again.
    share_keys: Same as share_strings, for keys.
    force_min_bit_width: Lower bound for the bit width chosen for vectors.
  """
  # Encoding options.
  self._force_min_bit_width = force_min_bit_width
  self._share_keys = share_keys
  self._share_strings = share_strings

  # Encoding state.
  self._buf = bytearray()
  self._stack = []
  self._finished = False
  self._key_pool = Pool()
  self._string_pool = Pool()
def __len__(self):
  """Returns the number of bytes written to the underlying buffer so far."""
  written = self._buf
  return len(written)
@property
def StringPool(self):
  """`Pool` of the strings written so far (filled when strings are shared)."""
  pool = self._string_pool
  return pool
@property
def KeyPool(self):
  """`Pool` of the keys written so far (filled when keys are shared)."""
  pool = self._key_pool
  return pool
def Clear(self):
  """Resets the builder to its initial empty state, keeping the options."""
  for pool in (self._string_pool, self._key_pool):
    pool.Clear()
  self._finished = False
  self._buf = bytearray()
  self._stack = []
def Finish(self):
  """Finishes encoding process and returns underlying buffer.

  Writes the root value followed by its packed type and byte width.

  Returns:
    The underlying `bytearray` with the complete encoded data.

  Raises:
    RuntimeError: If the builder has already been finished, or if there is
      not exactly one root value on the internal stack.
  """
  if self._finished:
    raise RuntimeError('builder has been already finished')

  # If you hit this exception, you likely have objects that were never
  # included in a parent. You need to have exactly one root to finish a
  # buffer. Check your Start/End calls are matched, and all objects are inside
  # some other object.
  if len(self._stack) != 1:
    raise RuntimeError('internal stack size must be one')

  value = self._stack[0]
  byte_width = self._Align(value.ElemWidth(len(self._buf)))
  self._WriteAny(value, byte_width=byte_width)  # Root value
  self._Write(U, value.StoredPackedType(), byte_width=1)  # Root type
  self._Write(U, byte_width, byte_width=1)  # Root size

  # Set the flag checked above, so that a second call cannot append another
  # root to the finished buffer.
  self._finished = True
  return self._buf
def _ReadKey(self, offset):
  """Returns the bytes from offset up to (excluding) the next zero byte."""
  tail = self._buf[offset:]
  terminator = tail.find(0)
  return tail[:terminator]
def _Align(self, alignment):
  """Pads the buffer with zeros up to a multiple of 2^alignment bytes.

  Args:
    alignment: Bit width (log2 of the byte width) to align to.

  Returns:
    The byte width corresponding to alignment.
  """
  byte_width = 1 << alignment
  padding = _PaddingBytes(len(self._buf), byte_width)
  self._buf.extend(b'\x00' * padding)
  return byte_width
def _Write(self, fmt, value, byte_width):
  """Appends one scalar packed as byte_width bytes to the buffer."""
  packed = _Pack(fmt, value, byte_width)
  self._buf.extend(packed)
def _WriteVector(self, fmt, values, byte_width):
  """Appends scalars, each packed as byte_width bytes, to the buffer."""
  packed = _PackVector(fmt, values, byte_width)
  self._buf.extend(packed)
def _WriteOffset(self, offset, byte_width):
  """Appends the distance from the current buffer end back to offset.

  Args:
    offset: Absolute position of already written data.
    byte_width: Number of bytes to store the relative offset with.
  """
  distance = len(self._buf) - offset
  limit = 1 << (8 * byte_width)
  assert byte_width == 8 or distance < limit
  self._Write(U, distance, byte_width)
def _WriteAny(self, value, byte_width):
  """Appends value: inline for scalar types, as relative offset otherwise."""
  inline_formats = {
      Type.NULL: U, Type.BOOL: U, Type.INT: I, Type.UINT: U, Type.FLOAT: F
  }
  fmt = inline_formats.get(value.Type)
  if not fmt:
    # Not an inline type: value.Value is the position of the written data.
    self._WriteOffset(value.Value, byte_width)
    return
  self._Write(fmt, value.Value, byte_width)
def _WriteBlob(self, data, append_zero, type_):
  """Writes size-prefixed data and pushes a value for it onto the stack.

  Args:
    data: Bytes to write.
    append_zero: Whether a zero byte is appended after the data.
    type_: `Type` recorded for the pushed value.

  Returns:
    Offset of the first data byte in the underlying buffer.
  """
  size = len(data)
  bit_width = BitWidth.U(size)
  byte_width = self._Align(bit_width)
  self._Write(U, size, byte_width)

  loc = len(self._buf)  # Data starts right after the size field.
  self._buf.extend(data)
  if append_zero:
    self._buf.append(0)

  self._stack.append(Value(loc, type_, bit_width))
  return loc
def _WriteScalarVector(self, element_type, byte_width, elements, fixed):
  """Writes scalar vector elements to the underlying buffer.

  Args:
    element_type: `Type.INT`, `Type.UINT` or `Type.FLOAT`.
    byte_width: Number of bytes per element (and for the size field).
    elements: Sequence of scalars to write.
    fixed: Whether to write a fixed typed vector, which has no size field.

  Returns:
    Offset of the first element in the underlying buffer.

  Raises:
    ValueError: If the number of elements doesn't fit into byte_width bytes.
    TypeError: If element_type is not a supported scalar type.
  """
  count = len(elements)
  bit_width = BitWidth.B(byte_width)
  # If you get this exception, you're trying to write a vector with a size
  # field that is bigger than the scalars you're trying to write (e.g. a
  # byte vector > 255 elements). For such types, write a "blob" instead.
  if BitWidth.U(count) > bit_width:
    raise ValueError('too many elements for the given byte_width')

  self._Align(bit_width)
  if not fixed:
    self._Write(U, count, byte_width)
  loc = len(self._buf)

  scalar_formats = {Type.INT: I, Type.UINT: U, Type.FLOAT: F}
  fmt = scalar_formats.get(element_type)
  if not fmt:
    raise TypeError('unsupported element_type')
  self._WriteVector(fmt, elements, byte_width)

  fixed_len = count if fixed else 0
  type_ = Type.ToTypedVector(element_type, fixed_len)
  self._stack.append(Value(loc, type_, bit_width))
  return loc
def _CreateVector(self, elements, typed, fixed, keys=None):
  """Writes vector elements to the underlying buffer.

  Args:
    elements: List of `Value` objects to write as vector elements.
    typed: Whether all elements must share one type (no type bytes are
      written then).
    fixed: Whether to write a fixed typed vector, which has no size field.
    keys: `Value` of the keys vector when a map is written, None otherwise.

  Returns:
    `Value` describing the written vector; it is not pushed onto the stack.

  Raises:
    ValueError: If fixed is requested for an untyped vector.
    RuntimeError: If the elements of a typed vector differ in type, or the
      element type is not allowed in a fixed typed vector.
  """
  length = len(elements)

  if fixed and not typed:
    raise ValueError('fixed vector must be typed')

  # Figure out smallest bit width we can store this vector with.
  bit_width = max(self._force_min_bit_width, BitWidth.U(length))
  prefix_elems = 1  # Vector size
  if keys:
    bit_width = max(bit_width, keys.ElemWidth(len(self._buf)))
    prefix_elems += 2  # Offset to the keys vector and its byte width.

  # Element type of a typed vector; stays KEY when there are no elements.
  vector_type = Type.KEY
  # Check bit widths and types for all elements.
  for index, element in enumerate(elements):
    needed = element.ElemWidth(len(self._buf), prefix_elems + index)
    bit_width = max(bit_width, needed)

    if not typed:
      continue
    if index == 0:
      vector_type = element.Type
    elif vector_type != element.Type:
      raise RuntimeError('typed vector elements must be of the same type')

  if fixed and not Type.IsFixedTypedVectorElementType(vector_type):
    raise RuntimeError('must be fixed typed vector element type')

  byte_width = self._Align(bit_width)
  # Write vector. First the keys width/offset if available, and size.
  if keys:
    self._WriteOffset(keys.Value, byte_width)
    self._Write(U, 1 << keys.MinBitWidth, byte_width)

  if not fixed:
    self._Write(U, length, byte_width)

  # Then the actual data.
  loc = len(self._buf)
  for element in elements:
    self._WriteAny(element, byte_width)

  # Then the types.
  if not typed:
    for element in elements:
      self._buf.append(element.StoredPackedType(bit_width))

  if keys:
    type_ = Type.MAP
  elif typed:
    type_ = Type.ToTypedVector(vector_type, length if fixed else 0)
  else:
    type_ = Type.VECTOR

  return Value(loc, type_, bit_width)
def _PushIndirect(self, value, type_, bit_width):
  """Writes a scalar to the buffer and pushes an offset value for it.

  Args:
    value: Scalar to write.
    type_: `Type.INDIRECT_INT`, `Type.INDIRECT_UINT` or
      `Type.INDIRECT_FLOAT`.
    bit_width: Bit width to write the scalar with.

  Raises:
    KeyError: If type_ is not one of the indirect scalar types.
  """
  byte_width = self._Align(bit_width)
  loc = len(self._buf)
  indirect_formats = {
      Type.INDIRECT_INT: I,
      Type.INDIRECT_UINT: U,
      Type.INDIRECT_FLOAT: F
  }
  self._Write(indirect_formats[type_], value, byte_width)
  self._stack.append(Value(loc, type_, bit_width))
@InMapForString
def String(self, value):
  """Encodes string value.

  Args:
    value: A `str` to encode as UTF-8.

  Returns:
    Offset of the string data in the underlying buffer; with string sharing
    enabled this is the offset of the first occurrence of an equal string.
  """
  rollback_to = len(self._buf)
  utf8 = value.encode('utf-8')
  loc = self._WriteBlob(utf8, append_zero=True, type_=Type.STRING)
  if not self._share_strings:
    return loc

  shared_loc = self._string_pool.FindOrInsert(utf8, loc)
  if shared_loc is not None:
    # Already written before: drop the new copy and point the value which
    # was just pushed at the earlier one.
    del self._buf[rollback_to:]
    loc = shared_loc
    self._stack[-1]._value = shared_loc  # pylint: disable=protected-access
  return loc
@InMap
def Blob(self, value):
  """Encodes binary blob value.

  Args:
    value: A byte/bytearray value to encode

  Returns:
    Offset of the encoded value in the underlying byte buffer.
  """
  offset = self._WriteBlob(value, append_zero=False, type_=Type.BLOB)
  return offset
def Key(self, value):
  """Encodes key value.

  Args:
    value: A byte/bytearray/str value to encode. Byte object must not contain
      zero bytes. String object must be convertible to ASCII.

  Returns:
    Offset of the encoded value in the underlying byte buffer.

  Raises:
    ValueError: If the key contains a zero byte.
  """
  is_binary = isinstance(value, (bytes, bytearray))
  encoded = value if is_binary else value.encode('ascii')
  if 0 in encoded:
    raise ValueError('key contains zero byte')

  # Keys are stored zero-terminated, without a size field.
  loc = len(self._buf)
  self._buf.extend(encoded)
  self._buf.append(0)

  if self._share_keys:
    shared_loc = self._key_pool.FindOrInsert(encoded, loc)
    if shared_loc is not None:
      # Already written before: drop the new copy and reuse the old one.
      del self._buf[loc:]
      loc = shared_loc

  self._stack.append(Value.Key(loc))
  return loc
def Null(self, key=None):
  """Encodes None value.

  Args:
    key: Optional map key to write before the value.
  """
  if key:
    self.Key(key)
  null_value = Value.Null()
  self._stack.append(null_value)
@InMap
def Bool(self, value):
  """Encodes boolean value.

  Args:
    value: A boolean value.
  """
  bool_value = Value.Bool(value)
  self._stack.append(bool_value)
@InMap
def Int(self, value, byte_width=0):
  """Encodes signed integer value.

  Args:
    value: A signed integer value.
    byte_width: Number of bytes to use: 1, 2, 4, or 8. The default of 0 picks
      the smallest width which can hold the value.
  """
  if byte_width == 0:
    bit_width = BitWidth.I(value)
  else:
    bit_width = BitWidth.B(byte_width)
  self._stack.append(Value.Int(value, bit_width))
@InMap
def IndirectInt(self, value, byte_width=0):
  """Encodes signed integer value indirectly.

  Args:
    value: A signed integer value.
    byte_width: Number of bytes to use: 1, 2, 4, or 8. The default of 0 picks
      the smallest width which can hold the value.
  """
  if byte_width == 0:
    bit_width = BitWidth.I(value)
  else:
    bit_width = BitWidth.B(byte_width)
  self._PushIndirect(value, Type.INDIRECT_INT, bit_width)
@InMap
def UInt(self, value, byte_width=0):
  """Encodes unsigned integer value.

  Args:
    value: An unsigned integer value.
    byte_width: Number of bytes to use: 1, 2, 4, or 8. The default of 0 picks
      the smallest width which can hold the value.
  """
  if byte_width == 0:
    bit_width = BitWidth.U(value)
  else:
    bit_width = BitWidth.B(byte_width)
  self._stack.append(Value.UInt(value, bit_width))
@InMap
def IndirectUInt(self, value, byte_width=0):
    """Encodes unsigned integer value indirectly.

    Args:
      value: An unsigned integer value.
      byte_width: Number of bytes to use: 1, 2, 4, or 8. The default of 0 lets
        `BitWidth.U` pick the minimum width for the value.
    """
    if byte_width == 0:
        bit_width = BitWidth.U(value)
    else:
        bit_width = BitWidth.B(byte_width)
    self._PushIndirect(value, Type.INDIRECT_UINT, bit_width)
@InMap
def Float(self, value, byte_width=0):
    """Encodes floating point value.

    Args:
      value: A floating point value.
      byte_width: Number of bytes to use: 4 or 8. The default of 0 lets
        `BitWidth.F` pick the width from the value.
    """
    if byte_width == 0:
        bit_width = BitWidth.F(value)
    else:
        bit_width = BitWidth.B(byte_width)
    self._stack.append(Value.Float(value, bit_width))
@InMap
def IndirectFloat(self, value, byte_width=0):
    """Encodes floating point value indirectly.

    Args:
      value: A floating point value.
      byte_width: Number of bytes to use: 4 or 8. The default of 0 lets
        `BitWidth.F` pick the width from the value.
    """
    if byte_width == 0:
        bit_width = BitWidth.F(value)
    else:
        bit_width = BitWidth.B(byte_width)
    self._PushIndirect(value, Type.INDIRECT_FLOAT, bit_width)
def _StartVector(self):
"""Starts vector construction."""
return len(self._stack)
def _EndVector(self, start, typed, fixed):
"""Finishes vector construction by encodung its elements."""
vec = self._CreateVector(self._stack[start:], typed, fixed)
del self._stack[start:]
self._stack.append(vec)
return vec.Value
@contextlib.contextmanager
def Vector(self, key=None):
    """Context manager that encodes the values added inside it as a vector.

    The vector is finished with `typed=False, fixed=False` when the block
    exits, including when the block raises.

    Args:
      key: Optional map key (bytes/bytearray/str) to encode before the vector.
        `None` means that no key is written. An empty key is a regular key and
        is encoded like any other.

    Yields:
      This builder.
    """
    # Compare against None explicitly so that an empty key is not dropped.
    if key is not None:
        self.Key(key)
    # Taken before entering `try` so that `finally` never sees an unbound name.
    start = self._StartVector()
    try:
        yield self
    finally:
        self._EndVector(start, typed=False, fixed=False)
@InMap
def VectorFromElements(self, elements):
    """Encodes sequence of any elements as a vector.

    Each element is encoded with `Add` inside a `Vector` block.

    Args:
      elements: Sequence of elements, they may have different types.
    """
    add = self.Add
    with self.Vector():
        for element in elements:
            add(element)
@contextlib.contextmanager
def TypedVector(self, key=None):
    """Context manager that encodes the values added inside it as a typed vector.

    The vector is finished with `typed=True, fixed=False` when the block
    exits, including when the block raises.

    Args:
      key: Optional map key (bytes/bytearray/str) to encode before the vector.
        `None` means that no key is written. An empty key is a regular key and
        is encoded like any other.

    Yields:
      This builder.
    """
    # Compare against None explicitly so that an empty key is not dropped.
    if key is not None:
        self.Key(key)
    # Taken before entering `try` so that `finally` never sees an unbound name.
    start = self._StartVector()
    try:
        yield self
    finally:
        self._EndVector(start, typed=True, fixed=False)
@InMap
def TypedVectorFromElements(self, elements, element_type=None):
    """Encodes sequence of elements of the same type as typed vector.

    `array.array` inputs are written directly as scalar vectors, with the
    element type and width derived from the array typecode. Any other sequence
    is encoded element by element inside a `TypedVector` block.

    Args:
      elements: Sequence of elements, they must be of the same type.
      element_type: Suggested element type. Setting it to None means determining
        correct value automatically based on the given elements. It is not
        consulted for `array.array` inputs.

    Raises:
      ValueError: If `elements` is an `array.array` with an unsupported
        typecode.
    """
    if not isinstance(elements, array.array):
        if element_type is None:
            add = self.Add
        else:
            add = self.Adder(element_type)
        with self.TypedVector():
            for element in elements:
                add(element)
        return

    typecode = elements.typecode
    if typecode == 'f':
        scalar_type, width = Type.FLOAT, 4
    elif typecode == 'd':
        scalar_type, width = Type.FLOAT, 8
    elif typecode in ('b', 'h', 'i', 'l', 'q'):
        scalar_type, width = Type.INT, elements.itemsize
    elif typecode in ('B', 'H', 'I', 'L', 'Q'):
        scalar_type, width = Type.UINT, elements.itemsize
    else:
        raise ValueError('unsupported array typecode: %s' % typecode)
    self._WriteScalarVector(scalar_type, width, elements, fixed=False)
@InMap
def FixedTypedVectorFromElements(self,
                                 elements,
                                 element_type=None,
                                 byte_width=0):
    """Encodes sequence of elements of the same type as fixed typed vector.

    Args:
      elements: Sequence of elements, they must be of the same type. Allowed
        types are `Type.INT`, `Type.UINT`, `Type.FLOAT`. Allowed number of
        elements are 2, 3, or 4.
      element_type: Suggested element type. Setting it to None means determining
        correct value automatically based on the given elements.
      byte_width: Number of bytes to use per element. For `Type.INT` and
        `Type.UINT`: 1, 2, 4, or 8. For `Type.FLOAT`: 4 or 8. Setting it to 0
        means determining correct value automatically based on the given
        elements.

    Raises:
      ValueError: If the number of elements is not 2, 3, or 4.
      TypeError: If the elements are not all of one Python type, or if that
        type is neither `int` nor `float` while `element_type` is None.
    """
    count = len(elements)
    if count < 2 or count > 4:
        raise ValueError('only 2, 3, or 4 elements are supported')

    python_types = set(map(type, elements))
    if len(python_types) != 1:
        raise TypeError('all elements must be of the same type')
    python_type = next(iter(python_types))

    if element_type is None:
        element_type = {int: Type.INT, float: Type.FLOAT}.get(python_type)
        if not element_type:
            raise TypeError('unsupported element_type: %s' % python_type)

    if byte_width == 0:
        width_of = {
            Type.UINT: BitWidth.U,
            Type.INT: BitWidth.I,
            Type.FLOAT: BitWidth.F,
        }[element_type]
        # Use the widest element so that every element fits.
        widest = max(map(width_of, elements))
        byte_width = 1 << widest

    self._WriteScalarVector(element_type, byte_width, elements, fixed=True)
def _StartMap(self):
"""Starts map construction."""
return len(self._stack)
def _EndMap(self, start):
"""Finishes map construction by encodung its elements."""
# Interleaved keys and values on the stack.
stack = self._stack[start:]
if len(stack) % 2 != 0:
raise RuntimeError('must be even number of keys and values')
for key in stack[::2]:
if key.Type is not Type.KEY:
raise RuntimeError('all map keys must be of %s type' % Type.KEY)
pairs = zip(stack[::2], stack[1::2]) # [(key, value), ...]
pairs = sorted(pairs, key=lambda pair: self._ReadKey(pair[0].Value))
del self._stack[start:]
for pair in pairs:
self._stack.extend(pair)
keys = self._CreateVector(self._stack[start::2], typed=True, fixed=False)
values = self._CreateVector(
self._stack[start + 1::2], typed=False, fixed=False, keys=keys)
del self._stack[start:]
self._stack.append(values)
return values.Value
@contextlib.contextmanager
def Map(self, key=None):
    """Context manager that encodes the keys and values added inside it as a map.

    The map is finished when the block exits, including when the block raises.

    Args:
      key: Optional map key (bytes/bytearray/str) to encode before the map, for
        use when this map is itself a value of an enclosing map. `None` means
        that no key is written. An empty key is a regular key and is encoded
        like any other.

    Yields:
      This builder.
    """
    # Compare against None explicitly so that an empty key is not dropped.
    if key is not None:
        self.Key(key)
    # Taken before entering `try` so that `finally` never sees an unbound name.
    start = self._StartMap()
    try:
        yield self
    finally:
        self._EndMap(start)
def MapFromElements(self, elements):
    """Encodes a mapping as a map.

    Every key is encoded with `Key` and every value with `Add`, in the
    iteration order of `elements.items()`.

    Args:
      elements: Mapping object exposing `items()`, e.g. a dict.
    """
    map_start = self._StartMap()
    for name, item in elements.items():
        self.Key(name)
        self.Add(item)
    self._EndMap(map_start)
def Adder(self, type_):
    """Returns the bound builder method that encodes values of the given type.

    Args:
      type_: One of the `Type` members listed in the table below.

    Returns:
      Bound method of this builder, e.g. `self.Int` for `Type.INT`.

    Raises:
      KeyError: If there is no encoder registered for `type_`.
    """
    method_names = {
        Type.BOOL: 'Bool',
        Type.INT: 'Int',
        Type.INDIRECT_INT: 'IndirectInt',
        Type.UINT: 'UInt',
        Type.INDIRECT_UINT: 'IndirectUInt',
        Type.FLOAT: 'Float',
        Type.INDIRECT_FLOAT: 'IndirectFloat',
        Type.KEY: 'Key',
        Type.BLOB: 'Blob',
        Type.STRING: 'String',
    }
    return getattr(self, method_names[type_])
@InMapForString
def Add(self, value):
    """Encodes value of any supported type.

    Dispatches on the Python type of `value`: None, bool, int, float, str,
    bytes/bytearray, dict, `array.array`, and finally any other iterable.

    Args:
      value: Python object to encode.

    Raises:
      TypeError: If the type of `value` is not supported.
    """
    if value is None:
        self.Null()
        return
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        self.Bool(value)
        return
    if isinstance(value, int):
        self.Int(value)
        return
    if isinstance(value, float):
        self.Float(value)
        return
    if isinstance(value, str):
        self.String(value)
        return
    if isinstance(value, (bytes, bytearray)):
        self.Blob(value)
        return
    if isinstance(value, dict):
        with self.Map():
            for name, item in value.items():
                self.Key(name)
                self.Add(item)
        return
    if isinstance(value, array.array):
        self.TypedVectorFromElements(value)
        return
    if _IsIterable(value):
        self.VectorFromElements(value)
        return
    raise TypeError('unsupported python type: %s' % type(value))
@property
def LastValue(self):
    """The item currently on top of the builder's stack."""
    stack = self._stack
    return stack[-1]
@InMap
def ReuseValue(self, value):
    """Pushes an existing value onto the builder's stack as is.

    Args:
      value: Stack item to push again (presumably one obtained earlier from
        `LastValue`; confirm against callers).
    """
    stack = self._stack
    stack.append(value)
def GetRoot(buf):
"""Returns root `Ref` object for the given buffer."""
if len(buf) < 3:
raise ValueError('buffer is too small')
byte_width = buf[-1]
return Ref.PackedType(
Buf(buf, -(2 + byte_width)), byte_width, packed_type=buf[-2])
def Dumps(obj):
    """Returns bytearray with the encoded python object.

    Args:
      obj: Python object of any type supported by `Builder.Add`.
    """
    builder = Builder()
    builder.Add(obj)
    encoded = builder.Finish()
    return encoded
def Loads(buf):
    """Returns python object decoded from the buffer.

    Args:
      buf: Bytes-like object holding an encoded FlexBuffer.
    """
    root = GetRoot(buf)
    return root.Value