Newgen_2025 / app.py
gaur3009's picture
Update app.py
1d6741e verified
raw
history blame
3.01 kB
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
import cv2
import numpy as np
import gradio as gr
from skimage.transform import warp, PiecewiseAffineTransform
# Define U-Net model for cloth fold segmentation
class ClothFoldUNet(nn.Module):
    """U-Net wrapper used to segment cloth folds.

    Wraps ``smp.Unet`` configured with a ``resnet34`` encoder initialised
    from ``imagenet`` weights, three input channels and one output class.
    No activation is configured here, so ``forward`` returns whatever the
    wrapped network produces.
    """

    def __init__(self):
        super().__init__()
        unet_config = {
            "encoder_name": "resnet34",  # pre-trained backbone
            "encoder_weights": "imagenet",
            "in_channels": 3,
            "classes": 1,  # single-channel segmentation output
        }
        self.model = smp.Unet(**unet_config)

    def forward(self, x):
        """Run *x* through the wrapped U-Net and return its output."""
        prediction = self.model(x)
        return prediction
# Load dataset (placeholder, replace with real dataset)
def get_dataloader(batch_size=8):
    """Build a shuffled ``DataLoader`` over torchvision's ``FakeData``.

    This is a placeholder data source: every sample is resized to 256x256
    and converted to a tensor before batching.

    Args:
        batch_size: Number of samples per batch (default 8).

    Returns:
        A ``DataLoader`` over the transformed ``FakeData`` dataset.
    """
    resize_step = transforms.Resize((256, 256))
    tensor_step = transforms.ToTensor()
    pipeline = transforms.Compose([resize_step, tensor_step])

    placeholder_data = datasets.FakeData(transform=pipeline)
    loader = DataLoader(placeholder_data, batch_size=batch_size, shuffle=True)
    return loader
# Train function
def train_model(epochs=10, lr=1e-4):
    """Train a ``ClothFoldUNet`` on the placeholder data and return it.

    The loop is a placeholder: batches come from ``get_dataloader()`` and
    the target for every batch is an all-ones tensor shaped like the model
    output, scored with ``BCEWithLogitsLoss``.

    Args:
        epochs: Number of passes over the dataloader (default 10).
        lr: Learning rate handed to Adam (default 1e-4).

    Returns:
        The trained model, on the device used for training. Earlier
        versions returned nothing, which discarded the training result;
        callers that ignore the return value are unaffected.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ClothFoldUNet().to(device)
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()
    dataloader = get_dataloader()

    for epoch in range(epochs):
        loss = None  # stays None when the dataloader yields no batches
        for images, _ in dataloader:
            images = images.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            # Placeholder target: every pixel is labelled positive.
            loss = criterion(outputs, torch.ones_like(outputs))
            loss.backward()
            optimizer.step()

        # The reported value is the loss of the last batch in the epoch.
        if loss is not None:
            print(f"Epoch {epoch+1}: Loss {loss.item():.4f}")

    return model
# Function to apply design onto cloth using segmentation mask
def apply_design(image, design, mask):
    """Blend *design* onto *image* using *mask* as per-pixel opacity.

    Each output pixel is ``mask * design + (1 - mask) * image``. The mask
    is clipped to [0, 1] and the blended result to [0, 255] before the cast
    to ``uint8``. Without the clipping, out-of-range mask values (for
    example raw logits) push the blend outside the ``uint8`` range and the
    cast wraps around.

    Args:
        image: Base image array; its first two axes set the output size.
        design: Design array; resized to the image size when it differs.
        mask: Opacity map; resized to the image size when it differs. A
            2-D mask gets a trailing channel axis so it broadcasts over
            the colour channels.

    Returns:
        The blended image as a ``uint8`` array.
    """
    height, width = image.shape[:2]

    # Only resample when the sizes actually differ.
    if tuple(mask.shape[:2]) != (height, width):
        mask = cv2.resize(mask, (width, height))
    if tuple(design.shape[:2]) != (height, width):
        design = cv2.resize(design, (width, height))

    alpha = np.clip(np.asarray(mask, dtype=np.float32), 0.0, 1.0)
    if alpha.ndim == 2:
        alpha = alpha[..., np.newaxis]

    blended = alpha * design + (1.0 - alpha) * image
    # Clip before casting so out-of-range values saturate instead of wrapping.
    return np.clip(blended, 0, 255).astype(np.uint8)
# Gradio callback: predict a fold mask for the cloth image and blend the design onto it
# Networks are built once per device and reused across requests, instead of
# constructing a new ClothFoldUNet on every call.
_MODEL_CACHE = {}


def _get_model(device):
    """Return an eval-mode ``ClothFoldUNet`` for *device*, built on first use."""
    key = str(device)
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = ClothFoldUNet().to(device)
        model.eval()
        _MODEL_CACHE[key] = model
    return model


def process_image(image, design):
    """Predict a mask for *image* and blend *design* onto it.

    The image is resized to 256x256 and passed through ``ClothFoldUNet``.
    The network output is squashed with a sigmoid, since ``train_model``
    optimises the same network with ``BCEWithLogitsLoss``, so the mask
    handed to ``apply_design`` is an opacity in [0, 1].

    NOTE(review): the network used here is built by ``ClothFoldUNet()`` and
    never receives weights produced by ``train_model`` -- confirm whether
    trained weights should be loaded.

    Args:
        image: Cloth image supplied by the Gradio ``"image"`` input.
        design: Design image supplied by the Gradio ``"image"`` input.

    Returns:
        The blended image returned by ``apply_design``.

    Raises:
        gr.Error: If either input is missing.
    """
    if image is None or design is None:
        raise gr.Error("Please upload both a cloth image and a design image.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _get_model(device)

    image = np.asarray(image)
    design = np.asarray(design)

    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    img_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(img_tensor)
        # First batch item, first channel; sigmoid maps logits to (0, 1).
        mask = torch.sigmoid(logits)[0, 0].cpu().numpy()

    return apply_design(image, design, mask)
# User-facing text shown on the demo page.
_TITLE = "AI Cloth Design Blending"
_DESCRIPTION = (
    "Upload a cloth image and a design to blend the design onto the cloth "
    "while considering the folds."
)

# Two image inputs (cloth, design) feed process_image; one image comes back.
iface = gr.Interface(
    fn=process_image,
    inputs=["image", "image"],
    outputs="image",
    title=_TITLE,
    description=_DESCRIPTION,
)

# Script entry point: train first, then serve the Gradio app.
if __name__ == "__main__":
    # The UI starts only after train_model() has returned.
    train_model()
    iface.launch()