# Kuldip2411's picture
# Upload 13 files
# 8478c4c verified
import os
import torch
import numpy as np
from PIL import Image
from simple_lama_inpainting.utils.util import prepare_img_and_mask, download_model
# Download location of the TorchScript LaMa checkpoint. The LAMA_MODEL_URL
# environment variable, when present, takes precedence over the release URL.
LAMA_MODEL_URL = os.getenv(
    "LAMA_MODEL_URL",
    default="https://github.com/enesmsahin/simple-lama-inpainting/releases/download/v0.1.0/big-lama.pt",  # noqa: E501
)
class SimpleLama:
    """Callable wrapper around a TorchScript LaMa inpainting model.

    The model file is taken from the ``LAMA_MODEL`` environment variable when
    that variable is set to a non-empty value; otherwise the path returned by
    ``download_model(LAMA_MODEL_URL)`` is used.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, device: "torch.device | None" = None) -> None:
        """Load the TorchScript model onto *device* and switch it to eval mode.

        Args:
            device: Device the model is loaded on. ``None`` (the default)
                selects CUDA when ``torch.cuda.is_available()`` is true and
                the CPU otherwise.

        Raises:
            FileNotFoundError: If ``LAMA_MODEL`` is set but the path it names
                does not exist.
        """
        # Resolve the default at call time rather than in the signature, so
        # importing this module does not query CUDA and the choice reflects
        # the environment at the moment the object is created.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Read the variable once; an empty value is treated like "unset".
        model_path = os.environ.get("LAMA_MODEL")
        if model_path:
            if not os.path.exists(model_path):
                raise FileNotFoundError(
                    f"lama torchscript model not found: {model_path}"
                )
        else:
            model_path = download_model(LAMA_MODEL_URL)

        # The loaded TorchScript module, in eval mode, on ``device``.
        self.model = torch.jit.load(model_path, map_location=device)
        self.model.eval()
        self.model.to(device)
        # Device handed to ``prepare_img_and_mask`` on every call.
        self.device = device

    # The annotations are quoted so they are not evaluated when the class is
    # defined: ``Image.Image | np.ndarray`` raises TypeError before Python 3.10.
    def __call__(
        self,
        image: "Image.Image | np.ndarray",
        mask: "Image.Image | np.ndarray",
    ) -> "Image.Image":
        """Run the model on *image* and *mask* and return the result as an image.

        Args:
            image: Input picture, as a PIL image or a NumPy array.
            mask: Mask accompanying *image*, as a PIL image or a NumPy array.
                NOTE(review): presumably marks the region to inpaint; the
                exact convention is defined by ``prepare_img_and_mask``.

        Returns:
            A PIL image built from the first item of the model output.
        """
        # NOTE(review): prepare_img_and_mask is assumed to return tensors
        # already placed on self.device; confirm against the helper.
        image, mask = prepare_img_and_mask(image, mask, self.device)

        with torch.inference_mode():
            inpainted = self.model(image, mask)

            # Move the first axis of the first output item to the end
            # (presumably channels-first -> channels-last), then to NumPy.
            cur_res = inpainted[0].permute(1, 2, 0).detach().cpu().numpy()
            # Scale by 255 and clamp into the uint8 range; the cast truncates.
            cur_res = np.clip(cur_res * 255, 0, 255).astype(np.uint8)

            return Image.fromarray(cur_res)