using TorchSharp; | |
/// <summary>
/// Thin wrapper around a TorchScript module loaded from disk. The module is
/// moved to a target device, switched to evaluation mode, and invoked with
/// its input divided by a fixed scale factor.
/// </summary>
public class AutoencoderKL
{
    // Scripted module loaded from the path given to the constructor.
    private readonly torch.jit.ScriptModule _model;

    // Divisor applied to the input tensor before it is handed to the module.
    private readonly float _scale;

    /// <summary>Device the module was moved to at construction time.</summary>
    public torch.Device Device { get; }

    /// <summary>
    /// Loads the TorchScript module at <paramref name="modelPath"/>, moves it to
    /// <paramref name="device"/> and puts it into evaluation mode.
    /// </summary>
    /// <param name="modelPath">Path of the serialized TorchScript module.</param>
    /// <param name="device">Device the module is moved to.</param>
    /// <param name="scale">
    /// Factor the input of <see cref="Forward"/> is divided by.
    /// NOTE(review): the default 0.18215 looks like the Stable Diffusion latent
    /// scaling factor - confirm against the exported model.
    /// </param>
    public AutoencoderKL(string modelPath, torch.Device device, float scale = 0.18215f)
    {
        _model = TorchSharp.torch.jit.load(modelPath);
        Device = device;
        _model.to(Device);
        _model.eval();
        _scale = scale;
    }

    /// <summary>
    /// Divides <paramref name="tokenTensor"/> by the configured scale and runs the
    /// result through the wrapped module with gradient tracking disabled.
    /// </summary>
    /// <param name="tokenTensor">
    /// Input tensor; it is not modified or disposed by this method.
    /// NOTE(review): presumably expected to live on <see cref="Device"/> already,
    /// since no transfer happens here - verify against callers.
    /// </param>
    /// <returns>The module output cast to <c>torch.Tensor</c>.</returns>
    public torch.Tensor Forward(torch.Tensor tokenTensor)
    {
        // The guard returned by enable_grad(false) must be disposed; otherwise the
        // previous autograd mode is never restored and gradient tracking stays
        // switched off for the caller after this method returns.
        using (torch.enable_grad(false))
        // The scaled tensor is a temporary owned by this method; release its
        // native memory deterministically instead of waiting for finalization.
        using (var scaled = 1.0f / _scale * tokenTensor)
        {
            return (torch.Tensor)_model.forward(scaled);
        }
    }
}