finalproject2 / app.py
mxiean's picture
Update app.py
409f50f verified
raw
history blame
2.61 kB
import streamlit as st
from transformers import pipeline
from datasets import load_dataset
from PIL import Image
import numpy as np
# Initialise the models once; st.cache_resource keeps them across reruns.
@st.cache_resource
def load_models():
    """Build the Hugging Face pipelines used by the app.

    Returns:
        dict with three entries:
        - "detector": object-detection pipeline (facebook/detr-resnet-50)
        - "style_classifier": image-classification pipeline
          (playrobin/furniture-styles)
        - "advisor": text2text-generation pipeline (google/flan-t5-base)
    """
    detector = pipeline("object-detection", model="facebook/detr-resnet-50")
    style_classifier = pipeline(
        "image-classification", model="playrobin/furniture-styles"
    )
    advisor = pipeline("text2text-generation", model="google/flan-t5-base")
    return {
        "detector": detector,
        "style_classifier": style_classifier,
        "advisor": advisor,
    }
# Load the example dataset once; st.cache_data caches the returned DataFrame.
@st.cache_data
def load_style_examples():
    """Return the "train" split of the home-decoration dataset as a DataFrame.

    NOTE(review): main() expects "style" and "image" columns here; this
    depends on the dataset schema, which is not visible in this file.
    """
    train_split = load_dataset("AntZet/home_decoration_objects_images")["train"]
    return train_split.to_pandas()
# 主函数
def main():
st.title("🏠 AI 装修风格匹配器")
uploaded_img = st.file_uploader("上传房间照片", type=["jpg", "png"])
if uploaded_img:
img = Image.open(uploaded_img)
models = load_models()
df = load_style_examples()
col1, col2 = st.columns(2)
with col1:
st.image(img, width=300)
# 风格分类
with st.spinner("分析房间风格..."):
style_result = models["style_classifier"](img)
main_style = style_result[0]['label']
confidence = style_result[0]['score']
st.success(f"检测风格: {main_style} (置信度: {confidence:.0%})")
# 物体检测
with st.spinner("识别家具物品..."):
objects = models["detector"](img)
top_objects = [obj['label'] for obj in objects if obj['score'] > 0.7]
st.subheader("检测到的主要物品")
st.write(", ".join(set(top_objects)))
with col2:
# 从数据集中找匹配案例
style_samples = df[df['style'] == main_style].sample(3)
st.subheader(f"{main_style} 风格案例")
st.image(style_samples['image'].tolist(), width=150)
# 生成建议
with st.spinner("生成装修建议..."):
prompt = f"""根据以下条件生成装修建议:
- 当前风格: {main_style}
- 现有物品: {top_objects[:5]}
- 目标风格: {main_style}
提供3条具体改进建议"""
advice = models["advisor"](prompt, max_length=300)[0]['generated_text']
st.subheader("专业建议")
st.write(advice)
if __name__ == "__main__":
main()