Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,112 +1,3 @@
|
|
1 |
-
import
|
2 |
-
import torch
|
3 |
from transformers import pipeline
|
4 |
-
|
5 |
-
from PIL import Image
|
6 |
-
import numpy as np
|
7 |
-
from collections import Counter
|
8 |
-
import functools
|
9 |
-
'''
|
10 |
-
# 使用标准库的缓存装饰器替代Gradio缓存
|
11 |
-
@functools.lru_cache(maxsize=None)
|
12 |
-
def load_models():
|
13 |
-
return {
|
14 |
-
"detector": pipeline(
|
15 |
-
"object-detection",
|
16 |
-
model="facebook/detr-resnet-50",
|
17 |
-
device=0 if torch.cuda.is_available() else -1
|
18 |
-
),
|
19 |
-
"generator": pipeline(
|
20 |
-
"text2text-generation",
|
21 |
-
model="google/flan-t5-base", # 改用基础版降低资源需求
|
22 |
-
device=0 if torch.cuda.is_available() else -1
|
23 |
-
)
|
24 |
-
}
|
25 |
-
|
26 |
-
# 数据集加载函数(移除Gradio缓存)
|
27 |
-
def load_dataset_data():
|
28 |
-
ds = load_dataset("AntZet/home_decoration_objects_images")
|
29 |
-
return ds['train'].to_pandas()
|
30 |
-
|
31 |
-
# 颜色分析函数保持不变
|
32 |
-
def get_dominant_colors(img, n_colors=3):
|
33 |
-
arr = np.array(img.resize((100,100)))
|
34 |
-
pixels = arr.reshape(-1,3)
|
35 |
-
from sklearn.cluster import KMeans
|
36 |
-
kmeans = KMeans(n_clusters=n_colors)
|
37 |
-
kmeans.fit(pixels)
|
38 |
-
return [f"#{int(c[0]):02x}{int(c[1]):02x}{int(c[2]):02x}" for c in kmeans.cluster_centers_]
|
39 |
-
|
40 |
-
# 核心处理函数
|
41 |
-
def generate_recommendation(target_style):
|
42 |
-
try:
|
43 |
-
models = load_models()
|
44 |
-
df = load_dataset_data()
|
45 |
-
|
46 |
-
style_df = df[df['style'] == target_style.lower()]
|
47 |
-
if len(style_df) < 3:
|
48 |
-
return f"⚠️ Not enough samples for {target_style} style"
|
49 |
-
|
50 |
-
sample_images = style_df.sample(5)['image']
|
51 |
-
|
52 |
-
all_objects = []
|
53 |
-
color_palette = []
|
54 |
-
|
55 |
-
for img in sample_images:
|
56 |
-
detected = models["detector"](img)
|
57 |
-
all_objects += [obj['label'] for obj in detected if obj['score'] > 0.9]
|
58 |
-
color_palette += get_dominant_colors(img)
|
59 |
-
|
60 |
-
top_objects = Counter(all_objects).most_common(3)
|
61 |
-
top_colors = Counter(color_palette).most_common(3)
|
62 |
-
|
63 |
-
prompt = f"""Create interior design recommendations for {target_style} style:
|
64 |
-
Key objects: {[o[0] for o in top_objects]}
|
65 |
-
Color palette: {[c[0] for c in top_colors]}
|
66 |
-
Include: 3 essentials, 2 budget tips, common mistakes"""
|
67 |
-
|
68 |
-
advice = models["generator"](prompt, max_length=300)[0]['generated_text']
|
69 |
-
|
70 |
-
output = f"## 🎨 {target_style.title()} Style Guide\n\n"
|
71 |
-
output += "### 🪑 Key Objects\n" + "\n".join(
|
72 |
-
[f"- {o[0]} ({o[1]}x)" for o in top_objects]) + "\n\n"
|
73 |
-
output += "### 🎨 Colors\n" + "\n".join(
|
74 |
-
[f"<span style='color:{c[0]};'>■</span> {c[0]}" for c in top_colors]) + "\n\n"
|
75 |
-
output += "### 💡 Advice\n" + advice.replace(". ", ".\n")
|
76 |
-
|
77 |
-
return output
|
78 |
-
|
79 |
-
except Exception as e:
|
80 |
-
return f"❌ Error: {str(e)}"
|
81 |
-
|
82 |
-
# Gradio界面保持不变
|
83 |
-
with gr.Blocks(title="Design Assistant") as demo:
|
84 |
-
gr.Markdown("# 🏡 AI Design Advisor")
|
85 |
-
|
86 |
-
with gr.Row():
|
87 |
-
style_input = gr.Dropdown(
|
88 |
-
label="Select Style",
|
89 |
-
choices=["Industrial", "Scandinavian", "Bohemian", "Modern"],
|
90 |
-
value="Industrial"
|
91 |
-
)
|
92 |
-
|
93 |
-
submit_btn = gr.Button("Generate Plan", variant="primary")
|
94 |
-
|
95 |
-
with gr.Row():
|
96 |
-
output = gr.Markdown()
|
97 |
-
gallery = gr.Gallery(
|
98 |
-
label="Examples",
|
99 |
-
object_fit="contain",
|
100 |
-
height="300px"
|
101 |
-
)
|
102 |
-
|
103 |
-
def update_gallery(style):
|
104 |
-
df = load_dataset_data()
|
105 |
-
return df[df['style'] == style.lower()].sample(3)['image'].tolist()
|
106 |
-
|
107 |
-
style_input.change(update_gallery, inputs=style_input, outputs=gallery)
|
108 |
-
submit_btn.click(generate_recommendation, inputs=style_input, outputs=output)
|
109 |
-
|
110 |
-
if __name__ == "__main__":
|
111 |
-
demo.launch()
|
112 |
-
'''
|
|
|
1 |
+
import streamlit as st
|
|
|
2 |
from transformers import pipeline
|
3 |
+
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|