mxiean commited on
Commit
409f50f
·
verified ·
1 Parent(s): 2b12344

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -65
app.py CHANGED
@@ -3,85 +3,74 @@ from transformers import pipeline
3
  from datasets import load_dataset
4
  from PIL import Image
5
  import numpy as np
6
- from collections import Counter
7
 
8
- # 设置页面标题
9
- st.set_page_config(page_title="🏠 Airbnb Design Advisor", layout="wide")
10
-
11
- # 缓存模型和数据集
12
  @st.cache_resource
13
  def load_models():
14
  return {
15
  "detector": pipeline("object-detection", model="facebook/detr-resnet-50"),
16
- "generator": pipeline("text2text-generation", model="google/flan-t5-base")
 
 
17
  }
18
 
 
19
  @st.cache_data
20
- def load_data():
21
- return load_dataset("AntZet/home_decoration_objects_images")['train'].to_pandas()
22
-
23
- # 颜色提取函数
24
- def get_colors(img, n=3):
25
- arr = np.array(img.resize((50,50)))
26
- from sklearn.cluster import KMeans
27
- return [f"#{int(c[0]):02x}{int(c[1]):02x}{int(c[2]):02x}"
28
- for c in KMeans(n_clusters=n).fit(arr.reshape(-1,3)).cluster_centers_]
29
-
30
- # 主界面
31
- st.title("AI-Powered Airbnb Design Advisor")
32
 
33
- # 侧边栏控制
34
- with st.sidebar:
35
- st.header("Settings")
36
- style = st.selectbox(
37
- "Select Style",
38
- ["industrial", "scandinavian", "bohemian", "modern"],
39
- index=0
40
- )
41
- analyze_btn = st.button("Analyze")
42
-
43
- # 主内容区
44
- if analyze_btn:
45
- models = load_models()
46
- df = load_data()
47
 
48
- with st.spinner("Generating recommendations..."):
49
- # 获取风格示例
50
- examples = df[df['style'] == style].sample(3)
51
-
52
- # 分析对象和颜色
53
- objects = []
54
- colors = []
55
- for img in examples['image']:
56
- detected = models["detector"](img)
57
- objects += [obj['label'] for obj in detected if obj['score'] > 0.8]
58
- colors += get_colors(img)
59
-
60
- # 生成建议
61
- prompt = f"""Create {style} style decoration tips for Airbnb with:
62
- - Key objects: {Counter(objects).most_common(3)}
63
- - Color palette: {Counter(colors).most_common(3)}
64
- Include: 3 essentials, 2 budget tips"""
65
-
66
- advice = models["generator"](prompt, max_length=300)[0]['generated_text']
67
 
68
- # 显示结果
69
  col1, col2 = st.columns(2)
70
 
71
  with col1:
72
- st.subheader("Key Elements")
73
- for obj, count in Counter(objects).most_common(3):
74
- st.markdown(f"- {obj} (appears in {count} samples)")
75
 
76
- st.subheader("Color Palette")
77
- for color in Counter(colors).most_common(3):
78
- st.markdown(f"<div style='background:{color[0]}; width:100%; height:30px'></div>",
79
- unsafe_allow_html=True)
80
- st.caption(color[0])
81
-
 
 
 
 
 
 
 
 
 
 
82
  with col2:
83
- st.subheader("Professional Advice")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  st.write(advice)
85
-
86
- st.subheader("Example Images")
87
- st.image(examples['image'].tolist(), width=300)
 
3
  from datasets import load_dataset
4
  from PIL import Image
5
  import numpy as np
 
6
 
7
+ # 初始化模型 (缓存)
 
 
 
8
  @st.cache_resource
9
  def load_models():
10
  return {
11
  "detector": pipeline("object-detection", model="facebook/detr-resnet-50"),
12
+ "style_classifier":pipeline("image-classification", model="playrobin/furniture-styles"
13
+ ),
14
+ "advisor": pipeline("text2text-generation", model="google/flan-t5-base")
15
  }
16
 
17
+ # 加载数据集 (缓存)
18
  @st.cache_data
19
+ def load_style_examples():
20
+ dataset = load_dataset("AntZet/home_decoration_objects_images")
21
+ return dataset['train'].to_pandas()
 
 
 
 
 
 
 
 
 
22
 
23
+ # 主函数
24
+ def main():
25
+ st.title("🏠 AI 装修风格匹配器")
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ uploaded_img = st.file_uploader("上传房间照片", type=["jpg", "png"])
28
+
29
+ if uploaded_img:
30
+ img = Image.open(uploaded_img)
31
+ models = load_models()
32
+ df = load_style_examples()
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
 
34
  col1, col2 = st.columns(2)
35
 
36
  with col1:
37
+ st.image(img, width=300)
 
 
38
 
39
+ # 风格分类
40
+ with st.spinner("分析房间风格..."):
41
+ style_result = models["style_classifier"](img)
42
+ main_style = style_result[0]['label']
43
+ confidence = style_result[0]['score']
44
+
45
+ st.success(f"检测风格: {main_style} (置信度: {confidence:.0%})")
46
+
47
+ # 物体检测
48
+ with st.spinner("识别家具物品..."):
49
+ objects = models["detector"](img)
50
+ top_objects = [obj['label'] for obj in objects if obj['score'] > 0.7]
51
+
52
+ st.subheader("检测到的主要物品")
53
+ st.write(", ".join(set(top_objects)))
54
+
55
  with col2:
56
+ # 从数据集中找匹配案例
57
+ style_samples = df[df['style'] == main_style].sample(3)
58
+
59
+ st.subheader(f"{main_style} 风格案例")
60
+ st.image(style_samples['image'].tolist(), width=150)
61
+
62
+ # 生成建议
63
+ with st.spinner("生成装修建议..."):
64
+ prompt = f"""根据以下条件生成装修建议:
65
+ - 当前风格: {main_style}
66
+ - 现有物品: {top_objects[:5]}
67
+ - 目标风格: {main_style}
68
+ 提供3条具体改进建议"""
69
+
70
+ advice = models["advisor"](prompt, max_length=300)[0]['generated_text']
71
+
72
+ st.subheader("专业建议")
73
  st.write(advice)
74
+
75
+ if __name__ == "__main__":
76
+ main()