samanfa9828 commited on
Commit
6d230f3
·
1 Parent(s): 4b01537

Add application file

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import uuid
3
+ from PIL import Image
4
+ import torch
5
+ from torchvision import transforms
6
+ from transformers import AutoModelForImageSegmentation
7
+ import os
8
+
9
+ # بارگذاری مدل
10
+ def load_model(device_type):
11
+ device = torch.device(device_type)
12
+ model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
13
+ model.to(device)
14
+ model.eval()
15
+ return model, device
16
+
17
+ # پیش پردازش تصویر
18
+ transform_image = transforms.Compose([
19
+ transforms.Resize((1024, 1024)),
20
+ transforms.ToTensor(),
21
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
22
+ ])
23
+
24
+ def remove_background(uploaded_image, device_type="cpu"):
25
+ try:
26
+ # بارگذاری مدل بر اساس انتخاب دستگاه
27
+ model, device = load_model(device_type)
28
+
29
+ image = Image.open(uploaded_image)
30
+
31
+ # پیش‌پردازش تصویر
32
+ input_image = transform_image(image).unsqueeze(0).to(device)
33
+
34
+ # پردازش تصویر با مدل
35
+ with torch.no_grad():
36
+ preds = model(input_image)[-1].sigmoid().cpu()
37
+
38
+ pred = preds[0].squeeze()
39
+ mask = transforms.ToPILImage()(pred).resize(image.size)
40
+ image.putalpha(mask)
41
+
42
+ # ایجاد پوشه "media" در صورت عدم وجود
43
+ media_dir = "../media"
44
+ if not os.path.exists(media_dir):
45
+ os.makedirs(media_dir)
46
+
47
+ # ذخیره تصویر پردازش‌شده
48
+ random_filename = str(uuid.uuid4()) + ".png"
49
+ processed_image_path = os.path.join(media_dir, f"processed_{random_filename}")
50
+ image.save(processed_image_path, format="PNG")
51
+
52
+ # بارگذاری تصویر پردازش‌شده و ارسال به عنوان خروجی
53
+ processed_image = Image.open(processed_image_path)
54
+
55
+ # برگرداندن تصویر پردازش‌شده به عنوان خروجی Gradio
56
+ return processed_image
57
+ except Exception as e:
58
+ return f"خطا در پردازش تصویر: {str(e)}"
59
+
60
+ # ایجاد رابط کاربری Gradio
61
+ gradio_app = gr.Interface(
62
+ fn=remove_background, # تابع پردازش تصویر
63
+ inputs=[
64
+ gr.Image(type="filepath"), # ورودی تصویر به صورت مسیر فایل
65
+ gr.Dropdown(choices=["cpu", "cuda"], label="Select Device", value="cpu") # انتخاب CPU یا GPU
66
+ ],
67
+ outputs=gr.Image(type="pil") # خروجی تصویر به صورت PIL
68
+ )
69
+ if __name__ == "__main__":
70
+ # اجرای Gradio
71
+ gradio_app.launch()