mxiean commited on
Commit
2b12344
·
verified ·
1 Parent(s): 64ec1cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -54
app.py CHANGED
@@ -1,64 +1,87 @@
1
- # import part
2
  import streamlit as st
3
  from transformers import pipeline
4
- import torch
 
 
 
5
 
6
- # function part
7
- # img2text
8
- def img2text(url):
9
- image_to_text_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
10
- text = image_to_text_model(url)[0]["generated_text"]
11
- return text
12
 
13
- # text2story
14
- def text2story(text):
15
- pipe = pipeline("text-generation", model="gpt2")
16
- story_text = pipe(text, max_length=100)[0]['generated_text']
17
- return story_text
 
 
18
 
19
- # text2audio
20
- def text2audio(story_text):
21
- pipe = pipeline("text-to-speech", model="facebook/mms-tts-eng")
22
- audio_data = pipe(story_text)
23
- return audio_data
24
 
 
 
 
 
 
 
25
 
26
- # main part
 
27
 
28
- st.set_page_config(page_title="Your Image to Audio Story",
29
- page_icon="🦜")
30
- st.header("Turn Your Image to Audio Story")
31
- uploaded_file = st.file_uploader("Select an Image...")#, type=["jpg", "jpeg", "png"])
32
-
33
-
34
- if uploaded_file is not None:
35
- print(uploaded_file)
36
- bytes_data = uploaded_file.getvalue()
37
- with open(uploaded_file.name, "wb") as file:
38
- file.write(bytes_data)
39
- st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
40
 
 
 
 
 
41
 
42
-
43
- #Stage 1: Image to Text
44
- st.text('Processing img2text...')
45
- scenario = img2text(uploaded_file.name)
46
- st.write(scenario)
47
-
48
-
49
- #Stage 2: Text to Story
50
- st.text('Generating a story...')
51
- story = text2story(scenario)
52
- st.write(story)
53
-
54
- #Stage 3: Story to Audio data
55
- st.text('Generating audio data...')
56
- audio_data =text2audio(story)
57
-
58
- # Play button
59
- if st.button("Play Audio"):
60
- st.audio(audio_data['audio'],
61
- format="audio/wav",
62
- start_time=0,
63
- sample_rate = audio_data['sampling_rate'])
64
- # st.audio("kids_playing_audio.wav")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from transformers import pipeline
3
+ from datasets import load_dataset
4
+ from PIL import Image
5
+ import numpy as np
6
+ from collections import Counter
7
 
8
+ # 设置页面标题
9
+ st.set_page_config(page_title="🏠 Airbnb Design Advisor", layout="wide")
 
 
 
 
10
 
11
+ # 缓存模型和数据集
12
+ @st.cache_resource
13
+ def load_models():
14
+ return {
15
+ "detector": pipeline("object-detection", model="facebook/detr-resnet-50"),
16
+ "generator": pipeline("text2text-generation", model="google/flan-t5-base")
17
+ }
18
 
19
+ @st.cache_data
20
+ def load_data():
21
+ return load_dataset("AntZet/home_decoration_objects_images")['train'].to_pandas()
 
 
22
 
23
+ # 颜色提取函数
24
+ def get_colors(img, n=3):
25
+ arr = np.array(img.resize((50,50)))
26
+ from sklearn.cluster import KMeans
27
+ return [f"#{int(c[0]):02x}{int(c[1]):02x}{int(c[2]):02x}"
28
+ for c in KMeans(n_clusters=n).fit(arr.reshape(-1,3)).cluster_centers_]
29
 
30
+ # 主界面
31
+ st.title("AI-Powered Airbnb Design Advisor")
32
 
33
+ # 侧边栏控制
34
+ with st.sidebar:
35
+ st.header("Settings")
36
+ style = st.selectbox(
37
+ "Select Style",
38
+ ["industrial", "scandinavian", "bohemian", "modern"],
39
+ index=0
40
+ )
41
+ analyze_btn = st.button("Analyze")
 
 
 
42
 
43
+ # 主内容区
44
+ if analyze_btn:
45
+ models = load_models()
46
+ df = load_data()
47
 
48
+ with st.spinner("Generating recommendations..."):
49
+ # 获取风格示例
50
+ examples = df[df['style'] == style].sample(3)
51
+
52
+ # 分析对象和颜色
53
+ objects = []
54
+ colors = []
55
+ for img in examples['image']:
56
+ detected = models["detector"](img)
57
+ objects += [obj['label'] for obj in detected if obj['score'] > 0.8]
58
+ colors += get_colors(img)
59
+
60
+ # 生成建议
61
+ prompt = f"""Create {style} style decoration tips for Airbnb with:
62
+ - Key objects: {Counter(objects).most_common(3)}
63
+ - Color palette: {Counter(colors).most_common(3)}
64
+ Include: 3 essentials, 2 budget tips"""
65
+
66
+ advice = models["generator"](prompt, max_length=300)[0]['generated_text']
67
+
68
+ # 显示结果
69
+ col1, col2 = st.columns(2)
70
+
71
+ with col1:
72
+ st.subheader("Key Elements")
73
+ for obj, count in Counter(objects).most_common(3):
74
+ st.markdown(f"- {obj} (appears in {count} samples)")
75
+
76
+ st.subheader("Color Palette")
77
+ for color in Counter(colors).most_common(3):
78
+ st.markdown(f"<div style='background:{color[0]}; width:100%; height:30px'></div>",
79
+ unsafe_allow_html=True)
80
+ st.caption(color[0])
81
+
82
+ with col2:
83
+ st.subheader("Professional Advice")
84
+ st.write(advice)
85
+
86
+ st.subheader("Example Images")
87
+ st.image(examples['image'].tolist(), width=300)