mxiean commited on
Commit
e132e99
·
verified ·
1 Parent(s): 635275b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -111
app.py CHANGED
@@ -1,112 +1,3 @@
1
- import gradio as gr
2
- import torch
3
  from transformers import pipeline
4
- from datasets import load_dataset
5
- from PIL import Image
6
- import numpy as np
7
- from collections import Counter
8
- import functools
9
- '''
10
- # 使用标准库的缓存装饰器替代Gradio缓存
11
- @functools.lru_cache(maxsize=None)
12
- def load_models():
13
- return {
14
- "detector": pipeline(
15
- "object-detection",
16
- model="facebook/detr-resnet-50",
17
- device=0 if torch.cuda.is_available() else -1
18
- ),
19
- "generator": pipeline(
20
- "text2text-generation",
21
- model="google/flan-t5-base", # 改用基础版降低资源需求
22
- device=0 if torch.cuda.is_available() else -1
23
- )
24
- }
25
-
26
- # 数据集加载函数(移除Gradio缓存)
27
- def load_dataset_data():
28
- ds = load_dataset("AntZet/home_decoration_objects_images")
29
- return ds['train'].to_pandas()
30
-
31
- # 颜色分析函数保持不变
32
- def get_dominant_colors(img, n_colors=3):
33
- arr = np.array(img.resize((100,100)))
34
- pixels = arr.reshape(-1,3)
35
- from sklearn.cluster import KMeans
36
- kmeans = KMeans(n_clusters=n_colors)
37
- kmeans.fit(pixels)
38
- return [f"#{int(c[0]):02x}{int(c[1]):02x}{int(c[2]):02x}" for c in kmeans.cluster_centers_]
39
-
40
- # 核心处理函数
41
- def generate_recommendation(target_style):
42
- try:
43
- models = load_models()
44
- df = load_dataset_data()
45
-
46
- style_df = df[df['style'] == target_style.lower()]
47
- if len(style_df) < 3:
48
- return f"⚠️ Not enough samples for {target_style} style"
49
-
50
- sample_images = style_df.sample(5)['image']
51
-
52
- all_objects = []
53
- color_palette = []
54
-
55
- for img in sample_images:
56
- detected = models["detector"](img)
57
- all_objects += [obj['label'] for obj in detected if obj['score'] > 0.9]
58
- color_palette += get_dominant_colors(img)
59
-
60
- top_objects = Counter(all_objects).most_common(3)
61
- top_colors = Counter(color_palette).most_common(3)
62
-
63
- prompt = f"""Create interior design recommendations for {target_style} style:
64
- Key objects: {[o[0] for o in top_objects]}
65
- Color palette: {[c[0] for c in top_colors]}
66
- Include: 3 essentials, 2 budget tips, common mistakes"""
67
-
68
- advice = models["generator"](prompt, max_length=300)[0]['generated_text']
69
-
70
- output = f"## 🎨 {target_style.title()} Style Guide\n\n"
71
- output += "### 🪑 Key Objects\n" + "\n".join(
72
- [f"- {o[0]} ({o[1]}x)" for o in top_objects]) + "\n\n"
73
- output += "### 🎨 Colors\n" + "\n".join(
74
- [f"<span style='color:{c[0]};'>■</span> {c[0]}" for c in top_colors]) + "\n\n"
75
- output += "### 💡 Advice\n" + advice.replace(". ", ".\n")
76
-
77
- return output
78
-
79
- except Exception as e:
80
- return f"❌ Error: {str(e)}"
81
-
82
- # Gradio界面保持不变
83
- with gr.Blocks(title="Design Assistant") as demo:
84
- gr.Markdown("# 🏡 AI Design Advisor")
85
-
86
- with gr.Row():
87
- style_input = gr.Dropdown(
88
- label="Select Style",
89
- choices=["Industrial", "Scandinavian", "Bohemian", "Modern"],
90
- value="Industrial"
91
- )
92
-
93
- submit_btn = gr.Button("Generate Plan", variant="primary")
94
-
95
- with gr.Row():
96
- output = gr.Markdown()
97
- gallery = gr.Gallery(
98
- label="Examples",
99
- object_fit="contain",
100
- height="300px"
101
- )
102
-
103
- def update_gallery(style):
104
- df = load_dataset_data()
105
- return df[df['style'] == style.lower()].sample(3)['image'].tolist()
106
-
107
- style_input.change(update_gallery, inputs=style_input, outputs=gallery)
108
- submit_btn.click(generate_recommendation, inputs=style_input, outputs=output)
109
-
110
- if __name__ == "__main__":
111
- demo.launch()
112
- '''
 
1
+ import streamlit as st
 
2
  from transformers import pipeline
3
+ import torch