from PIL import Image

class Equalize:
	"""Per-channel histogram equalization of RGB images with adjustable strength.

	Each channel has a level between MIN_LEVEL and MAX_LEVEL. At MIN_LEVEL the
	channel is left untouched, at MAX_LEVEL it is fully equalized, and values
	in between move each intensity only part of the way towards its
	equalized value.

	Every per-channel list handled by this class (levels, histograms, lookup
	tables) is indexed with the R, G and B constants below.
	"""

	MIN_LEVEL = 0.0
	MAX_LEVEL = 1.0
	# Channel indices; they match the byte order of a pixel in 'RGB' mode.
	R = 0
	G = 1
	B = 2

	@property
	def level(self):
		"""Levels of all three channels as a list indexed by R, G and B."""
		return self._level

	@property
	def level_b(self) -> float:
		"""Equalization level of the blue channel."""
		return self._level[Equalize.B]

	@property
	def level_g(self) -> float:
		"""Equalization level of the green channel."""
		return self._level[Equalize.G]

	@property
	def level_r(self) -> float:
		"""Equalization level of the red channel."""
		return self._level[Equalize.R]

	def __init__(self, level_b: float, level_g: float, level_r: float) -> None:
		"""Create a filter with the given per-channel levels.

		Raises:
			ValueError: if a level lies outside [MIN_LEVEL, MAX_LEVEL] (NaN
				included); the message is the name of the offending parameter.
		"""
		for name, value in (('level_b', level_b), ('level_g', level_g), ('level_r', level_r)):
			# The chained comparison is False for NaN, so NaN is rejected too.
			if not Equalize.MIN_LEVEL <= value <= Equalize.MAX_LEVEL:
				raise ValueError(name)
		# Store by channel index so the properties, set_level() and
		# _create_map() all agree on which slot belongs to which channel.
		levels = [Equalize.MIN_LEVEL] * 3
		levels[Equalize.R] = level_r
		levels[Equalize.G] = level_g
		levels[Equalize.B] = level_b
		self._level = levels

	def set_level(self, image: 'Image.Image') -> None:
		"""Replace the current levels with ones derived from *image*.

		The image is converted to RGB, its per-channel histogram is measured
		and the result of _get_auto_level() is stored. Nothing is returned
		and the image itself is not modified.
		"""
		src = image.convert('RGB')
		self._level = self._get_auto_level(self._build_histogram(src.tobytes()))

	def filter(self, image: 'Image.Image') -> 'Image.Image':
		"""Return a new RGB image with the equalization applied to *image*."""
		src = image.convert('RGB')
		src_bytes = src.tobytes()
		eqmap = [[0] * 256 for _ in range(3)]
		self._create_map(self._build_histogram(src_bytes), eqmap)
		dest = Image.new('RGB', src.size)
		dest.frombytes(bytes(self._apply_map(src_bytes, eqmap)))
		return dest

	@staticmethod
	def _build_histogram(rgb_bytes) -> list:
		"""Count intensities per channel in packed RGB bytes (3 bytes per pixel)."""
		histogram = [[0] * 256 for _ in range(3)]
		for offset, channel in enumerate((Equalize.R, Equalize.G, Equalize.B)):
			counts = histogram[channel]
			# Every third byte starting at `offset` belongs to this channel.
			for value in rgb_bytes[offset::3]:
				counts[value] += 1
		return histogram

	@staticmethod
	def _apply_map(rgb_bytes: bytes, eqmap: list) -> bytearray:
		"""Translate packed RGB bytes through the per-channel lookup tables."""
		result = bytearray(len(rgb_bytes))
		for offset, channel in enumerate((Equalize.R, Equalize.G, Equalize.B)):
			table = bytes(eqmap[channel])
			result[offset::3] = rgb_bytes[offset::3].translate(table)
		return result

	@staticmethod
	def _cumulative(histogram: list) -> list:
		"""Return the running totals of each channel's histogram."""
		totals = []
		for counts in histogram:
			running = 0
			cdf = []
			for count in counts:
				running += count
				cdf.append(running)
			totals.append(cdf)
		return totals

	@staticmethod
	def _equalized(cdf: list, i: int) -> int:
		"""Return the fully equalized value (0..255) for intensity *i*."""
		low = cdf[0]
		# max(..., 1) avoids division by zero when all pixels share value 0.
		return int((cdf[i] - low) * 255 / max(cdf[255] - low, 1))

	def _get_auto_level(self, histogram: list) -> list:
		"""Pick a level per channel that limits how far any intensity moves.

		A channel keeps level 1.0 unless full equalization would shift some
		intensity by more than the threshold of 64; in that case the level is
		scaled down to threshold / largest shift.

		Returns:
			A list of three floats indexed by R, G and B.
		"""
		threshold = 64
		level = [1.0, 1.0, 1.0]
		for color, cdf in enumerate(self._cumulative(histogram)):
			max_diff = max(abs(self._equalized(cdf, i) - i) for i in range(256))
			if max_diff > threshold:
				level[color] = threshold / max_diff
		return level

	def _create_map(self, histogram: list, eqmap: list) -> None:
		"""Fill *eqmap* with one 256-entry lookup table per channel.

		Args:
			histogram: three 256-entry count lists indexed by R, G and B.
			eqmap: three 256-entry lists, overwritten in place.
		"""
		for color, cdf in enumerate(self._cumulative(histogram)):
			level = self.level[color]
			table = eqmap[color]
			for i in range(256):
				c = i
				if level > Equalize.MIN_LEVEL:
					value = self._equalized(cdf, i)
					if level == Equalize.MAX_LEVEL:
						c = value
					else:
						# Move only part of the way towards the equalized value.
						c = i + int((value - i) * level)
				table[i] = c