# import the necessary packages
from tensorflow import keras
import tensorflow as tf

# PatchConvNet: composes a stem, a trunk, and an attention-pooling head
class PatchConvNet(keras.Model):
	"""Keras model that chains a stem, a trunk and an attention-pooling head.

	The three sub-models are supplied by the caller; this class only wires
	them together in ``call``.
	"""

	def __init__(self, stem, trunk, attention_pooling, **kwargs):
		"""Store the three sub-models on the instance.

		Args:
			stem: callable applied first to the input images.
			trunk: callable applied to the stem's output.
			attention_pooling: callable applied to the trunk's output; it
				must return exactly two values.
			**kwargs: forwarded unchanged to ``keras.Model.__init__``.
		"""
		super().__init__(**kwargs)
		# NOTE(review): keep this assignment order. Keras tracks sub-layers
		# as they are assigned, which presumably fixes the order of
		# model.weights. Verify before reordering if checkpoints are loaded
		# by position.
		self.stem = stem
		self.trunk = trunk
		self.attention_pooling = attention_pooling

	@tf.function(
		input_signature=[
			tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.uint8),
		],
	)
	def call(self, images):
		"""Run images through stem -> trunk -> attention pooling.

		The tf.function signature admits only rank-4 uint8 tensors whose last
		dimension is 3 (presumably batch x height x width x RGB; confirm).
		Any uint8-to-float conversion is assumed to happen inside ``stem``
		(TODO confirm).

		Returns:
			A tuple ``(predictions, viz_weights)`` holding the two values
			produced by ``attention_pooling``.
		"""
		stem_out = self.stem(images)
		trunk_out = self.trunk(stem_out)
		# Unpacking enforces that the head yields exactly two outputs and
		# normalises the result to a tuple.
		head_out, viz_out = self.attention_pooling(trunk_out)
		return head_out, viz_out