removeback / app.py
joermd's picture
Update app.py
6f2b794 verified
raw
history blame
2.79 kB
import streamlit as st
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO
# إعداد الجهاز للعمل على CPU
device = torch.device("cpu")
torch.set_float32_matmul_precision("high")
# تحميل النموذج
birefnet = AutoModelForImageSegmentation.from_pretrained(
"ZhengPeng7/BiRefNet", trust_remote_code=True
).to(device)
# تحويل الصورة
transform_image = transforms.Compose(
[
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
# دالة المعالجة
def process(image):
image_size = image.size
input_images = transform_image(image).unsqueeze(0).to(device)
with torch.no_grad():
preds = birefnet(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
image.putalpha(mask)
return image
st.title("أداة إزالة الخلفية")
# واجهة المستخدم
tab = st.sidebar.selectbox("اختر طريقة الإدخال:", ["رفع صورة", "رابط صورة", "ملف"])
if tab == "رفع صورة":
uploaded_file = st.file_uploader("ارفع صورة:")
if uploaded_file is not None:
image = Image.open(uploaded_file).convert("RGB")
st.image(image, caption="الصورة الأصلية", use_column_width=True)
processed_image = process(image)
st.image(processed_image, caption="الصورة المعالجة", use_column_width=True)
elif tab == "رابط صورة":
url = st.text_input("أدخل رابط الصورة:")
if url:
try:
response = requests.get(url)
image = Image.open(BytesIO(response.content)).convert("RGB")
st.image(image, caption="الصورة الأصلية", use_column_width=True)
processed_image = process(image)
st.image(processed_image, caption="الصورة المعالجة", use_column_width=True)
except Exception as e:
st.error(f"خطأ أثناء تحميل الصورة: {e}")
elif tab == "ملف":
uploaded_file = st.file_uploader("ارفع ملف:")
if uploaded_file is not None:
image = Image.open(uploaded_file).convert("RGB")
processed_image = process(image)
output_path = uploaded_file.name.rsplit(".", 1)[0] + ".png"
processed_image.save(output_path)
st.image(processed_image, caption="الصورة المعالجة", use_column_width=True)
st.download_button("تحميل الصورة المعالجة", data=open(output_path, "rb"), file_name=output_path)