# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0


import numpy as np

from onnx.reference.op_run import OpRun


def _dilated_kernel(W, kernel_shape, dilations):  # type: ignore
    """Return ``(W, kernel_shape)`` with zeros inserted according to *dilations*.

    The trailing ``len(dilations)`` axes of *W* are spread out so that
    consecutive kernel taps are ``d`` cells apart; the gaps are zero-filled.
    A dilated convolution can then be computed as a plain convolution with
    the enlarged kernel.
    """
    nd = len(dilations)
    new_kernel_shape = []
    new_shape = list(W.shape[:-nd])
    for i, d in enumerate(dilations):
        di = len(W.shape) - nd + i
        new_shape.append(W.shape[di] + (W.shape[di] - 1) * (d - 1))
        new_kernel_shape.append(kernel_shape[i] + (kernel_shape[i] - 1) * (d - 1))
    new_w = np.zeros(tuple(new_shape), dtype=W.dtype)
    # The two leading axes are copied entirely, the remaining ones are strided.
    indices = [slice(0, new_w.shape[0]), slice(0, new_w.shape[1])]
    for i, d in enumerate(dilations):
        di = len(W.shape) - nd + i
        indices.append(slice(0, new_w.shape[di], d))
    new_w[tuple(indices)] = W
    return new_w, new_kernel_shape


def _same_auto_pads(auto_pad, spatial_shape, kernel_shape, strides):  # type: ignore
    """Compute ``SAME_LOWER`` / ``SAME_UPPER`` paddings.

    Returns a list laid out as all head paddings followed by all tail
    paddings, chosen so that every output extent is
    ``ceil(input_extent / stride)``.
    """
    head = []
    tail = []
    for i, d in enumerate(spatial_shape):
        target_size = (d + strides[i] - 1) // strides[i]
        # A kernel smaller than the stride would give a negative amount;
        # no padding is needed in that case.
        pad_needed = max(0, (target_size - 1) * strides[i] + kernel_shape[i] - d)
        if auto_pad == "SAME_LOWER":
            # The extra cell (odd total) goes to the beginning.
            pad_head = (pad_needed + 1) // 2
        else:
            # SAME_UPPER: the extra cell goes to the end.
            pad_head = pad_needed // 2
        head.append(pad_head)
        tail.append(pad_needed - pad_head)
    return head + tail


def _conv_implementation(  # type: ignore
    X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides
):
    """Compute a 1-D, 2-D or 3-D convolution with explicit Python loops.

    Args:
        X: input array indexed as ``(batch, channels, *spatial)``.
        W: weights indexed as ``(out_channels, channels // group, *kernel)``.
        B: optional bias, reshaped to broadcast over the channel axis.
        auto_pad: ``"SAME_LOWER"`` / ``"SAME_UPPER"`` compute the paddings,
            ``"VALID"`` forces zero padding, any other value keeps *pads*.
        dilations: one dilation per spatial axis, defaults to 1.
        group: number of channel groups, defaults to 1 when ``None``.
        kernel_shape: kernel extents, defaults to ``W.shape[2:]``.
        pads: all head paddings followed by all tail paddings, defaults to 0.
        strides: one stride per spatial axis, defaults to 1.

    Returns:
        The convolution result, accumulated in NumPy's default float dtype
        (callers are expected to cast if needed).

    Raises:
        ValueError: if the channel counts of *X*, *W* and *group* disagree,
            or if a per-group convolution fails.
        RuntimeError: if an input window and the matching kernel slice have
            different shapes, or if *X* does not have 3, 4 or 5 dimensions.
    """
    n_spatial = len(X.shape) - 2
    if dilations is None:
        dilations = [1] * n_spatial
    if kernel_shape is None:
        kernel_shape = W.shape[2:]
    if pads is None:
        pads = [0] * (2 * n_spatial)
    if strides is None:
        strides = [1] * n_spatial
    if group is None:
        group = 1

    if X.shape[1] != W.shape[1] * group or W.shape[0] % group != 0:
        raise ValueError(
            f"Shape inconsistencies, X.shape={X.shape}, W.shape={W.shape}, group={group}, "
            f"W should be {(W.shape[0], X.shape[1] // group, np.prod(W.shape[1:]) // X.shape[1] * group)}."
        )

    if group > 1:
        # Every (batch item, group) pair is convolved independently with
        # group=1; results are stitched together along the channel axis
        # and then along the batch axis.
        mg = W.shape[0] // group
        dw = W.shape[1]
        per_batch = []
        for b in range(X.shape[0]):
            per_group = []
            for g in range(group):
                gx = X[b : b + 1, g * dw : (g + 1) * dw]
                gw = W[g * mg : (g + 1) * mg]
                try:
                    cv = _conv_implementation(
                        gx,
                        gw,
                        None,
                        auto_pad,
                        dilations,
                        1,
                        kernel_shape,
                        pads,
                        strides,
                    )
                except (ValueError, RuntimeError) as e:
                    raise ValueError(
                        f"Shape inconsistencies, X.shape={X.shape}, W.shape={W.shape}, group={g}/{group}, "
                        f"gx.shape={gx.shape}, gw.shape={gw.shape}, auto_pad={auto_pad}, "
                        f"dilations={dilations}, kernel_shape={kernel_shape}, pads={pads}, "
                        f"strides={strides}."
                    ) from e
                per_group.append(cv)
            per_batch.append(np.concatenate(per_group, axis=1))
        final = np.concatenate(per_batch, axis=0)
        if B is not None:
            # The bias is added once, after all groups were assembled.
            bias_shape = [1] * len(final.shape)
            bias_shape[1] = B.shape[0]
            final += B.reshape(tuple(bias_shape))
        return final

    if any(d != 1 for d in dilations):
        # Let's compute the dilated kernel.
        W, kernel_shape = _dilated_kernel(W, kernel_shape, dilations)

    if auto_pad == "VALID":
        # VALID means no padding at all.
        pads = [0] * (2 * n_spatial)
    elif auto_pad in {"SAME_LOWER", "SAME_UPPER"}:
        # Paddings depend on the spatial extents, i.e. X.shape[2:].
        pads = _same_auto_pads(auto_pad, X.shape[2:], kernel_shape, strides)

    if len(X.shape) == 3:
        sN, sC, sH = X.shape
        # M, C_group, kH, kW = W.shape
        (kh,) = kernel_shape
        (sth,) = strides

        h_out = int(((sH - kh + pads[0] + pads[1]) / sth) + 1)

        h0 = pads[0]
        oh = -1 * (kh % 2)
        # Window origins run from -pad_head in steps of the stride.
        bh = -h0
        eh = h_out * sth
        res = np.zeros((X.shape[0], W.shape[0], h_out))  # type: ignore[assignment]
        if B is not None:
            res[:, :, :] = B.reshape((1, -1, 1))  # type: ignore

        for n in range(0, sN):
            for nw in range(W.shape[0]):
                for c in range(0, sC):
                    w = W[nw : nw + 1, c : c + 1]
                    for io in range(bh, eh, sth):
                        hr = (io - bh) // sth
                        if hr >= h_out:
                            continue
                        i = io + kh % 2
                        # Part of the window that lies inside the input.
                        ih1, ih2 = max(0, i + oh), min(i + oh + kh, sH)
                        img = X[n : n + 1, c : c + 1, ih1:ih2]
                        if img.shape != w.shape:
                            # The window overlaps the padding: keep only the
                            # kernel taps facing real input cells.
                            jh1, jh2 = max(-oh - i, 0), min(kh, kh + sH - (i + oh + kh))
                            w_ = w[:1, :1, jh1:jh2]
                            if img.shape != w_.shape:
                                raise RuntimeError(
                                    f"Unexpected shape {img.shape} != {w_.shape}, oh={oh}, "
                                    f"i={i}, kh={kh}, sH={sH}, sth={sth}."
                                )
                            s = np.dot(img.reshape((1, -1)), w_.reshape((-1, 1)))[
                                0, 0
                            ]  # (img * w_).sum()
                        else:
                            s = np.dot(img.reshape((1, -1)), w.reshape((-1, 1)))[
                                0, 0
                            ]  # (img * w).sum()
                        res[n, nw, hr] += s  # type: ignore

        return res

    if len(X.shape) == 4:
        sN, sC, sH, sW = X.shape
        # M, C_group, kH, kW = W.shape
        kh, kw = kernel_shape
        sth, stw = strides

        h_out = int(((sH - kh + pads[0] + pads[2]) / sth) + 1)
        w_out = int(((sW - kw + pads[1] + pads[3]) / stw) + 1)

        h0, w0 = pads[0], pads[1]
        oh, ow = -1 * (kh % 2), -1 * (kw % 2)
        bh, bw = -h0, -w0
        eh, ew = h_out * sth, w_out * stw
        res = np.zeros((X.shape[0], W.shape[0], h_out, w_out))  # type: ignore[assignment]
        if B is not None:
            res[:, :, :, :] = B.reshape((1, -1, 1, 1))  # type: ignore

        for n in range(0, sN):
            for nw in range(W.shape[0]):
                for c in range(0, sC):
                    w = W[nw : nw + 1, c : c + 1]
                    for io in range(bh, eh, sth):
                        hr = (io - bh) // sth
                        if hr >= h_out:
                            continue
                        i = io + kh % 2
                        ih1, ih2 = max(0, i + oh), min(i + oh + kh, sH)
                        for jo in range(bw, ew, stw):
                            wr = (jo - bw) // stw
                            if wr >= w_out:
                                continue
                            j = jo + kw % 2
                            iw1, iw2 = max(0, j + ow), min(j + ow + kw, sW)
                            img = X[n : n + 1, c : c + 1, ih1:ih2, iw1:iw2]
                            if img.shape != w.shape:
                                # Crop the kernel to the part covering real
                                # input cells (the rest faces zero padding).
                                jh1, jh2 = max(-oh - i, 0), min(
                                    kh, kh + sH - (i + oh + kh)
                                )
                                jw1, jw2 = max(-ow - j, 0), min(
                                    kw, kw + sW - (j + ow + kw)
                                )
                                w_ = w[:1, :1, jh1:jh2, jw1:jw2]
                                if img.shape != w_.shape:
                                    raise RuntimeError(
                                        f"Unexpected shape {img.shape} != {w_.shape}, oh={oh}, ow={ow}, "
                                        f"i={i}, j={j}, kh={kh}, kw={kw}, sH={sH}, sW={sW}, sth={sth}, stw={stw}."
                                    )
                                s = np.dot(img.reshape((1, -1)), w_.reshape((-1, 1)))[
                                    0, 0
                                ]  # (img * w_).sum()
                            else:
                                s = np.dot(img.reshape((1, -1)), w.reshape((-1, 1)))[
                                    0, 0
                                ]  # (img * w).sum()
                            res[n, nw, hr, wr] += s  # type: ignore

        return res

    if len(X.shape) == 5:
        sN, sC, sH, sW, sZ = X.shape
        kh, kw, kz = kernel_shape
        sth, stw, stz = strides

        h_out = int(((sH - kh + pads[0] + pads[3]) / sth) + 1)
        w_out = int(((sW - kw + pads[1] + pads[4]) / stw) + 1)
        z_out = int(((sZ - kz + pads[2] + pads[5]) / stz) + 1)

        h0, w0, z0 = pads[0], pads[1], pads[2]
        oh, ow, oz = -1 * (kh % 2), -1 * (kw % 2), -1 * (kz % 2)
        bh, bw, bz = -h0, -w0, -z0
        eh, ew, ez = h_out * sth, w_out * stw, z_out * stz
        res = np.zeros((X.shape[0], W.shape[0], h_out, w_out, z_out))  # type: ignore[assignment]
        if B is not None:
            res[:, :, :, :, :] = B.reshape((1, -1, 1, 1, 1))  # type: ignore

        for n in range(0, sN):
            for nw in range(W.shape[0]):
                for c in range(0, sC):
                    w = W[nw : nw + 1, c : c + 1]
                    for io in range(bh, eh, sth):
                        hr = (io - bh) // sth
                        if hr >= h_out:
                            continue
                        i = io + kh % 2
                        ih1, ih2 = max(0, i + oh), min(i + oh + kh, sH)
                        for jo in range(bw, ew, stw):
                            wr = (jo - bw) // stw
                            if wr >= w_out:
                                continue
                            j = jo + kw % 2
                            iw1, iw2 = max(0, j + ow), min(j + ow + kw, sW)
                            for zo in range(bz, ez, stz):
                                zr = (zo - bz) // stz
                                if zr >= z_out:
                                    continue
                                z = zo + kz % 2
                                iz1, iz2 = max(0, z + oz), min(z + oz + kz, sZ)
                                img = X[n : n + 1, c : c + 1, ih1:ih2, iw1:iw2, iz1:iz2]
                                if img.shape != w.shape:
                                    # Crop the kernel to the part covering
                                    # real input cells.
                                    jh1, jh2 = max(-oh - i, 0), min(
                                        kh, kh + sH - (i + oh + kh)
                                    )
                                    jw1, jw2 = max(-ow - j, 0), min(
                                        kw, kw + sW - (j + ow + kw)
                                    )
                                    jz1, jz2 = max(-oz - z, 0), min(
                                        kz, kz + sZ - (z + oz + kz)
                                    )
                                    w_ = w[:1, :1, jh1:jh2, jw1:jw2, jz1:jz2]
                                    if img.shape != w_.shape:
                                        raise RuntimeError(
                                            f"Unexpected shape {img.shape} != {w_.shape}, oh={oh}, ow={ow}, oz={oz}, "
                                            f"i={i}, j={j}, z={z}, kh={kh}, kw={kw}, kz={kz}, "
                                            f"sH={sH}, sW={sW}, sZ={sZ}, sth={sth}, stw={stw}, stz={stz}."
                                        )
                                    s = np.dot(
                                        img.reshape((1, -1)), w_.reshape((-1, 1))
                                    )[
                                        0, 0
                                    ]  # (img * w_).sum()
                                else:
                                    s = np.dot(
                                        img.reshape((1, -1)), w.reshape((-1, 1))
                                    )[
                                        0, 0
                                    ]  # (img * w).sum()
                                res[n, nw, hr, wr, zr] += s  # type: ignore

        return res

    raise RuntimeError(
        f"The convolution for X.shape={X.shape}, W.shape={W.shape}, "
        f"kernel_shape={kernel_shape} is not implemented yet."
    )


class Conv(OpRun):
    """Reference runtime for the ``Conv`` operator."""

    def _run(  # type: ignore
        self,
        X,
        W,
        B=None,
        auto_pad=None,
        dilations=None,
        group=None,
        kernel_shape=None,
        pads=None,
        strides=None,
    ):
        """Run the convolution and return a 1-tuple holding the result.

        The result is cast back to the dtype of *X*.

        Raises:
            ValueError: if *X* has fewer than 3 dimensions.
        """
        if len(X.shape) < 3:
            raise ValueError(
                f"X must have at least 3 dimensions but its shape is {X.shape}."
            )
        return (
            _conv_implementation(
                X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides
            ).astype(X.dtype),
        )