Spaces:
Runtime error
Runtime error
# Import the necessary packages.
from tensorflow import keras | |
import tensorflow as tf | |
# PatchConvNet: a model that chains a stem, a trunk and an attention-pooling block.
class PatchConvNet(keras.Model):
    """Model that chains a stem, a trunk and an attention-pooling block.

    The three sub-blocks are supplied by the caller and stored as
    attributes; this class only wires them together in ``call``.

    Args:
        stem: Callable applied directly to the input images.
        trunk: Callable applied to the output of ``stem``.
        attention_pooling: Callable applied to the output of ``trunk``.
            Its result is unpacked into exactly two values.
        **kwargs: Forwarded unchanged to ``keras.Model.__init__``.
    """

    def __init__(self, stem, trunk, attention_pooling, **kwargs):
        super().__init__(**kwargs)
        # Stored in the same order in which ``call`` applies them.
        self.stem = stem
        self.trunk = trunk
        self.attention_pooling = attention_pooling

    def call(self, images):
        """Run ``images`` through stem, trunk and attention pooling in turn.

        Args:
            images: Input handed as-is to ``self.stem``.

        Returns:
            A two-element tuple holding the two values produced by
            ``self.attention_pooling`` -- presumably the predictions and
            the attention weights kept for visualisation; confirm against
            the pooling block's implementation.
        """
        stem_out = self.stem(images)
        trunk_out = self.trunk(stem_out)
        # The pooling block must yield exactly two values; they are
        # re-packed into a tuple for the caller.
        outputs, weights = self.attention_pooling(trunk_out)
        return outputs, weights