"""
23-nov-21

https://github.com/sniklaus/revisiting-sepconv/blob/fea509d98157170df1fb35bf615bd41d98858e1a/run.py
https://github.com/sniklaus/revisiting-sepconv/blob/fea509d98157170df1fb35bf615bd41d98858e1a/sepconv/sepconv.py

Removed the arguments_strModel and getopt command-line handling.
"""

import math
import typing

import torch

from comfy.model_management import get_torch_device
from vfi_models.ops import sepconv_func

# requires at least pytorch version 1.3.0
assert int(str("").join(torch.__version__.split(".")[0:2])) >= 13

# inference only, so disable gradients and enable cudnn for performance
torch.set_grad_enabled(False)

torch.backends.cudnn.enabled = True
|
|
class Basic(torch.nn.Module):
    def __init__(
        self,
        strType: str,
        intChans: typing.List[int],
        objScratch: typing.Optional[typing.Dict] = None,
    ):
        super().__init__()

        self.strType = strType
        self.netEvenize = None
        self.netMain = None
        self.netShortcut = None

        intIn = intChans[0]
        intOut = intChans[-1]
        netMain = []
        intChans = intChans.copy()
        fltStride = 1.0

        # the layer spec is a mini-language: dash-separated parts before the
        # first "+" build the main branch (e.g. "conv(3)", "sconv(3)",
        # "up(bilinear)", "prelu(0.25)"), while parts after a "+" add
        # modifiers such as "skip"
        for intPart, strPart in enumerate(self.strType.split("+")[0].split("-")):
            if strPart.startswith("conv"):
                intKsize = 3
                intPad = 1
                strPad = "zeros"

                if "(" in strPart:
                    intKsize = int(strPart.split("(")[1].split(")")[0].split(",")[0])
                    intPad = int(math.floor(0.5 * (intKsize - 1)))

                    if "replpad" in strPart.split("(")[1].split(")")[0].split(","):
                        strPad = "replicate"
                    if "reflpad" in strPart.split("(")[1].split(")")[0].split(","):
                        strPad = "reflect"

                if "nopad" in self.strType.split("+"):
                    intPad = 0

                netMain += [
                    torch.nn.Conv2d(
                        in_channels=intChans[0],
                        out_channels=intChans[1],
                        kernel_size=intKsize,
                        stride=1,
                        padding=intPad,
                        padding_mode=strPad,
                        bias="nobias" not in self.strType.split("+"),
                    )
                ]
                intChans = intChans[1:]
                fltStride *= 1.0

            elif strPart.startswith("sconv"):
                intKsize = 3
                intPad = 1
                strPad = "zeros"

                if "(" in strPart:
                    intKsize = int(strPart.split("(")[1].split(")")[0].split(",")[0])
                    intPad = int(math.floor(0.5 * (intKsize - 1)))

                    if "replpad" in strPart.split("(")[1].split(")")[0].split(","):
                        strPad = "replicate"
                    if "reflpad" in strPart.split("(")[1].split(")")[0].split(","):
                        strPad = "reflect"

                if "nopad" in self.strType.split("+"):
                    intPad = 0

                netMain += [
                    torch.nn.Conv2d(
                        in_channels=intChans[0],
                        out_channels=intChans[1],
                        kernel_size=intKsize,
                        stride=2,
                        padding=intPad,
                        padding_mode=strPad,
                        bias="nobias" not in self.strType.split("+"),
                    )
                ]
                intChans = intChans[1:]
                fltStride *= 2.0

            elif strPart.startswith("up"):

                class Up(torch.nn.Module):
                    def __init__(self, strType):
                        super().__init__()

                        self.strType = strType

                    def forward(self, tenIn: torch.Tensor) -> torch.Tensor:
                        if self.strType == "nearest":
                            # interpolate() rejects align_corners for
                            # mode="nearest", so it must not be passed here
                            return torch.nn.functional.interpolate(
                                input=tenIn, scale_factor=2.0, mode="nearest"
                            )

                        elif self.strType == "bilinear":
                            return torch.nn.functional.interpolate(
                                input=tenIn,
                                scale_factor=2.0,
                                mode="bilinear",
                                align_corners=False,
                            )

                        elif self.strType == "pyramid":
                            # the pyramid() helper from the original repository
                            # is not included in this trimmed file
                            return pyramid(tenIn, None, "up")

                        elif self.strType == "shuffle":
                            return torch.nn.functional.pixel_shuffle(
                                tenIn, upscale_factor=2
                            )

                        assert False  # unsupported upsampling type

                strType = "bilinear"

                if "(" in strPart:
                    if "nearest" in strPart.split("(")[1].split(")")[0].split(","):
                        strType = "nearest"
                    if "pyramid" in strPart.split("(")[1].split(")")[0].split(","):
                        strType = "pyramid"
                    if "shuffle" in strPart.split("(")[1].split(")")[0].split(","):
                        strType = "shuffle"

                netMain += [Up(strType)]
                fltStride *= 0.5

            elif strPart.startswith("prelu"):
                netMain += [
                    torch.nn.PReLU(
                        num_parameters=1,
                        init=float(strPart.split("(")[1].split(")")[0].split(",")[0]),
                    )
                ]
                fltStride *= 1.0

            else:
                assert False  # unsupported layer type

        self.netMain = torch.nn.Sequential(*netMain)

        for strPart in self.strType.split("+")[1:]:
            if strPart.startswith("skip"):

                class Down(torch.nn.Module):
                    def __init__(self, fltScale):
                        super().__init__()

                        self.fltScale = fltScale

                    def forward(self, tenIn: torch.Tensor) -> torch.Tensor:
                        return torch.nn.functional.interpolate(
                            input=tenIn,
                            scale_factor=self.fltScale,
                            mode="bilinear",
                            align_corners=False,
                        )

                if intIn == intOut and fltStride == 1.0:
                    self.netShortcut = torch.nn.Identity()

                elif intIn != intOut and fltStride == 1.0:
                    self.netShortcut = torch.nn.Conv2d(
                        in_channels=intIn,
                        out_channels=intOut,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        bias="nobias" not in self.strType.split("+"),
                    )

                elif intIn == intOut and fltStride != 1.0:
                    # resample the shortcut so it matches the main branch
                    self.netShortcut = Down(1.0 / fltStride)

                elif intIn != intOut and fltStride != 1.0:
                    # resample and project the shortcut to match the main branch
                    self.netShortcut = torch.nn.Sequential(
                        Down(1.0 / fltStride),
                        torch.nn.Conv2d(
                            in_channels=intIn,
                            out_channels=intOut,
                            kernel_size=1,
                            stride=1,
                            padding=0,
                            bias="nobias" not in self.strType.split("+"),
                        ),
                    )

            elif strPart.startswith("..."):
                pass

        assert len(intChans) == 1

    def forward(self, tenIn: torch.Tensor) -> torch.Tensor:
        if self.netEvenize is not None:
            tenIn = self.netEvenize(tenIn)

        tenOut = self.netMain(tenIn)

        if self.netShortcut is not None:
            tenOut = tenOut + self.netShortcut(tenIn)

        return tenOut
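
# A minimal, hedged example of the layer-spec mini-language (the spec string
# and channel counts below are illustrative, not taken from the original
# repository; weights are randomly initialized):
#
#     net = Basic("conv(3)-prelu(0.25)-conv(3)+skip", [8, 16, 8])
#     out = net(torch.rand(1, 8, 32, 32))  # -> shape [1, 8, 32, 32]
#
# "conv(3)" adds a 3x3 stride-1 convolution, "prelu(0.25)" a PReLU initialized
# at 0.25, and "+skip" attaches an identity shortcut here since the input and
# output channels match and the stride is 1.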
|
|
class Encode(torch.nn.Module):
    objScratch: typing.Dict[str, typing.List[int]] = None

    def __init__(
        self,
        intIns: typing.List[int],
        intOuts: typing.List[int],
        strHor: str,
        strVer: str,
        objScratch: typing.Dict[str, typing.List[int]],
    ):
        super().__init__()

        assert len(intIns) == len(intOuts)

        self.intRows = len(intIns)
        self.intIns = intIns.copy()
        self.intOuts = intOuts.copy()
        self.strHor = strHor
        self.strVer = strVer
        self.objScratch = objScratch

        self.netHor = torch.nn.ModuleList()
        self.netVer = torch.nn.ModuleList()

        for intRow in range(self.intRows):
            netHor = torch.nn.Identity()
            netVer = torch.nn.Identity()

            if self.intOuts[intRow] != 0:
                # horizontal pass within the row, only if the row has an input
                if self.intIns[intRow] != 0:
                    netHor = Basic(
                        self.strHor,
                        [
                            self.intIns[intRow],
                            self.intOuts[intRow],
                            self.intOuts[intRow],
                        ],
                        objScratch,
                    )

                # vertical pass from the row above, downsampling on the way
                if intRow != 0:
                    netVer = Basic(
                        self.strVer,
                        [
                            self.intOuts[intRow - 1],
                            self.intOuts[intRow],
                            self.intOuts[intRow],
                        ],
                        objScratch,
                    )

            self.netHor.append(netHor)
            self.netVer.append(netVer)

    def forward(self, tenIns: typing.List[torch.Tensor]) -> typing.List[torch.Tensor]:
        intRow = 0
        for netHor in self.netHor:
            if self.intOuts[intRow] != 0:
                if self.intIns[intRow] != 0:
                    tenIns[intRow] = netHor(tenIns[intRow])

            intRow += 1

        intRow = 0
        for netVer in self.netVer:
            if self.intOuts[intRow] != 0:
                if intRow != 0:
                    tenIns[intRow] = tenIns[intRow] + netVer(tenIns[intRow - 1])

            intRow += 1

        # record the per-level shapes so Decode can crop its upsampled tensors
        # back to matching sizes
        for intRow, tenIn in enumerate(tenIns):
            self.objScratch["levelshape" + str(intRow)] = tenIn.shape

        return tenIns
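
# A minimal, hedged sketch of a two-row Encode column (illustrative channel
# counts, randomly initialized weights):
#
#     objScratch = {}
#     netEncode = Encode(
#         [0, 0],
#         [8, 16],
#         "prelu(0.25)-conv(3)-prelu(0.25)-conv(3)+skip",
#         "prelu(0.25)-sconv(3)-prelu(0.25)-conv(3)",
#         objScratch,
#     )
#     tenOuts = netEncode([torch.rand(1, 8, 32, 32), 0.0])
#     # tenOuts[0]: [1, 8, 32, 32]; tenOuts[1]: [1, 16, 16, 16]
#
# Rows whose intIns entry is 0 skip the horizontal pass; each lower row is
# then filled in from the row above via the strided vertical pass (the 0.0
# placeholder simply broadcasts into the sum).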
|
|
class Decode(torch.nn.Module):
    objScratch: typing.Dict[str, typing.List[int]] = None

    def __init__(
        self,
        intIns: typing.List[int],
        intOuts: typing.List[int],
        strHor: str,
        strVer: str,
        objScratch: typing.Dict[str, typing.List[int]],
    ):
        super().__init__()

        assert len(intIns) == len(intOuts)

        self.intRows = len(intIns)
        self.intIns = intIns.copy()
        self.intOuts = intOuts.copy()
        self.strHor = strHor
        self.strVer = strVer
        self.objScratch = objScratch

        self.netHor = torch.nn.ModuleList()
        self.netVer = torch.nn.ModuleList()

        # rows are processed bottom-up, so build the module lists in the same
        # order that forward() walks them
        for intRow in range(self.intRows - 1, -1, -1):
            netHor = torch.nn.Identity()
            netVer = torch.nn.Identity()

            if self.intOuts[intRow] != 0:
                if self.intIns[intRow] != 0:
                    netHor = Basic(
                        self.strHor,
                        [
                            self.intIns[intRow],
                            self.intOuts[intRow],
                            self.intOuts[intRow],
                        ],
                        objScratch,
                    )

                # vertical pass from the row below, upsampling on the way
                if intRow != self.intRows - 1:
                    netVer = Basic(
                        self.strVer,
                        [
                            self.intOuts[intRow + 1],
                            self.intOuts[intRow],
                            self.intOuts[intRow],
                        ],
                        objScratch,
                    )

            self.netHor.append(netHor)
            self.netVer.append(netVer)

    def forward(self, tenIns: typing.List[torch.Tensor]) -> typing.List[torch.Tensor]:
        intRow = self.intRows - 1
        for netHor in self.netHor:
            if self.intOuts[intRow] != 0:
                if self.intIns[intRow] != 0:
                    tenIns[intRow] = netHor(tenIns[intRow])

            intRow -= 1

        intRow = self.intRows - 1
        for netVer in self.netVer:
            if self.intOuts[intRow] != 0:
                if intRow != self.intRows - 1:
                    tenVer = netVer(tenIns[intRow + 1])

                    # upsampling can overshoot by one pixel when the encoder
                    # level had an odd size, so crop back to the recorded shape
                    if "levelshape" + str(intRow) in self.objScratch:
                        if (
                            tenVer.shape[2]
                            == self.objScratch["levelshape" + str(intRow)][2] + 1
                        ):
                            tenVer = torch.nn.functional.pad(
                                input=tenVer,
                                pad=[0, 0, 0, -1],
                                mode="constant",
                                value=0.0,
                            )
                        if (
                            tenVer.shape[3]
                            == self.objScratch["levelshape" + str(intRow)][3] + 1
                        ):
                            tenVer = torch.nn.functional.pad(
                                input=tenVer,
                                pad=[0, -1, 0, 0],
                                mode="constant",
                                value=0.0,
                            )

                    tenIns[intRow] = tenIns[intRow] + tenVer

            intRow -= 1

        return tenIns
|
|
class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.intEncdec = [1, 1]
        self.intChannels = [32, 64, 128, 256, 512]

        self.objScratch = {}

        # each input frame is embedded separately; the two embeddings are then
        # concatenated, which is why netInput outputs half of intChannels[0]
        self.netInput = torch.nn.Conv2d(
            in_channels=3,
            out_channels=int(round(0.5 * self.intChannels[0])),
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode="zeros",
        )

        self.netEncode = torch.nn.Sequential(
            *(
                [
                    Encode(
                        [0] * len(self.intChannels),
                        self.intChannels,
                        "prelu(0.25)-conv(3)-prelu(0.25)-conv(3)+skip",
                        "prelu(0.25)-sconv(3)-prelu(0.25)-conv(3)",
                        self.objScratch,
                    )
                ]
                + [
                    Encode(
                        self.intChannels,
                        self.intChannels,
                        "prelu(0.25)-conv(3)-prelu(0.25)-conv(3)+skip",
                        "prelu(0.25)-sconv(3)-prelu(0.25)-conv(3)",
                        self.objScratch,
                    )
                    for intEncdec in range(1, self.intEncdec[0])
                ]
            )
        )

        self.netDecode = torch.nn.Sequential(
            *(
                [
                    Decode(
                        [0] + self.intChannels[1:],
                        [0] + self.intChannels[1:],
                        "prelu(0.25)-conv(3)-prelu(0.25)-conv(3)+skip",
                        "prelu(0.25)-up(bilinear)-conv(3)-prelu(0.25)-conv(3)",
                        self.objScratch,
                    )
                    for intEncdec in range(0, self.intEncdec[1])
                ]
            )
        )

        # four heads estimate the per-pixel separable kernels: one vertical and
        # one horizontal 51-tap kernel for each of the two input frames
        self.netVerone = Basic(
            "up(bilinear)-conv(3)-prelu(0.25)-conv(3)",
            [self.intChannels[1], self.intChannels[1], 51],
        )
        self.netVertwo = Basic(
            "up(bilinear)-conv(3)-prelu(0.25)-conv(3)",
            [self.intChannels[1], self.intChannels[1], 51],
        )
        self.netHorone = Basic(
            "up(bilinear)-conv(3)-prelu(0.25)-conv(3)",
            [self.intChannels[1], self.intChannels[1], 51],
        )
        self.netHortwo = Basic(
            "up(bilinear)-conv(3)-prelu(0.25)-conv(3)",
            [self.intChannels[1], self.intChannels[1], 51],
        )

    def forward(self, x1, x2):
        intWidth = x1.shape[3]
        intHeight = x1.shape[2]

        # pad to even dimensions so the strided encoder downsamples cleanly
        intPadr = (2 - (intWidth % 2)) % 2
        intPadb = (2 - (intHeight % 2)) % 2

        tenOne = torch.nn.functional.pad(
            input=x1, pad=[0, intPadr, 0, intPadb], mode="replicate"
        )
        tenTwo = torch.nn.functional.pad(
            input=x2, pad=[0, intPadr, 0, intPadb], mode="replicate"
        )

        tenSeq = [tenOne, tenTwo]

        # jointly normalize both frames to zero mean and unit variance
        with torch.set_grad_enabled(False):
            tenStack = torch.stack(tenSeq, 1)
            tenMean = (
                tenStack.view(tenStack.shape[0], -1)
                .mean(1, True)
                .view(tenStack.shape[0], 1, 1, 1)
            )
            tenStd = (
                tenStack.view(tenStack.shape[0], -1)
                .std(1, True)
                .view(tenStack.shape[0], 1, 1, 1)
            )
            tenSeq = [
                (tenFrame - tenMean) / (tenStd + 0.0000001) for tenFrame in tenSeq
            ]
            tenSeq = [tenFrame.detach() for tenFrame in tenSeq]

        # the 0.0 entries are placeholders for the not-yet-populated pyramid
        # levels; Encode broadcasts them into the sums that fill each level
        tenOut = self.netDecode(
            self.netEncode(
                [torch.cat([self.netInput(tenSeq[0]), self.netInput(tenSeq[1])], 1)]
                + ([0.0] * (len(self.intChannels) - 1))
            )
        )[1]

        # pad each frame by half the 51-tap kernel size on every side so the
        # separable filtering stays within bounds
        tenOne = torch.nn.functional.pad(
            input=tenOne,
            pad=[int(math.floor(0.5 * 51))] * 4,
            mode="replicate",
        )
        tenTwo = torch.nn.functional.pad(
            input=tenTwo,
            pad=[int(math.floor(0.5 * 51))] * 4,
            mode="replicate",
        )

        # append an all-ones channel; filtering it produces the per-pixel sum
        # of the kernel weights, which is used below to normalize the output
        tenOne = torch.cat(
            [
                tenOne,
                tenOne.new_ones([tenOne.shape[0], 1, tenOne.shape[2], tenOne.shape[3]]),
            ],
            1,
        ).detach()
        tenTwo = torch.cat(
            [
                tenTwo,
                tenTwo.new_ones([tenTwo.shape[0], 1, tenTwo.shape[2], tenTwo.shape[3]]),
            ],
            1,
        ).detach()

        tenVerone = self.netVerone(tenOut)
        tenVertwo = self.netVertwo(tenOut)
        tenHorone = self.netHorone(tenOut)
        tenHortwo = self.netHortwo(tenOut)

        tenOut = sepconv_func.apply(tenOne, tenVerone, tenHorone) + sepconv_func.apply(
            tenTwo, tenVertwo, tenHortwo
        )

        tenNormalize = tenOut[:, -1:, :, :]
        tenNormalize[tenNormalize.abs() < 0.01] = 1.0
        tenOut = tenOut[:, :-1, :, :] / tenNormalize

        return tenOut[:, :, :intHeight, :intWidth]
|
|
netNetwork = None
|
|
def estimate(tenOne, tenTwo):
    global netNetwork

    if netNetwork is None:
        netNetwork = Network().to(get_torch_device()).eval()

    assert tenOne.shape[1] == tenTwo.shape[1]
    assert tenOne.shape[2] == tenTwo.shape[2]

    intWidth = tenOne.shape[2]
    intHeight = tenOne.shape[1]

    # input-size limits carried over from the original run.py
    assert intWidth <= 1280
    assert intHeight <= 720

    tenPreprocessedOne = tenOne.to(get_torch_device()).view(1, 3, intHeight, intWidth)
    tenPreprocessedTwo = tenTwo.to(get_torch_device()).view(1, 3, intHeight, intWidth)

    intPadr = (2 - (intWidth % 2)) % 2
    intPadb = (2 - (intHeight % 2)) % 2

    tenPreprocessedOne = torch.nn.functional.pad(
        input=tenPreprocessedOne, pad=[0, intPadr, 0, intPadb], mode="replicate"
    )
    tenPreprocessedTwo = torch.nn.functional.pad(
        input=tenPreprocessedTwo, pad=[0, intPadr, 0, intPadb], mode="replicate"
    )

    # Network.forward takes the two frames as separate arguments, not a list
    return netNetwork(tenPreprocessedOne, tenPreprocessedTwo)[
        0, :, :intHeight, :intWidth
    ].cpu()
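

if __name__ == "__main__":
    # A minimal, hedged smoke test: the frames here are random tensors in
    # [0, 1] with CHW layout, standing in for two consecutive video frames,
    # since this trimmed file includes no image I/O. Note the network is
    # randomly initialized at this point; real use would first load trained
    # weights, and sepconv_func is a custom op that may require a CUDA device.
    tenOne = torch.rand(3, 128, 128)
    tenTwo = torch.rand(3, 128, 128)

    tenOut = estimate(tenOne, tenTwo)  # the interpolated in-between frame
    print(tenOut.shape)  # torch.Size([3, 128, 128])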