samanfa9828's picture
Add application file
6d230f3
raw
history blame
2.56 kB
import gradio as gr
import uuid
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import os
# بارگذاری مدل
def load_model(device_type):
device = torch.device(device_type)
model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
model.to(device)
model.eval()
return model, device
# پیش پردازش تصویر
transform_image = transforms.Compose([
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def remove_background(uploaded_image, device_type="cpu"):
try:
# بارگذاری مدل بر اساس انتخاب دستگاه
model, device = load_model(device_type)
image = Image.open(uploaded_image)
# پیش‌پردازش تصویر
input_image = transform_image(image).unsqueeze(0).to(device)
# پردازش تصویر با مدل
with torch.no_grad():
preds = model(input_image)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
mask = transforms.ToPILImage()(pred).resize(image.size)
image.putalpha(mask)
# ایجاد پوشه "media" در صورت عدم وجود
media_dir = "../media"
if not os.path.exists(media_dir):
os.makedirs(media_dir)
# ذخیره تصویر پردازش‌شده
random_filename = str(uuid.uuid4()) + ".png"
processed_image_path = os.path.join(media_dir, f"processed_{random_filename}")
image.save(processed_image_path, format="PNG")
# بارگذاری تصویر پردازش‌شده و ارسال به عنوان خروجی
processed_image = Image.open(processed_image_path)
# برگرداندن تصویر پردازش‌شده به عنوان خروجی Gradio
return processed_image
except Exception as e:
return f"خطا در پردازش تصویر: {str(e)}"
# ایجاد رابط کاربری Gradio
gradio_app = gr.Interface(
fn=remove_background, # تابع پردازش تصویر
inputs=[
gr.Image(type="filepath"), # ورودی تصویر به صورت مسیر فایل
gr.Dropdown(choices=["cpu", "cuda"], label="Select Device", value="cpu") # انتخاب CPU یا GPU
],
outputs=gr.Image(type="pil") # خروجی تصویر به صورت PIL
)
if __name__ == "__main__":
# اجرای Gradio
gradio_app.launch()