removeback / app.py
joermd's picture
Update app.py
93ccd5a verified
raw
history blame
3.81 kB
import streamlit as st
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO
# إعداد الجهاز للعمل على CPU
device = torch.device("cpu")
torch.set_float32_matmul_precision("high")
# تحميل النموذج
birefnet = AutoModelForImageSegmentation.from_pretrained(
"ZhengPeng7/BiRefNet", trust_remote_code=True
).to(device)
# تحويل الصورة
transform_image = transforms.Compose(
[
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
# دالة المعالجة
def process(image):
image_size = image.size
input_images = transform_image(image).unsqueeze(0).to(device)
with torch.no_grad():
preds = birefnet(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
image.putalpha(mask)
return image
# إعداد واجهة المستخدم
st.set_page_config(page_title="أداة إزالة الخلفية", layout="centered")
st.title("🌟 أداة إزالة الخلفية")
# اختيار نوع الإدخال
st.subheader("اختر طريقة إدخال الصورة:")
tab = st.radio("", ["رفع صورة", "رابط صورة", "ملف"])
# عملية إزالة الخلفية
if tab == "رفع صورة":
uploaded_file = st.file_uploader("ارفع صورة:")
if uploaded_file is not None:
image = Image.open(uploaded_file).convert("RGB")
st.image(image, caption="📷 الصورة الأصلية", use_column_width=True)
with st.spinner("🔄 يتم إزالة الخلفية، يرجى الانتظار..."):
processed_image = process(image)
st.image(processed_image, caption="✨ الصورة المعالجة", use_column_width=True)
elif tab == "رابط صورة":
url = st.text_input("أدخل رابط الصورة:")
if url:
try:
response = requests.get(url)
image = Image.open(BytesIO(response.content)).convert("RGB")
st.image(image, caption="📷 الصورة الأصلية", use_column_width=True)
with st.spinner("🔄 يتم إزالة الخلفية، يرجى الانتظار..."):
processed_image = process(image)
st.image(processed_image, caption="✨ الصورة المعالجة", use_column_width=True)
except Exception as e:
st.error(f"❌ خطأ أثناء تحميل الصورة: {e}")
elif tab == "ملف":
uploaded_file = st.file_uploader("ارفع ملف:")
if uploaded_file is not None:
image = Image.open(uploaded_file).convert("RGB")
with st.spinner("🔄 يتم إزالة الخلفية، يرجى الانتظار..."):
processed_image = process(image)
output_path = uploaded_file.name.rsplit(".", 1)[0] + ".png"
processed_image.save(output_path)
st.image(processed_image, caption="✨ الصورة المعالجة", use_column_width=True)
st.download_button(
"📥 تحميل الصورة المعالجة",
data=open(output_path, "rb"),
file_name=output_path,
mime="image/png",
)
# تحسين الألوان والخط
st.markdown(
"""
<style>
body {
background-color: #f8f9fa;
}
.stButton>button {
background-color: #4CAF50;
color: white;
border-radius: 10px;
}
.stRadio>div>label {
font-size: 16px;
color: #333;
}
.stSpinner {
color: #FF5733;
}
</style>
""",
unsafe_allow_html=True,
)