Spaces:
Running
Running
Commit
·
6d230f3
1
Parent(s):
4b01537
Add application file
Browse files
app.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import uuid
|
3 |
+
from PIL import Image
|
4 |
+
import torch
|
5 |
+
from torchvision import transforms
|
6 |
+
from transformers import AutoModelForImageSegmentation
|
7 |
+
import os
|
8 |
+
|
9 |
+
# بارگذاری مدل
|
10 |
+
def load_model(device_type):
|
11 |
+
device = torch.device(device_type)
|
12 |
+
model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
|
13 |
+
model.to(device)
|
14 |
+
model.eval()
|
15 |
+
return model, device
|
16 |
+
|
17 |
+
# پیش پردازش تصویر
|
18 |
+
transform_image = transforms.Compose([
|
19 |
+
transforms.Resize((1024, 1024)),
|
20 |
+
transforms.ToTensor(),
|
21 |
+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
22 |
+
])
|
23 |
+
|
24 |
+
def remove_background(uploaded_image, device_type="cpu"):
|
25 |
+
try:
|
26 |
+
# بارگذاری مدل بر اساس انتخاب دستگاه
|
27 |
+
model, device = load_model(device_type)
|
28 |
+
|
29 |
+
image = Image.open(uploaded_image)
|
30 |
+
|
31 |
+
# پیشپردازش تصویر
|
32 |
+
input_image = transform_image(image).unsqueeze(0).to(device)
|
33 |
+
|
34 |
+
# پردازش تصویر با مدل
|
35 |
+
with torch.no_grad():
|
36 |
+
preds = model(input_image)[-1].sigmoid().cpu()
|
37 |
+
|
38 |
+
pred = preds[0].squeeze()
|
39 |
+
mask = transforms.ToPILImage()(pred).resize(image.size)
|
40 |
+
image.putalpha(mask)
|
41 |
+
|
42 |
+
# ایجاد پوشه "media" در صورت عدم وجود
|
43 |
+
media_dir = "../media"
|
44 |
+
if not os.path.exists(media_dir):
|
45 |
+
os.makedirs(media_dir)
|
46 |
+
|
47 |
+
# ذخیره تصویر پردازششده
|
48 |
+
random_filename = str(uuid.uuid4()) + ".png"
|
49 |
+
processed_image_path = os.path.join(media_dir, f"processed_{random_filename}")
|
50 |
+
image.save(processed_image_path, format="PNG")
|
51 |
+
|
52 |
+
# بارگذاری تصویر پردازششده و ارسال به عنوان خروجی
|
53 |
+
processed_image = Image.open(processed_image_path)
|
54 |
+
|
55 |
+
# برگرداندن تصویر پردازششده به عنوان خروجی Gradio
|
56 |
+
return processed_image
|
57 |
+
except Exception as e:
|
58 |
+
return f"خطا در پردازش تصویر: {str(e)}"
|
59 |
+
|
60 |
+
# ایجاد رابط کاربری Gradio
|
61 |
+
gradio_app = gr.Interface(
|
62 |
+
fn=remove_background, # تابع پردازش تصویر
|
63 |
+
inputs=[
|
64 |
+
gr.Image(type="filepath"), # ورودی تصویر به صورت مسیر فایل
|
65 |
+
gr.Dropdown(choices=["cpu", "cuda"], label="Select Device", value="cpu") # انتخاب CPU یا GPU
|
66 |
+
],
|
67 |
+
outputs=gr.Image(type="pil") # خروجی تصویر به صورت PIL
|
68 |
+
)
|
69 |
+
if __name__ == "__main__":
|
70 |
+
# اجرای Gradio
|
71 |
+
gradio_app.launch()
|