File size: 2,906 Bytes
0f308a3
 
2b12344
 
 
d946490
0f308a3
d946490
 
 
4d061ac
d946490
2b12344
 
 
d946490
 
 
409f50f
d946490
2b12344
0f308a3
d946490
 
 
 
 
 
 
 
 
 
 
4d061ac
409f50f
 
 
 
 
d946490
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b12344
d946490
2b12344
 
 
409f50f
d946490
2b12344
d946490
 
 
 
 
2b12344
d946490
2b12344
d946490
 
 
 
409f50f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import streamlit as st
from transformers import pipeline
from datasets import load_dataset
from PIL import Image
import numpy as np
from collections import Counter

# 设置页面
st.set_page_config(page_title="🏠 装修风格分析器", layout="wide")
st.title("AI 装修风格匹配工具")

# 缓存模型(移除了物体检测)
@st.cache_resource
def load_models():
    return {
        "style_classifier": pipeline(
            "image-classification", 
            model="dima806/interior_design_style_classification"
        ),
        "advisor": pipeline("text2text-generation", model="google/flan-t5-small")
    }

# 颜色分析函数(替代物体检测)
def analyze_image(img):
    # 简化的视觉分析:仅提取颜色
    img = img.resize((50,50))
    arr = np.array(img)
    pixels = arr.reshape(-1,3)
    
    # 使用简化版颜色分析(避免sklearn依赖)
    unique_colors = np.unique(pixels, axis=0)
    main_colors = unique_colors[:3]  # 取前3种主要颜色
    return [f"#{r:02x}{g:02x}{b:02x}" for r,g,b in main_colors]

def main():
    uploaded_img = st.file_uploader("上传房间照片", type=["jpg", "png"])
    
    if uploaded_img:
        models = load_models()
        img = Image.open(uploaded_img)
        
        with st.spinner("正在分析..."):
            # 1. 风格分类
            style_result = models["style_classifier"](img)
            main_style = style_result[0]['label']
            
            # 2. 视觉分析(颜色替代物体检测)
            colors = analyze_image(img)
            
            # 3. 从数据集找案例
            try:
                dataset = load_dataset("AntZet/home_decoration_objects_images", streaming=True)
                examples = [ex['image'] for ex in dataset['train'] 
                           if ex['style'] == main_style][:3]
            except:
                examples = []
            
            # 4. 生成建议
            prompt = f"""基于{main_style}风格,给出3条装修建议:
            - 主色调: {colors}
            - 避免: 与风格冲突的元素
            - 预算: 中等成本方案"""
            advice = models["advisor"](prompt, max_length=200)[0]['generated_text']
        
        # 显示结果
        col1, col2 = st.columns(2)
        
        with col1:
            st.image(img, width=300)
            st.success(f"识别风格: {main_style}")
            
            st.subheader("主要色调")
            for color in colors:
                st.markdown(f"<div style='background:{color}; height:30px'></div>", 
                           unsafe_allow_html=True)
        
        with col2:
            st.subheader("风格建议")
            st.write(advice)
            
            if examples:
                st.subheader("参考案例")
                st.image(examples, width=150)

if __name__ == "__main__":
    main()