File size: 2,918 Bytes
0f308a3
 
2b12344
 
 
0f308a3
4d061ac
 
409f50f
2b12344
 
 
 
409f50f
 
 
2b12344
0f308a3
409f50f
2b12344
409f50f
 
 
0f308a3
4d061ac
 
 
 
 
 
 
 
 
409f50f
 
 
0f308a3
409f50f
 
 
 
 
 
2b12344
 
 
 
409f50f
2b12344
409f50f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b12344
409f50f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b12344
409f50f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import numpy as np
import pandas as pd
import streamlit as st
from datasets import load_dataset
from PIL import Image
from transformers import pipeline


        
# 初始化模型 (缓存)
@st.cache_resource
def load_models():
    return {
        "detector": pipeline("object-detection", model="facebook/detr-resnet-50"),
        "style_classifier":pipeline("image-classification", model="playrobin/furniture-styles"
        ),
        "advisor": pipeline("text2text-generation", model="google/flan-t5-base")
    }

# 加载数据集 (缓存)
@st.cache_data
def load_style_examples():
    dataset = load_dataset("AntZet/home_decoration_objects_images")
    return dataset['train'].to_pandas()

@st.cache_data
def load_style_examples():
    try:
        dataset = load_dataset("AntZet/home_decoration_objects_images", streaming=True)
        return dataset['train'].to_pandas()
    except:
        return pd.DataFrame(columns=['style', 'image'])  # 返回空数据集作为fallback

        
# 主函数
def main():
    st.title("🏠 AI 装修风格匹配器")
    
    uploaded_img = st.file_uploader("上传房间照片", type=["jpg", "png"])
    
    if uploaded_img:
        img = Image.open(uploaded_img)
        models = load_models()
        df = load_style_examples()
        
        col1, col2 = st.columns(2)
        
        with col1:
            st.image(img, width=300)
            
            # 风格分类
            with st.spinner("分析房间风格..."):
                style_result = models["style_classifier"](img)
                main_style = style_result[0]['label']
                confidence = style_result[0]['score']
            
            st.success(f"检测风格: {main_style} (置信度: {confidence:.0%})")
            
            # 物体检测
            with st.spinner("识别家具物品..."):
                objects = models["detector"](img)
                top_objects = [obj['label'] for obj in objects if obj['score'] > 0.7]
            
            st.subheader("检测到的主要物品")
            st.write(", ".join(set(top_objects)))

        with col2:
            # 从数据集中找匹配案例
            style_samples = df[df['style'] == main_style].sample(3)
            
            st.subheader(f"{main_style} 风格案例")
            st.image(style_samples['image'].tolist(), width=150)
            
            # 生成建议
            with st.spinner("生成装修建议..."):
                prompt = f"""根据以下条件生成装修建议:
                - 当前风格: {main_style}
                - 现有物品: {top_objects[:5]}
                - 目标风格: {main_style}
                提供3条具体改进建议"""
                
                advice = models["advisor"](prompt, max_length=300)[0]['generated_text']
            
            st.subheader("专业建议")
            st.write(advice)

# Build the page only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()