import os

import torch

import numpy as np

from PIL import Image

from simple_lama_inpainting.utils.util import prepare_img_and_mask, download_model

# Default location of the TorchScript LaMa checkpoint; can be overridden
# through the LAMA_MODEL_URL environment variable.
LAMA_MODEL_URL = os.environ.get(
    "LAMA_MODEL_URL",
    "https://github.com/enesmsahin/simple-lama-inpainting/releases/download/v0.1.0/big-lama.pt",
)


class SimpleLama:
    def __init__(
        self,
        device: torch.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        ),
    ) -> None:
        # Prefer a local checkpoint pointed to by LAMA_MODEL; otherwise
        # download the default model from LAMA_MODEL_URL.
        model_path = os.environ.get("LAMA_MODEL")
        if model_path:
            if not os.path.exists(model_path):
                raise FileNotFoundError(
                    f"lama torchscript model not found: {model_path}"
                )
        else:
            model_path = download_model(LAMA_MODEL_URL)

        self.model = torch.jit.load(model_path, map_location=device)
        self.model.eval()
        self.model.to(device)
        self.device = device

    def __call__(
        self, image: Image.Image | np.ndarray, mask: Image.Image | np.ndarray
    ) -> Image.Image:
        # Convert the inputs to normalized tensors on the target device.
        image, mask = prepare_img_and_mask(image, mask, self.device)

        with torch.inference_mode():
            inpainted = self.model(image, mask)

            # CHW float tensor in [0, 1] -> HWC uint8 array -> PIL image.
            cur_res = inpainted[0].permute(1, 2, 0).detach().cpu().numpy()
            cur_res = np.clip(cur_res * 255, 0, 255).astype(np.uint8)

            return Image.fromarray(cur_res)
|
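# Minimal usage sketch (not part of the library itself), assuming two local
# files "input.png" and "mask.png" exist; the file names are illustrative
# placeholders, and white pixels in the mask mark the region to inpaint.
if __name__ == "__main__":
    simple_lama = SimpleLama()

    image = Image.open("input.png").convert("RGB")
    mask = Image.open("mask.png").convert("L")

    result = simple_lama(image, mask)
    result.save("inpainted.png")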