mxiean commited on
Commit
d946490
·
verified ·
1 Parent(s): 4d061ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -57
app.py CHANGED
@@ -3,85 +3,84 @@ from transformers import pipeline
3
  from datasets import load_dataset
4
  from PIL import Image
5
  import numpy as np
 
6
 
 
 
 
7
 
8
-
9
- # 初始化模型 (缓存)
10
  @st.cache_resource
11
  def load_models():
12
  return {
13
- "detector": pipeline("object-detection", model="facebook/detr-resnet-50"),
14
- "style_classifier":pipeline("image-classification", model="playrobin/furniture-styles"
 
15
  ),
16
- "advisor": pipeline("text2text-generation", model="google/flan-t5-base")
17
  }
18
 
19
- # 加载数据集 (缓存)
20
- @st.cache_data
21
- def load_style_examples():
22
- dataset = load_dataset("AntZet/home_decoration_objects_images")
23
- return dataset['train'].to_pandas()
24
-
25
- @st.cache_data
26
- def load_style_examples():
27
- try:
28
- dataset = load_dataset("AntZet/home_decoration_objects_images", streaming=True)
29
- return dataset['train'].to_pandas()
30
- except:
31
- return pd.DataFrame(columns=['style', 'image']) # 返回空数据集作为fallback
32
 
33
-
34
- # 主函数
35
  def main():
36
- st.title("🏠 AI 装修风格匹配器")
37
-
38
  uploaded_img = st.file_uploader("上传房间照片", type=["jpg", "png"])
39
 
40
  if uploaded_img:
41
- img = Image.open(uploaded_img)
42
  models = load_models()
43
- df = load_style_examples()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
 
45
  col1, col2 = st.columns(2)
46
 
47
  with col1:
48
  st.image(img, width=300)
 
49
 
50
- # 风格分类
51
- with st.spinner("分析房间风格..."):
52
- style_result = models["style_classifier"](img)
53
- main_style = style_result[0]['label']
54
- confidence = style_result[0]['score']
55
-
56
- st.success(f"检测风格: {main_style} (置信度: {confidence:.0%})")
57
-
58
- # 物体检测
59
- with st.spinner("识别家具物品..."):
60
- objects = models["detector"](img)
61
- top_objects = [obj['label'] for obj in objects if obj['score'] > 0.7]
62
-
63
- st.subheader("检测到的主要物品")
64
- st.write(", ".join(set(top_objects)))
65
-
66
  with col2:
67
- # 从数据集中找匹配案例
68
- style_samples = df[df['style'] == main_style].sample(3)
69
-
70
- st.subheader(f"{main_style} 风格案例")
71
- st.image(style_samples['image'].tolist(), width=150)
72
-
73
- # 生成建议
74
- with st.spinner("生成装修建议..."):
75
- prompt = f"""根据以下条件生成装修建议:
76
- - 当前风格: {main_style}
77
- - 现有物品: {top_objects[:5]}
78
- - 目标风格: {main_style}
79
- 提供3条具体改进建议"""
80
-
81
- advice = models["advisor"](prompt, max_length=300)[0]['generated_text']
82
-
83
- st.subheader("专业建议")
84
  st.write(advice)
 
 
 
 
85
 
86
  if __name__ == "__main__":
87
  main()
 
3
  from datasets import load_dataset
4
  from PIL import Image
5
  import numpy as np
6
+ from collections import Counter
7
 
8
+ # 设置页面
9
+ st.set_page_config(page_title="🏠 装修风格分析器", layout="wide")
10
+ st.title("AI 装修风格匹配工具")
11
 
12
+ # 缓存模型(移除了物体检测)
 
13
  @st.cache_resource
14
  def load_models():
15
  return {
16
+ "style_classifier": pipeline(
17
+ "image-classification",
18
+ model="dima806/interior_design_style_classification"
19
  ),
20
+ "advisor": pipeline("text2text-generation", model="google/flan-t5-small")
21
  }
22
 
23
+ # 颜色分析函数(替代物体检测)
24
+ def analyze_image(img):
25
+ # 简化的视觉分析:仅提取颜色
26
+ img = img.resize((50,50))
27
+ arr = np.array(img)
28
+ pixels = arr.reshape(-1,3)
29
+
30
+ # 使用简化版颜色分析(避免sklearn依赖)
31
+ unique_colors = np.unique(pixels, axis=0)
32
+ main_colors = unique_colors[:3] # 取前3种主要颜色
33
+ return [f"#{r:02x}{g:02x}{b:02x}" for r,g,b in main_colors]
 
 
34
 
 
 
35
  def main():
 
 
36
  uploaded_img = st.file_uploader("上传房间照片", type=["jpg", "png"])
37
 
38
  if uploaded_img:
 
39
  models = load_models()
40
+ img = Image.open(uploaded_img)
41
+
42
+ with st.spinner("正在分析..."):
43
+ # 1. 风格分类
44
+ style_result = models["style_classifier"](img)
45
+ main_style = style_result[0]['label']
46
+
47
+ # 2. 视觉分析(颜色替代物体检测)
48
+ colors = analyze_image(img)
49
+
50
+ # 3. 从数据集找案例
51
+ try:
52
+ dataset = load_dataset("AntZet/home_decoration_objects_images", streaming=True)
53
+ examples = [ex['image'] for ex in dataset['train']
54
+ if ex['style'] == main_style][:3]
55
+ except:
56
+ examples = []
57
+
58
+ # 4. 生成建议
59
+ prompt = f"""基于{main_style}风格,给出3条装修建议:
60
+ - 主色调: {colors}
61
+ - 避免: 与风格冲突的元素
62
+ - 预算: 中等成本方案"""
63
+ advice = models["advisor"](prompt, max_length=200)[0]['generated_text']
64
 
65
+ # 显示结果
66
  col1, col2 = st.columns(2)
67
 
68
  with col1:
69
  st.image(img, width=300)
70
+ st.success(f"识别风格: {main_style}")
71
 
72
+ st.subheader("主要色调")
73
+ for color in colors:
74
+ st.markdown(f"<div style='background:{color}; height:30px'></div>",
75
+ unsafe_allow_html=True)
76
+
 
 
 
 
 
 
 
 
 
 
 
77
  with col2:
78
+ st.subheader("风格建议")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  st.write(advice)
80
+
81
+ if examples:
82
+ st.subheader("参考案例")
83
+ st.image(examples, width=150)
84
 
85
  if __name__ == "__main__":
86
  main()