import gradio as gr import uuid from PIL import Image import torch from torchvision import transforms from transformers import AutoModelForImageSegmentation import os # بارگذاری مدل def load_model(device_type): device = torch.device(device_type) model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True) model.to(device) model.eval() return model, device # پیش پردازش تصویر transform_image = transforms.Compose([ transforms.Resize((1024, 1024)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) def remove_background(uploaded_image, device_type="cpu"): try: # بارگذاری مدل بر اساس انتخاب دستگاه model, device = load_model(device_type) image = Image.open(uploaded_image) # پیش‌پردازش تصویر input_image = transform_image(image).unsqueeze(0).to(device) # پردازش تصویر با مدل with torch.no_grad(): preds = model(input_image)[-1].sigmoid().cpu() pred = preds[0].squeeze() mask = transforms.ToPILImage()(pred).resize(image.size) image.putalpha(mask) # ایجاد پوشه "media" در صورت عدم وجود media_dir = "../media" if not os.path.exists(media_dir): os.makedirs(media_dir) # ذخیره تصویر پردازش‌شده random_filename = str(uuid.uuid4()) + ".png" processed_image_path = os.path.join(media_dir, f"processed_{random_filename}") image.save(processed_image_path, format="PNG") # بارگذاری تصویر پردازش‌شده و ارسال به عنوان خروجی processed_image = Image.open(processed_image_path) # برگرداندن تصویر پردازش‌شده به عنوان خروجی Gradio return processed_image except Exception as e: return f"خطا در پردازش تصویر: {str(e)}" # ایجاد رابط کاربری Gradio gradio_app = gr.Interface( fn=remove_background, # تابع پردازش تصویر inputs=[ gr.Image(type="filepath"), # ورودی تصویر به صورت مسیر فایل gr.Dropdown(choices=["cpu", "cuda"], label="Select Device", value="cpu") # انتخاب CPU یا GPU ], outputs=gr.Image(type="pil") # خروجی تصویر به صورت PIL ) if __name__ == "__main__": # اجرای Gradio gradio_app.launch()