from diffusers import StableDiffusionInpaintPipeline
from src.utils.exceptions import CustomException
from cvzone.PoseModule import PoseDetector
from src.utils.functions import getConfig
from src.utils.logger import logger
from PIL.ImageOps import grayscale
from PIL import Image
import numpy as np
import cvzone
import torch
import math
import cv2
import gc


class ClothingTryOn:
    """
    A class to simulate clothing try-ons by overlaying clothing images on user images
    and generating modified outputs using inpainting techniques.

    This class utilizes a pose detection model to identify key landmarks on the user's
    body, allowing for accurate placement and scaling of clothing images. It integrates
    with a Stable Diffusion model for image generation, providing realistic visual
    outputs based on specified prompts while ensuring that jewelry and accessories
    do not interfere with the clothing representation.

    Attributes:
        detector (PoseDetector): An instance of PoseDetector for identifying body landmarks.
        config (ConfigParser): Configuration settings loaded from an external config file.
        pipeline (StableDiffusionInpaintPipeline): The Stable Diffusion inpainting model for
            generating images based on user prompts and masks.

    Methods:
        getBinaryMask(image: Image.Image, jewellery: Image.Image) -> tuple[Image.Image, Image.Image]:
            Generates a binary mask indicating the presence of the clothing on the user's image.
        generateImage(image: Image.Image, mask: Image.Image) -> tuple[Image.Image, Image.Image, Image.Image]:
            Applies inpainting to an image using the provided binary mask, generating new images
            based on specific color prompts while excluding jewelry and accessories.
    """

    def __init__(self):
        """Initialize the ClothingTryOn class with a PoseDetector and configuration settings."""
        self.detector = PoseDetector()
        self.config = getConfig("config.ini")
        modelId = self.config.get("CLOTHING TRY ON", "modelId")
        device = self.config.get("CLOTHING TRY ON", "device")
        self.pipeline = StableDiffusionInpaintPipeline.from_pretrained(
            modelId, torch_dtype = torch.float16
        ).to(device)

    def getBinaryMask(self, image: Image.Image, jewellery: Image.Image) -> tuple[Image.Image, Image.Image]:
        """
        Generate a binary mask indicating the presence of the necklace on the user's image.

        This function overlays a jewelry image on the user's image and creates a binary mask, where
        the necklace is represented in white and the background in black.

        Args:
            image (Image.Image): The user's image, ideally captured in a standing, upright position.
            jewellery (Image.Image): The image of the jewelry piece (e.g., necklace) to be overlaid.

        Returns:
            tuple[Image.Image, Image.Image]: A tuple containing:
                - The first image as the necklace try-on output.
                - The second image as the binary mask, with the necklace shown in white and the background in black.

        Raises:
            CustomException: If an error occurs during the image processing.
        """
        try:
            logger.info("converting images to numpy arrays")
            image = np.array(image)
            jewellery = np.array(jewellery)

            logger.info("creating a copy of original image for actual overlay")
            copyImage = image.copy()

            logger.info("detecting body landmarks from the input image")
            image = self.detector.findPose(image)
            lmList, _ = self.detector.findPosition(image, bboxWithHands = False, draw = False)
            pt12, pt11, pt10, pt9 = (
                lmList[12][:2],
                lmList[11][:2],
                lmList[10][:2],
                lmList[9][:2],
            )
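            # Note: with cvzone's MediaPipe-based PoseDetector, landmarks 11/12 are the
            # left/right shoulders and 9/10 are the mouth corners; the neck anchor points
            # below are interpolated between shoulder and mouth on each side.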
logger.info("calculating the precise neck points") | |
avgX1 = int(pt12[0] + (pt10[0] - pt12[0]) / 1.75) | |
avgY1 = int(pt12[1] - (pt12[1] - pt10[1]) / 1.75) | |
avgX2 = int(pt11[0] - (pt11[0] - pt9[0]) / 1.75) | |
avgY2 = int(pt11[1] - (pt11[1] - pt9[1]) / 1.75) | |
logger.info("rescaling the necklace to appropriate dimensions") | |
xDist = avgX2 - avgX1 | |
origImgRatio = xDist / jewellery.shape[1] | |
yDist = jewellery.shape[0] * origImgRatio | |
jewellery = cv2.resize( | |
jewellery, (int(xDist), int(yDist)), interpolation = cv2.INTER_CUBIC | |
) | |
logger.info("calculating required offset to be added to the necklace image for perfect fitting") | |
imageGray = cv2.cvtColor(jewellery, cv2.COLOR_BGRA2GRAY) | |
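            # Scan the top row of the resized necklace for the first pixel that is neither
            # pure white nor pure black; its horizontal position approximates where the
            # visible necklace begins and is used below to scale the vertical offset.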
            for offsetOrig in range(imageGray.shape[1]):
                pixelValue = imageGray[0, offsetOrig]
                if pixelValue != 255 and pixelValue != 0:
                    break
            offset = int(self.config.getfloat("NECKLACE TRY ON", "offsetFactor") * xDist * (offsetOrig / jewellery.shape[1]))
            yCoordinate = avgY1 - offset
logger.info("tilting the necklace image as per the necklace points") | |
angle = math.ceil( | |
self.detector.findAngle( | |
p1 = (avgX2, avgY2), p2 = (avgX1, avgY1), p3 = (avgX2, avgY1) | |
)[0] | |
) | |
if avgY2 < avgY1: | |
pass | |
else: | |
angle = angle * -1 | |
jewellery = cvzone.rotateImage(jewellery, angle) | |
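            # findAngle measures the angle at p2 between the line to p1 (the neck line) and
            # the horizontal reference p3; the sign flip above makes the rotation follow the
            # direction of the neck tilt (image y coordinates grow downwards).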
logger.info("checking if the necklace is getting out of the frame and trimming from above if needed") | |
availableSpace = copyImage.shape[0] - yCoordinate | |
extra = jewellery.shape[0] - availableSpace | |
logger.info("applying the calculated settings") | |
if extra > 0: | |
jewellery = jewellery[extra + 10 :, :] | |
return self.getBinaryMask( | |
Image.fromarray(copyImage), Image.fromarray(jewellery) | |
) | |
else: | |
tryOnOutput = cvzone.overlayPNG(copyImage, jewellery, (avgX1, yCoordinate)) | |
tryOnOutput = Image.fromarray(tryOnOutput.astype(np.uint8)) | |
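                # Rendering the same overlay on a black canvas isolates the necklace pixels;
                # thresholding that render produces the binary mask used for inpainting.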
                blackedNecklace = np.zeros(shape = copyImage.shape)
                blackedNecklace = cvzone.overlayPNG(blackedNecklace, jewellery, (avgX1, yCoordinate))
                blackedNecklace = cv2.cvtColor(blackedNecklace.astype(np.uint8), cv2.COLOR_BGR2GRAY)
                binaryMask = np.where(blackedNecklace > 5, 255, 0).astype(np.uint8)
                binaryMask = Image.fromarray(binaryMask)
                return (tryOnOutput, binaryMask)
        except Exception as e:
            logger.error(CustomException(e))
            raise CustomException(e)

    def generateImage(self, image: Image.Image, mask: Image.Image) -> tuple[Image.Image, Image.Image, Image.Image]:
        """
        Apply inpainting to an image using the provided binary mask.

        This function utilizes the binary mask to inpaint areas of the image, enhancing the visual output
        by generating new images based on specific color prompts while excluding jewelry and other accessories.

        Args:
            image (Image.Image): The input image where inpainting will be applied.
            mask (Image.Image): The binary mask indicating areas to be inpainted.

        Returns:
            tuple[Image.Image, Image.Image, Image.Image]: A tuple containing three images generated based on different color prompts.

        Raises:
            CustomException: If an error occurs during the image processing.
        """
        try:
            logger.info("creating a mask where the jewellery is represented")
            jewelleryMask = Image.fromarray(np.bitwise_and(np.array(mask.convert("RGB")), np.array(image.convert("RGB"))))
            arrOrig = np.array(grayscale(mask))

            logger.info("inpainting the image using the original mask")
            image = cv2.inpaint(np.array(image), arrOrig, 15, cv2.INPAINT_TELEA)
            image = Image.fromarray(image)

            logger.info("preparing the mask for processing")
            arr = arrOrig.copy()
            maskY = np.where(arr == arr[arr != 0][0])[0][0]
            arr[maskY:, :] = 255
            newMask = Image.fromarray(arr)
            mask = newMask.copy()
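            # Everything from the first masked (necklace) row downwards is flooded to white,
            # so the diffusion model regenerates the whole region below the neckline instead
            # of only the thin necklace outline.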
logger.info("resizing images for consistency") | |
origSize = image.size | |
image = image.resize((512, 512)) | |
mask = mask.resize((512, 512)) | |
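            # 512 x 512 is the resolution Stable Diffusion v1-style inpainting checkpoints are
            # typically trained at; outputs are resized back to the original size afterwards.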
logger.info("generating images for different colors") | |
results = [] | |
for colour in ["Red", "Blue", "Green"]: | |
prompt = f"{colour}, South Indian Saree, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple" | |
negativePrompt = ("necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, " | |
"jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, " | |
"watermark, text, changed background, wider body, narrower body, bad proportions, " | |
"extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, " | |
"blurry, ugly") | |
output = self.pipeline( | |
prompt = prompt, | |
negative_prompt = negativePrompt, | |
image = image, | |
mask_image = mask, | |
strength = 0.95, | |
guidance_score = 9, | |
).images[0] | |
logger.info("resizing the output to original size") | |
output = output.resize(origSize) | |
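                # Blank out the necklace region in the generated image; the original necklace
                # pixels are OR'ed back in after the loop so the jewellery itself is preserved
                # exactly as rendered in the try-on output.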
                tempGenerated = np.bitwise_and(
                    np.array(output),
                    np.bitwise_not(np.array(Image.fromarray(arrOrig).convert("RGB"))),
                )
                results.append(tempGenerated)

            logger.info("combining the results with the jewellery mask")
            results = [
                Image.fromarray(np.bitwise_or(x, np.array(jewelleryMask))) for x in results
            ]
            logger.info("Image generation completed successfully.")
            gc.collect()
            torch.cuda.empty_cache()
            return (results[0], results[1], results[2])
        except Exception as e:
            logger.error(CustomException(e))
            raise CustomException(e)
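

# Minimal usage sketch (illustrative only): the file names below are hypothetical and
# assume a full-body user photo, a transparent-background necklace PNG, and a valid
# config.ini next to this module.
if __name__ == "__main__":
    tryOn = ClothingTryOn()
    person = Image.open("person.jpg").convert("RGB")
    necklace = Image.open("necklace.png").convert("RGBA")
    tryOnOutput, binaryMask = tryOn.getBinaryMask(person, necklace)
    red, blue, green = tryOn.generateImage(tryOnOutput, binaryMask)
    red.save("saree_red.png")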