# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flax modules and functions for using external memory."""

from typing import Any, Optional, Tuple

from absl import logging
from flax import linen
import gin
import jax
from transformer import memory_layer


PRNGKey = Any
Shape = Tuple[int]
Dtype = Any
Array = Any

MemoryResource = Any


class MemoryManager:
  """Manages any external resources that may be required by external memory.

  MemoryManager also functions as a factory, to create Flax modules that
  will read and write to whatever external memory has been configured.
  """

  def __init__(self,
               batch_size: int,
               mode: str,
               num_heads: int,
               key_size: int,
               value_size: int,
               database_size: Optional[int] = None,
               dtype: Dtype = "float32",
               off_device_memory: Optional[MemoryResource] = None):
    """Create a MemoryManager object.

    A MemoryManager configures external memory, and is used as a factory
    to construct flax modules that read or write to the memory.

    Args:
      batch_size: The number of separate documents in a batch.
      mode: The mode the model is running in, e.g. "train" or "test".
      num_heads: The number of transformer heads.
      key_size: The length of the key vectors.
      value_size: The length of the value vectors.
      database_size: The total number of tokens in the database.
      dtype: The datatype used for keys and values.
      off_device_memory: An object which manages underlying SCAM memory.
        If None, then the model will use on-device memory.
    """
    self.batch_size = batch_size
    self.mode = mode
    self.num_heads = num_heads
    self.key_size = key_size
    self.value_size = value_size
    self.database_size = database_size
    self.dtype = dtype
    self.off_device_memory = off_device_memory

  def create_memory_layer(self) -> linen.Module:
    """Create a flax Module that implements external memory."""
    # One memory dataset per (document, head) pair for on-device memory;
    # one per head when memory lives off-device.
    num_datasets = (
        self.batch_size * self.num_heads if self.off_device_memory is None
        else self.num_heads)
    if self.off_device_memory is not None:
      mem_layer = None
      if mem_layer is None:
        raise ValueError("Off-device memory is not supported at this time.")
      return memory_layer.BatchedMemory(
          mem_layer,
          split_dimensions=(-2,),
      )
    else:
      assert self.database_size is not None
      mem_layer = memory_layer.MemoryOnTpu(num_datasets=num_datasets,
                                           key_features=self.key_size,
                                           value_features=self.value_size,
                                           database_size=self.database_size,
                                           dtype=self.dtype)
      # Handle queries of shape [batch_size, seq_len, num_heads, kv_features].
      return memory_layer.BatchedMemory(mem_layer,
                                        split_dimensions=(0, -2))


@gin.configurable
def memory_on_tpu_factory(batch_size: int,
                          mode: str,
                          num_heads: int = gin.REQUIRED,
                          key_size: int = gin.REQUIRED,
                          value_size: int = gin.REQUIRED,
                          database_size: int = gin.REQUIRED,
                          dtype: Dtype = gin.REQUIRED) -> MemoryManager:
  """Implement SCAM memory on device."""
  return MemoryManager(batch_size=batch_size,
                       mode=mode,
                       num_heads=num_heads,
                       key_size=key_size,
                       value_size=value_size,
                       database_size=database_size,
                       dtype=dtype,
                       off_device_memory=None)
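

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch, not part of the library API. It relies
# only on the definitions above; the sizes below are illustrative
# placeholders, and the gin-configured defaults are supplied explicitly so
# no gin config file is needed.
if __name__ == "__main__":
  manager = memory_on_tpu_factory(batch_size=2,
                                  mode="train",
                                  num_heads=8,
                                  key_size=64,
                                  value_size=64,
                                  database_size=32768,
                                  dtype="float32")
  # With off_device_memory=None, create_memory_layer() returns a
  # BatchedMemory module wrapping MemoryOnTpu, which handles queries of
  # shape [batch_size, seq_len, num_heads, kv_features] as noted above.
  mem = manager.create_memory_layer()
  logging.info("Created external memory module: %r", type(mem))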