Spaces:
Sleeping
Sleeping
Delete pages
Browse files- pages/10_Saved_Scenarios.py +0 -407
- pages/11_Optimized_Result_Analysis.py +0 -453
- pages/1_Data_Import.py +0 -1547
- pages/2_Data_Validation.py +0 -509
- pages/3_Transformations.py +0 -686
- pages/4_Model_Build.py +0 -1062
- pages/5_Model_Tuning.py +0 -912
- pages/6_AI_Model_Results.py +0 -728
- pages/7_Current_Media_Performance.py +0 -573
- pages/8_Build_Response_Curves.py +0 -596
- pages/9_Scenario_Planner.py +0 -1712
pages/10_Saved_Scenarios.py
DELETED
@@ -1,407 +0,0 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
from numerize.numerize import numerize
|
3 |
-
import io
|
4 |
-
import pandas as pd
|
5 |
-
from utilities import (
|
6 |
-
format_numbers,
|
7 |
-
decimal_formater,
|
8 |
-
channel_name_formating,
|
9 |
-
load_local_css,
|
10 |
-
set_header,
|
11 |
-
initialize_data,
|
12 |
-
load_authenticator,
|
13 |
-
)
|
14 |
-
from openpyxl import Workbook
|
15 |
-
from openpyxl.styles import Alignment, Font, PatternFill
|
16 |
-
import pickle
|
17 |
-
import streamlit_authenticator as stauth
|
18 |
-
import yaml
|
19 |
-
from yaml import SafeLoader
|
20 |
-
from classes import class_from_dict
|
21 |
-
from utilities import update_db
|
22 |
-
|
23 |
-
st.set_page_config(layout="wide")
|
24 |
-
load_local_css("styles.css")
|
25 |
-
set_header()
|
26 |
-
|
27 |
-
# for k, v in st.session_state.items():
|
28 |
-
# if k not in ['logout', 'login','config'] and not k.startswith('FormSubmitter'):
|
29 |
-
# st.session_state[k] = v
|
30 |
-
|
31 |
-
|
32 |
-
def create_scenario_summary(scenario_dict):
    """Build the Actual-vs-Simulated summary table for one scenario.

    One row is produced for every entry of ``scenario_dict["channels"]``,
    followed by a closing "Total" row read from the scenario-level totals.
    Channel spends are multiplied by the channel's ``conversion_rate``
    before use; the scenario-level totals are used exactly as stored.

    Returns a ``pandas.DataFrame`` whose columns are a two-level index: a
    ("", "Channel") label column, then Spends / NRPU / ROI / MROI /
    Spend per NRPU, each split into "Actual" and "Simulated".
    """

    def _channel_row(channel):
        # Resolve the display name first, as the row starts with it.
        label = channel_name_formating(channel["name"])
        actual_spend = channel.get("actual_total_spends") * channel.get(
            "conversion_rate"
        )
        simulated_spend = channel.get("modified_total_spends") * channel.get(
            "conversion_rate"
        )
        actual_sales = channel.get("actual_total_sales")
        simulated_sales = channel.get("modified_total_sales")
        return [
            label,
            actual_spend,
            simulated_spend,
            actual_sales,
            simulated_sales,
            actual_sales / actual_spend,
            simulated_sales / simulated_spend,
            channel.get("actual_mroi"),
            channel.get("modified_mroi"),
            actual_spend / actual_sales,
            simulated_spend / simulated_sales,
        ]

    rows = [_channel_row(channel) for channel in scenario_dict["channels"]]

    total_actual_spend = scenario_dict.get("actual_total_spends")
    total_simulated_spend = scenario_dict.get("modified_total_spends")
    total_actual_sales = scenario_dict.get("actual_total_sales")
    total_simulated_sales = scenario_dict.get("modified_total_sales")
    rows.append(
        [
            "Total",
            total_actual_spend,
            total_simulated_spend,
            total_actual_sales,
            total_simulated_sales,
            total_actual_sales / total_actual_spend,
            total_simulated_sales / total_simulated_spend,
            # No MROI is reported at scenario level.
            "-",
            "-",
            total_actual_spend / total_actual_sales,
            total_simulated_spend / total_simulated_sales,
        ]
    )

    level_names = ["first", "second"]
    label_column = pd.MultiIndex.from_product(
        [[""], ["Channel"]], names=level_names
    )
    metric_columns = pd.MultiIndex.from_product(
        [
            ["Spends", "NRPU", "ROI", "MROI", "Spend per NRPU"],
            ["Actual", "Simulated"],
        ],
        names=level_names,
    )
    return pd.DataFrame(rows, columns=label_column.append(metric_columns))
|
99 |
-
|
100 |
-
|
101 |
-
def summary_df_to_worksheet(df, ws):
    """Write a summary frame with two-level column labels into ``ws``.

    Rows 1 and 2 receive the first and second element of every column
    label (bold font, solid fill).  Data rows start at worksheet row 3 and
    are written in column order; the frame's own index is not written.
    """
    header_fill = PatternFill(
        fill_type="solid", start_color="FF11B6BD", end_color="FF11B6BD"
    )

    for position, header in enumerate(df.columns.values, start=1):
        for level in (1, 2):
            cell = ws.cell(row=level, column=position, value=header[level - 1])
            # NOTE(review): the font colour equals the fill colour, so the
            # header text may be unreadable in the sheet - confirm intent.
            cell.font = Font(bold=True, color="FF11B6BD")
            cell.fill = header_fill

        closes_group = position > 1 and (position - 6) % 5 == 0
        if closes_group:
            # Columns 6, 11, 16, ... trigger a merge of the four columns
            # ending at `position` in the first header row.
            # NOTE(review): the export groups have five columns each, so
            # the merge may be one column short - confirm.
            ws.merge_cells(
                start_row=1,
                end_row=1,
                start_column=position - 3,
                end_column=position,
            )
            ws.cell(row=1, column=position).alignment = Alignment(
                horizontal="center"
            )

    for offset, record in enumerate(df.itertuples()):
        sheet_row = offset + 3
        # record[0] is the frame index, which is not exported.
        for position, value in enumerate(record[1:], start=1):
            cell = ws.cell(row=sheet_row, column=position, value=value)
            # Columns 2, 3, 6, 7, 10, 11, ... get a currency format.
            # NOTE(review): this pattern repeats every four columns while
            # the column groups repeat every five - confirm.
            if (position - 2) % 4 in (0, 1):
                cell.number_format = "$#,##0.0"
|
123 |
-
|
124 |
-
|
125 |
-
from openpyxl.utils import get_column_letter
|
126 |
-
from openpyxl.styles import Font, PatternFill
|
127 |
-
import logging
|
128 |
-
|
129 |
-
|
130 |
-
def scenario_df_to_worksheet(df, ws):
    """Write ``df`` into worksheet ``ws`` and return the worksheet.

    Row 1 receives the column names (bold font, solid fill).  Data rows
    start at row 2; the frame's index is not written.  Values that are
    ``int`` or ``float`` get a currency number format, strings are cut to
    32767 characters, and every other value is stored as ``str(value)``.

    A ``ValueError`` raised while storing a value is logged and that cell
    is left empty, so one bad value does not abort the export.
    """
    heading_fill = PatternFill(
        start_color="FF11B6BD", end_color="FF11B6BD", fill_type="solid"
    )

    for column, header in enumerate(df.columns.values, start=1):
        cell = ws.cell(row=1, column=column, value=header)
        cell.font = Font(bold=True, color="FF11B6BD")
        cell.fill = heading_fill

    for offset, record in enumerate(df.itertuples()):
        sheet_row = offset + 2
        # record[0] is the frame index, which is not exported.
        for column, value in enumerate(record[1:], start=1):
            # Fetch the target cell before assigning.  Previously the lookup
            # and the assignment were one call inside the `try`; when it
            # raised, `cell` was still bound to the previously written cell
            # and the handler cleared that one instead of the failing one.
            cell = ws.cell(row=sheet_row, column=column)
            try:
                cell.value = value
                if isinstance(value, (int, float)):
                    cell.number_format = "$#,##0.0"
                elif isinstance(value, str):
                    cell.value = value[:32767]
                else:
                    cell.value = str(value)
            except ValueError as e:
                logging.error(
                    f"Error assigning value '{value}' to cell "
                    f"{get_column_letter(column)}{sheet_row}: {e}"
                )
                cell.value = None  # leave the failing cell empty

    return ws
|
159 |
-
|
160 |
-
|
161 |
-
def download_scenarios():
    """Build an Excel workbook for the scenarios chosen for download.

    For every name in the module-level ``scenarios_to_download`` a sheet is
    created holding the dates, each channel's modified spends (multiplied
    by its ``conversion_rate``) and the summed modified sales under "NRPU".
    A "Summary" sheet with one column group per scenario is inserted as the
    first sheet.  The workbook is written to the in-memory buffer stored in
    ``st.session_state["xlsx_buffer"]`` (nothing is written to disk) and
    the download button is enabled afterwards.

    Does nothing when no scenario is selected.
    """
    if not scenarios_to_download:
        return

    wb = Workbook()
    wb.iso_dates = True
    wb.remove(wb.active)  # drop the default sheet; sheets are added below
    st.session_state["xlsx_buffer"] = io.BytesIO()
    summary_df = None

    for scenario_name in scenarios_to_download:
        scenario_dict = st.session_state["saved_scenarios"][scenario_name]
        sheet_columns = []
        column_names = ["Date"]
        total_sales = None
        dates = None
        summary_rows = []

        for channel in scenario_dict["channels"]:
            if dates is None:
                dates = channel.get("dates")
                sheet_columns.append(dates)

            channel_sales = channel.get("modified_sales")
            if total_sales is None:
                total_sales = channel_sales
            else:
                # Build a new object rather than using `+=`: the running
                # total starts out as the first channel's own sales object,
                # so an in-place add would modify the saved scenario.
                total_sales = total_sales + channel_sales

            rate = channel.get("conversion_rate")
            sheet_columns.append(channel.get("modified_spends") * rate)
            column_names.append(channel.get("name"))

            spend = channel.get("modified_total_spends") * rate
            sales = channel.get("modified_total_sales")
            summary_rows.append(
                [
                    channel_name_formating(channel["name"]),
                    spend,
                    sales,
                    # ROI: sales per converted spend, the same formula that
                    # create_scenario_summary uses.
                    sales / spend,
                    channel.get("modified_mroi"),
                    # Spends per NRPU: the inverse ratio, matching the
                    # "Total" row below.
                    spend / sales,
                ]
            )

        sheet_columns.append(total_sales)
        column_names.append("NRPU")
        scenario_df = pd.DataFrame(sheet_columns).T
        scenario_df.columns = column_names

        # One sheet per scenario, named after the scenario.
        ws = wb.create_sheet(scenario_name)
        scenario_df_to_worksheet(scenario_df, ws)

        summary_rows.append(
            [
                "Total",
                scenario_dict.get("modified_total_spends"),
                scenario_dict.get("modified_total_sales"),
                scenario_dict.get("modified_total_sales")
                / scenario_dict.get("modified_total_spends"),
                "-",
                scenario_dict.get("modified_total_spends")
                / scenario_dict.get("modified_total_sales"),
            ]
        )

        columns_index = pd.MultiIndex.from_product(
            [[""], ["Channel"]], names=["first", "second"]
        )
        columns_index = columns_index.append(
            pd.MultiIndex.from_product(
                [
                    [scenario_name],
                    ["Spends", "NRPU", "ROI", "MROI", "Spends per NRPU"],
                ],
                names=["first", "second"],
            )
        )
        scenario_summary = pd.DataFrame(
            summary_rows, columns=columns_index
        ).set_index(("", "Channel"))

        if summary_df is None:
            summary_df = scenario_summary
        else:
            # Index-on-index merge with pandas' default join: only channel
            # labels present in every selected scenario are kept.
            summary_df = summary_df.merge(
                scenario_summary, left_index=True, right_index=True
            )

    ws = wb.create_sheet("Summary", 0)
    summary_df_to_worksheet(summary_df.reset_index(), ws)
    wb.save(st.session_state["xlsx_buffer"])
    st.session_state["disable_download_button"] = False
|
249 |
-
|
250 |
-
|
251 |
-
def disable_download_button():
    """Mark the download button as disabled in the session state.

    Used as a widget callback in this file, so the button stays disabled
    until ``download_scenarios`` resets the flag.
    """
    st.session_state["disable_download_button"] = True
|
253 |
-
|
254 |
-
|
255 |
-
def transform(x):
    """Format one column of the scenario summary for display.

    ``x`` is a column (Series) whose ``name`` is a two-element label.  The
    ("", "Channel") column is returned untouched.  In every other column
    string entries are kept as they are; the remaining entries go through
    ``format_numbers``, and for ROI / MROI columns additionally through
    ``decimal_formater`` with four decimals.
    """
    if x.name == ("", "Channel"):
        return x

    if x.name[0] in ("ROI", "MROI"):

        def _render(value):
            return decimal_formater(
                format_numbers(value, include_indicator=False, n_decimals=4),
                n_decimals=4,
            )

    else:

        def _render(value):
            return format_numbers(value)

    return x.apply(
        lambda value: value if isinstance(value, str) else _render(value)
    )
|
271 |
-
|
272 |
-
|
273 |
-
def delete_scenario():
    """Delete the currently selected scenario and persist the rest.

    Reads the module-level ``selected_scenario``.  When that name is among
    the saved scenarios it is removed from the session state and the
    remaining scenarios are pickled to ``../saved_scenarios.pkl``.
    """
    scenarios = st.session_state["saved_scenarios"]
    if selected_scenario not in scenarios:
        return

    del scenarios[selected_scenario]
    with open("../saved_scenarios.pkl", "wb") as handle:
        pickle.dump(st.session_state["saved_scenarios"], handle)
|
278 |
-
|
279 |
-
|
280 |
-
def load_scenario():
    """Load the currently selected scenario into the session state.

    Reads the module-level ``selected_scenario`` and
    ``selected_scenario_details``.  When the name is among the saved
    scenarios, the details are converted with ``class_from_dict`` and
    stored under ``st.session_state["scenario"]``.
    """
    if selected_scenario not in st.session_state["saved_scenarios"]:
        return
    scenario = class_from_dict(selected_scenario_details)
    st.session_state["scenario"] = scenario
|
283 |
-
|
284 |
-
|
285 |
-
# ---------------------------------------------------------------------------
# Page body: authentication gate, scenario picker, download controls and the
# summary table of the selected scenario.
# ---------------------------------------------------------------------------

# Reuse the authenticator kept in the session state, or build one when none
# is stored yet.
authenticator = st.session_state.get("authenticator")
if authenticator is None:
    authenticator = load_authenticator()

name, authentication_status, username = authenticator.login("Login", "main")
auth_status = st.session_state.get("authentication_status")

# `== True` / `== False` are kept as in the original so that a missing
# status (None) falls through both branches.
if auth_status == True:
    is_state_initiaized = st.session_state.get("initialized", False)
    if not is_state_initiaized:
        initialize_data()

    saved_scenarios = st.session_state["saved_scenarios"]

    if len(saved_scenarios) == 0:
        st.header("No saved scenarios")

    else:
        selected_scenario_list = list(saved_scenarios.keys())

        # The remembered selectbox position can point past the end of the
        # list once a scenario has been deleted; clamp it so neither the
        # list lookup nor st.selectbox receives an out-of-range index.
        scenario_page_state = st.session_state["project_dct"]["saved_scenarios"]
        selected_index = min(
            max(scenario_page_state["selected_scenario_selectbox_key"], 0),
            len(selected_scenario_list) - 1,
        )
        scenario_page_state["selected_scenario_selectbox_key"] = selected_index

        if "selected_scenario_selectbox_key" not in st.session_state:
            st.session_state["selected_scenario_selectbox_key"] = (
                selected_scenario_list[selected_index]
            )

        col_a, col_b = st.columns(2)
        selected_scenario = col_a.selectbox(
            "Pick a scenario to view details",
            selected_scenario_list,
            index=selected_index,
        )
        # Remember the choice so it is restored on the next run.
        scenario_page_state["selected_scenario_selectbox_key"] = (
            selected_scenario_list.index(selected_scenario)
        )

        # Read by download_scenarios() through the module namespace.
        scenarios_to_download = col_b.multiselect(
            "Select scenarios to download",
            list(saved_scenarios.keys()),
            on_change=disable_download_button,
        )

        with col_a:
            col3, col4 = st.columns(2)

            col4.button(
                "Delete scenarios",
                on_click=delete_scenario,
                use_container_width=True,
            )
            col3.button(
                "Load Scenario",
                on_click=load_scenario,
                use_container_width=True,
            )

        with col_b:
            col1, col2 = st.columns(2)

            col1.button(
                "Prepare download",
                on_click=download_scenarios,
                use_container_width=True,
            )
            # NOTE(review): assumes "xlsx_buffer" and
            # "disable_download_button" already exist in the session state
            # (presumably set by initialize_data) - confirm.
            col2.download_button(
                label="Download Scenarios",
                data=st.session_state["xlsx_buffer"].getvalue(),
                file_name="scenarios.xlsx",
                mime="application/vnd.ms-excel",
                disabled=st.session_state["disable_download_button"],
                on_click=disable_download_button,
                use_container_width=True,
            )

        # Read by load_scenario() through the module namespace.
        selected_scenario_details = saved_scenarios[selected_scenario]

        pd.set_option("display.max_colwidth", 100)

        # Render the formatted summary as a styled HTML table.
        st.markdown(
            create_scenario_summary(selected_scenario_details)
            .transform(transform)
            .style.set_table_styles(
                [
                    {"selector": "th", "props": [("background-color", "#11B6BD")]},
                    {
                        "selector": "tr:nth-child(even)",
                        "props": [("background-color", "#11B6BD")],
                    },
                ]
            )
            .to_html(),
            unsafe_allow_html=True,
        )

elif auth_status == False:
    st.error("Username/Password is incorrect")

if auth_status != True:
    # Offer the password-reset form to anyone who is not logged in.
    try:
        username_forgot_pw, email_forgot_password, random_password = (
            authenticator.forgot_password("Forgot password")
        )
        if username_forgot_pw:
            st.success("New password sent securely")
            # Random password to be transferred to user securely
        elif username_forgot_pw == False:
            st.error("Username not found")
    except Exception as e:
        st.error(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/11_Optimized_Result_Analysis.py
DELETED
@@ -1,453 +0,0 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
from numerize.numerize import numerize
|
3 |
-
import pandas as pd
|
4 |
-
from utilities import (format_numbers,decimal_formater,
|
5 |
-
load_local_css,set_header,
|
6 |
-
initialize_data,
|
7 |
-
load_authenticator)
|
8 |
-
import pickle
|
9 |
-
import streamlit_authenticator as stauth
|
10 |
-
import yaml
|
11 |
-
from yaml import SafeLoader
|
12 |
-
from classes import class_from_dict
|
13 |
-
import plotly.express as px
|
14 |
-
import numpy as np
|
15 |
-
import plotly.graph_objects as go
|
16 |
-
import pandas as pd
|
17 |
-
from plotly.subplots import make_subplots
|
18 |
-
import sqlite3
|
19 |
-
from utilities import update_db
|
20 |
-
def format_number(x):
    """Return ``x`` as a compact string with two decimals.

    Magnitudes of at least one million are shown in millions with an "M"
    suffix, magnitudes of at least one thousand in thousands with a "K"
    suffix, and anything smaller as a plain number.  The thresholds are
    applied to the absolute value, so negative numbers are abbreviated the
    same way as positive ones (previously they never got a suffix).
    """
    magnitude = abs(x)
    if magnitude >= 1_000_000:
        return f'{x / 1_000_000:.2f}M'
    if magnitude >= 1_000:
        return f'{x / 1_000:.2f}K'
    return f'{x:.2f}'
|
27 |
-
|
28 |
-
def summary_plot(data, x, y, title, text_column, color, format_as_percent=False, format_as_decimal=False):
    """Return a horizontal bar chart of ``data`` with value labels.

    ``x`` / ``y`` / ``color`` / ``text_column`` are column names passed to
    ``px.bar``.  The label and hover formats are whole percent when
    ``format_as_percent`` is set, two decimals when ``format_as_decimal``
    is set, and the two-significant-digit SI format otherwise.

    Side effect: ``data[text_column]`` is replaced in the caller's frame by
    its numeric coercion (unparseable entries become NaN).
    """
    chart = px.bar(
        data,
        x=x,
        y=y,
        orientation='h',
        title=title,
        text=text_column,
        color=color,
    )
    chart.update_layout(showlegend=False)

    # The coercion runs after the figure has been created, as before.
    data[text_column] = pd.to_numeric(data[text_column], errors='coerce')

    # Choose the label and hover templates for the requested format.
    if format_as_percent:
        text_template, hover_template = '%{text:.0%}', '%{x:.0%}'
    elif format_as_decimal:
        text_template, hover_template = '%{text:.2f}', '%{x:.2f}'
    else:
        text_template, hover_template = '%{text:.2s}', '%{x:.2s}'
    chart.update_traces(
        texttemplate=text_template,
        textposition='outside',
        hovertemplate=hover_template,
    )

    chart.update_layout(xaxis_title=x, yaxis_title='Channel Name', showlegend=False)
    return chart
|
44 |
-
|
45 |
-
|
46 |
-
def stacked_summary_plot(data, x, y, title, text_column, color_column, stack_column=None, format_as_percent=False, format_as_decimal=False):
    """Return a horizontal bar chart of ``data``, faceted by ``stack_column``.

    Same as ``summary_plot`` except that ``stack_column`` is passed to
    ``px.bar`` as ``facet_col`` and the y-axis title is left empty.  The
    label and hover formats are whole percent, two decimals, or the
    two-significant-digit SI format, chosen by the two flags.

    Side effect: ``data[text_column]`` is replaced in the caller's frame by
    its numeric coercion (unparseable entries become NaN).
    """
    chart = px.bar(
        data,
        x=x,
        y=y,
        orientation='h',
        title=title,
        text=text_column,
        color=color_column,
        facet_col=stack_column,
    )
    chart.update_layout(showlegend=False)

    # The coercion runs after the figure has been created, as before.
    data[text_column] = pd.to_numeric(data[text_column], errors='coerce')

    # Choose the label and hover templates for the requested format.
    if format_as_percent:
        text_template, hover_template = '%{text:.0%}', '%{x:.0%}'
    elif format_as_decimal:
        text_template, hover_template = '%{text:.2f}', '%{x:.2f}'
    else:
        text_template, hover_template = '%{text:.2s}', '%{x:.2s}'
    chart.update_traces(
        texttemplate=text_template,
        textposition='outside',
        hovertemplate=hover_template,
    )

    chart.update_layout(xaxis_title=x, yaxis_title='', showlegend=False)
    return chart
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
def funnel_plot(data, x, y, title, text_column, color_column, format_as_percent=False, format_as_decimal=False):
    """Return a funnel chart of ``data`` coloured by ``color_column``.

    ``x`` and ``y`` name the value and stage columns; ``text_column`` names
    the column shown as bar text.  Layout by flag:

    * ``format_as_percent``: stages are labelled with their percentage of
      the first stage (``textinfo="percent initial"``).
    * ``format_as_decimal``: ``funnelmode="overlay"``.
    * neither: ``funnelmode="group"``.

    Side effect: ``data[text_column]`` is replaced in the caller's frame by
    its numeric coercion rounded to two decimals.
    """
    data[text_column] = pd.to_numeric(data[text_column], errors='coerce').round(2)

    # Derive one RGB colour per distinct category from its position.
    color_map = {
        category: f'rgb({i * 30 % 255},{i * 50 % 255},{i * 70 % 255})'
        for i, category in enumerate(data[color_column].unique())
    }

    fig = go.Figure(go.Funnel(
        y=data[y],
        x=data[x],
        text=data[text_column],
        marker=dict(color=data[color_column].map(color_map)),
        textinfo="value",
        hoverinfo='y+x+text'
    ))

    if format_as_percent:
        # "percent" is not an accepted funnelmode (Plotly allows "stack",
        # "group" and "overlay"), so the previous call raised.  Percentages
        # are a trace text setting; the layout keeps Plotly's default mode.
        fig.update_traces(textinfo="percent initial")
        fig.update_layout(title=title, funnelmode="stack")
    elif format_as_decimal:
        fig.update_layout(title=title, funnelmode="overlay")
    else:
        fig.update_layout(title=title, funnelmode="group")

    return fig
|
92 |
-
|
93 |
-
|
94 |
-
st.set_page_config(layout='wide')
|
95 |
-
load_local_css('styles.css')
|
96 |
-
set_header()
|
97 |
-
|
98 |
-
# for k, v in st.session_state.items():
|
99 |
-
# if k not in ['logout', 'login','config'] and not k.startswith('FormSubmitter'):
|
100 |
-
# st.session_state[k] = v
|
101 |
-
|
102 |
-
st.empty()
|
103 |
-
st.header('Model Result Analysis')
|
104 |
-
spends_data=pd.read_excel('Overview_data_test.xlsx')
|
105 |
-
|
106 |
-
with open('summary_df.pkl', 'rb') as file:
|
107 |
-
summary_df_sorted = pickle.load(file)
|
108 |
-
#st.write(summary_df_sorted)
|
109 |
-
|
110 |
-
selected_scenario= st.selectbox('Select Saved Scenarios',['S1','S2'])
|
111 |
-
summary_df_sorted=summary_df_sorted.sort_values(by=['Optimized_spend'],ascending=False)
|
112 |
-
st.header('Optimized Spends Overview')
|
113 |
-
___columns=st.columns(3)
|
114 |
-
with ___columns[2]:
|
115 |
-
fig=summary_plot(summary_df_sorted, x='Delta_percent', y='Channel_name', title='Delta', text_column='Delta_percent',color='Channel_name')
|
116 |
-
st.plotly_chart(fig,use_container_width=True)
|
117 |
-
with ___columns[0]:
|
118 |
-
fig=summary_plot(summary_df_sorted, x='Actual_spend', y='Channel_name', title='Actual Spend', text_column='Actual_spend',color='Channel_name')
|
119 |
-
st.plotly_chart(fig,use_container_width=True)
|
120 |
-
with ___columns[1]:
|
121 |
-
fig=summary_plot(summary_df_sorted, x='Optimized_spend', y='Channel_name', title='Planned Spend', text_column='Optimized_spend',color='Channel_name')
|
122 |
-
st.plotly_chart(fig,use_container_width=False)
|
123 |
-
|
124 |
-
st.header(' Budget Allocation')
|
125 |
-
summary_df_sorted['Perc_alloted']=np.round(summary_df_sorted['Optimized_spend']/summary_df_sorted['Optimized_spend'].sum(),2)
|
126 |
-
columns2=st.columns(2)
|
127 |
-
with columns2[0]:
|
128 |
-
fig=summary_plot(summary_df_sorted, x='Optimized_spend', y='Channel_name', title='Planned Spend', text_column='Optimized_spend',color='Channel_name')
|
129 |
-
st.plotly_chart(fig,use_container_width=True)
|
130 |
-
with columns2[1]:
|
131 |
-
fig=summary_plot(summary_df_sorted, x='Perc_alloted', y='Channel_name', title='% Split', text_column='Perc_alloted',color='Channel_name',format_as_percent=True)
|
132 |
-
st.plotly_chart(fig,use_container_width=True)
|
133 |
-
|
134 |
-
|
135 |
-
if 'raw_data' not in st.session_state:
|
136 |
-
st.session_state['raw_data']=pd.read_excel('raw_data_nov7_combined1.xlsx')
|
137 |
-
st.session_state['raw_data']=st.session_state['raw_data'][st.session_state['raw_data']['MediaChannelName'].isin(summary_df_sorted['Channel_name'].unique())]
|
138 |
-
st.session_state['raw_data']=st.session_state['raw_data'][st.session_state['raw_data']['Date'].isin(spends_data["Date"].unique())]
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
#st.write(st.session_state['raw_data']['ResponseMetricName'])
|
143 |
-
# st.write(st.session_state['raw_data'])
|
144 |
-
|
145 |
-
|
146 |
-
st.header('Response Forecast Overview')
|
147 |
-
raw_data=st.session_state['raw_data']
|
148 |
-
effectiveness_overall=raw_data.groupby('ResponseMetricName').agg({'ResponseMetricValue': 'sum'}).reset_index()
|
149 |
-
effectiveness_overall['Efficiency']=effectiveness_overall['ResponseMetricValue'].map(lambda x: x/raw_data['Media Spend'].sum() )
|
150 |
-
# st.write(effectiveness_overall)
|
151 |
-
|
152 |
-
columns6=st.columns(3)
|
153 |
-
|
154 |
-
effectiveness_overall.sort_values(by=['ResponseMetricValue'],ascending=False,inplace=True)
|
155 |
-
effectiveness_overall=np.round(effectiveness_overall,2)
|
156 |
-
effectiveness_overall['ResponseMetric'] = effectiveness_overall['ResponseMetricName'].apply(lambda x: 'BAU' if 'BAU' in x else ('Gamified' if 'Gamified' in x else x))
|
157 |
-
# effectiveness_overall=np.where(effectiveness_overall[effectiveness_overall['ResponseMetricName']=="Adjusted Account Approval BAU"],"Adjusted Account Approval BAU",effectiveness_overall['ResponseMetricName'])
|
158 |
-
|
159 |
-
effectiveness_overall.replace({'ResponseMetricName':{'BAU approved clients - Appsflyer':'Approved clients - Appsflyer',
|
160 |
-
'Gamified approved clients - Appsflyer':'Approved clients - Appsflyer'}},inplace=True)
|
161 |
-
|
162 |
-
# st.write(effectiveness_overall.sort_values(by=['ResponseMetricValue'],ascending=False))
|
163 |
-
|
164 |
-
|
165 |
-
condition = effectiveness_overall['ResponseMetricName'] == "Adjusted Account Approval BAU"
|
166 |
-
condition1= effectiveness_overall['ResponseMetricName'] == "Approved clients - Appsflyer"
|
167 |
-
effectiveness_overall['ResponseMetric'] = np.where(condition, "Adjusted Account Approval BAU", effectiveness_overall['ResponseMetric'])
|
168 |
-
|
169 |
-
effectiveness_overall['ResponseMetricName'] = np.where(condition1, "Approved clients - Appsflyer (BAU, Gamified)", effectiveness_overall['ResponseMetricName'])
|
170 |
-
# effectiveness_overall=pd.DataFrame({'ResponseMetricName':["App Installs - Appsflyer",'Account Requests - Appsflyer',
|
171 |
-
# 'Total Adjusted Account Approval','Adjusted Account Approval BAU',
|
172 |
-
# 'Approved clients - Appsflyer','Approved clients - Appsflyer'],
|
173 |
-
# 'ResponseMetricValue':[683067,367020,112315,79768,36661,16834],
|
174 |
-
# 'Efficiency':[1.24,0.67,0.2,0.14,0.07,0.03],
|
175 |
-
custom_colors = {
|
176 |
-
'App Installs - Appsflyer': 'rgb(255, 135, 0)', # Steel Blue (Blue)
|
177 |
-
'Account Requests - Appsflyer': 'rgb(125, 239, 161)', # Cornflower Blue (Blue)
|
178 |
-
'Adjusted Account Approval': 'rgb(129, 200, 255)', # Dodger Blue (Blue)
|
179 |
-
'Adjusted Account Approval BAU': 'rgb(255, 207, 98)', # Light Sky Blue (Blue)
|
180 |
-
'Approved clients - Appsflyer': 'rgb(0, 97, 198)', # Light Blue (Blue)
|
181 |
-
"BAU": 'rgb(41, 176, 157)', # Steel Blue (Blue)
|
182 |
-
"Gamified": 'rgb(213, 218, 229)' # Silver (Gray)
|
183 |
-
# Add more categories and their respective shades of blue as needed
|
184 |
-
}
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
with columns6[0]:
|
192 |
-
revenue=(effectiveness_overall[effectiveness_overall['ResponseMetricName']=='Total Approved Accounts - Revenue']['ResponseMetricValue']).iloc[0]
|
193 |
-
revenue=round(revenue / 1_000_000, 2)
|
194 |
-
|
195 |
-
# st.metric('Total Revenue', f"${revenue} M")
|
196 |
-
# with columns6[1]:
|
197 |
-
# BAU=(effectiveness_overall[effectiveness_overall['ResponseMetricName']=='BAU approved clients - Revenue']['ResponseMetricValue']).iloc[0]
|
198 |
-
# BAU=round(BAU / 1_000_000, 2)
|
199 |
-
# st.metric('BAU approved clients - Revenue', f"${BAU} M")
|
200 |
-
# with columns6[2]:
|
201 |
-
# Gam=(effectiveness_overall[effectiveness_overall['ResponseMetricName']=='Gamified approved clients - Revenue']['ResponseMetricValue']).iloc[0]
|
202 |
-
# Gam=round(Gam / 1_000_000, 2)
|
203 |
-
# st.metric('Gamified approved clients - Revenue', f"${Gam} M")
|
204 |
-
|
205 |
-
# st.write(effectiveness_overall)
|
206 |
-
data = {'Revenue': ['BAU approved clients - Revenue', 'Gamified approved clients- Revenue'],
|
207 |
-
'ResponseMetricValue': [70200000, 1770000],
|
208 |
-
'Efficiency':[127.54,3.21]}
|
209 |
-
df = pd.DataFrame(data)
|
210 |
-
|
211 |
-
|
212 |
-
columns9=st.columns([0.60,0.40])
|
213 |
-
with columns9[0]:
|
214 |
-
figd = px.pie(df,
|
215 |
-
names='Revenue',
|
216 |
-
values='ResponseMetricValue',
|
217 |
-
hole=0.3, # set the size of the hole in the donut
|
218 |
-
title='Effectiveness')
|
219 |
-
figd.update_layout(
|
220 |
-
margin=dict(l=0, r=0, b=0, t=0),width=100, height=180,legend=dict(
|
221 |
-
orientation='v', # set orientation to horizontal
|
222 |
-
x=0, # set x to 0 to move to the left
|
223 |
-
y=0.8 # adjust y as needed
|
224 |
-
)
|
225 |
-
)
|
226 |
-
|
227 |
-
st.plotly_chart(figd, use_container_width=True)
|
228 |
-
|
229 |
-
with columns9[1]:
|
230 |
-
figd1 = px.pie(df,
|
231 |
-
names='Revenue',
|
232 |
-
values='Efficiency',
|
233 |
-
hole=0.3, # set the size of the hole in the donut
|
234 |
-
title='Efficiency')
|
235 |
-
figd1.update_layout(
|
236 |
-
margin=dict(l=0, r=0, b=0, t=0),width=100,height=180,showlegend=False
|
237 |
-
)
|
238 |
-
st.plotly_chart(figd1, use_container_width=True)
|
239 |
-
|
240 |
-
effectiveness_overall['Response Metric Name']=effectiveness_overall['ResponseMetricName']
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
columns4= st.columns([0.55,0.45])
|
245 |
-
with columns4[0]:
|
246 |
-
fig=px.funnel(effectiveness_overall[~(effectiveness_overall['ResponseMetricName'].isin(['Total Approved Accounts - Revenue',
|
247 |
-
'BAU approved clients - Revenue',
|
248 |
-
'Gamified approved clients - Revenue',
|
249 |
-
"Total Approved Accounts - Appsflyer"]))],
|
250 |
-
x='ResponseMetricValue', y='Response Metric Name',color='ResponseMetric',
|
251 |
-
color_discrete_map=custom_colors,title='Effectiveness',
|
252 |
-
labels=None)
|
253 |
-
custom_y_labels=['App Installs - Appsflyer','Account Requests - Appsflyer','Adjusted Account Approval','Adjusted Account Approval BAU',
|
254 |
-
"Approved clients - Appsflyer (BAU, Gamified)"
|
255 |
-
]
|
256 |
-
fig.update_layout(showlegend=False,
|
257 |
-
yaxis=dict(
|
258 |
-
tickmode='array',
|
259 |
-
ticktext=custom_y_labels,
|
260 |
-
)
|
261 |
-
)
|
262 |
-
fig.update_traces(textinfo='value', textposition='inside', texttemplate='%{x:.2s} ', hoverinfo='y+x+percent initial')
|
263 |
-
|
264 |
-
last_trace_index = len(fig.data) - 1
|
265 |
-
fig.update_traces(marker=dict(line=dict(color='black', width=2)), selector=dict(marker=dict(color='blue')))
|
266 |
-
|
267 |
-
st.plotly_chart(fig,use_container_width=True)
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
with columns4[1]:
|
274 |
-
|
275 |
-
# Your existing code for creating the bar chart
|
276 |
-
fig1 = px.bar((effectiveness_overall[~(effectiveness_overall['ResponseMetricName'].isin(['Total Approved Accounts - Revenue',
|
277 |
-
'BAU approved clients - Revenue',
|
278 |
-
'Gamified approved clients - Revenue',
|
279 |
-
"Total Approved Accounts - Appsflyer"]))]).sort_values(by='ResponseMetricValue'),
|
280 |
-
x='Efficiency', y='Response Metric Name',
|
281 |
-
color_discrete_map=custom_colors, color='ResponseMetric',
|
282 |
-
labels=None,text_auto=True,title='Efficiency'
|
283 |
-
)
|
284 |
-
|
285 |
-
# Update layout and traces
|
286 |
-
fig1.update_traces(customdata=effectiveness_overall['Efficiency'],
|
287 |
-
textposition='auto')
|
288 |
-
fig1.update_layout(showlegend=False)
|
289 |
-
fig1.update_yaxes(title='',showticklabels=False)
|
290 |
-
fig1.update_xaxes(title='',showticklabels=False)
|
291 |
-
fig1.update_xaxes(tickfont=dict(size=20))
|
292 |
-
fig1.update_yaxes(tickfont=dict(size=20))
|
293 |
-
st.plotly_chart(fig1, use_container_width=True)
|
294 |
-
|
295 |
-
|
296 |
-
effectiveness_overall_revenue=pd.DataFrame({'ResponseMetricName':['Approved Clients','Approved Clients'],
|
297 |
-
'ResponseMetricValue':[70201070,1768900],
|
298 |
-
'Efficiency':[127.54,3.21],
|
299 |
-
'ResponseMetric':['BAU','Gamified']
|
300 |
-
})
|
301 |
-
# from plotly.subplots import make_subplots
|
302 |
-
# fig = make_subplots(rows=1, cols=2,
|
303 |
-
# subplot_titles=["Effectiveness", "Efficiency"])
|
304 |
-
|
305 |
-
# # Add first plot as subplot
|
306 |
-
# fig.add_trace(go.Funnel(
|
307 |
-
# x = fig.data[0].x,
|
308 |
-
# y = fig.data[0].y,
|
309 |
-
# textinfo = 'value+percent initial',
|
310 |
-
# hoverinfo = 'x+y+percent initial'
|
311 |
-
# ), row=1, col=1)
|
312 |
-
|
313 |
-
# # Update layout for first subplot
|
314 |
-
# fig.update_xaxes(title_text="Response Metric Value", row=1, col=1)
|
315 |
-
# fig.update_yaxes(ticktext = custom_y_labels, row=1, col=1)
|
316 |
-
|
317 |
-
# # Add second plot as subplot
|
318 |
-
# fig.add_trace(go.Bar(
|
319 |
-
# x = fig1.data[0].x,
|
320 |
-
# y = fig1.data[0].y,
|
321 |
-
# customdata = fig1.data[0].customdata,
|
322 |
-
# textposition = 'auto'
|
323 |
-
# ), row=1, col=2)
|
324 |
-
|
325 |
-
# # Update layout for second subplot
|
326 |
-
# fig.update_xaxes(title_text="Efficiency", showticklabels=False, row=1, col=2)
|
327 |
-
# fig.update_yaxes(title='', showticklabels=False, row=1, col=2)
|
328 |
-
|
329 |
-
# fig.update_layout(height=600, width=800, title_text="Key Metrics")
|
330 |
-
# st.plotly_chart(fig)
|
331 |
-
|
332 |
-
|
333 |
-
st.header('Return Forecast by Media Channel')
with st.expander("Return Forecast by Media Channel"):
    # Response metrics offered in the selector. NaN never compares equal to
    # anything (itself included), so `val != np.NaN` kept every value; use
    # pd.notna() to actually drop missing metric names.
    metric_data = [
        val
        for val in st.session_state['raw_data']['ResponseMetricName'].unique()
        if pd.notna(val)
    ]
    # Keep the second metric as the default when it exists; fall back to the
    # first entry so a single-metric dataset does not raise.
    metric = st.selectbox(
        'Select Metric',
        metric_data,
        index=1 if len(metric_data) > 1 else 0,
    )

    # Rows of the raw data that belong to the chosen metric
    selected_metric = st.session_state['raw_data'][
        st.session_state['raw_data']['ResponseMetricName'] == metric
    ]
    # Effectiveness = summed response metric value per media channel
    effectiveness = selected_metric.groupby(by=['MediaChannelName'])['ResponseMetricValue'].sum()
    effectiveness_df = pd.DataFrame(
        {'Channel': effectiveness.index, "ResponseMetricValue": effectiveness.values}
    )

    summary_df_sorted = summary_df_sorted.merge(
        effectiveness_df, left_on="Channel_name", right_on='Channel'
    )

    # Efficiency = response metric value per unit of optimized spend
    summary_df_sorted['Efficiency'] = (
        summary_df_sorted['ResponseMetricValue'] / summary_df_sorted['Optimized_spend']
    )
    summary_df_sorted = summary_df_sorted.sort_values(by='Optimized_spend', ascending=True)

    channel_colors = px.colors.qualitative.Plotly

    fig = make_subplots(
        rows=1,
        cols=3,
        subplot_titles=('Optimized Spends', 'Effectiveness', 'Efficiency'),
        horizontal_spacing=0.05,
    )

    # One horizontal bar per channel in each of the three panels, sharing the
    # same colour for a channel across panels.
    for i, channel in enumerate(summary_df_sorted['Channel_name'].unique()):
        channel_df = summary_df_sorted[summary_df_sorted['Channel_name'] == channel]
        channel_color = channel_colors[i % len(channel_colors)]

        fig.add_trace(go.Bar(x=channel_df['Optimized_spend'],
                             y=channel_df['Channel_name'],
                             text=channel_df['Optimized_spend'].apply(format_number),
                             marker_color=channel_color,
                             orientation='h'), row=1, col=1)

        fig.add_trace(go.Bar(x=channel_df['ResponseMetricValue'],
                             y=channel_df['Channel_name'],
                             text=channel_df['ResponseMetricValue'].apply(format_number),
                             marker_color=channel_color,
                             orientation='h', showlegend=False), row=1, col=2)

        fig.add_trace(go.Bar(x=channel_df['Efficiency'],
                             y=channel_df['Channel_name'],
                             text=channel_df['Efficiency'].apply(format_number),
                             marker_color=channel_color,
                             orientation='h', showlegend=False), row=1, col=3)

    fig.update_layout(
        height=600,
        width=900,
        title='Media Channel Performance',
        showlegend=False
    )

    # Channel labels are shown only on the first panel
    fig.update_yaxes(showticklabels=False, row=1, col=2)
    fig.update_yaxes(showticklabels=False, row=1, col=3)

    fig.update_xaxes(showticklabels=False, row=1, col=1)
    fig.update_xaxes(showticklabels=False, row=1, col=2)
    fig.update_xaxes(showticklabels=False, row=1, col=3)

    st.plotly_chart(fig, use_container_width=True)
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
# columns= st.columns(3)
|
398 |
-
# with columns[0]:
|
399 |
-
# fig=summary_plot(summary_df_sorted, x='Optimized_spend', y='Channel_name', title='', text_column='Optimized_spend',color='Channel_name')
|
400 |
-
# st.plotly_chart(fig,use_container_width=True)
|
401 |
-
# with columns[1]:
|
402 |
-
|
403 |
-
# # effectiveness=(selected_metric.groupby(by=['MediaChannelName'])['ResponseMetricValue'].sum()).values
|
404 |
-
# # effectiveness_df=pd.DataFrame({'Channel':st.session_state['raw_data']['MediaChannelName'].unique(),"ResponseMetricValue":effectiveness})
|
405 |
-
# # # effectiveness.reset_index(inplace=True)
|
406 |
-
# # # st.dataframe(effectiveness.head())
|
407 |
-
|
408 |
-
|
409 |
-
# fig=summary_plot(summary_df_sorted, x='ResponseMetricValue', y='Channel_name', title='Effectiveness', text_column='ResponseMetricValue',color='Channel_name')
|
410 |
-
# st.plotly_chart(fig,use_container_width=True)
|
411 |
-
|
412 |
-
# with columns[2]:
|
413 |
-
# fig=summary_plot(summary_df_sorted, x='Efficiency', y='Channel_name', title='Efficiency', text_column='Efficiency',color='Channel_name',format_as_decimal=True)
|
414 |
-
# st.plotly_chart(fig,use_container_width=True)
|
415 |
-
|
416 |
-
|
417 |
-
# Create figure with subplots
|
418 |
-
# fig = make_subplots(rows=1, cols=2)
|
419 |
-
|
420 |
-
# # Add funnel plot to subplot 1
|
421 |
-
# fig.add_trace(
|
422 |
-
# go.Funnel(
|
423 |
-
# x=effectiveness_overall[~(effectiveness_overall['ResponseMetricName'].isin(['Total Approved Accounts - Revenue', 'BAU approved clients - Revenue', 'Gamified approved clients - Revenue', "Total Approved Accounts - Appsflyer"]))]['ResponseMetricValue'],
|
424 |
-
# y=effectiveness_overall[~(effectiveness_overall['ResponseMetricName'].isin(['Total Approved Accounts - Revenue', 'BAU approved clients - Revenue', 'Gamified approved clients - Revenue', "Total Approved Accounts - Appsflyer"]))]['ResponseMetricName'],
|
425 |
-
# textposition="inside",
|
426 |
-
# texttemplate="%{x:.2s}",
|
427 |
-
# customdata=effectiveness_overall['Efficiency'],
|
428 |
-
# hovertemplate="%{customdata:.2f}<extra></extra>"
|
429 |
-
# ),
|
430 |
-
# row=1, col=1
|
431 |
-
# )
|
432 |
-
|
433 |
-
# # Add bar plot to subplot 2
|
434 |
-
# fig.add_trace(
|
435 |
-
# go.Bar(
|
436 |
-
# x=effectiveness_overall.sort_values(by='ResponseMetricValue')['Efficiency'],
|
437 |
-
# y=effectiveness_overall.sort_values(by='ResponseMetricValue')['ResponseMetricName'],
|
438 |
-
# marker_color=effectiveness_overall['ResponseMetric'],
|
439 |
-
# customdata=effectiveness_overall['Efficiency'],
|
440 |
-
# hovertemplate="%{customdata:.2f}<extra></extra>",
|
441 |
-
# textposition="outside"
|
442 |
-
# ),
|
443 |
-
# row=1, col=2
|
444 |
-
# )
|
445 |
-
|
446 |
-
# # Update layout
|
447 |
-
# fig.update_layout(title_text="Effectiveness")
|
448 |
-
# fig.update_yaxes(title_text="", row=1, col=1)
|
449 |
-
# fig.update_yaxes(title_text="", showticklabels=False, row=1, col=2)
|
450 |
-
# fig.update_xaxes(title_text="Efficiency", showticklabels=False, row=1, col=2)
|
451 |
-
|
452 |
-
# # Show figure
|
453 |
-
# st.plotly_chart(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/1_Data_Import.py
DELETED
@@ -1,1547 +0,0 @@
|
|
1 |
-
# Importing necessary libraries
|
2 |
-
import streamlit as st
|
3 |
-
import os
|
4 |
-
|
5 |
-
# from Home_redirecting import home
|
6 |
-
from utilities import update_db
|
7 |
-
|
8 |
-
st.set_page_config(
|
9 |
-
page_title="Data Import",
|
10 |
-
page_icon=":shark:",
|
11 |
-
layout="wide",
|
12 |
-
initial_sidebar_state="collapsed",
|
13 |
-
)
|
14 |
-
|
15 |
-
import pickle
|
16 |
-
import pandas as pd
|
17 |
-
from utilities import set_header, load_local_css
|
18 |
-
import streamlit_authenticator as stauth
|
19 |
-
import yaml
|
20 |
-
from yaml import SafeLoader
|
21 |
-
import sqlite3
|
22 |
-
|
23 |
-
load_local_css("styles.css")
|
24 |
-
set_header()
|
25 |
-
|
26 |
-
for k, v in st.session_state.items():
|
27 |
-
if (
|
28 |
-
k not in ["logout", "login", "config"]
|
29 |
-
and not k.startswith("FormSubmitter")
|
30 |
-
and not k.startswith("data-editor")
|
31 |
-
):
|
32 |
-
st.session_state[k] = v
|
33 |
-
with open("config.yaml") as file:
|
34 |
-
config = yaml.load(file, Loader=SafeLoader)
|
35 |
-
st.session_state["config"] = config
|
36 |
-
authenticator = stauth.Authenticate(
|
37 |
-
config["credentials"],
|
38 |
-
config["cookie"]["name"],
|
39 |
-
config["cookie"]["key"],
|
40 |
-
config["cookie"]["expiry_days"],
|
41 |
-
config["preauthorized"],
|
42 |
-
)
|
43 |
-
st.session_state["authenticator"] = authenticator
|
44 |
-
name, authentication_status, username = authenticator.login("Login", "main")
|
45 |
-
auth_status = st.session_state.get("authentication_status")
|
46 |
-
|
47 |
-
if auth_status == True:
|
48 |
-
authenticator.logout("Logout", "main")
|
49 |
-
is_state_initiaized = st.session_state.get("initialized", False)
|
50 |
-
|
51 |
-
if not is_state_initiaized:
|
52 |
-
|
53 |
-
if "session_name" not in st.session_state:
|
54 |
-
st.session_state["session_name"] = None
|
55 |
-
|
56 |
-
# Function to validate date column in dataframe
|
57 |
-
|
58 |
-
if "project_dct" not in st.session_state:
|
59 |
-
# home()
|
60 |
-
st.warning("please select a project from Home page")
|
61 |
-
st.stop()
|
62 |
-
|
63 |
-
def validate_date_column(df):
    """Convert ``df["date"]`` to datetime in place and report whether it worked.

    The column is parsed with the fixed format ``DD-MM-YYYY``. On success the
    converted column is written back to ``df`` and ``True`` is returned; if the
    column is missing or cannot be parsed, ``df`` is left untouched and
    ``False`` is returned.
    """
    try:
        # Attempt to convert the 'date' column to datetime
        df["date"] = pd.to_datetime(df["date"], format="%d-%m-%Y")
    except Exception:
        # Any parsing/lookup failure means "not a valid date column".
        # Unlike a bare ``except:``, this lets KeyboardInterrupt and
        # SystemExit propagate.
        return False
    return True
|
70 |
-
|
71 |
-
# Function to determine data interval
|
72 |
-
def determine_data_interval(common_freq):
    """Map the most common gap between dates (in days) to an interval label.

    Returns "daily" for 1, "weekly" for 7, "monthly" for anything from 28 to
    31 inclusive, and "irregular" for every other value.
    """
    label = "irregular"
    if common_freq == 1:
        label = "daily"
    elif common_freq == 7:
        label = "weekly"
    elif 28 <= common_freq <= 31:
        label = "monthly"
    return label
|
81 |
-
|
82 |
-
# Function to read each uploaded Excel file into a pandas DataFrame and store them in a dictionary
|
83 |
-
st.cache_resource(show_spinner=False)
|
84 |
-
|
85 |
-
def files_to_dataframes(uploaded_files):
    """Read uploaded Excel files into validated DataFrames keyed by file name.

    Each file is read with ``pd.read_excel``, its column names are lowercased
    and stripped, and its ``date`` column is parsed via
    ``validate_date_column``. A file is skipped (with a Streamlit warning)
    when its name was already seen, when the date column is invalid or no
    numeric column exists, or when its date spacing is not daily, weekly or
    monthly.

    Returns a dict mapping the file name (without extension) to a dict with
    the keys ``"numeric"``, ``"non_numeric"``, ``"interval"`` and ``"df"``.
    """
    df_dict = {}
    for uploaded_file in uploaded_files:
        # Extract file name without extension
        file_name = uploaded_file.name.rsplit(".", 1)[0]

        # Check for duplicate file names
        if file_name in df_dict:
            st.warning(
                f"Duplicate File: {file_name}. This file will be skipped.",
                icon="⚠️",
            )
            continue

        # Read the file into a DataFrame
        df = pd.read_excel(uploaded_file)

        # Convert all column names to lowercase
        df.columns = df.columns.str.lower().str.strip()

        # Separate numeric and non-numeric columns
        numeric_cols = list(df.select_dtypes(include=["number"]).columns)
        non_numeric_cols = [
            col
            for col in df.select_dtypes(exclude=["number"]).columns
            if col.lower() != "date"
        ]

        # Check for 'date' column (this also converts it to datetime)
        if not (validate_date_column(df) and len(numeric_cols) > 0):
            st.warning(
                f"File Name: {file_name} ➜ Please upload data with Date column in 'DD-MM-YYYY' format and at least one media/exogenous column. This file will be skipped.",
                icon="⚠️",
            )
            continue

        # Most common gap (in days) between consecutive unique dates. With
        # fewer than two unique dates there is no gap at all, so the mode is
        # empty; treat that as an irregular interval instead of crashing on
        # ``.mode()[0]``.
        day_gaps = (
            pd.Series(df["date"].unique())
            .diff()
            .dt.days.dropna()
            .mode()
        )
        if day_gaps.empty:
            interval = "irregular"
        else:
            # Calculate the data interval (daily, weekly, monthly or irregular)
            interval = determine_data_interval(day_gaps[0])

        if interval == "irregular":
            st.warning(
                f"File Name: {file_name} ➜ Please upload data in daily, weekly or monthly interval. This file will be skipped.",
                icon="⚠️",
            )
            continue

        # Store the DataFrame and its metadata under the file name
        df_dict[file_name] = {
            "numeric": numeric_cols,
            "non_numeric": non_numeric_cols,
            "interval": interval,
            "df": df,
        }

    return df_dict
|
146 |
-
|
147 |
-
# Function to adjust dataframe granularity
|
148 |
-
def adjust_dataframe_granularity(df, current_granularity, target_granularity):
    """Return ``df`` converted from one time granularity to another.

    ``df`` must contain a datetime ``date`` column. Supported conversions:

    * daily -> weekly  (weeks anchored on Monday, labelled by their start)
    * daily -> monthly (labelled by month start)
    * daily -> daily   (re-sampled on a daily grid)
    * weekly/monthly -> daily: numeric values are split evenly over the
      7 days / the days of the month, non-numeric values are repeated.

    When aggregating, numeric columns are summed and other columns keep their
    first value. Any other combination returns the data unchanged. The result
    always carries ``date`` as a regular column. The input frame is not
    modified.
    """
    # Index a new frame instead of using inplace=True so the caller's
    # DataFrame keeps its "date" column.
    df = df.set_index("date")

    # Decide once per column whether it is numeric (loop-invariant)
    numeric_cols = {
        col for col in df.columns if pd.api.types.is_numeric_dtype(df[col])
    }

    # Define aggregation rules for resampling
    aggregation_rules = {
        col: "sum" if col in numeric_cols else "first" for col in df.columns
    }

    # Default: pass the data through unchanged
    resampled_df = df
    if current_granularity == "daily" and target_granularity == "weekly":
        resampled_df = df.resample(
            "W-MON", closed="left", label="left"
        ).agg(aggregation_rules)

    elif current_granularity == "daily" and target_granularity == "monthly":
        resampled_df = df.resample("MS", closed="left", label="left").agg(
            aggregation_rules
        )

    elif current_granularity == "daily" and target_granularity == "daily":
        resampled_df = df.resample("D").agg(aggregation_rules)

    elif (
        current_granularity in ["weekly", "monthly"]
        and target_granularity == "daily"
    ):
        # For higher to lower granularity, distribute numeric values equally
        # and replicate non-numeric values across the new period
        expanded_index = []
        expanded_rows = []
        for period_start, row in df.iterrows():
            if current_granularity == "weekly":
                n_days = 7
            else:
                n_days = period_start.days_in_month

            # Every day of the period receives the same per-day values
            daily_row = {
                col: row[col] / n_days if col in numeric_cols else row[col]
                for col in df.columns
            }
            for date in pd.date_range(start=period_start, periods=n_days):
                expanded_index.append(date)
                expanded_rows.append(daily_row)

        resampled_df = pd.DataFrame(expanded_rows, index=expanded_index)

    # Reset index; the expanded frame has an unnamed index, hence the rename
    return resampled_df.reset_index().rename(columns={"index": "date"})
|
216 |
-
|
217 |
-
# Function to clean and extract unique values of Panel_1 and Panel_2
|
218 |
-
st.cache_resource(show_spinner=False)
|
219 |
-
|
220 |
-
def clean_and_extract_unique_values(files_dict, selections):
    """Normalize the selected panel columns and collect their unique values.

    For every file in ``files_dict`` the columns chosen under "Panel_1" and
    "Panel_2" in ``selections[file_name]`` are lowercased, stripped and have
    underscores replaced by spaces (the DataFrame is updated in place). A
    selection is ignored when it is empty, equals "N/A" or is not a column of
    the file.

    Returns a tuple ``(all_panel1_values, all_panel2_values)`` of sets with
    the non-null cleaned values seen across all files.
    """
    all_panel1_values = set()
    all_panel2_values = set()
    collectors = (
        ("Panel_1", all_panel1_values),
        ("Panel_2", all_panel2_values),
    )

    for file_name, file_data in files_dict.items():
        frame = file_data["df"]
        chosen = [
            (selections[file_name].get(panel_key), bucket)
            for panel_key, bucket in collectors
        ]

        for column, bucket in chosen:
            usable = column and column != "N/A" and column in frame.columns
            if not usable:
                continue
            # Clean and standardize the panel column
            lowered = frame[column].str.lower()
            trimmed = lowered.str.strip()
            frame[column] = trimmed.str.replace("_", " ")
            bucket.update(frame[column].dropna().unique())

        # Update the processed DataFrame back in the dictionary
        files_dict[file_name]["df"] = frame

    return all_panel1_values, all_panel2_values
|
263 |
-
|
264 |
-
# Function to format values for display
|
265 |
-
st.cache_resource(show_spinner=False)
|
266 |
-
|
267 |
-
def format_values_for_display(values_list):
    """Render values as a human-readable, comma separated phrase.

    Underscores become spaces and every word is title-cased. Two or more
    values are joined as "A, B, and C", a single value is returned on its
    own, and an empty input yields "No values available".
    """
    pretty = []
    for value in values_list:
        pretty.append(value.replace("_", " ").title())

    if not pretty:
        return "No values available"
    if len(pretty) == 1:
        return pretty[0]

    # Join values with commas and 'and' before the last value
    *leading, final = pretty
    return ", ".join(leading) + ", and " + final
|
280 |
-
|
281 |
-
# Function to normalize all data within files_dict to a daily granularity
|
282 |
-
st.cache(show_spinner=False, allow_output_mutation=True)
|
283 |
-
|
284 |
-
def standardize_data_to_daily(files_dict, selections):
    """Convert every dataset in ``files_dict`` to daily granularity.

    Delegates the conversion to ``apply_granularity_to_all`` and then marks
    each dataset's "interval" entry as "daily". Returns the updated dict.
    """
    daily_files = apply_granularity_to_all(files_dict, "daily", selections)

    # Record the new granularity on every dataset
    for dataset in daily_files.values():
        dataset["interval"] = "daily"

    return daily_files
|
293 |
-
|
294 |
-
# Function to apply granularity transformation to all DataFrames in files_dict
|
295 |
-
st.cache_resource(show_spinner=False)
|
296 |
-
|
297 |
-
def apply_granularity_to_all(files_dict, granularity_selection, selections):
    """Convert every DataFrame in ``files_dict`` to ``granularity_selection``.

    For each file the panel columns chosen in ``selections[file_name]``
    ("Panel_1" / "Panel_2") define segments: every unique combination of
    panel values is converted separately with
    ``adjust_dataframe_granularity`` and the results are concatenated. When
    no panel column is selected the whole frame is converted at once. The
    converted frame replaces ``files_dict[file_name]["df"]``.

    Returns the same ``files_dict`` object.
    """
    for file_name, file_data in files_dict.items():
        df = file_data["df"].copy()

        # Panel columns to segment by. A missing selection (None) is treated
        # like "N/A", matching how merge_into_main_df defaults it; duplicates
        # are dropped so the same column is never selected twice.
        picked = (
            selections[file_name].get("Panel_1"),
            selections[file_name].get("Panel_2"),
        )
        panel_cols = list(
            dict.fromkeys(
                col for col in picked if col is not None and col != "N/A"
            )
        )

        if not panel_cols:
            # No panels selected: process the entire dataframe as is
            files_dict[file_name]["df"] = adjust_dataframe_granularity(
                df, file_data["interval"], granularity_selection
            )
            continue  # Skip to the next file

        unique_combinations = df[panel_cols].drop_duplicates()

        transformed_segments = []
        for _, combo in unique_combinations.iterrows():
            # Rows whose panel values match this combination
            mask = pd.Series(True, index=df.index)
            for col in panel_cols:
                mask &= df[col] == combo[col]

            # Hand over an independent copy so the callee never operates on
            # a slice of ``df``
            segment = df[mask].copy()

            # Adjust granularity of the segment
            transformed_segments.append(
                adjust_dataframe_granularity(
                    segment, file_data["interval"], granularity_selection
                )
            )

        # Combine all transformed segments into a single DataFrame for this file
        files_dict[file_name]["df"] = pd.concat(
            transformed_segments, ignore_index=True
        )

    return files_dict
|
349 |
-
|
350 |
-
# Function to create main dataframe structure
|
351 |
-
st.cache_resource(show_spinner=False)
|
352 |
-
|
353 |
-
def create_main_dataframe(
    files_dict, all_panel1_values, all_panel2_values, granularity_selection
):
    """Build the empty template frame that all datasets are merged into.

    The template is the cartesian product of the Panel_1 values, the Panel_2
    values (each only when non-empty) and a date range spanning the earliest
    to the latest ``date`` found across ``files_dict``. The date range steps
    by Monday-anchored weeks for "weekly", by month start for "monthly" and
    by day otherwise.

    Returns a DataFrame with the columns "Panel_1"/"Panel_2" (when present)
    followed by "date".
    """
    frames = [entry["df"] for entry in files_dict.values()]

    # Global start and end dates across all DataFrames
    global_start = min(frame["date"].min() for frame in frames)
    global_end = max(frame["date"].max() for frame in frames)

    # Step of the date range for the requested granularity
    if granularity_selection == "weekly":
        step = "W-MON"  # weeks starting on Monday
    elif granularity_selection == "monthly":
        step = "MS"  # first day of each month
    else:
        step = "D"  # default to daily
    date_range = pd.date_range(start=global_start, end=global_end, freq=step)

    # Only panels that actually have values become dimensions
    dimensions = []
    merge_keys = []
    for key, values in (
        ("Panel_1", all_panel1_values),
        ("Panel_2", all_panel2_values),
    ):
        if values:
            dimensions.append(values)
            merge_keys.append(key)

    # Date range is always included
    dimensions.append(date_range)
    merge_keys.append("date")

    # Create the template from the product of all dimensions
    template_index = pd.MultiIndex.from_product(dimensions, names=merge_keys)
    main_df = template_index.to_frame(index=False)

    return main_df.reset_index(drop=True)
|
401 |
-
|
402 |
-
# Function to prepare and merge dataFrames
|
403 |
-
st.cache_resource(show_spinner=False)
|
404 |
-
|
405 |
-
def merge_into_main_df(main_df, files_dict, selections):
    """Left-merge every dataset in ``files_dict`` into ``main_df``.

    For each file, the columns selected as "Panel_1"/"Panel_2" in
    ``selections[file_name]`` (unless "N/A") are renamed to those standard
    names on a copy of the data. The copy is then merged on "date" plus
    whichever panel columns it contains. The merged result is sorted by
    "date", "Panel_1" and "Panel_2" (those that exist) and re-indexed.
    """
    for file_name, file_data in files_dict.items():
        frame = file_data["df"].copy()

        # Rename selected Panel_1 and Panel_2 columns if not 'N/A'
        for panel_key in ("Panel_1", "Panel_2"):
            source_column = selections[file_name].get(panel_key, "N/A")
            if source_column != "N/A":
                frame.rename(columns={source_column: panel_key}, inplace=True)

        # Merge on 'date' and, where applicable, the panel columns
        join_columns = ["date"] + [
            panel_key
            for panel_key in ("Panel_1", "Panel_2")
            if panel_key in frame.columns
        ]
        main_df = pd.merge(main_df, frame, on=join_columns, how="left")

    # After all merges, sort and reset the index for cleanliness
    ordering = ["date"] + [
        panel_key
        for panel_key in ("Panel_1", "Panel_2")
        if panel_key in main_df.columns
    ]
    main_df.sort_values(by=ordering, inplace=True)
    main_df.reset_index(drop=True, inplace=True)

    return main_df
|
435 |
-
|
436 |
-
# Function to categorize column
|
437 |
-
def categorize_column(column_name):
    """Suggest a category ("Internal", "Exogenous" or "Media") for a column.

    A non-None category already stored for the column in
    ``st.session_state["project_dct"]["data_import"]["cat_dct"]`` wins.
    Otherwise the column name is matched case-insensitively against keyword
    lists: any internal keyword gives "Internal", else any exogenous keyword
    gives "Exogenous", else the default "Media" is returned.
    """
    # Define keywords for each category
    internal_keywords = [
        "Price",
        "Discount",
        "product_price",
        "cost",
        "margin",
        "inventory",
        "sales",
        "revenue",
        "turnover",
        "expense",
    ]
    exogenous_keywords = [
        "GDP",
        "Tax",
        "Inflation",
        "interest_rate",
        "employment_rate",
        "exchange_rate",
        "consumer_spending",
        "retail_sales",
        "oil_prices",
        "weather",
    ]

    # A category saved with the project takes precedence (resumed project)
    if (
        column_name
        in st.session_state["project_dct"]["data_import"]["cat_dct"].keys()
        and st.session_state["project_dct"]["data_import"]["cat_dct"][
            column_name
        ]
        is not None
    ):
        return st.session_state["project_dct"]["data_import"]["cat_dct"][
            column_name
        ]

    # NOTE(review): internal keywords are tested first, so names such as
    # "retail_sales" or "oil_prices" match "sales" / "Price" and come back as
    # "Internal" - confirm this precedence is intended.
    if any(
        keyword.lower() in column_name.lower() for keyword in internal_keywords
    ):
        return "Internal"
    if any(
        keyword.lower() in column_name.lower()
        for keyword in exogenous_keywords
    ):
        return "Exogenous"

    # Default to Media if no match found
    return "Media"
|
489 |
-
|
490 |
-
# Function to calculate missing stats and prepare for editable DataFrame
|
491 |
-
st.cache_resource(show_spinner=False)
|
492 |
-
|
493 |
-
def prepare_missing_stats_df(df):
    """Summarize missing values per column for the editable stats table.

    The "date", "Panel_1" and "Panel_2" columns are skipped. Every other
    column yields one row with its name, the count and percentage (two
    decimals) of null values, the default impute method "Fill with 0" and the
    category suggested by ``categorize_column``.

    Returns the rows as a DataFrame.
    """
    excluded = ("date", "Panel_2", "Panel_1")
    row_count = len(df)

    records = []
    for column in df.columns:
        if column in excluded:
            continue

        null_count = df[column].isnull().sum()
        records.append(
            {
                "Column": column,
                "Missing Values": null_count,
                "Missing Percentage": round((null_count / row_count) * 100, 2),
                "Impute Method": "Fill with 0",  # Default value
                # Category is assigned dynamically from the column name
                "Category": categorize_column(column),
            }
        )

    return pd.DataFrame(records)
|
520 |
-
|
521 |
-
# Function to add API DataFrame details to the files dictionary
|
522 |
-
st.cache_resource(show_spinner=False)
|
523 |
-
|
524 |
-
def add_api_dataframe_to_dict(main_df, files_dict):
    """Register ``main_df`` in ``files_dict`` under the key "API".

    The entry mirrors the structure used for uploaded files: the numeric
    column names, the non-numeric column names (excluding "date"), the
    interval derived from the most common gap between unique dates, and the
    DataFrame itself. Returns the updated ``files_dict``.
    """
    numeric_names = list(main_df.select_dtypes(include=["number"]).columns)

    other_names = []
    for col in main_df.select_dtypes(exclude=["number"]).columns:
        if col.lower() != "date":
            other_names.append(col)

    # Most common spacing (in days) between consecutive unique dates
    unique_dates = pd.Series(main_df["date"].unique())
    typical_gap = unique_dates.diff().dt.days.dropna().mode()[0]

    files_dict["API"] = {
        "numeric": numeric_names,
        "non_numeric": other_names,
        "interval": determine_data_interval(typical_gap),
        "df": main_df,
    }

    return files_dict
|
542 |
-
|
543 |
-
# Function to read the API data into a DataFrame, parsing the specified columns as datetime
|
544 |
-
@st.cache_resource(show_spinner=False)
|
545 |
-
def read_API_data():
    """Load the bundled API workbook, parsing its "Date" column as datetime."""
    workbook_path = r"./upf_data_converted_randomized_resp_metrics.xlsx"
    date_columns = ["Date"]
    return pd.read_excel(workbook_path, parse_dates=date_columns)
|
550 |
-
|
551 |
-
# Function to set the 'Panel_1_Panel_2_Selected' session state variable to False
|
552 |
-
def set_Panel_1_Panel_2_Selected_false():
    """Clear the panel-selected flag and reset the saved data-import state.

    Sets ``st.session_state["Panel_1_Panel_2_Selected"]`` to False and
    restores the entries of ``project_dct["data_import"]`` to their defaults
    (None for everything except "cat_dct", which becomes an empty dict).
    """
    st.session_state["Panel_1_Panel_2_Selected"] = False

    # Restore project_dct to default values when the user modifies any widget
    data_import = st.session_state["project_dct"]["data_import"]
    for key in ("edited_stats_df", "merged_df", "missing_stats_df"):
        data_import[key] = None
    data_import["cat_dct"] = {}
    for key in ("numeric_columns", "default_df", "final_df", "edited_df"):
        data_import[key] = None
|
571 |
-
|
572 |
-
# Function to serialize and save the objects into a pickle file
|
573 |
-
@st.cache_resource(show_spinner=False)
|
574 |
-
def save_to_pickle(file_path, final_df, bin_dict):
    """Pickle ``final_df`` and ``bin_dict`` together into ``file_path``.

    The file is written in binary mode and holds a single dict with the keys
    "final_df" and "bin_dict".
    """
    payload = {"final_df": final_df, "bin_dict": bin_dict}
    with open(file_path, "wb") as handle:
        pickle.dump(payload, handle)
|
578 |
-
# Data is now saved to file
|
579 |
-
|
580 |
-
# Function to process the merged_df DataFrame based on operations defined in edited_df
|
581 |
-
@st.cache_resource(show_spinner=False)
|
582 |
-
def process_dataframes(merged_df, edited_df, edited_stats_df):
    """Apply the user-defined column operations to ``merged_df``.

    Each row of ``edited_df`` names "Column 1", "Operator" (+, -, *, /),
    "Column 2" and a "Category". The result is stored in ``merged_df`` under
    the name "<Column 1><Operator><Column 2>" (division replaces zeros in
    the divisor with 1e-9) and a matching row is appended to
    ``edited_stats_df``. Afterwards all operand columns are dropped from
    ``merged_df`` and removed from the stats table.

    Returns ``(merged_df, edited_stats_df)``; both are returned untouched
    when ``edited_df`` is empty. ``merged_df`` is modified in place.
    """
    # Nothing to do when the user defined no operations
    if edited_df.empty:
        return merged_df, edited_stats_df

    # Supported operators mapped to the calculation they perform
    calculations = {
        "+": lambda left, right: left + right,
        "-": lambda left, right: left - right,
        "*": lambda left, right: left * right,
        "/": lambda left, right: left / right.replace(0, 1e-9),
    }

    for _, operation in edited_df.iterrows():
        first = operation["Column 1"]
        second = operation["Column 2"]
        symbol = operation["Operator"]
        derived_name = f"{first}{symbol}{second}"

        # Apply the specified operation (unknown operators add no column)
        calculate = calculations.get(symbol)
        if calculate is not None:
            merged_df[derived_name] = calculate(
                merged_df[first], merged_df[second]
            )

        # Add summary of operation to edited_stats_df
        summary_row = pd.DataFrame(
            [
                {
                    "Column": derived_name,
                    "Missing Values": None,
                    "Missing Percentage": None,
                    "Impute Method": None,
                    "Category": operation["Category"],
                }
            ]
        )
        edited_stats_df = pd.concat(
            [edited_stats_df, summary_row], ignore_index=True, axis=0
        )

    # Operand columns are replaced by their derived columns
    operands = set(edited_df["Column 1"]).union(set(edited_df["Column 2"]))

    keep_rows = ~edited_stats_df["Column"].isin(operands)
    edited_stats_df = edited_stats_df[keep_rows]
    merged_df.drop(columns=list(operands), errors="ignore", inplace=True)

    return merged_df, edited_stats_df
|
646 |
-
|
647 |
-
# Build the list of numeric column names and an empty, pre-structured
# DataFrame for the feature-engineering editor.
#
# NOTE: a bare `st.cache_resource(show_spinner=False)` expression used to sit
# here without a leading `@`. It built a decorator and threw it away, so this
# function was never cached. The no-op call has been removed; runtime behavior
# is unchanged. If caching is actually wanted it has to be applied as a
# decorator -- TODO confirm with the page owner.
def prepare_numeric_columns_and_default_df(merged_df, edited_stats_df):
    """Return the selectable numeric columns and an empty operations table.

    Args:
        merged_df: DataFrame whose numeric columns are candidates.
        edited_stats_df: DataFrame with at least the columns ``"Column"`` and
            ``"Category"``; rows whose category is ``"Response Metrics"``
            name the columns to leave out.

    Returns:
        A tuple ``(numeric_columns, default_df)`` where ``numeric_columns`` is
        a list of the numeric column names of ``merged_df`` that are not
        categorized as response metrics (in ``merged_df`` column order), and
        ``default_df`` is an empty DataFrame with the string columns
        ``"Column 1"``, ``"Operator"``, ``"Column 2"`` and ``"Category"``.
    """
    # Names of the columns the user categorized as response metrics
    is_response_metric = edited_stats_df["Category"] == "Response Metrics"
    response_metric_columns = set(
        edited_stats_df.loc[is_response_metric, "Column"]
    )

    # Numeric columns that remain once the response metrics are excluded
    numeric_columns = [
        col
        for col in merged_df.select_dtypes(include=["number"]).columns
        if col not in response_metric_columns
    ]

    # Empty frame that only fixes the column names and their string dtype
    default_df = pd.DataFrame(
        {
            name: pd.Series([], dtype="str")
            for name in ("Column 1", "Operator", "Column 2", "Category")
        }
    )

    return numeric_columns, default_df
|
673 |
-
|
674 |
-
# function to reset to default values in project_dct:
|
675 |
-
|
676 |
-
# Seed the session-state entries used further down this page. A key that is
# already present (from an earlier run of the script) keeps its value.
for state_key, make_default in (
    ("final_df", pd.DataFrame),
    ("bin_dict", dict),
    ("Panel_1_Panel_2_Selected", lambda: False),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = make_default()
|
687 |
-
|
688 |
-
# Page Title
st.write("")  # Top padding
st.title("Data Import")

# Open the user database and keep a cursor on it.
# The path is spelled with a forward slash so that it points into the `DB`
# directory on every platform. The earlier raw string r"DB\User.db" only did
# that on Windows; on POSIX systems the backslash is an ordinary file-name
# character, so SQLite would open (or create) a file literally named
# `DB\User.db` in the working directory.
# check_same_thread=False allows the connection to be used from threads other
# than the one that created it.
conn = sqlite3.connect("DB/User.db", check_same_thread=False)
c = conn.cursor()
|
696 |
-
|
697 |
-
#########################################################################################################################################################
|
698 |
-
# Create a dictionary to hold all DataFrames and collect user input to specify "Panel_2" and "Panel_1" columns for each file
|
699 |
-
#########################################################################################################################################################
|
700 |
-
|
701 |
-
# Load the API dataset.
main_df = read_API_data()

# The script checks `main_df is not None` before adding it to `files_dict`,
# i.e. a missing API frame is an expected case. Guard the header
# normalization the same way so a None result does not raise AttributeError.
if main_df is not None:
    # Lower-case the column names and strip surrounding whitespace
    main_df.columns = main_df.columns.str.lower().str.strip()
|
706 |
-
|
707 |
-
# Upload widget for additional Excel workbooks. Several files may be chosen;
# any change to the upload runs the callback that clears the
# panel-selection flag.
uploader_options = dict(
    type=["xlsx"],
    accept_multiple_files=True,
    on_change=set_Panel_1_Panel_2_Selected_false,
)
uploaded_files = st.file_uploader("Upload additional data", **uploader_options)
|
714 |
-
|
715 |
-
# Custom HTML for upload instructions
|
716 |
-
recommendation_html = f"""
|
717 |
-
<div style="text-align: justify;">
|
718 |
-
<strong>Recommendation:</strong> For optimal processing, please ensure that all uploaded datasets including panel, media, internal, and exogenous data adhere to the following guidelines: Each dataset must include a <code>Date</code> column formatted as <code>DD-MM-YYYY</code>, be free of missing values.
|
719 |
-
</div>
|
720 |
-
"""
|
721 |
-
st.markdown(recommendation_html, unsafe_allow_html=True)
|
722 |
-
|
723 |
-
# Date granularity picker
st.markdown("#### Choose Desired Granularity")

granularity_options = ["Daily", "Weekly", "Monthly"]
data_import_state = st.session_state["project_dct"]["data_import"]

granularity_selection = st.selectbox(
    "Choose Date Granularity",
    granularity_options,
    index=data_import_state["granularity_selection"],  # resume saved choice
    on_change=set_Panel_1_Panel_2_Selected_false,
    label_visibility="collapsed",
)

# Store the position of the chosen option so it can be restored on resume
data_import_state["granularity_selection"] = granularity_options.index(
    granularity_selection
)

# The rest of the page works with the lower-case name
granularity_selection = str(granularity_selection).lower()
|
746 |
-
|
747 |
-
# Turn the uploaded workbooks into the `files_dict` mapping
files_dict = files_to_dataframes(uploaded_files)

# Register the API frame too, when one was loaded
if main_df is not None:
    files_dict = add_api_dataframe_to_dict(main_df, files_dict)

# Without any data there is nothing to configure: warn and end this run
if not files_dict:
    st.warning("Please upload at least one file to proceed.", icon="⚠️")
    st.stop()  # Script resumes on the next rerun, e.g. after an upload
|
761 |
-
|
762 |
-
# Select Panel_1 and Panel_2 columns
st.markdown("#### Select Panel columns")
selections = {}
with st.expander("Select Panel columns", expanded=False):
    # Saved dropdown positions are kept in the project dictionary, keyed by
    # file name, and are fed back to the widgets through `index`.
    saved_choices = st.session_state["project_dct"]["data_import"]

    # Number of files that have rendered widgets so far; it decides label
    # visibility and keeps the widget keys unique.
    count = 0
    for file_name, file_data in files_dict.items():
        panel1_key = f"Panel_1_selectbox{file_name}"
        panel2_key = f"Panel_2_selectbox{file_name}"

        # Create the saved positions the first time a file name is seen
        saved_choices.setdefault(panel1_key, 0)
        saved_choices.setdefault(panel2_key, 0)

        # Only the first rendered row shows the widget labels
        label_visibility = "visible" if count == 0 else "collapsed"

        # Both dropdowns offer the file's non-numeric columns plus "N/A"
        panel_options = file_data["non_numeric"] + ["N/A"]

        # No candidate column at all: record "N/A" and render nothing
        if len(panel_options) == 1:
            selections[file_name] = {"Panel_1": "N/A", "Panel_2": "N/A"}
            continue

        # A saved position can be stale, e.g. when a file with the same name
        # now has fewer non-numeric columns. An out-of-range `index` makes
        # st.selectbox raise, so fall back to the first option instead.
        for saved_key in (panel1_key, panel2_key):
            if not 0 <= saved_choices[saved_key] < len(panel_options):
                saved_choices[saved_key] = 0

        # Layout: file name, Panel Level 1 dropdown, Panel Level 2 dropdown
        file_name_col, Panel_1_col, Panel_2_col = st.columns([2, 4, 4])

        with file_name_col:
            # Header text only above the first file
            st.write("File Name" if count == 0 else "")
            st.write(file_name)

        with Panel_1_col:
            selected_panel1 = st.selectbox(
                "Select Panel Level 1",
                panel_options,
                on_change=set_Panel_1_Panel_2_Selected_false,
                label_visibility=label_visibility,
                key=f"Panel_1_selectbox{count}",  # unique per rendered row
                index=saved_choices[panel1_key],
            )
            saved_choices[panel1_key] = panel_options.index(selected_panel1)

        with Panel_2_col:
            selected_panel2 = st.selectbox(
                "Select Panel Level 2",
                panel_options,
                on_change=set_Panel_1_Panel_2_Selected_false,
                label_visibility=label_visibility,
                key=f"Panel_2_selectbox{count}",  # unique per rendered row
                index=saved_choices[panel2_key],
            )
            saved_choices[panel2_key] = panel_options.index(selected_panel2)

        # The same real column may not be used for both levels ("N/A" twice
        # is fine). Warn and end this script run until the user changes it.
        if selected_panel1 == selected_panel2 and selected_panel1 != "N/A":
            st.warning(
                f"File: {file_name} → The same column cannot serve as both Panel_1 and Panel_2. Please adjust your selections.",
            )
            st.stop()

        # Record the choices for the current file
        selections[file_name] = {
            "Panel_1": selected_panel1,
            "Panel_2": selected_panel2,
        }

        count += 1  # One more file has rendered its widgets
|
877 |
-
st.write()
|
878 |
-
# Accept Panel_1 and Panel_2 selection
|
879 |
-
accept = st.button(
|
880 |
-
"Accept and Process", use_container_width=True
|
881 |
-
) # resume project manoj
|
882 |
-
|
883 |
-
if (
|
884 |
-
accept == False
|
885 |
-
and st.session_state["project_dct"]["data_import"]["edited_stats_df"]
|
886 |
-
is not None
|
887 |
-
):
|
888 |
-
|
889 |
-
# st.write(st.session_state['project_dct'])
|
890 |
-
st.markdown("#### Unique Panel values")
|
891 |
-
# Display Panel_1 and Panel_2 values
|
892 |
-
with st.expander("Unique Panel values"):
|
893 |
-
st.write("")
|
894 |
-
st.markdown(
|
895 |
-
f"""
|
896 |
-
<style>
|
897 |
-
.justify-text {{
|
898 |
-
text-align: justify;
|
899 |
-
}}
|
900 |
-
</style>
|
901 |
-
<div class="justify-text">
|
902 |
-
<strong>Panel Level 1 Values:</strong> {st.session_state['project_dct']['data_import']['formatted_panel1_values']}<br>
|
903 |
-
<strong>Panel Level 2 Values:</strong> {st.session_state['project_dct']['data_import']['formatted_panel2_values']}
|
904 |
-
</div>
|
905 |
-
""",
|
906 |
-
unsafe_allow_html=True,
|
907 |
-
)
|
908 |
-
|
909 |
-
# Display total Panel_1 and Panel_2
|
910 |
-
st.write("")
|
911 |
-
# Show the panel counts from the values saved in the project dictionary.
# The Level 1 figure now reads the Level 1 entry; it used to read the Level 2
# entry, so both lines always showed the same number.
# NOTE(review): len() is taken of the value returned by
# format_values_for_display. If that is a display string rather than a list,
# this counts characters, not panels -- confirm against that helper (the
# non-resume branch counts the underlying lists instead).
st.markdown(
    f"""
<div style="text-align: justify;">
<strong>Number of Level 1 Panels detected:</strong> {len(st.session_state['project_dct']['data_import']['formatted_panel1_values'])}<br>
<strong>Number of Level 2 Panels detected:</strong> {len(st.session_state['project_dct']['data_import']['formatted_panel2_values'])}
</div>
""",
    unsafe_allow_html=True,
)
|
920 |
-
st.write("")
|
921 |
-
|
922 |
-
# Create an editable DataFrame in Streamlit
|
923 |
-
|
924 |
-
st.markdown("#### Select Variables Category & Impute Missing Values")

data_import_state = st.session_state["project_dct"]["data_import"]

# Work on a copy so the frame saved in the project dictionary is not
# modified by the imputation and column drops that follow
merged_df = data_import_state["merged_df"].copy()
missing_stats_df = data_import_state["missing_stats_df"]

impute_choices = [
    "Drop Column",
    "Fill with Mean",
    "Fill with Median",
    "Fill with 0",
]
category_choices = ["Media", "Exogenous", "Internal", "Response Metrics"]

# Editable table seeded with the previously saved edits; only the impute
# method and the category can be changed
edited_stats_df = st.data_editor(
    data_import_state["edited_stats_df"],
    column_config={
        "Impute Method": st.column_config.SelectboxColumn(
            options=impute_choices,
            required=True,
            default="Fill with 0",
        ),
        "Category": st.column_config.SelectboxColumn(
            options=category_choices,
            required=True,
            default="Media",
        ),
    },
    disabled=["Column", "Missing Values", "Missing Percentage"],
    hide_index=True,
    use_container_width=True,
    key="data-editor-1",
)

# Column name -> chosen category
data_import_state["cat_dct"] = dict(
    zip(edited_stats_df["Column"], edited_stats_df["Category"])
)
|
977 |
-
|
978 |
-
# Apply the impute method chosen for each column to the working copy.
# Mean and median are taken from the frame saved in the project dictionary.
# The result of fillna is assigned back to the column: calling
# `merged_df[column].fillna(..., inplace=True)` acts on an intermediate
# object, which pandas flags as chained assignment and which leaves
# `merged_df` untouched under Copy-on-Write.
saved_merged_df = st.session_state["project_dct"]["data_import"]["merged_df"]
for _, row in edited_stats_df.iterrows():
    column = row["Column"]
    method = row["Impute Method"]

    if method == "Drop Column":
        merged_df.drop(columns=[column], inplace=True)

    elif method == "Fill with Mean":
        merged_df[column] = merged_df[column].fillna(
            saved_merged_df[column].mean()
        )

    elif method == "Fill with Median":
        merged_df[column] = merged_df[column].fillna(
            saved_merged_df[column].median()
        )

    elif method == "Fill with 0":
        merged_df[column] = merged_df[column].fillna(0)
|
1001 |
-
|
1002 |
-
# st.session_state['project_dct']['data_import']['edited_stats_df']=edited_stats_df
|
1003 |
-
#########################################################################################################################################################
|
1004 |
-
# Group columns
|
1005 |
-
#########################################################################################################################################################
|
1006 |
-
|
1007 |
-
# Display Group columns header
|
1008 |
-
# Restore the saved inputs of the feature-engineering editor
numeric_columns = st.session_state["project_dct"]["data_import"][
    "numeric_columns"
]
default_df = st.session_state["project_dct"]["data_import"]["default_df"]

st.markdown("#### Feature engineering")

# Editable table of "<Column 1> <Operator> <Column 2>" operations, seeded
# with the previously saved rows. Rows can be added and removed.
edited_df = st.data_editor(
    st.session_state["project_dct"]["data_import"]["edited_df"],
    column_config={
        "Column 1": st.column_config.SelectboxColumn(
            options=numeric_columns,
            required=True,
            width=400,
        ),
        "Operator": st.column_config.SelectboxColumn(
            options=["+", "-", "*", "/"],
            required=True,
            default="+",
            width=100,
        ),
        "Column 2": st.column_config.SelectboxColumn(
            options=numeric_columns,
            required=True,
            # Indexing [0] on an empty list raised IndexError and took the
            # whole page down; use no default when nothing is selectable
            default=numeric_columns[0] if numeric_columns else None,
            width=400,
        ),
        "Category": st.column_config.SelectboxColumn(
            options=[
                "Media",
                "Exogenous",
                "Internal",
                "Response Metrics",
            ],
            required=True,
            default="Media",
            width=200,
        ),
    },
    num_rows="dynamic",
    key="data-editor-4",
)
|
1052 |
-
|
1053 |
-
final_df, edited_stats_df = process_dataframes(
|
1054 |
-
merged_df, edited_df, edited_stats_df
|
1055 |
-
)
|
1056 |
-
|
1057 |
-
st.markdown("#### Final DataFrame")
|
1058 |
-
st.dataframe(final_df, hide_index=True)
|
1059 |
-
|
1060 |
-
# Group the variable names by the category chosen for them:
# {category: [column, column, ...]}
category_dict = {}
for column, category in zip(
    edited_stats_df["Column"], edited_stats_df["Category"]
):
    category_dict.setdefault(category, []).append(column)

# Add Date, Panel_1 and Panel_2 to the category dictionary; the panel
# entries only when the final frame actually has those columns
category_dict["Date"] = ["date"]
for panel_column, panel_label in (
    ("Panel_1", "Panel Level 1"),
    ("Panel_2", "Panel Level 2"),
):
    if panel_column in final_df.columns:
        category_dict[panel_label] = [panel_column]
|
1084 |
-
|
1085 |
-
# List every category together with its variables
st.markdown("#### Variable Category")
for category, variables in category_dict.items():
    leading, last = variables[:-1], variables[-1]

    # "a, b and c" for several variables, just "a" for a single one
    variables_str = (
        (", ".join(leading) + " and " + last) if leading else last
    )

    st.markdown(
        f"<div style='text-align: justify;'><strong>{category}:</strong> {variables_str}</div>",
        unsafe_allow_html=True,
    )
|
1103 |
-
|
1104 |
-
# At least one variable has to be categorized as a response metric;
# otherwise warn and end this script run.
# (A check limiting the selection to a single response metric is disabled.)
st.write("")
response_metrics_col = category_dict.get("Response Metrics", [])
if not response_metrics_col:
    st.warning("Please select Response Metrics column", icon="⚠️")
    st.stop()

# Keep the final frame and the category dictionary in session state
st.session_state["final_df"] = final_df
st.session_state["bin_dict"] = category_dict
|
1119 |
-
|
1120 |
-
# Persist the results when the user confirms.
if st.button(
    "Accept and Save",
    use_container_width=True,
    key="data-editor-button",
):
    # NOTE(review): update_db still runs before the project dictionary is
    # refreshed below. If it also serializes `project_dct` it will see the
    # previous values -- confirm against its implementation.
    update_db("1_Data_Import.py")

    # Keep only the first occurrence of each column name
    final_df = final_df.loc[:, ~final_df.columns.duplicated()]

    # Write the final frame and the category dictionary to data_import.pkl.
    # NOTE(review): this writes st.session_state["final_df"], which was
    # stored before the de-duplication above -- confirm that is intended.
    data_path = os.path.join(
        st.session_state["project_path"], "data_import.pkl"
    )
    st.session_state["data_path"] = data_path
    save_to_pickle(
        data_path,
        st.session_state["final_df"],
        st.session_state["bin_dict"],
    )

    # Refresh the resumable state of this page
    data_import_state = st.session_state["project_dct"]["data_import"]
    data_import_state.update(
        {
            "edited_stats_df": edited_stats_df,
            "merged_df": merged_df,
            "missing_stats_df": missing_stats_df,
            "cat_dct": dict(
                zip(edited_stats_df["Column"], edited_stats_df["Category"])
            ),
            "numeric_columns": numeric_columns,
            "default_df": default_df,
            "final_df": final_df,
            "edited_df": edited_df,
        }
    )

    # Dump the project dictionary only now, after it has been refreshed.
    # It used to be pickled before the updates above, so project_dct.pkl
    # held the state from before this save.
    project_dct_path = os.path.join(
        st.session_state["project_path"], "project_dct.pkl"
    )
    with open(project_dct_path, "wb") as f:
        pickle.dump(st.session_state["project_dct"], f)

    st.toast("💾 Saved Successfully!")
|
1178 |
-
|
1179 |
-
# Runs once the user has pressed "Accept and Process"
if accept:
    with st.spinner("Processing..."):
        # Step 1: bring every dataset to daily level, which gives the later
        # conversion a common starting point
        daily_files = standardize_data_to_daily(files_dict, selections)

        # Step 2: apply the granularity chosen on this page to all datasets
        files_dict = apply_granularity_to_all(
            daily_files, granularity_selection, selections
        )

    # Keep the processed datasets for later reruns of the script
    st.session_state["files_dict"] = files_dict

    # Flag that the panel selection has been accepted
    st.session_state["Panel_1_Panel_2_Selected"] = True
|
1194 |
-
|
1195 |
-
#########################################################################################################################################################
|
1196 |
-
# Display unique Panel_1 and Panel_2 values
|
1197 |
-
#########################################################################################################################################################
|
1198 |
-
|
1199 |
-
# Halts further execution until Panel_1 and Panel_2 columns are selected
|
1200 |
-
if (
|
1201 |
-
st.session_state["project_dct"]["data_import"]["edited_stats_df"]
|
1202 |
-
is None
|
1203 |
-
):
|
1204 |
-
|
1205 |
-
if (
|
1206 |
-
"files_dict" in st.session_state
|
1207 |
-
and st.session_state["Panel_1_Panel_2_Selected"]
|
1208 |
-
):
|
1209 |
-
files_dict = st.session_state["files_dict"]
|
1210 |
-
|
1211 |
-
st.session_state["project_dct"]["data_import"][
|
1212 |
-
"files_dict"
|
1213 |
-
] = files_dict # resume
|
1214 |
-
else:
|
1215 |
-
st.stop()
|
1216 |
-
|
1217 |
-
# Set to store unique values of Panel_1 and Panel_2
|
1218 |
-
with st.spinner("Fetching Panel values..."):
|
1219 |
-
all_panel1_values, all_panel2_values = (
|
1220 |
-
clean_and_extract_unique_values(files_dict, selections)
|
1221 |
-
)
|
1222 |
-
|
1223 |
-
# List of Panel_1 and Panel_2 columns unique values
|
1224 |
-
list_of_all_panel1_values = list(all_panel1_values)
|
1225 |
-
list_of_all_panel2_values = list(all_panel2_values)
|
1226 |
-
|
1227 |
-
# Format Panel_1 and Panel_2 values for display
|
1228 |
-
formatted_panel1_values = format_values_for_display(
|
1229 |
-
list_of_all_panel1_values
|
1230 |
-
) ##
|
1231 |
-
formatted_panel2_values = format_values_for_display(
|
1232 |
-
list_of_all_panel2_values
|
1233 |
-
) ##
|
1234 |
-
|
1235 |
-
# storing panel values in project_dct
|
1236 |
-
|
1237 |
-
st.session_state["project_dct"]["data_import"][
|
1238 |
-
"formatted_panel1_values"
|
1239 |
-
] = formatted_panel1_values
|
1240 |
-
st.session_state["project_dct"]["data_import"][
|
1241 |
-
"formatted_panel2_values"
|
1242 |
-
] = formatted_panel2_values
|
1243 |
-
|
1244 |
-
# Unique Panel_1 and Panel_2 values
|
1245 |
-
st.markdown("#### Unique Panel values")
|
1246 |
-
# Display Panel_1 and Panel_2 values
|
1247 |
-
with st.expander("Unique Panel values"):
|
1248 |
-
st.write("")
|
1249 |
-
st.markdown(
|
1250 |
-
f"""
|
1251 |
-
<style>
|
1252 |
-
.justify-text {{
|
1253 |
-
text-align: justify;
|
1254 |
-
}}
|
1255 |
-
</style>
|
1256 |
-
<div class="justify-text">
|
1257 |
-
<strong>Panel Level 1 Values:</strong> {formatted_panel1_values}<br>
|
1258 |
-
<strong>Panel Level 2 Values:</strong> {formatted_panel2_values}
|
1259 |
-
</div>
|
1260 |
-
""",
|
1261 |
-
unsafe_allow_html=True,
|
1262 |
-
)
|
1263 |
-
|
1264 |
-
# Display total Panel_1 and Panel_2
|
1265 |
-
st.write("")
|
1266 |
-
st.markdown(
|
1267 |
-
f"""
|
1268 |
-
<div style="text-align: justify;">
|
1269 |
-
<strong>Number of Level 1 Panels detected:</strong> {len(list_of_all_panel1_values)}<br>
|
1270 |
-
<strong>Number of Level 2 Panels detected:</strong> {len(list_of_all_panel2_values)}
|
1271 |
-
</div>
|
1272 |
-
""",
|
1273 |
-
unsafe_allow_html=True,
|
1274 |
-
)
|
1275 |
-
st.write("")
|
1276 |
-
|
1277 |
-
#########################################################################################################################################################
|
1278 |
-
# Merge all DataFrames
|
1279 |
-
#########################################################################################################################################################
|
1280 |
-
|
1281 |
-
# Merge all DataFrames selected
|
1282 |
-
|
1283 |
-
main_df = create_main_dataframe(
|
1284 |
-
files_dict,
|
1285 |
-
all_panel1_values,
|
1286 |
-
all_panel2_values,
|
1287 |
-
granularity_selection,
|
1288 |
-
)
|
1289 |
-
|
1290 |
-
merged_df = merge_into_main_df(main_df, files_dict, selections) ##
|
1291 |
-
|
1292 |
-
#########################################################################################################################################################
|
1293 |
-
# Categorize Variables and Impute Missing Values
|
1294 |
-
#########################################################################################################################################################
|
1295 |
-
|
1296 |
-
# Create an editable DataFrame in Streamlit
|
1297 |
-
|
1298 |
-
st.markdown("#### Select Variables Category & Impute Missing Values")
|
1299 |
-
|
1300 |
-
# Prepare missing stats DataFrame for editing
|
1301 |
-
missing_stats_df = prepare_missing_stats_df(merged_df)
|
1302 |
-
|
1303 |
-
# storing missing stats df
|
1304 |
-
|
1305 |
-
edited_stats_df = st.data_editor(
|
1306 |
-
missing_stats_df,
|
1307 |
-
column_config={
|
1308 |
-
"Impute Method": st.column_config.SelectboxColumn(
|
1309 |
-
options=[
|
1310 |
-
"Drop Column",
|
1311 |
-
"Fill with Mean",
|
1312 |
-
"Fill with Median",
|
1313 |
-
"Fill with 0",
|
1314 |
-
],
|
1315 |
-
required=True,
|
1316 |
-
default="Fill with 0",
|
1317 |
-
),
|
1318 |
-
"Category": st.column_config.SelectboxColumn(
|
1319 |
-
options=[
|
1320 |
-
"Media",
|
1321 |
-
"Exogenous",
|
1322 |
-
"Internal",
|
1323 |
-
"Response Metrics",
|
1324 |
-
],
|
1325 |
-
required=True,
|
1326 |
-
default="Media",
|
1327 |
-
),
|
1328 |
-
},
|
1329 |
-
disabled=["Column", "Missing Values", "Missing Percentage"],
|
1330 |
-
hide_index=True,
|
1331 |
-
use_container_width=True,
|
1332 |
-
key="data-editor-2",
|
1333 |
-
)
|
1334 |
-
|
1335 |
-
# edited_stats_df_path=os.path.join(st.session_state['project_path'],"edited_stats_df.pkl")
|
1336 |
-
|
1337 |
-
# edited_stats_df.to_pickle(edited_stats_df_path)
|
1338 |
-
|
1339 |
-
# Apply changes based on edited DataFrame
|
1340 |
-
# Apply the impute method chosen for each column to `merged_df`.
# The result of fillna is assigned back to the column: calling
# `merged_df[column].fillna(..., inplace=True)` acts on an intermediate
# object, which pandas flags as chained assignment and which leaves
# `merged_df` untouched under Copy-on-Write.
for _, row in edited_stats_df.iterrows():
    column = row["Column"]
    method = row["Impute Method"]

    if method == "Drop Column":
        merged_df.drop(columns=[column], inplace=True)

    elif method == "Fill with Mean":
        merged_df[column] = merged_df[column].fillna(merged_df[column].mean())

    elif method == "Fill with Median":
        merged_df[column] = merged_df[column].fillna(
            merged_df[column].median()
        )

    elif method == "Fill with 0":
        merged_df[column] = merged_df[column].fillna(0)
|
1357 |
-
|
1358 |
-
# st.session_state['project_dct']['data_import']['edited_stats_df']=edited_stats_df
|
1359 |
-
|
1360 |
-
#########################################################################################################################################################
|
1361 |
-
# Group columns
|
1362 |
-
#########################################################################################################################################################
|
1363 |
-
|
1364 |
-
# Display Group columns header
|
1365 |
-
st.markdown("#### Feature engineering")
|
1366 |
-
|
1367 |
-
# Prepare the numeric columns and an empty DataFrame for user input
|
1368 |
-
numeric_columns, default_df = prepare_numeric_columns_and_default_df(
|
1369 |
-
merged_df, edited_stats_df
|
1370 |
-
)
|
1371 |
-
|
1372 |
-
# st.session_state['project_dct']['data_import']['edited_stats_df']=edited_stats_df
|
1373 |
-
|
1374 |
-
# Display editable Dataframe
|
1375 |
-
# Editable table of "<Column 1> <Operator> <Column 2>" operations, starting
# from the empty default frame. Rows can be added and removed.
edited_df = st.data_editor(
    default_df,
    column_config={
        "Column 1": st.column_config.SelectboxColumn(
            options=numeric_columns,
            required=True,
            width=400,
        ),
        "Operator": st.column_config.SelectboxColumn(
            options=["+", "-", "*", "/"],
            required=True,
            default="+",
            width=100,
        ),
        "Column 2": st.column_config.SelectboxColumn(
            options=numeric_columns,
            required=True,
            # Indexing [0] on an empty list raised IndexError and took the
            # whole page down; use no default when nothing is selectable
            default=numeric_columns[0] if numeric_columns else None,
            width=400,
        ),
        "Category": st.column_config.SelectboxColumn(
            options=[
                "Media",
                "Exogenous",
                "Internal",
                "Response Metrics",
            ],
            required=True,
            default="Media",
            width=200,
        ),
    },
    num_rows="dynamic",
    key="data-editor-3",
)
|
1410 |
-
|
1411 |
-
# Process the DataFrame based on user inputs and operations specified in edited_df
|
1412 |
-
final_df, edited_stats_df = process_dataframes(
|
1413 |
-
merged_df, edited_df, edited_stats_df
|
1414 |
-
)
|
1415 |
-
|
1416 |
-
# edited_df_path=os.path.join(st.session_state['project_path'],'edited_df.pkl')
|
1417 |
-
# edited_df.to_pickle(edited_df_path)
|
1418 |
-
|
1419 |
-
#########################################################################################################################################################
|
1420 |
-
# Display the Final DataFrame and variables
|
1421 |
-
#########################################################################################################################################################
|
1422 |
-
|
1423 |
-
# Display the Final DataFrame and variables
|
1424 |
-
|
1425 |
-
st.markdown("#### Final DataFrame")
|
1426 |
-
|
1427 |
-
st.dataframe(final_df, hide_index=True)
|
1428 |
-
|
1429 |
-
# Initialize an empty dictionary to hold categories and their variables
|
1430 |
-
category_dict = {}
|
1431 |
-
|
1432 |
-
# Iterate over each row in the edited DataFrame to populate the dictionary
|
1433 |
-
for i, row in edited_stats_df.iterrows():
|
1434 |
-
column = row["Column"]
|
1435 |
-
category = row[
|
1436 |
-
"Category"
|
1437 |
-
] # The category chosen by the user for this variable
|
1438 |
-
|
1439 |
-
# Check if the category already exists in the dictionary
|
1440 |
-
if category not in category_dict:
|
1441 |
-
# If not, initialize it with the current column as its first element
|
1442 |
-
category_dict[category] = [column]
|
1443 |
-
else:
|
1444 |
-
# If it exists, append the current column to the list of variables under this category
|
1445 |
-
category_dict[category].append(column)
|
1446 |
-
|
1447 |
-
# Add Date, Panel_1 and Panel_2 in category dictionary
|
1448 |
-
category_dict.update({"Date": ["date"]})
|
1449 |
-
if "Panel_1" in final_df.columns:
|
1450 |
-
category_dict["Panel Level 1"] = ["Panel_1"]
|
1451 |
-
if "Panel_2" in final_df.columns:
|
1452 |
-
category_dict["Panel Level 2"] = ["Panel_2"]
|
1453 |
-
|
1454 |
-
# Display the dictionary
|
1455 |
-
st.markdown("#### Variable Category")
|
1456 |
-
for category, variables in category_dict.items():
|
1457 |
-
# Check if there are multiple variables to handle "and" insertion correctly
|
1458 |
-
if len(variables) > 1:
|
1459 |
-
# Join all but the last variable with ", ", then add " and " before the last variable
|
1460 |
-
variables_str = (
|
1461 |
-
", ".join(variables[:-1]) + " and " + variables[-1]
|
1462 |
-
)
|
1463 |
-
else:
|
1464 |
-
# If there's only one variable, no need for "and"
|
1465 |
-
variables_str = variables[0]
|
1466 |
-
|
1467 |
-
# Display the category and its variables in the desired format
|
1468 |
-
st.markdown(
|
1469 |
-
f"<div style='text-align: justify;'><strong>{category}:</strong> {variables_str}</div>",
|
1470 |
-
unsafe_allow_html=True,
|
1471 |
-
)
|
1472 |
-
|
1473 |
-
# Check that at least one Response Metrics column is selected
|
1474 |
-
st.write("")
|
1475 |
-
|
1476 |
-
response_metrics_col = category_dict.get("Response Metrics", [])
|
1477 |
-
if len(response_metrics_col) == 0:
|
1478 |
-
st.warning("Please select Response Metrics column", icon="⚠️")
|
1479 |
-
st.stop()
|
1480 |
-
# elif len(response_metrics_col) > 1:
|
1481 |
-
# st.warning("Please select only one Response Metrics column", icon="⚠️")
|
1482 |
-
# st.stop()
|
1483 |
-
|
1484 |
-
# Store final dataframe and bin dictionary into session state
|
1485 |
-
|
1486 |
-
st.session_state["final_df"], st.session_state["bin_dict"] = (
|
1487 |
-
final_df,
|
1488 |
-
category_dict,
|
1489 |
-
)
|
1490 |
-
|
1491 |
-
# Save the DataFrame and dictionary from the session state to the pickle file
|
1492 |
-
|
1493 |
-
if st.button("Accept and Save", use_container_width=True):
|
1494 |
-
|
1495 |
-
print("test*************")
|
1496 |
-
update_db("1_Data_Import.py")
|
1497 |
-
|
1498 |
-
project_dct_path = os.path.join(
|
1499 |
-
st.session_state["project_path"], "project_dct.pkl"
|
1500 |
-
)
|
1501 |
-
|
1502 |
-
with open(project_dct_path, "wb") as f:
|
1503 |
-
pickle.dump(st.session_state["project_dct"], f)
|
1504 |
-
|
1505 |
-
data_path = os.path.join(
|
1506 |
-
st.session_state["project_path"], "data_import.pkl"
|
1507 |
-
)
|
1508 |
-
st.session_state["data_path"] = data_path
|
1509 |
-
|
1510 |
-
save_to_pickle(
|
1511 |
-
data_path,
|
1512 |
-
st.session_state["final_df"],
|
1513 |
-
st.session_state["bin_dict"],
|
1514 |
-
)
|
1515 |
-
|
1516 |
-
st.session_state["project_dct"]["data_import"][
|
1517 |
-
"edited_stats_df"
|
1518 |
-
] = edited_stats_df
|
1519 |
-
st.session_state["project_dct"]["data_import"][
|
1520 |
-
"merged_df"
|
1521 |
-
] = merged_df
|
1522 |
-
st.session_state["project_dct"]["data_import"][
|
1523 |
-
"missing_stats_df"
|
1524 |
-
] = missing_stats_df
|
1525 |
-
st.session_state["project_dct"]["data_import"]["cat_dct"] = {
|
1526 |
-
col: cat
|
1527 |
-
for col, cat in zip(
|
1528 |
-
edited_stats_df["Column"], edited_stats_df["Category"]
|
1529 |
-
)
|
1530 |
-
}
|
1531 |
-
st.session_state["project_dct"]["data_import"][
|
1532 |
-
"numeric_columns"
|
1533 |
-
] = numeric_columns
|
1534 |
-
st.session_state["project_dct"]["data_import"][
|
1535 |
-
"default_df"
|
1536 |
-
] = default_df
|
1537 |
-
st.session_state["project_dct"]["data_import"][
|
1538 |
-
"final_df"
|
1539 |
-
] = final_df
|
1540 |
-
st.session_state["project_dct"]["data_import"][
|
1541 |
-
"edited_df"
|
1542 |
-
] = edited_df
|
1543 |
-
|
1544 |
-
st.toast("💾 Saved Successfully!")
|
1545 |
-
|
1546 |
-
# *****************************************************************
|
1547 |
-
# *********************************Persistent flow****************
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/2_Data_Validation.py
DELETED
@@ -1,509 +0,0 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
import pandas as pd
|
3 |
-
import plotly.express as px
|
4 |
-
import plotly.graph_objects as go
|
5 |
-
from Eda_functions import *
|
6 |
-
import numpy as np
|
7 |
-
import pickle
|
8 |
-
|
9 |
-
# from streamlit_pandas_profiling import st_profile_report
|
10 |
-
import streamlit as st
|
11 |
-
import streamlit.components.v1 as components
|
12 |
-
import sweetviz as sv
|
13 |
-
from utilities import set_header, load_local_css
|
14 |
-
from st_aggrid import GridOptionsBuilder, GridUpdateMode
|
15 |
-
from st_aggrid import GridOptionsBuilder
|
16 |
-
from st_aggrid import AgGrid
|
17 |
-
import base64
|
18 |
-
import os
|
19 |
-
import tempfile
|
20 |
-
#import pandas_profiling
|
21 |
-
#from pydantic_settings import BaseSettings
|
22 |
-
from ydata_profiling import ProfileReport
|
23 |
-
import re
|
24 |
-
|
25 |
-
# from pygwalker.api.streamlit import StreamlitRenderer
|
26 |
-
# from Home_redirecting import home
|
27 |
-
import sqlite3
|
28 |
-
from utilities import update_db
|
29 |
-
|
30 |
-
st.set_page_config(
|
31 |
-
page_title="Data Validation",
|
32 |
-
page_icon=":shark:",
|
33 |
-
layout="wide",
|
34 |
-
initial_sidebar_state="collapsed",
|
35 |
-
)
|
36 |
-
load_local_css("styles.css")
|
37 |
-
set_header()
|
38 |
-
|
39 |
-
|
40 |
-
if "project_dct" not in st.session_state:
|
41 |
-
# home()
|
42 |
-
st.warning("Please select a project from home page")
|
43 |
-
st.stop()
|
44 |
-
|
45 |
-
|
46 |
-
data_path = os.path.join(st.session_state["project_path"], "data_import.pkl")
|
47 |
-
|
48 |
-
try:
|
49 |
-
with open(data_path, "rb") as f:
|
50 |
-
data = pickle.load(f)
|
51 |
-
except Exception as e:
|
52 |
-
st.error(f"Please import data from the Data Import Page")
|
53 |
-
st.stop()
|
54 |
-
|
55 |
-
conn = sqlite3.connect(
|
56 |
-
r"DB\User.db", check_same_thread=False
|
57 |
-
) # connection with sql db
|
58 |
-
c = conn.cursor()
|
59 |
-
st.session_state["cleaned_data"] = data["final_df"]
|
60 |
-
st.session_state["category_dict"] = data["bin_dict"]
|
61 |
-
# st.write(st.session_state['category_dict'])
|
62 |
-
|
63 |
-
st.title("Data Validation and Insights")
|
64 |
-
|
65 |
-
|
66 |
-
target_variables = [
|
67 |
-
st.session_state["category_dict"][key]
|
68 |
-
for key in st.session_state["category_dict"].keys()
|
69 |
-
if key == "Response Metrics"
|
70 |
-
]
|
71 |
-
target_variables = list(*target_variables)
|
72 |
-
target_column = st.selectbox(
|
73 |
-
"Select the Target Feature/Dependent Variable (will be used in all charts as reference)",
|
74 |
-
target_variables,
|
75 |
-
index=st.session_state["project_dct"]["data_validation"]["target_column"],
|
76 |
-
)
|
77 |
-
|
78 |
-
st.session_state["project_dct"]["data_validation"]["target_column"] = (
|
79 |
-
target_variables.index(target_column)
|
80 |
-
)
|
81 |
-
|
82 |
-
st.session_state["target_column"] = target_column
|
83 |
-
|
84 |
-
panels = st.session_state["category_dict"]["Panel Level 1"][0]
|
85 |
-
|
86 |
-
selected_panels = st.multiselect(
|
87 |
-
"Please choose the panels you wish to analyze.If no panels are selected, insights will be derived from the overall data.",
|
88 |
-
st.session_state["cleaned_data"][panels].unique(),
|
89 |
-
default=st.session_state["project_dct"]["data_validation"][
|
90 |
-
"selected_panels"
|
91 |
-
],
|
92 |
-
)
|
93 |
-
|
94 |
-
st.session_state["project_dct"]["data_validation"][
|
95 |
-
"selected_panels"
|
96 |
-
] = selected_panels
|
97 |
-
|
98 |
-
aggregation_dict = {
|
99 |
-
item: "sum" if key == "Media" else "mean"
|
100 |
-
for key, value in st.session_state["category_dict"].items()
|
101 |
-
for item in value
|
102 |
-
if item not in ["date", "Panel_1"]
|
103 |
-
}
|
104 |
-
|
105 |
-
with st.expander("**Reponse Metric Analysis**"):
|
106 |
-
|
107 |
-
if len(selected_panels) > 0:
|
108 |
-
st.session_state["Cleaned_data_panel"] = st.session_state[
|
109 |
-
"cleaned_data"
|
110 |
-
][st.session_state["cleaned_data"]["Panel_1"].isin(selected_panels)]
|
111 |
-
|
112 |
-
st.session_state["Cleaned_data_panel"] = (
|
113 |
-
st.session_state["Cleaned_data_panel"]
|
114 |
-
.groupby(by="date")
|
115 |
-
.agg(aggregation_dict)
|
116 |
-
)
|
117 |
-
st.session_state["Cleaned_data_panel"] = st.session_state[
|
118 |
-
"Cleaned_data_panel"
|
119 |
-
].reset_index()
|
120 |
-
else:
|
121 |
-
# st.write(st.session_state['cleaned_data'])
|
122 |
-
st.session_state["Cleaned_data_panel"] = (
|
123 |
-
st.session_state["cleaned_data"]
|
124 |
-
.groupby(by="date")
|
125 |
-
.agg(aggregation_dict)
|
126 |
-
)
|
127 |
-
st.session_state["Cleaned_data_panel"] = st.session_state[
|
128 |
-
"Cleaned_data_panel"
|
129 |
-
].reset_index()
|
130 |
-
|
131 |
-
fig = line_plot_target(
|
132 |
-
st.session_state["Cleaned_data_panel"],
|
133 |
-
target=target_column,
|
134 |
-
title=f"{target_column} Over Time",
|
135 |
-
)
|
136 |
-
st.plotly_chart(fig, use_container_width=True)
|
137 |
-
|
138 |
-
media_channel = list(
|
139 |
-
*[
|
140 |
-
st.session_state["category_dict"][key]
|
141 |
-
for key in st.session_state["category_dict"].keys()
|
142 |
-
if key == "Media"
|
143 |
-
]
|
144 |
-
)
|
145 |
-
# st.write(media_channel)
|
146 |
-
|
147 |
-
exo_var = list(
|
148 |
-
*[
|
149 |
-
st.session_state["category_dict"][key]
|
150 |
-
for key in st.session_state["category_dict"].keys()
|
151 |
-
if key == "Exogenous"
|
152 |
-
]
|
153 |
-
)
|
154 |
-
internal_var = list(
|
155 |
-
*[
|
156 |
-
st.session_state["category_dict"][key]
|
157 |
-
for key in st.session_state["category_dict"].keys()
|
158 |
-
if key == "Internal"
|
159 |
-
]
|
160 |
-
)
|
161 |
-
Non_media_variables = exo_var + internal_var
|
162 |
-
|
163 |
-
st.markdown("### Annual Data Summary")
|
164 |
-
|
165 |
-
st.dataframe(
|
166 |
-
summary(
|
167 |
-
st.session_state["Cleaned_data_panel"],
|
168 |
-
media_channel + [target_column],
|
169 |
-
spends=None,
|
170 |
-
Target=True,
|
171 |
-
),
|
172 |
-
use_container_width=True,
|
173 |
-
)
|
174 |
-
|
175 |
-
if st.checkbox("Show raw data"):
|
176 |
-
st.write(
|
177 |
-
pd.concat(
|
178 |
-
[
|
179 |
-
pd.to_datetime(
|
180 |
-
st.session_state["Cleaned_data_panel"]["date"]
|
181 |
-
).dt.strftime("%m/%d/%Y"),
|
182 |
-
st.session_state["Cleaned_data_panel"]
|
183 |
-
.select_dtypes(np.number)
|
184 |
-
.applymap(format_numbers),
|
185 |
-
],
|
186 |
-
axis=1,
|
187 |
-
)
|
188 |
-
)
|
189 |
-
col1 = st.columns(1)
|
190 |
-
|
191 |
-
if "selected_feature" not in st.session_state:
|
192 |
-
st.session_state["selected_feature"] = None
|
193 |
-
|
194 |
-
|
195 |
-
def generate_report_with_target(channel_data, target_feature):
|
196 |
-
report = sv.analyze([channel_data, "Dataset"], target_feat=target_feature)
|
197 |
-
temp_dir = tempfile.mkdtemp()
|
198 |
-
report_path = os.path.join(temp_dir, "report.html")
|
199 |
-
report.show_html(
|
200 |
-
filepath=report_path, open_browser=False
|
201 |
-
) # Generate the report as an HTML file
|
202 |
-
return report_path
|
203 |
-
|
204 |
-
|
205 |
-
def generate_profile_report(df):
|
206 |
-
pr = df.profile_report()
|
207 |
-
temp_dir = tempfile.mkdtemp()
|
208 |
-
report_path = os.path.join(temp_dir, "report.html")
|
209 |
-
pr.to_file(report_path)
|
210 |
-
return report_path
|
211 |
-
|
212 |
-
|
213 |
-
# st.header()
|
214 |
-
with st.expander("Univariate and Bivariate Report"):
|
215 |
-
eda_columns = st.columns(2)
|
216 |
-
with eda_columns[0]:
|
217 |
-
if st.button(
|
218 |
-
"Generate Profile Report",
|
219 |
-
help="Univariate report which inlcudes all statistical analysis",
|
220 |
-
):
|
221 |
-
with st.spinner("Generating Report"):
|
222 |
-
report_file = generate_profile_report(
|
223 |
-
st.session_state["Cleaned_data_panel"]
|
224 |
-
)
|
225 |
-
|
226 |
-
if os.path.exists(report_file):
|
227 |
-
with open(report_file, "rb") as f:
|
228 |
-
st.success("Report Generated")
|
229 |
-
st.download_button(
|
230 |
-
label="Download EDA Report",
|
231 |
-
data=f.read(),
|
232 |
-
file_name="pandas_profiling_report.html",
|
233 |
-
mime="text/html",
|
234 |
-
)
|
235 |
-
else:
|
236 |
-
st.warning(
|
237 |
-
"Report generation failed. Unable to find the report file."
|
238 |
-
)
|
239 |
-
|
240 |
-
with eda_columns[1]:
|
241 |
-
if st.button(
|
242 |
-
"Generate Sweetviz Report",
|
243 |
-
help="Bivariate report for selected response metric",
|
244 |
-
):
|
245 |
-
with st.spinner("Generating Report"):
|
246 |
-
report_file = generate_report_with_target(
|
247 |
-
st.session_state["Cleaned_data_panel"], target_column
|
248 |
-
)
|
249 |
-
|
250 |
-
if os.path.exists(report_file):
|
251 |
-
with open(report_file, "rb") as f:
|
252 |
-
st.success("Report Generated")
|
253 |
-
st.download_button(
|
254 |
-
label="Download EDA Report",
|
255 |
-
data=f.read(),
|
256 |
-
file_name="report.html",
|
257 |
-
mime="text/html",
|
258 |
-
)
|
259 |
-
else:
|
260 |
-
st.warning(
|
261 |
-
"Report generation failed. Unable to find the report file."
|
262 |
-
)
|
263 |
-
|
264 |
-
|
265 |
-
# st.warning('Work in Progress')
|
266 |
-
with st.expander("Media Variables Analysis"):
|
267 |
-
# Get the selected feature
|
268 |
-
|
269 |
-
media_variables = [
|
270 |
-
col
|
271 |
-
for col in media_channel
|
272 |
-
if "cost" not in col.lower() and "spend" not in col.lower()
|
273 |
-
]
|
274 |
-
|
275 |
-
st.session_state["selected_feature"] = st.selectbox(
|
276 |
-
"Select media", media_variables
|
277 |
-
)
|
278 |
-
|
279 |
-
st.session_state["project_dct"]["data_validation"]["selected_feature"] = (
|
280 |
-
media_variables.index(st.session_state["selected_feature"])
|
281 |
-
)
|
282 |
-
|
283 |
-
# Filter spends features based on the selected feature
|
284 |
-
spends_features = [
|
285 |
-
col
|
286 |
-
for col in st.session_state["Cleaned_data_panel"].columns
|
287 |
-
if any(keyword in col.lower() for keyword in ["cost", "spend"])
|
288 |
-
]
|
289 |
-
spends_feature = [
|
290 |
-
col
|
291 |
-
for col in spends_features
|
292 |
-
if re.split(r"_cost|_spend", col.lower())[0]
|
293 |
-
in st.session_state["selected_feature"]
|
294 |
-
]
|
295 |
-
|
296 |
-
if "validation" not in st.session_state:
|
297 |
-
|
298 |
-
st.session_state["validation"] = st.session_state["project_dct"][
|
299 |
-
"data_validation"
|
300 |
-
]["validated_variables"]
|
301 |
-
|
302 |
-
val_variables = [col for col in media_channel if col != "date"]
|
303 |
-
|
304 |
-
if not set(
|
305 |
-
st.session_state["project_dct"]["data_validation"][
|
306 |
-
"validated_variables"
|
307 |
-
]
|
308 |
-
).issubset(set(val_variables)):
|
309 |
-
|
310 |
-
st.session_state["validation"] = []
|
311 |
-
|
312 |
-
if len(spends_feature) == 0:
|
313 |
-
st.warning(
|
314 |
-
"No spends varaible available for the selected metric in data"
|
315 |
-
)
|
316 |
-
|
317 |
-
else:
|
318 |
-
fig_row1 = line_plot(
|
319 |
-
st.session_state["Cleaned_data_panel"],
|
320 |
-
x_col="date",
|
321 |
-
y1_cols=[st.session_state["selected_feature"]],
|
322 |
-
y2_cols=[target_column],
|
323 |
-
title=f'Analysis of {st.session_state["selected_feature"]} and {[target_column][0]} Over Time',
|
324 |
-
)
|
325 |
-
st.plotly_chart(fig_row1, use_container_width=True)
|
326 |
-
st.markdown("### Summary")
|
327 |
-
st.dataframe(
|
328 |
-
summary(
|
329 |
-
st.session_state["cleaned_data"],
|
330 |
-
[st.session_state["selected_feature"]],
|
331 |
-
spends=spends_feature[0],
|
332 |
-
),
|
333 |
-
use_container_width=True,
|
334 |
-
)
|
335 |
-
|
336 |
-
cols2 = st.columns(2)
|
337 |
-
|
338 |
-
if len(
|
339 |
-
set(st.session_state["validation"]).intersection(val_variables)
|
340 |
-
) == len(val_variables):
|
341 |
-
disable = True
|
342 |
-
help = "All media variables are validated"
|
343 |
-
else:
|
344 |
-
disable = False
|
345 |
-
help = ""
|
346 |
-
|
347 |
-
with cols2[0]:
|
348 |
-
if st.button("Validate", disabled=disable, help=help):
|
349 |
-
st.session_state["validation"].append(
|
350 |
-
st.session_state["selected_feature"]
|
351 |
-
)
|
352 |
-
with cols2[1]:
|
353 |
-
|
354 |
-
if st.checkbox("Validate all", disabled=disable, help=help):
|
355 |
-
st.session_state["validation"].extend(val_variables)
|
356 |
-
st.success("All media variables are validated ✅")
|
357 |
-
|
358 |
-
if len(
|
359 |
-
set(st.session_state["validation"]).intersection(val_variables)
|
360 |
-
) != len(val_variables):
|
361 |
-
validation_data = pd.DataFrame(
|
362 |
-
{
|
363 |
-
"Validate": [
|
364 |
-
(
|
365 |
-
True
|
366 |
-
if col in st.session_state["validation"]
|
367 |
-
else False
|
368 |
-
)
|
369 |
-
for col in val_variables
|
370 |
-
],
|
371 |
-
"Variables": val_variables,
|
372 |
-
}
|
373 |
-
)
|
374 |
-
cols3 = st.columns([1, 30])
|
375 |
-
with cols3[1]:
|
376 |
-
validation_df = st.data_editor(
|
377 |
-
validation_data,
|
378 |
-
# column_config={
|
379 |
-
# 'Validate':st.column_config.CheckboxColumn(wi)
|
380 |
-
# },
|
381 |
-
column_config={
|
382 |
-
"Validate": st.column_config.CheckboxColumn(
|
383 |
-
default=False,
|
384 |
-
width=100,
|
385 |
-
),
|
386 |
-
"Variables": st.column_config.TextColumn(width=1000),
|
387 |
-
},
|
388 |
-
hide_index=True,
|
389 |
-
)
|
390 |
-
|
391 |
-
selected_rows = validation_df[
|
392 |
-
validation_df["Validate"] == True
|
393 |
-
]["Variables"]
|
394 |
-
|
395 |
-
# st.write(selected_rows)
|
396 |
-
|
397 |
-
st.session_state["validation"].extend(selected_rows)
|
398 |
-
|
399 |
-
st.session_state["project_dct"]["data_validation"][
|
400 |
-
"validated_variables"
|
401 |
-
] = st.session_state["validation"]
|
402 |
-
|
403 |
-
not_validated_variables = [
|
404 |
-
col
|
405 |
-
for col in val_variables
|
406 |
-
if col not in st.session_state["validation"]
|
407 |
-
]
|
408 |
-
|
409 |
-
if not_validated_variables:
|
410 |
-
not_validated_message = f'The following variables are not validated:\n{" , ".join(not_validated_variables)}'
|
411 |
-
st.warning(not_validated_message)
|
412 |
-
|
413 |
-
|
414 |
-
with st.expander("Non Media Variables Analysis"):
|
415 |
-
selected_columns_row4 = st.selectbox(
|
416 |
-
"Select Channel",
|
417 |
-
Non_media_variables,
|
418 |
-
index=st.session_state["project_dct"]["data_validation"][
|
419 |
-
"Non_media_variables"
|
420 |
-
],
|
421 |
-
)
|
422 |
-
|
423 |
-
st.session_state["project_dct"]["data_validation"][
|
424 |
-
"Non_media_variables"
|
425 |
-
] = Non_media_variables.index(selected_columns_row4)
|
426 |
-
|
427 |
-
# # Create the dual-axis line plot
|
428 |
-
fig_row4 = line_plot(
|
429 |
-
st.session_state["Cleaned_data_panel"],
|
430 |
-
x_col="date",
|
431 |
-
y1_cols=[selected_columns_row4],
|
432 |
-
y2_cols=[target_column],
|
433 |
-
title=f"Analysis of {selected_columns_row4} and {target_column} Over Time",
|
434 |
-
)
|
435 |
-
st.plotly_chart(fig_row4, use_container_width=True)
|
436 |
-
selected_non_media = selected_columns_row4
|
437 |
-
sum_df = st.session_state["Cleaned_data_panel"][
|
438 |
-
["date", selected_non_media, target_column]
|
439 |
-
]
|
440 |
-
sum_df["Year"] = pd.to_datetime(
|
441 |
-
st.session_state["Cleaned_data_panel"]["date"]
|
442 |
-
).dt.year
|
443 |
-
# st.dataframe(df)
|
444 |
-
# st.dataframe(sum_df.head(2))
|
445 |
-
print(sum_df)
|
446 |
-
sum_df = sum_df.drop("date", axis=1).groupby("Year").agg("sum")
|
447 |
-
sum_df.loc["Grand Total"] = sum_df.sum()
|
448 |
-
sum_df = sum_df.applymap(format_numbers)
|
449 |
-
sum_df.fillna("-", inplace=True)
|
450 |
-
sum_df = sum_df.replace({"0.0": "-", "nan": "-"})
|
451 |
-
st.markdown("### Summary")
|
452 |
-
st.dataframe(sum_df, use_container_width=True)
|
453 |
-
|
454 |
-
# with st.expander('Interactive Dashboard'):
|
455 |
-
|
456 |
-
# pygg_app=StreamlitRenderer(st.session_state['cleaned_data'])
|
457 |
-
|
458 |
-
# pygg_app.explorer()
|
459 |
-
|
460 |
-
with st.expander("Correlation Analysis"):
|
461 |
-
options = list(
|
462 |
-
st.session_state["Cleaned_data_panel"].select_dtypes(np.number).columns
|
463 |
-
)
|
464 |
-
|
465 |
-
# selected_options = []
|
466 |
-
# num_columns = 4
|
467 |
-
# num_rows = -(-len(options) // num_columns) # Ceiling division to calculate rows
|
468 |
-
|
469 |
-
# # Create a grid of checkboxes
|
470 |
-
# st.header('Select Features for Correlation Plot')
|
471 |
-
# tick=False
|
472 |
-
# if st.checkbox('Select all'):
|
473 |
-
# tick=True
|
474 |
-
# selected_options = []
|
475 |
-
# for row in range(num_rows):
|
476 |
-
# cols = st.columns(num_columns)
|
477 |
-
# for col in cols:
|
478 |
-
# if options:
|
479 |
-
# option = options.pop(0)
|
480 |
-
# selected = col.checkbox(option,value=tick)
|
481 |
-
# if selected:
|
482 |
-
# selected_options.append(option)
|
483 |
-
# # Display selected options
|
484 |
-
|
485 |
-
selected_options = st.multiselect(
|
486 |
-
"Select Variables For correlation plot",
|
487 |
-
[var for var in options if var != target_column],
|
488 |
-
default=options[3],
|
489 |
-
)
|
490 |
-
|
491 |
-
st.pyplot(
|
492 |
-
correlation_plot(
|
493 |
-
st.session_state["Cleaned_data_panel"],
|
494 |
-
selected_options,
|
495 |
-
target_column,
|
496 |
-
)
|
497 |
-
)
|
498 |
-
|
499 |
-
if st.button("Save Changes", use_container_width=True):
|
500 |
-
|
501 |
-
update_db("2_Data_Validation.py")
|
502 |
-
|
503 |
-
project_dct_path = os.path.join(
|
504 |
-
st.session_state["project_path"], "project_dct.pkl"
|
505 |
-
)
|
506 |
-
|
507 |
-
with open(project_dct_path, "wb") as f:
|
508 |
-
pickle.dump(st.session_state["project_dct"], f)
|
509 |
-
st.success("Changes saved")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/3_Transformations.py
DELETED
@@ -1,686 +0,0 @@
|
|
1 |
-
# Importing necessary libraries
|
2 |
-
import streamlit as st
|
3 |
-
|
4 |
-
st.set_page_config(
|
5 |
-
page_title="Transformations",
|
6 |
-
page_icon=":shark:",
|
7 |
-
layout="wide",
|
8 |
-
initial_sidebar_state="collapsed",
|
9 |
-
)
|
10 |
-
|
11 |
-
import pickle
|
12 |
-
import numpy as np
|
13 |
-
import pandas as pd
|
14 |
-
from utilities import set_header, load_local_css
|
15 |
-
import streamlit_authenticator as stauth
|
16 |
-
import yaml
|
17 |
-
from yaml import SafeLoader
|
18 |
-
import os
|
19 |
-
import sqlite3
|
20 |
-
from utilities import update_db
|
21 |
-
|
22 |
-
|
23 |
-
load_local_css("styles.css")
|
24 |
-
set_header()
|
25 |
-
|
26 |
-
|
27 |
-
# Check for authentication status
|
28 |
-
for k, v in st.session_state.items():
|
29 |
-
if k not in ["logout", "login", "config"] and not k.startswith(
|
30 |
-
"FormSubmitter"
|
31 |
-
):
|
32 |
-
st.session_state[k] = v
|
33 |
-
with open("config.yaml") as file:
|
34 |
-
config = yaml.load(file, Loader=SafeLoader)
|
35 |
-
st.session_state["config"] = config
|
36 |
-
authenticator = stauth.Authenticate(
|
37 |
-
config["credentials"],
|
38 |
-
config["cookie"]["name"],
|
39 |
-
config["cookie"]["key"],
|
40 |
-
config["cookie"]["expiry_days"],
|
41 |
-
config["preauthorized"],
|
42 |
-
)
|
43 |
-
st.session_state["authenticator"] = authenticator
|
44 |
-
name, authentication_status, username = authenticator.login("Login", "main")
|
45 |
-
auth_status = st.session_state.get("authentication_status")
|
46 |
-
|
47 |
-
if auth_status == True:
|
48 |
-
authenticator.logout("Logout", "main")
|
49 |
-
is_state_initiaized = st.session_state.get("initialized", False)
|
50 |
-
|
51 |
-
if "project_dct" not in st.session_state:
|
52 |
-
st.error("Please load a project from Home page")
|
53 |
-
st.stop()
|
54 |
-
|
55 |
-
conn = sqlite3.connect(
|
56 |
-
r"DB/User.db", check_same_thread=False
|
57 |
-
) # connection with sql db
|
58 |
-
c = conn.cursor()
|
59 |
-
|
60 |
-
if not is_state_initiaized:
|
61 |
-
if "session_name" not in st.session_state:
|
62 |
-
st.session_state["session_name"] = None
|
63 |
-
|
64 |
-
if not os.path.exists(
|
65 |
-
os.path.join(st.session_state["project_path"], "data_import.pkl")
|
66 |
-
):
|
67 |
-
st.error("Please move to Data Import page")
|
68 |
-
# Deserialize and load the objects from the pickle file
|
69 |
-
with open(
|
70 |
-
os.path.join(st.session_state["project_path"], "data_import.pkl"), "rb"
|
71 |
-
) as f:
|
72 |
-
data = pickle.load(f)
|
73 |
-
|
74 |
-
# Accessing the loaded objects
|
75 |
-
final_df_loaded = data["final_df"]
|
76 |
-
bin_dict_loaded = data["bin_dict"]
|
77 |
-
# final_df_loaded.to_csv("Test/final_df_loaded.csv",index=False)
|
78 |
-
# Initialize session state
|
79 |
-
if "transformed_columns_dict" not in st.session_state:
|
80 |
-
st.session_state["transformed_columns_dict"] = (
|
81 |
-
{}
|
82 |
-
) # Default empty dictionary
|
83 |
-
|
84 |
-
if "final_df" not in st.session_state:
|
85 |
-
st.session_state["final_df"] = (
|
86 |
-
final_df_loaded # Default as original dataframe
|
87 |
-
)
|
88 |
-
|
89 |
-
if "summary_string" not in st.session_state:
|
90 |
-
st.session_state["summary_string"] = None # Default as None
|
91 |
-
|
92 |
-
# Extract original columns for specified categories
|
93 |
-
original_columns = {
|
94 |
-
category: bin_dict_loaded[category]
|
95 |
-
for category in ["Media", "Internal", "Exogenous"]
|
96 |
-
if category in bin_dict_loaded
|
97 |
-
}
|
98 |
-
|
99 |
-
# Retrieve Panel columns
|
100 |
-
panel_1 = bin_dict_loaded.get("Panel Level 1")
|
101 |
-
panel_2 = bin_dict_loaded.get("Panel Level 2")
|
102 |
-
|
103 |
-
# # For testing on non panel level
|
104 |
-
# final_df_loaded = final_df_loaded.drop("Panel_1", axis=1)
|
105 |
-
# final_df_loaded = final_df_loaded.groupby("date").mean().reset_index()
|
106 |
-
# panel_1 = None
|
107 |
-
|
108 |
-
# Apply transformations on panel level
|
109 |
-
if panel_1:
|
110 |
-
panel = panel_1 + panel_2 if panel_2 else panel_1
|
111 |
-
else:
|
112 |
-
panel = []
|
113 |
-
|
114 |
-
# Function to build transformation widgets
|
115 |
-
def _inclusive_range(start, end, step):
    """Return start, start + step, ... up to and including end as an array.

    The element count is computed explicitly. np.arange(start, end + step,
    step) with a fractional step can emit one value past ``end`` because of
    floating-point rounding (e.g. 0.0-0.55 in steps of 0.05 yields 0.6).
    """
    count = int(round((end - start) / step)) + 1
    return start + step * np.arange(count)


def transformation_widgets(category, transform_params, date_granularity):
    """Render the transformation pickers for one category and record the choices.

    Draws an expander holding a multiselect of the transformations offered for
    ``category`` and one range slider per selected transformation. Every
    selection is persisted under
    ``st.session_state["project_dct"]["transformations"][category]`` and the
    expanded parameter values are written to ``transform_params[category]``.

    Args:
        category: One of "Media", "Internal" or "Exogenous".
        transform_params: Dict keyed by category; mutated in place with
            ``{transformation name: values to try}``.
        date_granularity: Text shown in the headings of the period-based
            sliders (Lead, Lag, Moving Average).
    """
    project_dct = st.session_state["project_dct"]
    if (
        project_dct["transformations"] is None
        or project_dct["transformations"] == {}
    ):
        project_dct["transformations"] = {}
    # Per-category store of everything the user picked on this page.
    saved = project_dct["transformations"].setdefault(category, {})

    # transformation name -> (heading, slider label, min, max,
    #                         fallback default, step, widget key prefix,
    #                         key under which the selection is persisted)
    slider_specs = {
        "Lead": (
            f"**Lead ({date_granularity})**",
            "Lead periods", 1, 10, (1, 2), 1, "lead", "Lead",
        ),
        "Lag": (
            f"**Lag ({date_granularity})**",
            "Lag periods", 1, 10, (1, 2), 1, "lag", "Lag",
        ),
        # Moving Average is persisted under "MA"; kept for compatibility with
        # already-saved project dictionaries.
        "Moving Average": (
            f"**Moving Average ({date_granularity})**",
            "Window size for Moving Average", 1, 10, (1, 2), 1, "ma", "MA",
        ),
        "Saturation": (
            "**Saturation (%)**",
            "Saturation Percentage", 0, 100, (10, 20), 10, "sat", "Saturation",
        ),
        "Power": (
            "**Power**",
            "Power", 0, 10, (2, 4), 1, "power", "Power",
        ),
        "Adstock": (
            "**Adstock**",
            f"Factor ({category})", 0.0, 1.0, (0.5, 0.7), 0.05,
            "adstock", "Adstock",
        ),
    }

    # Transformations offered per category
    transformation_options = {
        "Media": ["Lag", "Moving Average", "Saturation", "Power", "Adstock"],
        "Internal": ["Lead", "Lag", "Moving Average"],
        "Exogenous": ["Lead", "Lag", "Moving Average"],
    }

    expanded = saved.get("expanded", False)
    saved["expanded"] = False
    with st.expander(f"{category} Transformations", expanded=expanded):
        saved["expanded"] = True

        # Let users select which transformations to apply
        transformations_to_apply = st.multiselect(
            "Select transformations to apply",
            options=transformation_options[category],
            default=saved.get(f"transformation_{category}", []),
            key=f"transformation_{category}",
        )
        saved["transformation_" + category] = transformations_to_apply

        # Split the selection over two columns; the first gets the extra one.
        first_column_count = (len(transformations_to_apply) + 1) // 2
        col1, col2 = st.columns(2)

        def create_transformation_widgets(column, transformations):
            """Draw one range slider per transformation inside ``column``."""
            with column:
                for transformation in transformations:
                    spec = slider_specs.get(transformation)
                    if spec is None:
                        continue
                    (heading, label, low, high, fallback, step,
                     key_prefix, saved_key) = spec

                    # Every slider, Lag included, starts from the persisted
                    # selection and only falls back to the predefined default.
                    default = saved.get(saved_key, fallback)
                    st.markdown(heading)
                    selected = st.slider(
                        label,
                        low,
                        high,
                        default,
                        step,
                        key=f"{key_prefix}_{category}",
                        label_visibility="collapsed",
                    )
                    saved[saved_key] = selected

                    values = _inclusive_range(selected[0], selected[1], step)
                    if transformation == "Adstock":
                        # Rounded so the values (and the column names built
                        # from them) do not carry float noise.
                        values = [round(a, 3) for a in values]
                    transform_params[category][transformation] = values

        # Create widgets in each column
        create_transformation_widgets(
            col1, transformations_to_apply[:first_column_count]
        )
        create_transformation_widgets(
            col2, transformations_to_apply[first_column_count:]
        )
|
382 |
-
|
383 |
-
# Function to apply Lag transformation
|
384 |
-
def apply_lag(df, lag):
    """Return ``df`` shifted forward by ``lag`` rows (vacated rows become NaN)."""
    lagged = df.shift(periods=lag)
    return lagged
|
386 |
-
|
387 |
-
# Function to apply Lead transformation
|
388 |
-
def apply_lead(df, lead):
    """Return ``df`` shifted backward by ``lead`` rows (vacated rows become NaN)."""
    led = df.shift(periods=-lead)
    return led
|
390 |
-
|
391 |
-
# Function to apply Moving Average transformation
|
392 |
-
def apply_moving_average(df, window_size):
    """Return the rolling mean of ``df`` over ``window_size`` consecutive rows."""
    windows = df.rolling(window=window_size)
    return windows.mean()
|
394 |
-
|
395 |
-
# Function to apply Saturation transformation
|
396 |
-
def apply_saturation(df, saturation_percent_100):
    """Apply an S-curve saturation transformation to one column of values.

    Each value ``x`` is scaled by ``1 / (1 + (saturation_point / x) ** steepness)``
    where ``saturation_point`` is the midpoint between the column minimum and
    maximum. ``steepness`` is solved so that the scaling factor at the column
    maximum equals the requested saturation fraction.

    Args:
        df: The values to transform.
            NOTE(review): despite the name this appears to be called with a
            single column (a Series) at a time - confirm against the callers.
        saturation_percent_100: Saturation level on a 0-100 scale.

    Returns:
        A float Series with the same index as ``df``. Zero inputs map to zero.
    """
    eps = 1e-9

    # Convert from a 0-100 scale to a fraction, kept strictly inside (0, 1) so
    # the logarithm below is finite at both ends of the slider.
    saturation_percent = min(max(saturation_percent_100 / 100.0, eps), 1 - eps)

    values = df.astype(float)
    column_max = values.max()
    column_min = values.min()
    saturation_point = (column_min + column_max) / 2

    numerator = np.log(1 / saturation_percent - 1)

    with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
        denominator = np.log(saturation_point / max(column_max, eps))
        # Guard only against a (near-)zero denominator. The denominator is
        # normally negative because the midpoint is below the maximum, so
        # clamping it with max() would throw the sign away and blow the
        # steepness up to numerator / 1e-9 for every column.
        if abs(denominator) < eps:
            denominator = eps
        steepness = numerator / denominator

        # Vectorised so that zero values do not raise on the division.
        response = 1 / (1 + np.power(saturation_point / values, steepness))
        transformed_series = response * values

    # A zero input carries no signal; pin it to zero rather than 0 * inf/NaN.
    return transformed_series.where(values != 0, 0.0)
|
421 |
-
|
422 |
-
# Function to apply Power transformation
|
423 |
-
def apply_power(df, power):
    """Return ``df`` with every value raised to ``power``."""
    powered = df ** power
    return powered
|
425 |
-
|
426 |
-
# Function to apply Adstock transformation
|
427 |
-
def apply_adstock(df, factor):
    """Return the adstocked version of ``df``.

    Each output value is the current input plus ``factor`` times the previous
    output, starting from a carry-over of zero. The result keeps the index of
    ``df``.
    """
    carry_over = 0
    adstocked = []
    for current in df:
        carry_over = carry_over * factor + current
        adstocked.append(carry_over)
    return pd.Series(adstocked, index=df.index)
|
433 |
-
|
434 |
-
# Function to generate transformed columns names
|
435 |
-
@st.cache_resource(show_spinner=False)
def generate_transformed_columns(original_columns, transform_params):
    """Build transformed column names and a readable summary per column.

    Args:
        original_columns: Dict mapping a category to its list of column names.
        transform_params: Dict mapping a category to
            ``{transformation name: sequence of parameter values}``.

    Returns:
        A tuple ``(transformed_columns, summary_string)``.
        ``transformed_columns`` maps every original column to the names
        ``"<column>@<transformation>_<value>"``. ``summary_string`` is a
        numbered, newline-separated list (with HTML ``<strong>`` tags) that
        covers only the columns having at least one transformation.
    """
    transformed_columns = {}
    summary = {}

    for category, columns in original_columns.items():
        # Categories without parameters still get an (empty) entry per column.
        category_params = transform_params.get(category, {})

        for column in columns:
            names = []
            details = []

            for transformation, values in category_params.items():
                names.extend(
                    f"{column}@{transformation}_{value}" for value in values
                )

                # "a, b and c" for several values, plain "a" for one.
                if len(values) > 1:
                    leading = ", ".join(map(str, values[:-1]))
                    formatted_values = leading + " and " + str(values[-1])
                else:
                    formatted_values = str(values[0])

                details.append(f"{transformation} ({formatted_values})")

            transformed_columns[column] = names

            if details:
                formatted_summary = "⮕ ".join(details)
                summary[column] = (
                    f"<strong>{column}</strong>: {formatted_summary}"
                )

    summary_string = "\n".join(
        f"{position}. {details}"
        for position, details in enumerate(summary.values(), start=1)
    )

    return transformed_columns, summary_string
|
491 |
-
|
492 |
-
# Function to apply transformations to DataFrame slices based on specified categories and parameters
|
493 |
-
@st.cache_resource(show_spinner=False)
def apply_category_transformations(df, bin_dict, transform_params, panel):
    """Apply the selected transformations category by category.

    For every category that has parameters, the category's columns are run
    through its transformations in the order they appear in
    ``transform_params[category]``; each stage works on the output of the
    previous one and suffixes the column names with
    ``@<transformation>_<value>``. NaNs produced by a stage are replaced by 0.

    Each category is processed in isolation. Its result is collected and only
    joined to ``df`` at the very end, so one category's columns are neither
    overwritten by, nor fed into, another category's transformations.

    Args:
        df: Source DataFrame.
        bin_dict: Dict mapping a category to its list of column names.
        transform_params: Dict mapping a category to
            ``{transformation name: sequence of parameter values}``.
        panel: List of panel column names; an empty list means no panel, in
            which case transformations run over whole columns instead of
            per panel group.

    Returns:
        ``df`` with the transformed columns appended, or ``df`` itself when
        nothing was transformed.
    """
    # Dictionary for function mapping
    transformation_functions = {
        "Lead": apply_lead,
        "Lag": apply_lag,
        "Moving Average": apply_moving_average,
        "Saturation": apply_saturation,
        "Power": apply_power,
        "Adstock": apply_adstock,
    }

    category_frames = []

    for category in ["Media", "Internal", "Exogenous"]:
        if (
            category not in transform_params
            or category not in bin_dict
            or not transform_params[category]
        ):
            continue  # Skip categories without transformations

        # Start every category from its own original columns and its own
        # empty accumulator.
        df_slice = df[bin_dict[category] + panel]
        category_df = pd.DataFrame()

        for transformation, parameters in transform_params[category].items():
            transformation_function = transformation_functions[transformation]

            if len(panel) > 0:
                # Transform within each panel group, once per parameter value.
                category_df = pd.concat(
                    [
                        df_slice.groupby(panel)
                        .transform(transformation_function, p)
                        .add_suffix(f"@{transformation}_{p}")
                        for p in parameters
                    ],
                    axis=1,
                )
            else:
                for p in parameters:
                    # Apply the transformation function to each column
                    temp_df = df_slice.apply(
                        lambda x: transformation_function(x, p), axis=0
                    ).rename(
                        lambda x: f"{x}@{transformation}_{p}",
                        axis="columns",
                    )
                    category_df = pd.concat([category_df, temp_df], axis=1)

            # Replace all NaN or null values with 0
            category_df.fillna(0, inplace=True)

            # The next transformation of this category works on these results.
            df_slice = pd.concat([df[panel], category_df], axis=1)

        if not category_df.empty:
            category_frames.append(category_df)

    if category_frames:
        final_df = pd.concat([df] + category_frames, axis=1)
    else:
        # If no transformations were applied, use the original DataFrame
        final_df = df

    return final_df
|
579 |
-
|
580 |
-
# Function to infer the granularity of the date column in a DataFrame
|
581 |
-
@st.cache_resource(show_spinner=False)
def infer_date_granularity(df):
    """Classify the spacing of the values in ``df["date"]``.

    The most common gap, in whole days, between consecutive distinct dates
    decides the label.

    Args:
        df: DataFrame with a ``date`` column.
            NOTE(review): the column must be datetime-like for ``.dt.days``
            to work - confirm it is parsed upstream.

    Returns:
        "daily" (gap of 1 day), "weekly" (7), "monthly" (28-31), otherwise
        "irregular" - also when there are fewer than two distinct dates.
    """
    # Sort first: unique() keeps order of appearance, and unsorted input would
    # yield negative or arbitrary gaps.
    unique_dates = pd.Series(df["date"].unique()).sort_values()
    day_gaps = unique_dates.diff().dt.days.dropna()

    if day_gaps.empty:
        # No gap to measure, so mode()[0] would raise.
        return "irregular"

    # Find the most common difference
    common_freq = day_gaps.mode()[0]

    # Map the most common difference to a granularity
    if common_freq == 1:
        return "daily"
    elif common_freq == 7:
        return "weekly"
    elif 28 <= common_freq <= 31:
        return "monthly"
    else:
        return "irregular"
|
597 |
-
|
598 |
-
#########################################################################################################################################################
|
599 |
-
# User input for transformations
|
600 |
-
#########################################################################################################################################################
|
601 |
-
|
602 |
-
# Infer date granularity
|
603 |
-
date_granularity = infer_date_granularity(final_df_loaded)
|
604 |
-
|
605 |
-
# Initialize the main dictionary to store the transformation parameters for each category
|
606 |
-
transform_params = {"Media": {}, "Internal": {}, "Exogenous": {}}
|
607 |
-
|
608 |
-
# User input for transformations
|
609 |
-
st.markdown("### Select Transformations to Apply")
|
610 |
-
for category in ["Media", "Internal", "Exogenous"]:
|
611 |
-
# Skip Internal
|
612 |
-
if category == "Internal":
|
613 |
-
continue
|
614 |
-
|
615 |
-
transformation_widgets(category, transform_params, date_granularity)
|
616 |
-
|
617 |
-
#########################################################################################################################################################
|
618 |
-
# Apply transformations
|
619 |
-
#########################################################################################################################################################
|
620 |
-
|
621 |
-
# Apply category-based transformations to the DataFrame
|
622 |
-
if st.button("Accept and Proceed", use_container_width=True):
|
623 |
-
with st.spinner("Applying transformations..."):
|
624 |
-
final_df = apply_category_transformations(
|
625 |
-
final_df_loaded, bin_dict_loaded, transform_params, panel
|
626 |
-
)
|
627 |
-
|
628 |
-
# Generate a dictionary mapping original column names to lists of transformed column names
|
629 |
-
transformed_columns_dict, summary_string = (
|
630 |
-
generate_transformed_columns(
|
631 |
-
original_columns, transform_params
|
632 |
-
)
|
633 |
-
)
|
634 |
-
|
635 |
-
# Store into transformed dataframe and summary session state
|
636 |
-
st.session_state["final_df"] = final_df
|
637 |
-
st.session_state["summary_string"] = summary_string
|
638 |
-
|
639 |
-
#########################################################################################################################################################
|
640 |
-
# Display the transformed DataFrame and summary
|
641 |
-
#########################################################################################################################################################
|
642 |
-
|
643 |
-
# Display the transformed DataFrame in the Streamlit app
|
644 |
-
st.markdown("### Transformed DataFrame")
|
645 |
-
st.dataframe(st.session_state["final_df"], hide_index=True)
|
646 |
-
|
647 |
-
# Total rows and columns
|
648 |
-
total_rows, total_columns = st.session_state["final_df"].shape
|
649 |
-
st.markdown(
|
650 |
-
f"<p style='text-align: justify;'>The transformed DataFrame contains <strong>{total_rows}</strong> rows and <strong>{total_columns}</strong> columns.</p>",
|
651 |
-
unsafe_allow_html=True,
|
652 |
-
)
|
653 |
-
|
654 |
-
# Display the summary of transformations as markdown
|
655 |
-
if st.session_state["summary_string"]:
|
656 |
-
with st.expander("Summary of Transformations"):
|
657 |
-
st.markdown("### Summary of Transformations")
|
658 |
-
st.markdown(
|
659 |
-
st.session_state["summary_string"], unsafe_allow_html=True
|
660 |
-
)
|
661 |
-
|
662 |
-
def save_to_pickle(file_path, final_df):
    """Pickle ``final_df`` to ``file_path`` under the key "final_df_transformed".

    Deliberately not wrapped in a Streamlit cache: a cached function is
    skipped when called again with the same arguments, which would leave the
    file unwritten if it had been removed or replaced in the meantime.

    Args:
        file_path: Destination path; an existing file is overwritten.
        final_df: Object to store (the transformed DataFrame).
    """
    # Open the file in write-binary mode and dump the object
    with open(file_path, "wb") as f:
        pickle.dump({"final_df_transformed": final_df}, f)
|
668 |
-
|
669 |
-
if st.button("Accept and Save", use_container_width=True):
|
670 |
-
|
671 |
-
save_to_pickle(
|
672 |
-
os.path.join(
|
673 |
-
st.session_state["project_path"], "final_df_transformed.pkl"
|
674 |
-
),
|
675 |
-
st.session_state["final_df"],
|
676 |
-
)
|
677 |
-
project_dct_path = os.path.join(
|
678 |
-
st.session_state["project_path"], "project_dct.pkl"
|
679 |
-
)
|
680 |
-
|
681 |
-
with open(project_dct_path, "wb") as f:
|
682 |
-
pickle.dump(st.session_state["project_dct"], f)
|
683 |
-
|
684 |
-
update_db("3_Transformations.py")
|
685 |
-
|
686 |
-
st.toast("💾 Saved Successfully!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/4_Model_Build.py
DELETED
@@ -1,1062 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
MMO Build Sprint 3
|
3 |
-
additions : adding more variables to session state for saved model : random effect, predicted train & test
|
4 |
-
|
5 |
-
MMO Build Sprint 4
|
6 |
-
additions : ability to run models for different response metrics
|
7 |
-
"""
|
8 |
-
|
9 |
-
import streamlit as st
|
10 |
-
import pandas as pd
|
11 |
-
import plotly.express as px
|
12 |
-
import plotly.graph_objects as go
|
13 |
-
from Eda_functions import format_numbers
|
14 |
-
import numpy as np
|
15 |
-
import pickle
|
16 |
-
from st_aggrid import AgGrid
|
17 |
-
from st_aggrid import GridOptionsBuilder, GridUpdateMode
|
18 |
-
from utilities import set_header, load_local_css
|
19 |
-
from st_aggrid import GridOptionsBuilder
|
20 |
-
import time
|
21 |
-
import itertools
|
22 |
-
import statsmodels.api as sm
|
23 |
-
import numpy as npc
|
24 |
-
import re
|
25 |
-
import itertools
|
26 |
-
from sklearn.metrics import (
|
27 |
-
mean_absolute_error,
|
28 |
-
r2_score,
|
29 |
-
mean_absolute_percentage_error,
|
30 |
-
)
|
31 |
-
from sklearn.preprocessing import MinMaxScaler
|
32 |
-
import os
|
33 |
-
import matplotlib.pyplot as plt
|
34 |
-
from statsmodels.stats.outliers_influence import variance_inflation_factor
|
35 |
-
import yaml
|
36 |
-
from yaml import SafeLoader
|
37 |
-
import streamlit_authenticator as stauth
|
38 |
-
|
39 |
-
st.set_option("deprecation.showPyplotGlobalUse", False)
|
40 |
-
import statsmodels.api as sm
|
41 |
-
import statsmodels.formula.api as smf
|
42 |
-
|
43 |
-
from datetime import datetime
|
44 |
-
import seaborn as sns
|
45 |
-
from Data_prep_functions import *
|
46 |
-
import sqlite3
|
47 |
-
from utilities import update_db
|
48 |
-
|
49 |
-
|
50 |
-
@st.cache_resource(show_spinner=False)
def get_random_effects(media_data, panel_col, _mdf):
    """Collect the first random-effect value of every panel member.

    Args:
        media_data: DataFrame whose ``panel_col`` column lists the panel
            members (called "market" in the loop below).
        panel_col: Name of the panel column.
        _mdf: Fitted model result exposing ``random_effects``, indexable by
            panel member.
            NOTE(review): the leading underscore presumably keeps Streamlit
            from hashing this argument for the cache key - confirm that stale
            results for a different model on the same data are acceptable.

    Returns:
        DataFrame with columns ``[panel_col, "random_effect"]``, one row per
        distinct panel member in order of first appearance.
    """
    random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
    markets = media_data[panel_col].unique()

    for row_number, market in enumerate(markets):
        print(row_number, end="\r")  # progress indicator on stdout
        market_effect = _mdf.random_effects[market].values[0]
        random_eff_df.loc[row_number, "random_effect"] = market_effect
        random_eff_df.loc[row_number, panel_col] = market

    return random_eff_df
|
67 |
-
|
68 |
-
|
69 |
-
def mdf_predict(X_df, mdf, random_eff_df):
    """Predict as fixed effect plus the row's panel random effect.

    Args:
        X_df: Feature DataFrame; it is copied, not modified.
        mdf: Fitted model result exposing ``predict``.
        random_eff_df: DataFrame with the panel column and a
            ``random_effect`` column.

    Returns:
        The "pred" Series of the merged frame, so it carries the merge
        result's index rather than the index of ``X_df``.

    NOTE(review): relies on a module-level ``panel_col`` being defined before
    the first call - it is not a parameter.
    """
    features = X_df.copy()
    features["fixed_effect"] = mdf.predict(features)

    # Left join keeps every feature row, in order, and attaches its effect.
    merged = pd.merge(features, random_eff_df, on=panel_col, how="left")
    merged["pred"] = merged["fixed_effect"] + merged["random_effect"]
    return merged["pred"]
|
77 |
-
|
78 |
-
|
79 |
-
st.set_page_config(
|
80 |
-
page_title="Model Build",
|
81 |
-
page_icon=":shark:",
|
82 |
-
layout="wide",
|
83 |
-
initial_sidebar_state="collapsed",
|
84 |
-
)
|
85 |
-
|
86 |
-
load_local_css("styles.css")
|
87 |
-
set_header()
|
88 |
-
|
89 |
-
# Check for authentication status
|
90 |
-
for k, v in st.session_state.items():
|
91 |
-
if k not in [
|
92 |
-
"logout",
|
93 |
-
"login",
|
94 |
-
"config",
|
95 |
-
"model_build_button",
|
96 |
-
] and not k.startswith("FormSubmitter"):
|
97 |
-
st.session_state[k] = v
|
98 |
-
with open("config.yaml") as file:
|
99 |
-
config = yaml.load(file, Loader=SafeLoader)
|
100 |
-
st.session_state["config"] = config
|
101 |
-
authenticator = stauth.Authenticate(
|
102 |
-
config["credentials"],
|
103 |
-
config["cookie"]["name"],
|
104 |
-
config["cookie"]["key"],
|
105 |
-
config["cookie"]["expiry_days"],
|
106 |
-
config["preauthorized"],
|
107 |
-
)
|
108 |
-
st.session_state["authenticator"] = authenticator
|
109 |
-
name, authentication_status, username = authenticator.login("Login", "main")
|
110 |
-
auth_status = st.session_state.get("authentication_status")
|
111 |
-
|
112 |
-
if auth_status == True:
|
113 |
-
authenticator.logout("Logout", "main")
|
114 |
-
is_state_initiaized = st.session_state.get("initialized", False)
|
115 |
-
|
116 |
-
conn = sqlite3.connect(
|
117 |
-
r"DB/User.db", check_same_thread=False
|
118 |
-
) # connection with sql db
|
119 |
-
c = conn.cursor()
|
120 |
-
|
121 |
-
if not is_state_initiaized:
|
122 |
-
|
123 |
-
if "session_name" not in st.session_state:
|
124 |
-
st.session_state["session_name"] = None
|
125 |
-
|
126 |
-
if "project_dct" not in st.session_state:
|
127 |
-
st.error("Please load a project from Home page")
|
128 |
-
st.stop()
|
129 |
-
|
130 |
-
st.title("1. Build Your Model")
|
131 |
-
|
132 |
-
if not os.path.exists(
|
133 |
-
os.path.join(st.session_state["project_path"], "data_import.pkl")
|
134 |
-
):
|
135 |
-
st.error("Please move to Data Import Page and save.")
|
136 |
-
st.stop()
|
137 |
-
with open(
|
138 |
-
os.path.join(st.session_state["project_path"], "data_import.pkl"), "rb"
|
139 |
-
) as f:
|
140 |
-
data = pickle.load(f)
|
141 |
-
st.session_state["bin_dict"] = data["bin_dict"]
|
142 |
-
|
143 |
-
if not os.path.exists(
|
144 |
-
os.path.join(
|
145 |
-
st.session_state["project_path"], "final_df_transformed.pkl"
|
146 |
-
)
|
147 |
-
):
|
148 |
-
st.error(
|
149 |
-
"Please move to Transformation Page and save transformations."
|
150 |
-
)
|
151 |
-
st.stop()
|
152 |
-
with open(
|
153 |
-
os.path.join(
|
154 |
-
st.session_state["project_path"], "final_df_transformed.pkl"
|
155 |
-
),
|
156 |
-
"rb",
|
157 |
-
) as f:
|
158 |
-
data = pickle.load(f)
|
159 |
-
media_data = data["final_df_transformed"]
|
160 |
-
#media_data.to_csv("Test/media_data.csv", index=False)
|
161 |
-
train_idx = int(len(media_data) / 5) * 4
|
162 |
-
# Sprint4 - available response metrics is a list of all reponse metrics in the data
|
163 |
-
## these will be put in a drop down
|
164 |
-
|
165 |
-
st.session_state["media_data"] = media_data
|
166 |
-
|
167 |
-
if "available_response_metrics" not in st.session_state:
|
168 |
-
# st.session_state['available_response_metrics'] = ['Total Approved Accounts - Revenue',
|
169 |
-
# 'Total Approved Accounts - Appsflyer',
|
170 |
-
# 'Account Requests - Appsflyer',
|
171 |
-
# 'App Installs - Appsflyer']
|
172 |
-
|
173 |
-
st.session_state["available_response_metrics"] = st.session_state[
|
174 |
-
"bin_dict"
|
175 |
-
]["Response Metrics"]
|
176 |
-
# Sprint4
|
177 |
-
if "is_tuned_model" not in st.session_state:
|
178 |
-
st.session_state["is_tuned_model"] = {}
|
179 |
-
for resp_metric in st.session_state["available_response_metrics"]:
|
180 |
-
resp_metric = (
|
181 |
-
resp_metric.lower()
|
182 |
-
.replace(" ", "_")
|
183 |
-
.replace("-", "")
|
184 |
-
.replace(":", "")
|
185 |
-
.replace("__", "_")
|
186 |
-
)
|
187 |
-
st.session_state["is_tuned_model"][resp_metric] = False
|
188 |
-
|
189 |
-
# Sprint4 - used_response_metrics is a list of resp metrics for which user has created & saved a model
|
190 |
-
if "used_response_metrics" not in st.session_state:
|
191 |
-
st.session_state["used_response_metrics"] = []
|
192 |
-
|
193 |
-
# Sprint4 - saved_model_names
|
194 |
-
if "saved_model_names" not in st.session_state:
|
195 |
-
st.session_state["saved_model_names"] = []
|
196 |
-
|
197 |
-
if "Model" not in st.session_state:
|
198 |
-
if (
|
199 |
-
"session_state_saved"
|
200 |
-
in st.session_state["project_dct"]["model_build"].keys()
|
201 |
-
and st.session_state["project_dct"]["model_build"][
|
202 |
-
"session_state_saved"
|
203 |
-
]
|
204 |
-
is not None
|
205 |
-
and "Model"
|
206 |
-
in st.session_state["project_dct"]["model_build"][
|
207 |
-
"session_state_saved"
|
208 |
-
].keys()
|
209 |
-
):
|
210 |
-
st.session_state["Model"] = st.session_state["project_dct"][
|
211 |
-
"model_build"
|
212 |
-
]["session_state_saved"]["Model"]
|
213 |
-
else:
|
214 |
-
st.session_state["Model"] = {}
|
215 |
-
|
216 |
-
# Sprint4 - select a response metric
|
217 |
-
# Offer the response metrics as a drop-down, pre-selecting the metric stored
# in the project dictionary when there is one.
available_metrics = st.session_state["available_response_metrics"]
saved_target = st.session_state["project_dct"]["model_build"].get(
    "sel_target_col", None
)

# A metric that was never saved, or that is no longer among the available
# metrics, falls back to the first option; looking it up with list.index()
# would otherwise raise ValueError and abort the page.
default_target_idx = (
    saved_target
    if saved_target in available_metrics
    else available_metrics[0]
)

sel_target_col = st.selectbox(
    "Select the response metric",
    available_metrics,
    index=available_metrics.index(default_target_idx),
)

# Remember the choice so it is restored on the next visit.
st.session_state["project_dct"]["model_build"][
    "sel_target_col"
] = sel_target_col
|
239 |
-
|
240 |
-
def _clean_name(raw_name):
    """Return *raw_name* in the normalised form used for data-frame columns.

    Lower-cases the name, turns ".", "@" and spaces into underscores, removes
    "-" and ":", and finally collapses each double underscore into one.
    """
    return (
        raw_name.lower()
        .replace(".", "_")
        .replace("@", "_")
        .replace(" ", "_")
        .replace("-", "")
        .replace(":", "")
        .replace("__", "_")
    )


# NOTE(review): the target name is deliberately normalised WITHOUT the "." and
# "@" replacements, mirroring the keys written to
# st.session_state["is_tuned_model"]; confirm no response metric contains
# those characters, otherwise it will not match the renamed columns below.
target_col = (
    sel_target_col.lower()
    .replace(" ", "_")
    .replace("-", "")
    .replace(":", "")
    .replace("__", "_")
)

# Mapping of original column name -> normalised column name.
new_name_dct = {col: _clean_name(col) for col in media_data.columns}
media_data.columns = [_clean_name(col) for col in media_data.columns]

# The first "Panel Level 1" entry is the panel column. When no panel column is
# configured, use an empty name instead of raising IndexError so that the
# non-panel code paths further down can actually be reached.
panel_candidates = [
    _clean_name(col) for col in st.session_state["bin_dict"]["Panel Level 1"]
]
panel_col = panel_candidates[0] if panel_candidates else ""
|
280 |
-
date_col = "date"
|
281 |
-
|
282 |
-
is_panel = True if len(panel_col) > 0 else False
|
283 |
-
|
284 |
-
if "is_panel" not in st.session_state:
|
285 |
-
st.session_state["is_panel"] = is_panel
|
286 |
-
|
287 |
-
if is_panel:
|
288 |
-
media_data.sort_values([date_col, panel_col], inplace=True)
|
289 |
-
else:
|
290 |
-
media_data.sort_values(date_col, inplace=True)
|
291 |
-
|
292 |
-
media_data.reset_index(drop=True, inplace=True)
|
293 |
-
|
294 |
-
date = media_data[date_col]
|
295 |
-
st.session_state["date"] = date
|
296 |
-
y = media_data[target_col]
|
297 |
-
|
298 |
-
if is_panel:
|
299 |
-
spends_data = media_data[
|
300 |
-
[
|
301 |
-
c
|
302 |
-
for c in media_data.columns
|
303 |
-
if "_cost" in c.lower() or "_spend" in c.lower()
|
304 |
-
]
|
305 |
-
+ [date_col, panel_col]
|
306 |
-
]
|
307 |
-
# Sprint3 - spends for resp curves
|
308 |
-
else:
|
309 |
-
spends_data = media_data[
|
310 |
-
[
|
311 |
-
c
|
312 |
-
for c in media_data.columns
|
313 |
-
if "_cost" in c.lower() or "_spend" in c.lower()
|
314 |
-
]
|
315 |
-
+ [date_col]
|
316 |
-
]
|
317 |
-
|
318 |
-
y = media_data[target_col]
|
319 |
-
media_data.drop([date_col], axis=1, inplace=True)
|
320 |
-
media_data.reset_index(drop=True, inplace=True)
|
321 |
-
|
322 |
-
columns = st.columns(2)
|
323 |
-
|
324 |
-
old_shape = media_data.shape
|
325 |
-
|
326 |
-
if "old_shape" not in st.session_state:
|
327 |
-
st.session_state["old_shape"] = old_shape
|
328 |
-
|
329 |
-
if "media_data" not in st.session_state:
|
330 |
-
st.session_state["media_data"] = pd.DataFrame()
|
331 |
-
|
332 |
-
# Sprint3
|
333 |
-
if "orig_media_data" not in st.session_state:
|
334 |
-
st.session_state["orig_media_data"] = pd.DataFrame()
|
335 |
-
|
336 |
-
# Sprint3 additions
|
337 |
-
if "random_effects" not in st.session_state:
|
338 |
-
st.session_state["random_effects"] = pd.DataFrame()
|
339 |
-
if "pred_train" not in st.session_state:
|
340 |
-
st.session_state["pred_train"] = []
|
341 |
-
if "pred_test" not in st.session_state:
|
342 |
-
st.session_state["pred_test"] = []
|
343 |
-
# end of Sprint3 additions
|
344 |
-
|
345 |
-
# Section 3 - Create combinations
|
346 |
-
|
347 |
-
# bucket=['paid_search', 'kwai','indicacao','infleux', 'influencer','FB: Level Achieved - Tier 1 Impressions',
|
348 |
-
# ' FB: Level Achieved - Tier 2 Impressions','paid_social_others',
|
349 |
-
# ' GA App: Will And Cid Pequena Baixo Risco Clicks',
|
350 |
-
# 'digital_tactic_others',"programmatic"
|
351 |
-
# ]
|
352 |
-
|
353 |
-
# srishti - bucket names changed
|
354 |
-
bucket = [
|
355 |
-
"paid_search",
|
356 |
-
"kwai",
|
357 |
-
"indicacao",
|
358 |
-
"infleux",
|
359 |
-
"influencer",
|
360 |
-
"fb_level_achieved_tier_2",
|
361 |
-
"fb_level_achieved_tier_1",
|
362 |
-
"paid_social_others",
|
363 |
-
"ga_app",
|
364 |
-
"digital_tactic_others",
|
365 |
-
"programmatic",
|
366 |
-
]
|
367 |
-
|
368 |
-
# with columns[0]:
|
369 |
-
# if st.button('Create Combinations of Variables'):
|
370 |
-
|
371 |
-
top_3_correlated_features = []
|
372 |
-
# # for col in st.session_state['media_data'].columns[:19]:
|
373 |
-
# original_cols = [c for c in st.session_state['media_data'].columns if
|
374 |
-
# "_clicks" in c.lower() or "_impressions" in c.lower()]
|
375 |
-
# original_cols = [c for c in original_cols if "_lag" not in c.lower() and "_adstock" not in c.lower()]
|
376 |
-
|
377 |
-
# Base variables to build feature combinations from: the "Media" and
# "Internal" bins, normalised the same way as the data-frame columns.
original_cols = (
    st.session_state["bin_dict"]["Media"]
    + st.session_state["bin_dict"]["Internal"]
)

original_cols = [
    col.lower()
    .replace(".", "_")
    .replace("@", "_")
    .replace(" ", "_")
    .replace("-", "")
    .replace(":", "")
    .replace("__", "_")
    for col in original_cols
]
original_cols = [col for col in original_cols if "_cost" not in col]

for col in original_cols:
    # Collect every column whose name contains this base variable and
    # correlate each with the target. `like=` matches the name as a literal
    # substring; `regex=` would interpret characters such as "(" or "+" in a
    # column name as pattern syntax and either mis-match or raise.
    corr_df = (
        pd.concat(
            [st.session_state["media_data"].filter(like=col), y], axis=1
        )
        .corr()[target_col]
        .iloc[:-1]  # drop the last entry, the target's correlation with itself
    )
    # Despite the list's name, the two most correlated variants are kept.
    top_3_correlated_features.append(
        list(corr_df.sort_values(ascending=False).head(2).index)
    )

flattened_list = [
    item for sublist in top_3_correlated_features for item in sublist
]

# Group the shortlisted features by channel bucket, skipping buckets that
# matched nothing.
all_features_set = {}
for var in bucket:
    matching_features = [col for col in flattened_list if var in col]
    if matching_features:
        all_features_set[var] = matching_features

# One feature per bucket: the cartesian product of the per-bucket lists.
channels_all = list(all_features_set.values())
st.session_state["combinations"] = list(itertools.product(*channels_all))
|
416 |
-
# if 'combinations' not in st.session_state:
|
417 |
-
# st.session_state['combinations']=combinations_all
|
418 |
-
|
419 |
-
st.session_state["final_selection"] = st.session_state["combinations"]
|
420 |
-
# st.success('Created combinations')
|
421 |
-
|
422 |
-
# revenue.reset_index(drop=True,inplace=True)
|
423 |
-
y.reset_index(drop=True, inplace=True)
|
424 |
-
def reset_model_result_dct():
    """Replace st.session_state["Model_results"] with a fresh, empty result store.

    The store maps each result field to an empty list; one entry per valid
    model is appended to every list while models are being built.
    """
    result_fields = (
        "Model_object",
        "Model_iteration",
        "Feature_set",
        "MAPE",
        "R2",
        "ADJR2",
        "pos_count",
    )
    st.session_state["Model_results"] = {field: [] for field in result_fields}


# First run of the session: create the empty result store.
if "Model_results" not in st.session_state:
    reset_model_result_dct()
|
445 |
-
|
446 |
-
# if st.button('Build Model'):
|
447 |
-
|
448 |
-
if "iterations" not in st.session_state:
|
449 |
-
st.session_state["iterations"] = 0
|
450 |
-
|
451 |
-
if "final_selection" not in st.session_state:
|
452 |
-
st.session_state["final_selection"] = False
|
453 |
-
|
454 |
-
save_path = r"Model/"
|
455 |
-
if st.session_state["final_selection"]:
|
456 |
-
st.write(
|
457 |
-
f'Total combinations created {format_numbers(len(st.session_state["final_selection"]))}'
|
458 |
-
)
|
459 |
-
|
460 |
-
# st.session_state["project_dct"]["model_build"]["all_iters_check"] = False
|
461 |
-
|
462 |
-
checkbox_default = (
|
463 |
-
st.session_state["project_dct"]["model_build"]["all_iters_check"]
|
464 |
-
if st.session_state["project_dct"]["model_build"]["all_iters_check"]
|
465 |
-
is not None
|
466 |
-
else False
|
467 |
-
)
|
468 |
-
|
469 |
-
if st.checkbox("Build all iterations", value=checkbox_default):
|
470 |
-
# st.session_state["project_dct"]["model_build"]["all_iters_check"]
|
471 |
-
iterations = len(st.session_state["final_selection"])
|
472 |
-
st.session_state["project_dct"]["model_build"][
|
473 |
-
"all_iters_check"
|
474 |
-
] = True
|
475 |
-
|
476 |
-
else:
|
477 |
-
iterations = st.number_input(
|
478 |
-
"Select the number of iterations to perform",
|
479 |
-
min_value=0,
|
480 |
-
step=100,
|
481 |
-
value=st.session_state["iterations"],
|
482 |
-
on_change=reset_model_result_dct,
|
483 |
-
)
|
484 |
-
st.session_state["project_dct"]["model_build"][
|
485 |
-
"all_iters_check"
|
486 |
-
] = False
|
487 |
-
st.session_state["project_dct"]["model_build"][
|
488 |
-
"iterations"
|
489 |
-
] = iterations
|
490 |
-
|
491 |
-
# st.stop()
|
492 |
-
|
493 |
-
# build_button = st.session_state["project_dct"]["model_build"]["build_button"] if \
|
494 |
-
# "build_button" in st.session_state["project_dct"]["model_build"].keys() else False
|
495 |
-
# model_button =st.button('Build Model', on_click=reset_model_result_dct, key='model_build_button')
|
496 |
-
# if
|
497 |
-
# if model_button:
|
498 |
-
if st.button(
|
499 |
-
"Build Model",
|
500 |
-
on_click=reset_model_result_dct,
|
501 |
-
key="model_build_button",
|
502 |
-
):
|
503 |
-
if iterations < 1:
|
504 |
-
st.error("Please select number of iterations")
|
505 |
-
st.stop()
|
506 |
-
st.session_state["project_dct"]["model_build"]["build_button"] = True
|
507 |
-
st.session_state["iterations"] = iterations
|
508 |
-
|
509 |
-
# Section 4 - Model
|
510 |
-
# st.session_state['media_data'] = st.session_state['media_data'].fillna(method='ffill')
|
511 |
-
st.session_state["media_data"] = st.session_state["media_data"].ffill()
|
512 |
-
st.markdown(
|
513 |
-
"Data Split -- Training Period: May 9th, 2023 - October 5th,2023 , Testing Period: October 6th, 2023 - November 7th, 2023 "
|
514 |
-
)
|
515 |
-
progress_bar = st.progress(0) # Initialize the progress bar
|
516 |
-
# time_remaining_text = st.empty() # Create an empty space for time remaining text
|
517 |
-
start_time = time.time() # Record the start time
|
518 |
-
progress_text = st.empty()
|
519 |
-
|
520 |
-
# time_elapsed_text = st.empty()
|
521 |
-
# for i, selected_features in enumerate(st.session_state["final_selection"][40000:40000 + int(iterations)]):
|
522 |
-
# for i, selected_features in enumerate(st.session_state["final_selection"]):
|
523 |
-
|
524 |
-
if is_panel == True:
|
525 |
-
for i, selected_features in enumerate(
|
526 |
-
st.session_state["final_selection"][0 : int(iterations)]
|
527 |
-
): # srishti
|
528 |
-
df = st.session_state["media_data"]
|
529 |
-
|
530 |
-
fet = [var for var in selected_features if len(var) > 0]
|
531 |
-
inp_vars_str = " + ".join(fet) # new
|
532 |
-
|
533 |
-
X = df[fet]
|
534 |
-
y = df[target_col]
|
535 |
-
ss = MinMaxScaler()
|
536 |
-
X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
|
537 |
-
|
538 |
-
X[target_col] = y # Sprint2
|
539 |
-
X[panel_col] = df[panel_col] # Sprint2
|
540 |
-
|
541 |
-
X_train = X.iloc[:train_idx]
|
542 |
-
X_test = X.iloc[train_idx:]
|
543 |
-
y_train = y.iloc[:train_idx]
|
544 |
-
y_test = y.iloc[train_idx:]
|
545 |
-
|
546 |
-
print(X_train.shape)
|
547 |
-
# model = sm.OLS(y_train, X_train).fit()
|
548 |
-
md_str = target_col + " ~ " + inp_vars_str
|
549 |
-
# md = smf.mixedlm("total_approved_accounts_revenue ~ {}".format(inp_vars_str),
|
550 |
-
# data=X_train[[target_col] + fet],
|
551 |
-
# groups=X_train[panel_col])
|
552 |
-
md = smf.mixedlm(
|
553 |
-
md_str,
|
554 |
-
data=X_train[[target_col] + fet],
|
555 |
-
groups=X_train[panel_col],
|
556 |
-
)
|
557 |
-
mdf = md.fit()
|
558 |
-
predicted_values = mdf.fittedvalues
|
559 |
-
|
560 |
-
coefficients = mdf.fe_params.to_dict()
|
561 |
-
model_positive = [
|
562 |
-
col for col in coefficients.keys() if coefficients[col] > 0
|
563 |
-
]
|
564 |
-
|
565 |
-
pvalues = [var for var in list(mdf.pvalues) if var <= 0.06]
|
566 |
-
|
567 |
-
if (len(model_positive) / len(selected_features)) > 0 and (
|
568 |
-
len(pvalues) / len(selected_features)
|
569 |
-
) >= 0: # srishti - changed just for testing, revert later
|
570 |
-
# predicted_values = model.predict(X_train)
|
571 |
-
mape = mean_absolute_percentage_error(
|
572 |
-
y_train, predicted_values
|
573 |
-
)
|
574 |
-
r2 = r2_score(y_train, predicted_values)
|
575 |
-
adjr2 = 1 - (1 - r2) * (len(y_train) - 1) / (
|
576 |
-
len(y_train) - len(selected_features) - 1
|
577 |
-
)
|
578 |
-
|
579 |
-
filename = os.path.join(save_path, f"model_{i}.pkl")
|
580 |
-
with open(filename, "wb") as f:
|
581 |
-
pickle.dump(mdf, f)
|
582 |
-
# with open(r"C:\Users\ManojP\Documents\MMM\simopt\Model\model.pkl", 'rb') as file:
|
583 |
-
# model = pickle.load(file)
|
584 |
-
|
585 |
-
st.session_state["Model_results"]["Model_object"].append(
|
586 |
-
filename
|
587 |
-
)
|
588 |
-
st.session_state["Model_results"][
|
589 |
-
"Model_iteration"
|
590 |
-
].append(i)
|
591 |
-
st.session_state["Model_results"]["Feature_set"].append(
|
592 |
-
fet
|
593 |
-
)
|
594 |
-
st.session_state["Model_results"]["MAPE"].append(mape)
|
595 |
-
st.session_state["Model_results"]["R2"].append(r2)
|
596 |
-
st.session_state["Model_results"]["pos_count"].append(
|
597 |
-
len(model_positive)
|
598 |
-
)
|
599 |
-
st.session_state["Model_results"]["ADJR2"].append(adjr2)
|
600 |
-
|
601 |
-
current_time = time.time()
|
602 |
-
time_taken = current_time - start_time
|
603 |
-
time_elapsed_minutes = time_taken / 60
|
604 |
-
completed_iterations_text = f"{i + 1}/{iterations}"
|
605 |
-
progress_bar.progress((i + 1) / int(iterations))
|
606 |
-
progress_text.text(
|
607 |
-
f"Completed iterations: {completed_iterations_text},Time Elapsed (min): {time_elapsed_minutes:.2f}"
|
608 |
-
)
|
609 |
-
st.write(
|
610 |
-
f'Out of {st.session_state["iterations"]} iterations : {len(st.session_state["Model_results"]["Model_object"])} valid models'
|
611 |
-
)
|
612 |
-
|
613 |
-
else:
|
614 |
-
|
615 |
-
for i, selected_features in enumerate(
|
616 |
-
st.session_state["final_selection"][0 : int(iterations)]
|
617 |
-
): # srishti
|
618 |
-
df = st.session_state["media_data"]
|
619 |
-
|
620 |
-
fet = [var for var in selected_features if len(var) > 0]
|
621 |
-
inp_vars_str = " + ".join(fet)
|
622 |
-
|
623 |
-
X = df[fet]
|
624 |
-
y = df[target_col]
|
625 |
-
ss = MinMaxScaler()
|
626 |
-
X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
|
627 |
-
X = sm.add_constant(X)
|
628 |
-
X_train = X.iloc[:130]
|
629 |
-
X_test = X.iloc[130:]
|
630 |
-
y_train = y.iloc[:130]
|
631 |
-
y_test = y.iloc[130:]
|
632 |
-
|
633 |
-
model = sm.OLS(y_train, X_train).fit()
|
634 |
-
|
635 |
-
coefficients = model.params.to_list()
|
636 |
-
model_positive = [coef for coef in coefficients if coef > 0]
|
637 |
-
predicted_values = model.predict(X_train)
|
638 |
-
pvalues = [var for var in list(model.pvalues) if var <= 0.06]
|
639 |
-
|
640 |
-
# if (len(model_possitive) / len(selected_features)) > 0.9 and (len(pvalues) / len(selected_features)) >= 0.8:
|
641 |
-
if (len(model_positive) / len(selected_features)) > 0 and (
|
642 |
-
len(pvalues) / len(selected_features)
|
643 |
-
) >= 0.5: # srishti - changed just for testing, revert later VALID MODEL CRITERIA
|
644 |
-
# predicted_values = model.predict(X_train)
|
645 |
-
mape = mean_absolute_percentage_error(
|
646 |
-
y_train, predicted_values
|
647 |
-
)
|
648 |
-
adjr2 = model.rsquared_adj
|
649 |
-
r2 = model.rsquared
|
650 |
-
|
651 |
-
filename = os.path.join(save_path, f"model_{i}.pkl")
|
652 |
-
with open(filename, "wb") as f:
|
653 |
-
pickle.dump(model, f)
|
654 |
-
# with open(r"C:\Users\ManojP\Documents\MMM\simopt\Model\model.pkl", 'rb') as file:
|
655 |
-
# model = pickle.load(file)
|
656 |
-
|
657 |
-
st.session_state["Model_results"]["Model_object"].append(
|
658 |
-
filename
|
659 |
-
)
|
660 |
-
st.session_state["Model_results"][
|
661 |
-
"Model_iteration"
|
662 |
-
].append(i)
|
663 |
-
st.session_state["Model_results"]["Feature_set"].append(
|
664 |
-
fet
|
665 |
-
)
|
666 |
-
st.session_state["Model_results"]["MAPE"].append(mape)
|
667 |
-
st.session_state["Model_results"]["R2"].append(r2)
|
668 |
-
st.session_state["Model_results"]["ADJR2"].append(adjr2)
|
669 |
-
st.session_state["Model_results"]["pos_count"].append(
|
670 |
-
len(model_positive)
|
671 |
-
)
|
672 |
-
|
673 |
-
current_time = time.time()
|
674 |
-
time_taken = current_time - start_time
|
675 |
-
time_elapsed_minutes = time_taken / 60
|
676 |
-
completed_iterations_text = f"{i + 1}/{iterations}"
|
677 |
-
progress_bar.progress((i + 1) / int(iterations))
|
678 |
-
progress_text.text(
|
679 |
-
f"Completed iterations: {completed_iterations_text},Time Elapsed (min): {time_elapsed_minutes:.2f}"
|
680 |
-
)
|
681 |
-
st.write(
|
682 |
-
f'Out of {st.session_state["iterations"]} iterations : {len(st.session_state["Model_results"]["Model_object"])} valid models'
|
683 |
-
)
|
684 |
-
|
685 |
-
pd.DataFrame(st.session_state["Model_results"]).to_csv(
|
686 |
-
"model_output.csv"
|
687 |
-
)
|
688 |
-
|
689 |
-
def to_percentage(value):
    """Format a fraction as a percentage string with one decimal place.

    For example 0.1234 becomes "12.3%".
    """
    scaled = value * 100
    return "{:.1f}%".format(scaled)
|
691 |
-
|
692 |
-
## Section 5 - Select Model
|
693 |
-
st.title("2. Select Models")
|
694 |
-
show_results_defualt = (
|
695 |
-
st.session_state["project_dct"]["model_build"]["show_results_check"]
|
696 |
-
if st.session_state["project_dct"]["model_build"]["show_results_check"]
|
697 |
-
is not None
|
698 |
-
else False
|
699 |
-
)
|
700 |
-
if "tick" not in st.session_state:
|
701 |
-
st.session_state["tick"] = False
|
702 |
-
if st.checkbox(
|
703 |
-
"Show results of top 10 models (based on MAPE and Adj. R2)",
|
704 |
-
value=show_results_defualt,
|
705 |
-
):
|
706 |
-
st.session_state["project_dct"]["model_build"][
|
707 |
-
"show_results_check"
|
708 |
-
] = True
|
709 |
-
st.session_state["tick"] = True
|
710 |
-
st.write(
|
711 |
-
"Select one model iteration to generate performance metrics for it:"
|
712 |
-
)
|
713 |
-
# Build the leaderboard from the collected model results.
data = pd.DataFrame(st.session_state["Model_results"])
# Keep only the models with the highest count of positive coefficients.
data = data[data["pos_count"] == data["pos_count"].max()].reset_index(
    drop=True
)
data.sort_values(by=["ADJR2"], ascending=False, inplace=True)
data.drop_duplicates(subset="Model_iteration", inplace=True)

# Work on an explicit copy: adding columns to the slice returned by head()
# triggers pandas' SettingWithCopyWarning.
top_10 = data.head(10).copy()
top_10["Rank"] = np.arange(1, len(top_10) + 1, 1)

# Show the metrics as percentages. Each column is formatted with Series.map
# because DataFrame.applymap is deprecated in recent pandas releases.
metric_cols = ["MAPE", "R2", "ADJR2"]
top_10[metric_cols] = np.round(top_10[metric_cols], 4).apply(
    lambda metric: metric.map(to_percentage)
)

top_10_table = top_10[
    ["Rank", "Model_iteration", "MAPE", "ADJR2", "R2"]
]
|
727 |
-
# top_10_table.columns=[['Rank','Model Iteration Index','MAPE','Adjusted R2','R2']]
|
728 |
-
gd = GridOptionsBuilder.from_dataframe(top_10_table)
|
729 |
-
gd.configure_pagination(enabled=True)
|
730 |
-
|
731 |
-
gd.configure_selection(
|
732 |
-
use_checkbox=True,
|
733 |
-
selection_mode="single",
|
734 |
-
pre_select_all_rows=False,
|
735 |
-
pre_selected_rows=[1],
|
736 |
-
)
|
737 |
-
|
738 |
-
gridoptions = gd.build()
|
739 |
-
|
740 |
-
table = AgGrid(
|
741 |
-
top_10,
|
742 |
-
gridOptions=gridoptions,
|
743 |
-
update_mode=GridUpdateMode.SELECTION_CHANGED,
|
744 |
-
)
|
745 |
-
|
746 |
-
selected_rows = table.selected_rows
|
747 |
-
# if st.session_state["selected_rows"] != selected_rows:
|
748 |
-
# st.session_state["build_rc_cb"] = False
|
749 |
-
st.session_state["selected_rows"] = selected_rows
|
750 |
-
|
751 |
-
# Section 6 - Display Results
|
752 |
-
|
753 |
-
if len(selected_rows) > 0:
|
754 |
-
st.header("2.1 Results Summary")
|
755 |
-
|
756 |
-
model_object = data[
|
757 |
-
data["Model_iteration"] == selected_rows[0]["Model_iteration"]
|
758 |
-
]["Model_object"]
|
759 |
-
features_set = data[
|
760 |
-
data["Model_iteration"] == selected_rows[0]["Model_iteration"]
|
761 |
-
]["Feature_set"]
|
762 |
-
|
763 |
-
with open(str(model_object.values[0]), "rb") as file:
|
764 |
-
# print(file)
|
765 |
-
model = pickle.load(file)
|
766 |
-
st.write(model.summary())
|
767 |
-
st.header("2.2 Actual vs. Predicted Plot")
|
768 |
-
|
769 |
-
if is_panel:
|
770 |
-
df = st.session_state["media_data"]
|
771 |
-
X = df[features_set.values[0]]
|
772 |
-
y = df[target_col]
|
773 |
-
|
774 |
-
ss = MinMaxScaler()
|
775 |
-
X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
|
776 |
-
|
777 |
-
# Sprint2 changes
|
778 |
-
X[target_col] = y # new
|
779 |
-
X[panel_col] = df[panel_col]
|
780 |
-
X[date_col] = date
|
781 |
-
|
782 |
-
X_train = X.iloc[:train_idx]
|
783 |
-
X_test = X.iloc[train_idx:].reset_index(drop=True)
|
784 |
-
y_train = y.iloc[:train_idx]
|
785 |
-
y_test = y.iloc[train_idx:].reset_index(drop=True)
|
786 |
-
|
787 |
-
test_spends = spends_data[
|
788 |
-
train_idx:
|
789 |
-
] # Sprint3 - test spends for resp curves
|
790 |
-
random_eff_df = get_random_effects(
|
791 |
-
media_data, panel_col, model
|
792 |
-
)
|
793 |
-
train_pred = model.fittedvalues
|
794 |
-
test_pred = mdf_predict(X_test, model, random_eff_df)
|
795 |
-
print("__" * 20, test_pred.isna().sum())
|
796 |
-
|
797 |
-
else:
|
798 |
-
df = st.session_state["media_data"]
|
799 |
-
X = df[features_set.values[0]]
|
800 |
-
y = df[target_col]
|
801 |
-
|
802 |
-
ss = MinMaxScaler()
|
803 |
-
X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
|
804 |
-
X = sm.add_constant(X)
|
805 |
-
|
806 |
-
X[date_col] = date
|
807 |
-
|
808 |
-
X_train = X.iloc[:130]
|
809 |
-
X_test = X.iloc[130:].reset_index(drop=True)
|
810 |
-
y_train = y.iloc[:130]
|
811 |
-
y_test = y.iloc[130:].reset_index(drop=True)
|
812 |
-
|
813 |
-
test_spends = spends_data[
|
814 |
-
130:
|
815 |
-
] # Sprint3 - test spends for resp curves
|
816 |
-
train_pred = model.predict(
|
817 |
-
X_train[features_set.values[0] + ["const"]]
|
818 |
-
)
|
819 |
-
test_pred = model.predict(
|
820 |
-
X_test[features_set.values[0] + ["const"]]
|
821 |
-
)
|
822 |
-
|
823 |
-
# save x test to test - srishti
|
824 |
-
# x_test_to_save = X_test.copy()
|
825 |
-
# x_test_to_save['Actuals'] = y_test
|
826 |
-
# x_test_to_save['Predictions'] = test_pred
|
827 |
-
#
|
828 |
-
# x_train_to_save = X_train.copy()
|
829 |
-
# x_train_to_save['Actuals'] = y_train
|
830 |
-
# x_train_to_save['Predictions'] = train_pred
|
831 |
-
#
|
832 |
-
# x_train_to_save.to_csv('Test/x_train_to_save.csv', index=False)
|
833 |
-
# x_test_to_save.to_csv('Test/x_test_to_save.csv', index=False)
|
834 |
-
|
835 |
-
st.session_state["X"] = X_train
|
836 |
-
st.session_state["features_set"] = features_set.values[0]
|
837 |
-
print(
|
838 |
-
"**" * 20, "selected model features : ", features_set.values[0]
|
839 |
-
)
|
840 |
-
metrics_table, line, actual_vs_predicted_plot = (
|
841 |
-
plot_actual_vs_predicted(
|
842 |
-
X_train[date_col],
|
843 |
-
y_train,
|
844 |
-
train_pred,
|
845 |
-
model,
|
846 |
-
target_column=sel_target_col,
|
847 |
-
is_panel=is_panel,
|
848 |
-
)
|
849 |
-
) # Sprint2
|
850 |
-
|
851 |
-
st.plotly_chart(actual_vs_predicted_plot, use_container_width=True)
|
852 |
-
|
853 |
-
st.markdown("## 2.3 Residual Analysis")
|
854 |
-
columns = st.columns(2)
|
855 |
-
with columns[0]:
|
856 |
-
fig = plot_residual_predicted(
|
857 |
-
y_train, train_pred, X_train
|
858 |
-
) # Sprint2
|
859 |
-
st.plotly_chart(fig)
|
860 |
-
|
861 |
-
with columns[1]:
|
862 |
-
st.empty()
|
863 |
-
fig = qqplot(y_train, train_pred) # Sprint2
|
864 |
-
st.plotly_chart(fig)
|
865 |
-
|
866 |
-
with columns[0]:
|
867 |
-
fig = residual_distribution(y_train, train_pred) # Sprint2
|
868 |
-
st.pyplot(fig)
|
869 |
-
|
870 |
-
vif_data = pd.DataFrame()
|
871 |
-
# X=X.drop('const',axis=1)
|
872 |
-
X_train_orig = (
|
873 |
-
X_train.copy()
|
874 |
-
) # Sprint2 -- creating a copy of xtrain. Later deleting panel, target & date from xtrain
|
875 |
-
del_col_list = list(
|
876 |
-
set([target_col, panel_col, date_col]).intersection(
|
877 |
-
set(X_train.columns)
|
878 |
-
)
|
879 |
-
)
|
880 |
-
X_train.drop(columns=del_col_list, inplace=True) # Sprint2
|
881 |
-
|
882 |
-
vif_data["Variable"] = X_train.columns
|
883 |
-
vif_data["VIF"] = [
|
884 |
-
variance_inflation_factor(X_train.values, i)
|
885 |
-
for i in range(X_train.shape[1])
|
886 |
-
]
|
887 |
-
vif_data.sort_values(by=["VIF"], ascending=False, inplace=True)
|
888 |
-
vif_data = np.round(vif_data)
|
889 |
-
vif_data["VIF"] = vif_data["VIF"].astype(float)
|
890 |
-
st.header("2.4 Variance Inflation Factor (VIF)")
|
891 |
-
# st.dataframe(vif_data)
|
892 |
-
color_mapping = {
|
893 |
-
"darkgreen": (vif_data["VIF"] < 3),
|
894 |
-
"orange": (vif_data["VIF"] >= 3) & (vif_data["VIF"] <= 10),
|
895 |
-
"darkred": (vif_data["VIF"] > 10),
|
896 |
-
}
|
897 |
-
|
898 |
-
# Create a horizontal bar plot
|
899 |
-
fig, ax = plt.subplots()
|
900 |
-
fig.set_figwidth(10) # Adjust the width of the figure as needed
|
901 |
-
|
902 |
-
# Sort the bars by descending VIF values
|
903 |
-
vif_data = vif_data.sort_values(by="VIF", ascending=False)
|
904 |
-
|
905 |
-
# Iterate through the color mapping and plot bars with corresponding colors
|
906 |
-
for color, condition in color_mapping.items():
|
907 |
-
subset = vif_data[condition]
|
908 |
-
bars = ax.barh(
|
909 |
-
subset["Variable"], subset["VIF"], color=color, label=color
|
910 |
-
)
|
911 |
-
|
912 |
-
# Add text annotations on top of the bars
|
913 |
-
for bar in bars:
|
914 |
-
width = bar.get_width()
|
915 |
-
ax.annotate(
|
916 |
-
f"{width:}",
|
917 |
-
xy=(width, bar.get_y() + bar.get_height() / 2),
|
918 |
-
xytext=(5, 0),
|
919 |
-
textcoords="offset points",
|
920 |
-
va="center",
|
921 |
-
)
|
922 |
-
|
923 |
-
# Customize the plot
|
924 |
-
ax.set_xlabel("VIF Values")
|
925 |
-
# ax.set_title('2.4 Variance Inflation Factor (VIF)')
|
926 |
-
# ax.legend(loc='upper right')
|
927 |
-
|
928 |
-
# Display the plot in Streamlit
|
929 |
-
st.pyplot(fig)
|
930 |
-
|
931 |
-
# Diagnostics for the held-out (test) rows, mirroring the training-data
# sections shown above.
with st.expander("Results Summary Test data"):
    st.header("2.2 Actual vs. Predicted Plot")

    metrics_table, line, actual_vs_predicted_plot = (
        plot_actual_vs_predicted(
            X_test[date_col],
            y_test,
            test_pred,
            model,
            target_column=sel_target_col,
            is_panel=is_panel,
        )
    )

    st.plotly_chart(
        actual_vs_predicted_plot, use_container_width=True
    )

    st.markdown("## 2.3 Residual Analysis")
    columns = st.columns(2)

    # The residual plots must compare the test predictions with the test
    # actuals (y_test), the same way the training plots pair y_train with
    # train_pred. Passing the full-length `y` mixed training rows into the
    # comparison.
    with columns[0]:
        fig = plot_residual_predicted(y_test, test_pred, X_test)
        st.plotly_chart(fig)

    with columns[1]:
        st.empty()
        fig = qqplot(y_test, test_pred)
        st.plotly_chart(fig)

    with columns[0]:
        fig = residual_distribution(y_test, test_pred)
        st.pyplot(fig)
|
967 |
-
|
968 |
-
value = False
|
969 |
-
save_button_model = st.checkbox(
|
970 |
-
"Save this model to tune", key="build_rc_cb"
|
971 |
-
) # , on_click=set_save())
|
972 |
-
|
973 |
-
if save_button_model:
|
974 |
-
mod_name = st.text_input("Enter model name")
|
975 |
-
if len(mod_name) > 0:
|
976 |
-
mod_name = (
|
977 |
-
mod_name + "__" + target_col
|
978 |
-
) # Sprint4 - adding target col to model name
|
979 |
-
if is_panel:
|
980 |
-
pred_train = model.fittedvalues
|
981 |
-
pred_test = mdf_predict(X_test, model, random_eff_df)
|
982 |
-
else:
|
983 |
-
st.session_state["features_set"] = st.session_state[
|
984 |
-
"features_set"
|
985 |
-
] + ["const"]
|
986 |
-
pred_train = model.predict(
|
987 |
-
X_train_orig[st.session_state["features_set"]]
|
988 |
-
)
|
989 |
-
pred_test = model.predict(
|
990 |
-
X_test[st.session_state["features_set"]]
|
991 |
-
)
|
992 |
-
|
993 |
-
st.session_state["Model"][mod_name] = {
|
994 |
-
"Model_object": model,
|
995 |
-
"feature_set": st.session_state["features_set"],
|
996 |
-
"X_train": X_train_orig,
|
997 |
-
"X_test": X_test,
|
998 |
-
"y_train": y_train,
|
999 |
-
"y_test": y_test,
|
1000 |
-
"pred_train": pred_train,
|
1001 |
-
"pred_test": pred_test,
|
1002 |
-
}
|
1003 |
-
st.session_state["X_train"] = X_train_orig
|
1004 |
-
st.session_state["X_test_spends"] = test_spends
|
1005 |
-
st.session_state["saved_model_names"].append(mod_name)
|
1006 |
-
# Sprint3 additions
|
1007 |
-
if is_panel:
|
1008 |
-
random_eff_df = get_random_effects(
|
1009 |
-
media_data, panel_col, model
|
1010 |
-
)
|
1011 |
-
st.session_state["random_effects"] = random_eff_df
|
1012 |
-
|
1013 |
-
with open(
|
1014 |
-
os.path.join(
|
1015 |
-
st.session_state["project_path"], "best_models.pkl"
|
1016 |
-
),
|
1017 |
-
"wb",
|
1018 |
-
) as f:
|
1019 |
-
pickle.dump(st.session_state["Model"], f)
|
1020 |
-
st.success(
|
1021 |
-
mod_name
|
1022 |
-
+ " model saved! Proceed to the next page to tune the model"
|
1023 |
-
)
|
1024 |
-
|
1025 |
-
urm = st.session_state["used_response_metrics"]
|
1026 |
-
urm.append(sel_target_col)
|
1027 |
-
st.session_state["used_response_metrics"] = list(
|
1028 |
-
set(urm)
|
1029 |
-
)
|
1030 |
-
mod_name = ""
|
1031 |
-
# Sprint4 - add the formatted name of the target col to used resp metrics
|
1032 |
-
value = False
|
1033 |
-
|
1034 |
-
st.session_state["project_dct"]["model_build"][
|
1035 |
-
"session_state_saved"
|
1036 |
-
] = {}
|
1037 |
-
for key in [
|
1038 |
-
"Model",
|
1039 |
-
"bin_dict",
|
1040 |
-
"used_response_metrics",
|
1041 |
-
"date",
|
1042 |
-
"saved_model_names",
|
1043 |
-
"media_data",
|
1044 |
-
"X_test_spends",
|
1045 |
-
]:
|
1046 |
-
st.session_state["project_dct"]["model_build"][
|
1047 |
-
"session_state_saved"
|
1048 |
-
][key] = st.session_state[key]
|
1049 |
-
|
1050 |
-
project_dct_path = os.path.join(
|
1051 |
-
st.session_state["project_path"], "project_dct.pkl"
|
1052 |
-
)
|
1053 |
-
with open(project_dct_path, "wb") as f:
|
1054 |
-
pickle.dump(st.session_state["project_dct"], f)
|
1055 |
-
|
1056 |
-
update_db("4_Model_Build.py")
|
1057 |
-
|
1058 |
-
st.toast("💾 Saved Successfully!")
|
1059 |
-
else:
|
1060 |
-
st.session_state["project_dct"]["model_build"][
|
1061 |
-
"show_results_check"
|
1062 |
-
] = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/5_Model_Tuning.py
DELETED
@@ -1,912 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
MMO Build Sprint 3
|
3 |
-
date :
|
4 |
-
changes : capability to tune MixedLM as well as simple LR in the same page
|
5 |
-
"""
|
6 |
-
|
7 |
-
import os
|
8 |
-
|
9 |
-
import streamlit as st
|
10 |
-
import pandas as pd
|
11 |
-
from Eda_functions import format_numbers
|
12 |
-
import pickle
|
13 |
-
from utilities import set_header, load_local_css
|
14 |
-
import statsmodels.api as sm
|
15 |
-
import re
|
16 |
-
from sklearn.preprocessing import MinMaxScaler
|
17 |
-
import matplotlib.pyplot as plt
|
18 |
-
from statsmodels.stats.outliers_influence import variance_inflation_factor
|
19 |
-
import yaml
|
20 |
-
from yaml import SafeLoader
|
21 |
-
import streamlit_authenticator as stauth
|
22 |
-
|
23 |
-
st.set_option("deprecation.showPyplotGlobalUse", False)
|
24 |
-
import statsmodels.formula.api as smf
|
25 |
-
from Data_prep_functions import *
|
26 |
-
import sqlite3
|
27 |
-
from utilities import update_db
|
28 |
-
|
29 |
-
# for i in ["model_tuned", "X_train_tuned", "X_test_tuned", "tuned_model_features", "tuned_model", "tuned_model_dict"] :
|
30 |
-
|
31 |
-
st.set_page_config(
|
32 |
-
page_title="Model Tuning",
|
33 |
-
page_icon=":shark:",
|
34 |
-
layout="wide",
|
35 |
-
initial_sidebar_state="collapsed",
|
36 |
-
)
|
37 |
-
load_local_css("styles.css")
|
38 |
-
set_header()
|
39 |
-
# Check for authentication status
|
40 |
-
for k, v in st.session_state.items():
|
41 |
-
# print(k, v)
|
42 |
-
if k not in [
|
43 |
-
"logout",
|
44 |
-
"login",
|
45 |
-
"config",
|
46 |
-
"build_tuned_model",
|
47 |
-
] and not k.startswith("FormSubmitter"):
|
48 |
-
st.session_state[k] = v
|
49 |
-
with open("config.yaml") as file:
|
50 |
-
config = yaml.load(file, Loader=SafeLoader)
|
51 |
-
st.session_state["config"] = config
|
52 |
-
authenticator = stauth.Authenticate(
|
53 |
-
config["credentials"],
|
54 |
-
config["cookie"]["name"],
|
55 |
-
config["cookie"]["key"],
|
56 |
-
config["cookie"]["expiry_days"],
|
57 |
-
config["preauthorized"],
|
58 |
-
)
|
59 |
-
st.session_state["authenticator"] = authenticator
|
60 |
-
name, authentication_status, username = authenticator.login("Login", "main")
|
61 |
-
auth_status = st.session_state.get("authentication_status")
|
62 |
-
|
63 |
-
if auth_status == True:
|
64 |
-
authenticator.logout("Logout", "main")
|
65 |
-
is_state_initiaized = st.session_state.get("initialized", False)
|
66 |
-
|
67 |
-
if "project_dct" not in st.session_state:
|
68 |
-
st.error("Please load a project from Home page")
|
69 |
-
st.stop()
|
70 |
-
|
71 |
-
if not os.path.exists(
|
72 |
-
os.path.join(st.session_state["project_path"], "best_models.pkl")
|
73 |
-
):
|
74 |
-
st.error("Please save a model before tuning")
|
75 |
-
st.stop()
|
76 |
-
|
77 |
-
conn = sqlite3.connect(
|
78 |
-
r"DB/User.db", check_same_thread=False
|
79 |
-
) # connection with sql db
|
80 |
-
c = conn.cursor()
|
81 |
-
|
82 |
-
if not is_state_initiaized:
|
83 |
-
if "session_name" not in st.session_state:
|
84 |
-
st.session_state["session_name"] = None
|
85 |
-
|
86 |
-
if (
|
87 |
-
"session_state_saved"
|
88 |
-
in st.session_state["project_dct"]["model_build"].keys()
|
89 |
-
):
|
90 |
-
for key in [
|
91 |
-
"Model",
|
92 |
-
"date",
|
93 |
-
"saved_model_names",
|
94 |
-
"media_data",
|
95 |
-
"X_test_spends",
|
96 |
-
]:
|
97 |
-
if key not in st.session_state:
|
98 |
-
st.session_state[key] = st.session_state["project_dct"][
|
99 |
-
"model_build"
|
100 |
-
]["session_state_saved"][key]
|
101 |
-
st.session_state["bin_dict"] = st.session_state["project_dct"][
|
102 |
-
"model_build"
|
103 |
-
]["session_state_saved"]["bin_dict"]
|
104 |
-
if (
|
105 |
-
"used_response_metrics" not in st.session_state
|
106 |
-
or st.session_state["used_response_metrics"] == []
|
107 |
-
):
|
108 |
-
st.session_state["used_response_metrics"] = st.session_state[
|
109 |
-
"project_dct"
|
110 |
-
]["model_build"]["session_state_saved"][
|
111 |
-
"used_response_metrics"
|
112 |
-
]
|
113 |
-
else:
|
114 |
-
st.error("Please load a session with a built model")
|
115 |
-
st.stop()
|
116 |
-
|
117 |
-
# if 'sel_model' not in st.session_state["project_dct"]["model_tuning"].keys():
|
118 |
-
# st.session_state["project_dct"]["model_tuning"]['sel_model']= {}
|
119 |
-
|
120 |
-
for key in ["select_all_flags_check", "selected_flags", "sel_model"]:
|
121 |
-
if key not in st.session_state["project_dct"]["model_tuning"].keys():
|
122 |
-
st.session_state["project_dct"]["model_tuning"][key] = {}
|
123 |
-
# Sprint3
|
124 |
-
# is_panel = st.session_state['is_panel']
|
125 |
-
# panel_col = 'markets' # set the panel column
|
126 |
-
date_col = "date"
|
127 |
-
|
128 |
-
panel_col = [
|
129 |
-
col.lower()
|
130 |
-
.replace(".", "_")
|
131 |
-
.replace("@", "_")
|
132 |
-
.replace(" ", "_")
|
133 |
-
.replace("-", "")
|
134 |
-
.replace(":", "")
|
135 |
-
.replace("__", "_")
|
136 |
-
for col in st.session_state["bin_dict"]["Panel Level 1"]
|
137 |
-
][
|
138 |
-
0
|
139 |
-
] # set the panel column
|
140 |
-
is_panel = True if len(panel_col) > 0 else False
|
141 |
-
|
142 |
-
# flag indicating there is not tuned model till now
|
143 |
-
|
144 |
-
# Sprint4 - model tuned dict
|
145 |
-
if "Model_Tuned" not in st.session_state:
|
146 |
-
st.session_state["Model_Tuned"] = {}
|
147 |
-
|
148 |
-
st.title("1. Model Tuning")
|
149 |
-
|
150 |
-
if "is_tuned_model" not in st.session_state:
|
151 |
-
st.session_state["is_tuned_model"] = {}
|
152 |
-
# Sprint4 - if used_response_metrics is not blank, then select one of the used_response_metrics, else target is revenue by default
|
153 |
-
if (
|
154 |
-
"used_response_metrics" in st.session_state
|
155 |
-
and st.session_state["used_response_metrics"] != []
|
156 |
-
):
|
157 |
-
default_target_idx = (
|
158 |
-
st.session_state["project_dct"]["model_tuning"].get(
|
159 |
-
"sel_target_col", None
|
160 |
-
)
|
161 |
-
if st.session_state["project_dct"]["model_tuning"].get(
|
162 |
-
"sel_target_col", None
|
163 |
-
)
|
164 |
-
is not None
|
165 |
-
else st.session_state["used_response_metrics"][0]
|
166 |
-
)
|
167 |
-
sel_target_col = st.selectbox(
|
168 |
-
"Select the response metric",
|
169 |
-
st.session_state["used_response_metrics"],
|
170 |
-
index=st.session_state["used_response_metrics"].index(
|
171 |
-
default_target_idx
|
172 |
-
),
|
173 |
-
)
|
174 |
-
target_col = (
|
175 |
-
sel_target_col.lower()
|
176 |
-
.replace(" ", "_")
|
177 |
-
.replace("-", "")
|
178 |
-
.replace(":", "")
|
179 |
-
.replace("__", "_")
|
180 |
-
)
|
181 |
-
st.session_state["project_dct"]["model_tuning"][
|
182 |
-
"sel_target_col"
|
183 |
-
] = sel_target_col
|
184 |
-
|
185 |
-
else:
|
186 |
-
sel_target_col = "Total Approved Accounts - Revenue"
|
187 |
-
target_col = "total_approved_accounts_revenue"
|
188 |
-
|
189 |
-
# Sprint4 - Look through all saved models, only show saved models of the sel resp metric (target_col)
|
190 |
-
# saved_models = st.session_state['saved_model_names']
|
191 |
-
with open(
|
192 |
-
os.path.join(st.session_state["project_path"], "best_models.pkl"), "rb"
|
193 |
-
) as file:
|
194 |
-
model_dict = pickle.load(file)
|
195 |
-
|
196 |
-
saved_models = model_dict.keys()
|
197 |
-
required_saved_models = [
|
198 |
-
m.split("__")[0]
|
199 |
-
for m in saved_models
|
200 |
-
if m.split("__")[1] == target_col
|
201 |
-
]
|
202 |
-
|
203 |
-
if len(required_saved_models) > 0:
|
204 |
-
default_model_idx = st.session_state["project_dct"]["model_tuning"][
|
205 |
-
"sel_model"
|
206 |
-
].get(sel_target_col, required_saved_models[0])
|
207 |
-
sel_model = st.selectbox(
|
208 |
-
"Select the model to tune",
|
209 |
-
required_saved_models,
|
210 |
-
index=required_saved_models.index(default_model_idx),
|
211 |
-
)
|
212 |
-
else:
|
213 |
-
default_model_idx = st.session_state["project_dct"]["model_tuning"][
|
214 |
-
"sel_model"
|
215 |
-
].get(sel_target_col, 0)
|
216 |
-
sel_model = st.selectbox(
|
217 |
-
"Select the model to tune", required_saved_models
|
218 |
-
)
|
219 |
-
|
220 |
-
st.session_state["project_dct"]["model_tuning"]["sel_model"][
|
221 |
-
sel_target_col
|
222 |
-
] = default_model_idx
|
223 |
-
|
224 |
-
sel_model_dict = model_dict[
|
225 |
-
sel_model + "__" + target_col
|
226 |
-
] # Sprint4 - get the model obj of the selected model
|
227 |
-
|
228 |
-
X_train = sel_model_dict["X_train"]
|
229 |
-
X_test = sel_model_dict["X_test"]
|
230 |
-
y_train = sel_model_dict["y_train"]
|
231 |
-
y_test = sel_model_dict["y_test"]
|
232 |
-
df = st.session_state["media_data"]
|
233 |
-
|
234 |
-
if "selected_model" not in st.session_state:
|
235 |
-
st.session_state["selected_model"] = 0
|
236 |
-
|
237 |
-
st.markdown("### 1.1 Event Flags")
|
238 |
-
st.markdown(
|
239 |
-
"Helps in quantifying the impact of specific occurrences of events"
|
240 |
-
)
|
241 |
-
|
242 |
-
flag_expander_default = (
|
243 |
-
st.session_state["project_dct"]["model_tuning"].get(
|
244 |
-
"flag_expander", None
|
245 |
-
)
|
246 |
-
if st.session_state["project_dct"]["model_tuning"].get(
|
247 |
-
"flag_expander", None
|
248 |
-
)
|
249 |
-
is not None
|
250 |
-
else False
|
251 |
-
)
|
252 |
-
|
253 |
-
with st.expander("Apply Event Flags", flag_expander_default):
|
254 |
-
st.session_state["project_dct"]["model_tuning"]["flag_expander"] = True
|
255 |
-
|
256 |
-
model = sel_model_dict["Model_object"]
|
257 |
-
date = st.session_state["date"]
|
258 |
-
date = pd.to_datetime(date)
|
259 |
-
X_train = sel_model_dict["X_train"]
|
260 |
-
|
261 |
-
# features_set= model_dict[st.session_state["selected_model"]]['feature_set']
|
262 |
-
features_set = sel_model_dict["feature_set"]
|
263 |
-
|
264 |
-
col = st.columns(3)
|
265 |
-
min_date = min(date)
|
266 |
-
max_date = max(date)
|
267 |
-
|
268 |
-
start_date_default = (
|
269 |
-
st.session_state["project_dct"]["model_tuning"].get(
|
270 |
-
"start_date_default"
|
271 |
-
)
|
272 |
-
if st.session_state["project_dct"]["model_tuning"].get(
|
273 |
-
"start_date_default"
|
274 |
-
)
|
275 |
-
is not None
|
276 |
-
else min_date
|
277 |
-
)
|
278 |
-
end_date_default = (
|
279 |
-
st.session_state["project_dct"]["model_tuning"].get(
|
280 |
-
"end_date_default"
|
281 |
-
)
|
282 |
-
if st.session_state["project_dct"]["model_tuning"].get(
|
283 |
-
"end_date_default"
|
284 |
-
)
|
285 |
-
is not None
|
286 |
-
else max_date
|
287 |
-
)
|
288 |
-
with col[0]:
|
289 |
-
start_date = st.date_input(
|
290 |
-
"Select Start Date",
|
291 |
-
start_date_default,
|
292 |
-
min_value=min_date,
|
293 |
-
max_value=max_date,
|
294 |
-
)
|
295 |
-
with col[1]:
|
296 |
-
end_date_default = (
|
297 |
-
end_date_default
|
298 |
-
if end_date_default >= start_date
|
299 |
-
else start_date
|
300 |
-
)
|
301 |
-
end_date = st.date_input(
|
302 |
-
"Select End Date",
|
303 |
-
end_date_default,
|
304 |
-
min_value=max(min_date, start_date),
|
305 |
-
max_value=max_date,
|
306 |
-
)
|
307 |
-
with col[2]:
|
308 |
-
repeat_default = (
|
309 |
-
st.session_state["project_dct"]["model_tuning"].get(
|
310 |
-
"repeat_default"
|
311 |
-
)
|
312 |
-
if st.session_state["project_dct"]["model_tuning"].get(
|
313 |
-
"repeat_default"
|
314 |
-
)
|
315 |
-
is not None
|
316 |
-
else "No"
|
317 |
-
)
|
318 |
-
repeat_default_idx = 0 if repeat_default.lower() == "yes" else 1
|
319 |
-
repeat = st.selectbox(
|
320 |
-
"Repeat Annually", ["Yes", "No"], index=repeat_default_idx
|
321 |
-
)
|
322 |
-
st.session_state["project_dct"]["model_tuning"][
|
323 |
-
"start_date_default"
|
324 |
-
] = start_date
|
325 |
-
st.session_state["project_dct"]["model_tuning"][
|
326 |
-
"end_date_default"
|
327 |
-
] = end_date
|
328 |
-
st.session_state["project_dct"]["model_tuning"][
|
329 |
-
"repeat_default"
|
330 |
-
] = repeat
|
331 |
-
|
332 |
-
if repeat == "Yes":
|
333 |
-
repeat = True
|
334 |
-
else:
|
335 |
-
repeat = False
|
336 |
-
|
337 |
-
if "Flags" not in st.session_state:
|
338 |
-
st.session_state["Flags"] = {}
|
339 |
-
if "flags" in st.session_state["project_dct"]["model_tuning"].keys():
|
340 |
-
st.session_state["Flags"] = st.session_state["project_dct"][
|
341 |
-
"model_tuning"
|
342 |
-
]["flags"]
|
343 |
-
# print("**"*50)
|
344 |
-
# print(y_train)
|
345 |
-
# print("**"*50)
|
346 |
-
# print(model.fittedvalues)
|
347 |
-
if is_panel: # Sprint3
|
348 |
-
met, line_values, fig_flag = plot_actual_vs_predicted(
|
349 |
-
X_train[date_col],
|
350 |
-
y_train,
|
351 |
-
model.fittedvalues,
|
352 |
-
model,
|
353 |
-
target_column=sel_target_col,
|
354 |
-
flag=(start_date, end_date),
|
355 |
-
repeat_all_years=repeat,
|
356 |
-
is_panel=True,
|
357 |
-
)
|
358 |
-
st.plotly_chart(fig_flag, use_container_width=True)
|
359 |
-
|
360 |
-
# create flag on test
|
361 |
-
met, test_line_values, fig_flag = plot_actual_vs_predicted(
|
362 |
-
X_test[date_col],
|
363 |
-
y_test,
|
364 |
-
sel_model_dict["pred_test"],
|
365 |
-
model,
|
366 |
-
target_column=sel_target_col,
|
367 |
-
flag=(start_date, end_date),
|
368 |
-
repeat_all_years=repeat,
|
369 |
-
is_panel=True,
|
370 |
-
)
|
371 |
-
|
372 |
-
else:
|
373 |
-
pred_train = model.predict(X_train[features_set])
|
374 |
-
met, line_values, fig_flag = plot_actual_vs_predicted(
|
375 |
-
X_train[date_col],
|
376 |
-
y_train,
|
377 |
-
pred_train,
|
378 |
-
model,
|
379 |
-
flag=(start_date, end_date),
|
380 |
-
repeat_all_years=repeat,
|
381 |
-
is_panel=False,
|
382 |
-
)
|
383 |
-
st.plotly_chart(fig_flag, use_container_width=True)
|
384 |
-
|
385 |
-
pred_test = model.predict(X_test[features_set])
|
386 |
-
met, test_line_values, fig_flag = plot_actual_vs_predicted(
|
387 |
-
X_test[date_col],
|
388 |
-
y_test,
|
389 |
-
pred_test,
|
390 |
-
model,
|
391 |
-
flag=(start_date, end_date),
|
392 |
-
repeat_all_years=repeat,
|
393 |
-
is_panel=False,
|
394 |
-
)
|
395 |
-
flag_name = "f1_flag"
|
396 |
-
flag_name = st.text_input("Enter Flag Name")
|
397 |
-
# Sprint4 - add selected target col to flag name
|
398 |
-
if st.button("Update flag"):
|
399 |
-
st.session_state["Flags"][flag_name + "__" + target_col] = {}
|
400 |
-
st.session_state["Flags"][flag_name + "__" + target_col][
|
401 |
-
"train"
|
402 |
-
] = line_values
|
403 |
-
st.session_state["Flags"][flag_name + "__" + target_col][
|
404 |
-
"test"
|
405 |
-
] = test_line_values
|
406 |
-
st.success(f'{flag_name + "__" + target_col} stored')
|
407 |
-
|
408 |
-
st.session_state["project_dct"]["model_tuning"]["flags"] = (
|
409 |
-
st.session_state["Flags"]
|
410 |
-
)
|
411 |
-
# Sprint4 - only show flag created for the particular target col
|
412 |
-
if st.session_state["Flags"] is None:
|
413 |
-
st.session_state["Flags"] = {}
|
414 |
-
target_model_flags = [
|
415 |
-
f.split("__")[0]
|
416 |
-
for f in st.session_state["Flags"].keys()
|
417 |
-
if f.split("__")[1] == target_col
|
418 |
-
]
|
419 |
-
options = list(target_model_flags)
|
420 |
-
selected_options = []
|
421 |
-
num_columns = 4
|
422 |
-
num_rows = -(-len(options) // num_columns)
|
423 |
-
|
424 |
-
tick = False
|
425 |
-
if st.checkbox(
|
426 |
-
"Select all",
|
427 |
-
value=st.session_state["project_dct"]["model_tuning"][
|
428 |
-
"select_all_flags_check"
|
429 |
-
].get(sel_target_col, False),
|
430 |
-
):
|
431 |
-
tick = True
|
432 |
-
st.session_state["project_dct"]["model_tuning"][
|
433 |
-
"select_all_flags_check"
|
434 |
-
][sel_target_col] = True
|
435 |
-
else:
|
436 |
-
st.session_state["project_dct"]["model_tuning"][
|
437 |
-
"select_all_flags_check"
|
438 |
-
][sel_target_col] = False
|
439 |
-
selection_defualts = st.session_state["project_dct"]["model_tuning"][
|
440 |
-
"selected_flags"
|
441 |
-
].get(sel_target_col, [])
|
442 |
-
selected_options = selection_defualts
|
443 |
-
for row in range(num_rows):
|
444 |
-
cols = st.columns(num_columns)
|
445 |
-
for col in cols:
|
446 |
-
if options:
|
447 |
-
option = options.pop(0)
|
448 |
-
option_default = (
|
449 |
-
True if option in selection_defualts else False
|
450 |
-
)
|
451 |
-
selected = col.checkbox(option, value=(tick or option_default))
|
452 |
-
if selected:
|
453 |
-
selected_options.append(option)
|
454 |
-
st.session_state["project_dct"]["model_tuning"]["selected_flags"][
|
455 |
-
sel_target_col
|
456 |
-
] = selected_options
|
457 |
-
|
458 |
-
st.markdown("### 1.2 Select Parameters to Apply")
|
459 |
-
parameters = st.columns(3)
|
460 |
-
with parameters[0]:
|
461 |
-
Trend = st.checkbox(
|
462 |
-
"**Trend**",
|
463 |
-
value=st.session_state["project_dct"]["model_tuning"].get(
|
464 |
-
"trend_check", False
|
465 |
-
),
|
466 |
-
)
|
467 |
-
st.markdown(
|
468 |
-
"Helps account for long-term trends or seasonality that could influence advertising effectiveness"
|
469 |
-
)
|
470 |
-
with parameters[1]:
|
471 |
-
week_number = st.checkbox(
|
472 |
-
"**Week_number**",
|
473 |
-
value=st.session_state["project_dct"]["model_tuning"].get(
|
474 |
-
"week_num_check", False
|
475 |
-
),
|
476 |
-
)
|
477 |
-
st.markdown(
|
478 |
-
"Assists in detecting and incorporating weekly patterns or seasonality"
|
479 |
-
)
|
480 |
-
with parameters[2]:
|
481 |
-
sine_cosine = st.checkbox(
|
482 |
-
"**Sine and Cosine Waves**",
|
483 |
-
value=st.session_state["project_dct"]["model_tuning"].get(
|
484 |
-
"sine_cosine_check", False
|
485 |
-
),
|
486 |
-
)
|
487 |
-
st.markdown(
|
488 |
-
"Helps in capturing cyclical patterns or seasonality in the data"
|
489 |
-
)
|
490 |
-
#
|
491 |
-
# def get_tuned_model():
|
492 |
-
# st.session_state['build_tuned_model']=True
|
493 |
-
|
494 |
-
if st.button(
|
495 |
-
"Build model with Selected Parameters and Flags",
|
496 |
-
key="build_tuned_model",
|
497 |
-
):
|
498 |
-
new_features = features_set
|
499 |
-
st.header("2.1 Results Summary")
|
500 |
-
# date=list(df.index)
|
501 |
-
# df = df.reset_index(drop=True)
|
502 |
-
# X_train=df[features_set]
|
503 |
-
ss = MinMaxScaler()
|
504 |
-
if is_panel == True:
|
505 |
-
X_train_tuned = X_train[features_set]
|
506 |
-
# X_train_tuned = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
|
507 |
-
X_train_tuned[target_col] = X_train[target_col]
|
508 |
-
X_train_tuned[date_col] = X_train[date_col]
|
509 |
-
X_train_tuned[panel_col] = X_train[panel_col]
|
510 |
-
|
511 |
-
X_test_tuned = X_test[features_set]
|
512 |
-
# X_test_tuned = pd.DataFrame(ss.transform(X), columns=X.columns)
|
513 |
-
X_test_tuned[target_col] = X_test[target_col]
|
514 |
-
X_test_tuned[date_col] = X_test[date_col]
|
515 |
-
X_test_tuned[panel_col] = X_test[panel_col]
|
516 |
-
|
517 |
-
else:
|
518 |
-
X_train_tuned = X_train[features_set]
|
519 |
-
# X_train_tuned = pd.DataFrame(ss.fit_transform(X_train_tuned), columns=X_train_tuned.columns)
|
520 |
-
|
521 |
-
X_test_tuned = X_test[features_set]
|
522 |
-
# X_test_tuned = pd.DataFrame(ss.transform(X_test_tuned), columns=X_test_tuned.columns)
|
523 |
-
|
524 |
-
for flag in selected_options:
|
525 |
-
# Spirnt4 - added target_col in flag name
|
526 |
-
X_train_tuned[flag] = st.session_state["Flags"][
|
527 |
-
flag + "__" + target_col
|
528 |
-
]["train"]
|
529 |
-
X_test_tuned[flag] = st.session_state["Flags"][
|
530 |
-
flag + "__" + target_col
|
531 |
-
]["test"]
|
532 |
-
|
533 |
-
# test
|
534 |
-
# X_train_tuned.to_csv("Test/X_train_tuned_flag.csv",index=False)
|
535 |
-
# X_test_tuned.to_csv("Test/X_test_tuned_flag.csv",index=False)
|
536 |
-
|
537 |
-
# print("()()"*20,flag, len(st.session_state['Flags'][flag]))
|
538 |
-
if Trend:
|
539 |
-
st.session_state["project_dct"]["model_tuning"][
|
540 |
-
"trend_check"
|
541 |
-
] = True
|
542 |
-
# Sprint3 - group by panel, calculate trend of each panel spearately. Add trend to new feature set
|
543 |
-
if is_panel:
|
544 |
-
newdata = pd.DataFrame()
|
545 |
-
panel_wise_end_point_train = {}
|
546 |
-
for panel, groupdf in X_train_tuned.groupby(panel_col):
|
547 |
-
groupdf.sort_values(date_col, inplace=True)
|
548 |
-
groupdf["Trend"] = np.arange(1, len(groupdf) + 1, 1)
|
549 |
-
newdata = pd.concat([newdata, groupdf])
|
550 |
-
panel_wise_end_point_train[panel] = len(groupdf)
|
551 |
-
X_train_tuned = newdata.copy()
|
552 |
-
|
553 |
-
test_newdata = pd.DataFrame()
|
554 |
-
for panel, test_groupdf in X_test_tuned.groupby(panel_col):
|
555 |
-
test_groupdf.sort_values(date_col, inplace=True)
|
556 |
-
start = panel_wise_end_point_train[panel] + 1
|
557 |
-
end = start + len(test_groupdf) # should be + 1? - Sprint4
|
558 |
-
# print("??"*20, panel, len(test_groupdf), len(np.arange(start, end, 1)), start)
|
559 |
-
test_groupdf["Trend"] = np.arange(start, end, 1)
|
560 |
-
test_newdata = pd.concat([test_newdata, test_groupdf])
|
561 |
-
X_test_tuned = test_newdata.copy()
|
562 |
-
|
563 |
-
new_features = new_features + ["Trend"]
|
564 |
-
|
565 |
-
else:
|
566 |
-
X_train_tuned["Trend"] = np.arange(
|
567 |
-
1, len(X_train_tuned) + 1, 1
|
568 |
-
)
|
569 |
-
X_test_tuned["Trend"] = np.arange(
|
570 |
-
len(X_train_tuned) + 1,
|
571 |
-
len(X_train_tuned) + len(X_test_tuned) + 1,
|
572 |
-
1,
|
573 |
-
)
|
574 |
-
new_features = new_features + ["Trend"]
|
575 |
-
else:
|
576 |
-
st.session_state["project_dct"]["model_tuning"][
|
577 |
-
"trend_check"
|
578 |
-
] = False
|
579 |
-
|
580 |
-
if week_number:
|
581 |
-
st.session_state["project_dct"]["model_tuning"][
|
582 |
-
"week_num_check"
|
583 |
-
] = True
|
584 |
-
# Sprint3 - create weeknumber from date column in xtrain tuned. add week num to new feature set
|
585 |
-
if is_panel:
|
586 |
-
X_train_tuned[date_col] = pd.to_datetime(
|
587 |
-
X_train_tuned[date_col]
|
588 |
-
)
|
589 |
-
X_train_tuned["Week_number"] = X_train_tuned[
|
590 |
-
date_col
|
591 |
-
].dt.day_of_week
|
592 |
-
if X_train_tuned["Week_number"].nunique() == 1:
|
593 |
-
st.write(
|
594 |
-
"All dates in the data are of the same week day. Hence Week number can't be used."
|
595 |
-
)
|
596 |
-
else:
|
597 |
-
X_test_tuned[date_col] = pd.to_datetime(
|
598 |
-
X_test_tuned[date_col]
|
599 |
-
)
|
600 |
-
X_test_tuned["Week_number"] = X_test_tuned[
|
601 |
-
date_col
|
602 |
-
].dt.day_of_week
|
603 |
-
new_features = new_features + ["Week_number"]
|
604 |
-
|
605 |
-
else:
|
606 |
-
date = pd.to_datetime(date.values)
|
607 |
-
X_train_tuned["Week_number"] = pd.to_datetime(
|
608 |
-
X_train[date_col]
|
609 |
-
).dt.day_of_week
|
610 |
-
X_test_tuned["Week_number"] = pd.to_datetime(
|
611 |
-
X_test[date_col]
|
612 |
-
).dt.day_of_week
|
613 |
-
new_features = new_features + ["Week_number"]
|
614 |
-
else:
|
615 |
-
st.session_state["project_dct"]["model_tuning"][
|
616 |
-
"week_num_check"
|
617 |
-
] = False
|
618 |
-
|
619 |
-
if sine_cosine:
|
620 |
-
st.session_state["project_dct"]["model_tuning"][
|
621 |
-
"sine_cosine_check"
|
622 |
-
] = True
|
623 |
-
# Sprint3 - create panel wise sine cosine waves in xtrain tuned. add to new feature set
|
624 |
-
if is_panel:
|
625 |
-
new_features = new_features + ["sine_wave", "cosine_wave"]
|
626 |
-
newdata = pd.DataFrame()
|
627 |
-
newdata_test = pd.DataFrame()
|
628 |
-
groups = X_train_tuned.groupby(panel_col)
|
629 |
-
frequency = 2 * np.pi / 365 # Adjust the frequency as needed
|
630 |
-
|
631 |
-
train_panel_wise_end_point = {}
|
632 |
-
for panel, groupdf in groups:
|
633 |
-
num_samples = len(groupdf)
|
634 |
-
train_panel_wise_end_point[panel] = num_samples
|
635 |
-
days_since_start = np.arange(num_samples)
|
636 |
-
sine_wave = np.sin(frequency * days_since_start)
|
637 |
-
cosine_wave = np.cos(frequency * days_since_start)
|
638 |
-
sine_cosine_df = pd.DataFrame(
|
639 |
-
{"sine_wave": sine_wave, "cosine_wave": cosine_wave}
|
640 |
-
)
|
641 |
-
assert len(sine_cosine_df) == len(groupdf)
|
642 |
-
# groupdf = pd.concat([groupdf, sine_cosine_df], axis=1)
|
643 |
-
groupdf["sine_wave"] = sine_wave
|
644 |
-
groupdf["cosine_wave"] = cosine_wave
|
645 |
-
newdata = pd.concat([newdata, groupdf])
|
646 |
-
|
647 |
-
X_train_tuned = newdata.copy()
|
648 |
-
|
649 |
-
test_groups = X_test_tuned.groupby(panel_col)
|
650 |
-
for panel, test_groupdf in test_groups:
|
651 |
-
num_samples = len(test_groupdf)
|
652 |
-
start = train_panel_wise_end_point[panel]
|
653 |
-
days_since_start = np.arange(start, start + num_samples, 1)
|
654 |
-
# print("##", panel, num_samples, start, len(np.arange(start, start+num_samples, 1)))
|
655 |
-
sine_wave = np.sin(frequency * days_since_start)
|
656 |
-
cosine_wave = np.cos(frequency * days_since_start)
|
657 |
-
sine_cosine_df = pd.DataFrame(
|
658 |
-
{"sine_wave": sine_wave, "cosine_wave": cosine_wave}
|
659 |
-
)
|
660 |
-
assert len(sine_cosine_df) == len(test_groupdf)
|
661 |
-
# groupdf = pd.concat([groupdf, sine_cosine_df], axis=1)
|
662 |
-
test_groupdf["sine_wave"] = sine_wave
|
663 |
-
test_groupdf["cosine_wave"] = cosine_wave
|
664 |
-
newdata_test = pd.concat([newdata_test, test_groupdf])
|
665 |
-
|
666 |
-
X_test_tuned = newdata_test.copy()
|
667 |
-
|
668 |
-
else:
|
669 |
-
new_features = new_features + ["sine_wave", "cosine_wave"]
|
670 |
-
|
671 |
-
num_samples = len(X_train_tuned)
|
672 |
-
frequency = 2 * np.pi / 365 # Adjust the frequency as needed
|
673 |
-
days_since_start = np.arange(num_samples)
|
674 |
-
sine_wave = np.sin(frequency * days_since_start)
|
675 |
-
cosine_wave = np.cos(frequency * days_since_start)
|
676 |
-
sine_cosine_df = pd.DataFrame(
|
677 |
-
{"sine_wave": sine_wave, "cosine_wave": cosine_wave}
|
678 |
-
)
|
679 |
-
# Concatenate the sine and cosine waves with the scaled X DataFrame
|
680 |
-
X_train_tuned = pd.concat(
|
681 |
-
[X_train_tuned, sine_cosine_df], axis=1
|
682 |
-
)
|
683 |
-
|
684 |
-
test_num_samples = len(X_test_tuned)
|
685 |
-
start = num_samples
|
686 |
-
days_since_start = np.arange(
|
687 |
-
start, start + test_num_samples, 1
|
688 |
-
)
|
689 |
-
sine_wave = np.sin(frequency * days_since_start)
|
690 |
-
cosine_wave = np.cos(frequency * days_since_start)
|
691 |
-
sine_cosine_df = pd.DataFrame(
|
692 |
-
{"sine_wave": sine_wave, "cosine_wave": cosine_wave}
|
693 |
-
)
|
694 |
-
# Concatenate the sine and cosine waves with the scaled X DataFrame
|
695 |
-
X_test_tuned = pd.concat(
|
696 |
-
[X_test_tuned, sine_cosine_df], axis=1
|
697 |
-
)
|
698 |
-
else:
|
699 |
-
st.session_state["project_dct"]["model_tuning"][
|
700 |
-
"sine_cosine_check"
|
701 |
-
] = False
|
702 |
-
|
703 |
-
# model
|
704 |
-
if selected_options:
|
705 |
-
new_features = new_features + selected_options
|
706 |
-
if is_panel:
|
707 |
-
inp_vars_str = " + ".join(new_features)
|
708 |
-
new_features = list(set(new_features))
|
709 |
-
|
710 |
-
md_str = target_col + " ~ " + inp_vars_str
|
711 |
-
md_tuned = smf.mixedlm(
|
712 |
-
md_str,
|
713 |
-
data=X_train_tuned[[target_col] + new_features],
|
714 |
-
groups=X_train_tuned[panel_col],
|
715 |
-
)
|
716 |
-
model_tuned = md_tuned.fit()
|
717 |
-
|
718 |
-
# plot act v pred for original model and tuned model
|
719 |
-
metrics_table, line, actual_vs_predicted_plot = (
|
720 |
-
plot_actual_vs_predicted(
|
721 |
-
X_train[date_col],
|
722 |
-
y_train,
|
723 |
-
model.fittedvalues,
|
724 |
-
model,
|
725 |
-
target_column=sel_target_col,
|
726 |
-
is_panel=True,
|
727 |
-
)
|
728 |
-
)
|
729 |
-
metrics_table_tuned, line, actual_vs_predicted_plot_tuned = (
|
730 |
-
plot_actual_vs_predicted(
|
731 |
-
X_train_tuned[date_col],
|
732 |
-
X_train_tuned[target_col],
|
733 |
-
model_tuned.fittedvalues,
|
734 |
-
model_tuned,
|
735 |
-
target_column=sel_target_col,
|
736 |
-
is_panel=True,
|
737 |
-
)
|
738 |
-
)
|
739 |
-
|
740 |
-
else:
|
741 |
-
new_features = list(set(new_features))
|
742 |
-
model_tuned = sm.OLS(y_train, X_train_tuned[new_features]).fit()
|
743 |
-
metrics_table, line, actual_vs_predicted_plot = (
|
744 |
-
plot_actual_vs_predicted(
|
745 |
-
date[:130],
|
746 |
-
y_train,
|
747 |
-
model.predict(X_train[features_set]),
|
748 |
-
model,
|
749 |
-
target_column=sel_target_col,
|
750 |
-
)
|
751 |
-
)
|
752 |
-
metrics_table_tuned, line, actual_vs_predicted_plot_tuned = (
|
753 |
-
plot_actual_vs_predicted(
|
754 |
-
date[:130],
|
755 |
-
y_train,
|
756 |
-
model_tuned.predict(X_train_tuned),
|
757 |
-
model_tuned,
|
758 |
-
target_column=sel_target_col,
|
759 |
-
)
|
760 |
-
)
|
761 |
-
|
762 |
-
mape = np.round(metrics_table.iloc[0, 1], 2)
|
763 |
-
r2 = np.round(metrics_table.iloc[1, 1], 2)
|
764 |
-
adjr2 = np.round(metrics_table.iloc[2, 1], 2)
|
765 |
-
|
766 |
-
mape_tuned = np.round(metrics_table_tuned.iloc[0, 1], 2)
|
767 |
-
r2_tuned = np.round(metrics_table_tuned.iloc[1, 1], 2)
|
768 |
-
adjr2_tuned = np.round(metrics_table_tuned.iloc[2, 1], 2)
|
769 |
-
|
770 |
-
parameters_ = st.columns(3)
|
771 |
-
with parameters_[0]:
|
772 |
-
st.metric("R2", r2_tuned, np.round(r2_tuned - r2, 2))
|
773 |
-
with parameters_[1]:
|
774 |
-
st.metric(
|
775 |
-
"Adjusted R2", adjr2_tuned, np.round(adjr2_tuned - adjr2, 2)
|
776 |
-
)
|
777 |
-
with parameters_[2]:
|
778 |
-
st.metric(
|
779 |
-
"MAPE", mape_tuned, np.round(mape_tuned - mape, 2), "inverse"
|
780 |
-
)
|
781 |
-
st.write(model_tuned.summary())
|
782 |
-
|
783 |
-
X_train_tuned[date_col] = X_train[date_col]
|
784 |
-
X_test_tuned[date_col] = X_test[date_col]
|
785 |
-
X_train_tuned[target_col] = y_train
|
786 |
-
X_test_tuned[target_col] = y_test
|
787 |
-
|
788 |
-
st.header("2.2 Actual vs. Predicted Plot")
|
789 |
-
# if is_panel:
|
790 |
-
# metrics_table, line, actual_vs_predicted_plot = plot_actual_vs_predicted(date, y_train, model.predict(X_train),
|
791 |
-
# model, target_column='Revenue',is_panel=True)
|
792 |
-
# else:
|
793 |
-
# metrics_table,line,actual_vs_predicted_plot=plot_actual_vs_predicted(date, y_train, model.predict(X_train), model,target_column='Revenue')
|
794 |
-
if is_panel:
|
795 |
-
metrics_table, line, actual_vs_predicted_plot = (
|
796 |
-
plot_actual_vs_predicted(
|
797 |
-
X_train_tuned[date_col],
|
798 |
-
X_train_tuned[target_col],
|
799 |
-
model_tuned.fittedvalues,
|
800 |
-
model_tuned,
|
801 |
-
target_column=sel_target_col,
|
802 |
-
is_panel=True,
|
803 |
-
)
|
804 |
-
)
|
805 |
-
else:
|
806 |
-
metrics_table, line, actual_vs_predicted_plot = (
|
807 |
-
plot_actual_vs_predicted(
|
808 |
-
X_train_tuned[date_col],
|
809 |
-
X_train_tuned[target_col],
|
810 |
-
model_tuned.predict(X_train_tuned[new_features]),
|
811 |
-
model_tuned,
|
812 |
-
target_column=sel_target_col,
|
813 |
-
is_panel=False,
|
814 |
-
)
|
815 |
-
)
|
816 |
-
# plot_actual_vs_predicted(X_train[date_col], y_train,
|
817 |
-
# model.fittedvalues, model,
|
818 |
-
# target_column='Revenue',
|
819 |
-
# is_panel=is_panel)
|
820 |
-
|
821 |
-
st.plotly_chart(actual_vs_predicted_plot, use_container_width=True)
|
822 |
-
|
823 |
-
st.markdown("## 2.3 Residual Analysis")
|
824 |
-
if is_panel:
|
825 |
-
columns = st.columns(2)
|
826 |
-
with columns[0]:
|
827 |
-
fig = plot_residual_predicted(
|
828 |
-
y_train, model_tuned.fittedvalues, X_train_tuned
|
829 |
-
)
|
830 |
-
st.plotly_chart(fig)
|
831 |
-
|
832 |
-
with columns[1]:
|
833 |
-
st.empty()
|
834 |
-
fig = qqplot(y_train, model_tuned.fittedvalues)
|
835 |
-
st.plotly_chart(fig)
|
836 |
-
|
837 |
-
with columns[0]:
|
838 |
-
fig = residual_distribution(y_train, model_tuned.fittedvalues)
|
839 |
-
st.pyplot(fig)
|
840 |
-
else:
|
841 |
-
columns = st.columns(2)
|
842 |
-
with columns[0]:
|
843 |
-
fig = plot_residual_predicted(
|
844 |
-
y_train,
|
845 |
-
model_tuned.predict(X_train_tuned[new_features]),
|
846 |
-
X_train,
|
847 |
-
)
|
848 |
-
st.plotly_chart(fig)
|
849 |
-
|
850 |
-
with columns[1]:
|
851 |
-
st.empty()
|
852 |
-
fig = qqplot(
|
853 |
-
y_train, model_tuned.predict(X_train_tuned[new_features])
|
854 |
-
)
|
855 |
-
st.plotly_chart(fig)
|
856 |
-
|
857 |
-
with columns[0]:
|
858 |
-
fig = residual_distribution(
|
859 |
-
y_train, model_tuned.predict(X_train_tuned[new_features])
|
860 |
-
)
|
861 |
-
st.pyplot(fig)
|
862 |
-
|
863 |
-
# st.session_state['is_tuned_model'][target_col] = True
|
864 |
-
# Sprint4 - saved tuned model in a dict
|
865 |
-
st.session_state["Model_Tuned"][sel_model + "__" + target_col] = {
|
866 |
-
"Model_object": model_tuned,
|
867 |
-
"feature_set": new_features,
|
868 |
-
"X_train_tuned": X_train_tuned,
|
869 |
-
"X_test_tuned": X_test_tuned,
|
870 |
-
}
|
871 |
-
|
872 |
-
# Pending
|
873 |
-
# if st.session_state['build_tuned_model']==True:
|
874 |
-
if st.session_state["Model_Tuned"] is not None:
|
875 |
-
if st.checkbox(
|
876 |
-
"Use this model to build response curves", key="save_model"
|
877 |
-
):
|
878 |
-
# save_model = st.button('Use this model to build response curves', key='saved_tuned_model')
|
879 |
-
# if save_model:
|
880 |
-
st.session_state["is_tuned_model"][target_col] = True
|
881 |
-
with open(
|
882 |
-
os.path.join(
|
883 |
-
st.session_state["project_path"], "tuned_model.pkl"
|
884 |
-
),
|
885 |
-
"wb",
|
886 |
-
) as f:
|
887 |
-
# pickle.dump(st.session_state['tuned_model'], f)
|
888 |
-
pickle.dump(st.session_state["Model_Tuned"], f) # Sprint4
|
889 |
-
|
890 |
-
st.session_state["project_dct"]["model_tuning"][
|
891 |
-
"session_state_saved"
|
892 |
-
] = {}
|
893 |
-
for key in [
|
894 |
-
"bin_dict",
|
895 |
-
"used_response_metrics",
|
896 |
-
"is_tuned_model",
|
897 |
-
"media_data",
|
898 |
-
"X_test_spends",
|
899 |
-
]:
|
900 |
-
st.session_state["project_dct"]["model_tuning"][
|
901 |
-
"session_state_saved"
|
902 |
-
][key] = st.session_state[key]
|
903 |
-
|
904 |
-
project_dct_path = os.path.join(
|
905 |
-
st.session_state["project_path"], "project_dct.pkl"
|
906 |
-
)
|
907 |
-
with open(project_dct_path, "wb") as f:
|
908 |
-
pickle.dump(st.session_state["project_dct"], f)
|
909 |
-
|
910 |
-
update_db("5_Model_Tuning.py")
|
911 |
-
|
912 |
-
st.success(sel_model + "__" + target_col + " Tuned saved!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/6_AI_Model_Results.py
DELETED
@@ -1,728 +0,0 @@
|
|
1 |
-
import plotly.express as px
|
2 |
-
import numpy as np
|
3 |
-
import plotly.graph_objects as go
|
4 |
-
import streamlit as st
|
5 |
-
import pandas as pd
|
6 |
-
import statsmodels.api as sm
|
7 |
-
from sklearn.metrics import mean_absolute_percentage_error
|
8 |
-
import sys
|
9 |
-
import os
|
10 |
-
from utilities import set_header, load_local_css, load_authenticator
|
11 |
-
import seaborn as sns
|
12 |
-
import matplotlib.pyplot as plt
|
13 |
-
import sweetviz as sv
|
14 |
-
import tempfile
|
15 |
-
from sklearn.preprocessing import MinMaxScaler
|
16 |
-
from st_aggrid import AgGrid
|
17 |
-
from st_aggrid import GridOptionsBuilder, GridUpdateMode
|
18 |
-
from st_aggrid import GridOptionsBuilder
|
19 |
-
import sys
|
20 |
-
import re
|
21 |
-
import pickle
|
22 |
-
from sklearn.metrics import r2_score, mean_absolute_percentage_error
|
23 |
-
from Data_prep_functions import plot_actual_vs_predicted
|
24 |
-
import sqlite3
|
25 |
-
from utilities import update_db
|
26 |
-
|
27 |
-
# ---- Page bootstrap: interpreter limit, Streamlit layout, login widget ----
sys.setrecursionlimit(10**6)

# stdout is pointed at a scratch file, that file is closed immediately, and
# the previous stream is put back. The visible net effect of these lines is
# that "temp_stdout.txt" is created (or truncated).
original_stdout = sys.stdout
sys.stdout = open("temp_stdout.txt", "w")
sys.stdout.close()
sys.stdout = original_stdout

# Page configuration is issued before any other Streamlit call in this file.
st.set_page_config(layout="wide")
load_local_css("styles.css")
set_header()

# TODO :
## 1. Add non panel model support
## 2. EDA Function

# Write every session-state entry back onto itself, except the auth-related
# keys and the form-submitter entries.
for k, v in st.session_state.items():
    if k in ["logout", "login", "config"] or k.startswith("FormSubmitter"):
        continue
    st.session_state[k] = v

# Reuse the authenticator stored in the session when there is one.
authenticator = st.session_state.get("authenticator")
if authenticator is None:
    authenticator = load_authenticator()

name, authentication_status, username = authenticator.login("Login", "main")
auth_status = st.session_state.get("authentication_status")
|
54 |
-
|
55 |
-
if auth_status == True:
|
56 |
-
is_state_initiaized = st.session_state.get("initialized", False)
|
57 |
-
if not is_state_initiaized:
|
58 |
-
if "session_name" not in st.session_state:
|
59 |
-
st.session_state["session_name"] = None
|
60 |
-
|
61 |
-
if "project_dct" not in st.session_state:
|
62 |
-
st.error("Please load a project from Home page")
|
63 |
-
st.stop()
|
64 |
-
|
65 |
-
conn = sqlite3.connect(
|
66 |
-
r"DB/User.db", check_same_thread=False
|
67 |
-
) # connection with sql db
|
68 |
-
c = conn.cursor()
|
69 |
-
|
70 |
-
if not os.path.exists(
|
71 |
-
os.path.join(st.session_state["project_path"], "tuned_model.pkl")
|
72 |
-
):
|
73 |
-
st.error("Please save a tuned model")
|
74 |
-
st.stop()
|
75 |
-
|
76 |
-
if (
|
77 |
-
"session_state_saved"
|
78 |
-
in st.session_state["project_dct"]["model_tuning"].keys()
|
79 |
-
and st.session_state["project_dct"]["model_tuning"][
|
80 |
-
"session_state_saved"
|
81 |
-
]
|
82 |
-
!= []
|
83 |
-
):
|
84 |
-
for key in ["used_response_metrics", "media_data", "bin_dict"]:
|
85 |
-
if key not in st.session_state:
|
86 |
-
st.session_state[key] = st.session_state["project_dct"][
|
87 |
-
"model_tuning"
|
88 |
-
]["session_state_saved"][key]
|
89 |
-
st.session_state["bin_dict"] = st.session_state["project_dct"][
|
90 |
-
"model_build"
|
91 |
-
]["session_state_saved"]["bin_dict"]
|
92 |
-
|
93 |
-
media_data = st.session_state["media_data"]
|
94 |
-
panel_col = [
|
95 |
-
col.lower()
|
96 |
-
.replace(".", "_")
|
97 |
-
.replace("@", "_")
|
98 |
-
.replace(" ", "_")
|
99 |
-
.replace("-", "")
|
100 |
-
.replace(":", "")
|
101 |
-
.replace("__", "_")
|
102 |
-
for col in st.session_state["bin_dict"]["Panel Level 1"]
|
103 |
-
][
|
104 |
-
0
|
105 |
-
] # set the panel column
|
106 |
-
is_panel = True if len(panel_col) > 0 else False
|
107 |
-
date_col = "date"
|
108 |
-
|
109 |
-
def plot_residual_predicted(actual, predicted, df_):
    """Scatter standardized residuals against the predicted values.

    Residuals are ``actual - predicted``. They are standardized by
    subtracting their mean and dividing by their standard deviation.

    Args:
        actual: Observed target values.
        predicted: Model predictions; used to compute the residuals and as
            the x values of the scatter plot.
        df_: DataFrame passed to ``px.scatter`` as the plotting frame. It is
            not modified; the residual columns are added to a copy.

    Returns:
        The Plotly figure, with a dashed line at 0 and solid lines at +2
        and -2.
    """
    # ``assign`` returns a new frame, so the caller's DataFrame no longer
    # silently gains "Residuals" / "StdResidual" columns.
    plot_df = df_.assign(Residuals=actual - pd.Series(predicted))
    residuals = plot_df["Residuals"]
    plot_df["StdResidual"] = (residuals - residuals.mean()) / residuals.std()

    # Create a Plotly scatter plot
    fig = px.scatter(
        plot_df,
        x=predicted,
        y="StdResidual",
        opacity=0.5,
        color_discrete_sequence=["#11B6BD"],
    )

    # Add horizontal reference lines
    fig.add_hline(y=0, line_dash="dash", line_color="darkorange")
    fig.add_hline(y=2, line_color="red")
    fig.add_hline(y=-2, line_color="red")

    fig.update_xaxes(title="Predicted")
    fig.update_yaxes(title="Standardized Residuals (Actual - Predicted)")

    # Fixed size so this figure lines up with the other diagnostic plots
    fig.update_layout(
        title="Residuals over Predicted Values",
        autosize=False,
        width=600,
        height=400,
    )

    return fig
|
141 |
-
|
142 |
-
def residual_distribution(actual, predicted):
    """Draw a histogram (with KDE overlay) of ``actual - predicted``.

    A new 6x4 matplotlib figure is opened and the plot is drawn on it.

    Args:
        actual: Observed target values.
        predicted: Model predictions.

    Returns:
        The ``matplotlib.pyplot`` module itself (not a Figure object); the
        plot lives on pyplot's current figure.
    """
    sns.set(style="whitegrid")
    plt.figure(figsize=(6, 4))

    # Seaborn distribution plot of the raw (unstandardized) residuals
    sns.histplot(actual - pd.Series(predicted), kde=True, color="#11B6BD")

    plt.title(" Distribution of Residuals")
    plt.xlabel("Residuals")
    plt.ylabel("Probability Density")

    return plt
|
155 |
-
|
156 |
-
def qqplot(actual, predicted):
    """Build a QQ plot of the standardized residuals.

    Residuals are ``actual - predicted``, standardized by their mean and
    standard deviation.

    Args:
        actual: Observed target values.
        predicted: Model predictions.

    Returns:
        A Plotly figure with the sample-vs-theoretical quantile markers and
        a red reference line from (-2, -2) to (2, 2).
    """
    residuals = pd.Series(actual - pd.Series(predicted))
    standardized = (residuals - residuals.mean()) / residuals.std()

    # Build the probability plot once; both quantile arrays come from it.
    prob_plot = sm.ProbPlot(standardized)

    # Create a QQ plot using Plotly with custom colors
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=prob_plot.theoretical_quantiles,
            y=prob_plot.sample_quantiles,
            mode="markers",
            marker=dict(size=5, color="#11B6BD"),
            name="QQ Plot",
        )
    )

    # Add the 45-degree reference line
    diagonal_line = go.Scatter(
        x=[-2, 2],  # Adjust the x values as needed to fit the range of your data
        y=[-2, 2],  # Adjust the y values accordingly
        mode="lines",
        line=dict(color="red"),  # Customize the line color and style
        name=" ",
    )
    fig.add_trace(diagonal_line)

    # Customize the layout
    fig.update_layout(
        title="QQ Plot of Residuals",
        title_x=0.5,
        autosize=False,
        width=600,
        height=400,
        xaxis_title="Theoretical Quantiles",
        yaxis_title="Sample Quantiles",
    )

    return fig
|
198 |
-
|
199 |
-
def get_random_effects(media_data, panel_col, mdf):
    """Collect the first random-effect value of every panel into a table.

    Args:
        media_data: DataFrame containing the panel column.
        panel_col: Name of the panel column in ``media_data``.
        mdf: Fitted model object exposing ``random_effects``, indexable by
            panel value; the first element of each entry's ``.values`` is
            taken.

    Returns:
        DataFrame with columns ``[panel_col, "random_effect"]`` and one row
        per unique panel value, in order of first appearance. Empty (but
        with both columns) when ``media_data`` has no rows.
    """
    # Build all rows first and construct the frame once, instead of growing
    # it cell by cell with ``.loc`` (slow, and leaves object-dtype columns).
    rows = [
        (market, mdf.random_effects[market].values[0])
        for market in media_data[panel_col].unique()
    ]
    return pd.DataFrame(rows, columns=[panel_col, "random_effect"])
|
208 |
-
|
209 |
-
def mdf_predict(X_df, mdf, random_eff_df):
    """Add a ``pred`` column combining the model prediction and panel effect.

    The per-panel ``random_effect`` is left-joined onto the data using the
    module-level ``panel_col``; ``mdf.predict`` is called on the joined frame
    and the two are summed.

    Args:
        X_df: Input DataFrame; must contain the panel column. Not modified.
        mdf: Fitted model object exposing ``predict``.
        random_eff_df: DataFrame with the panel column and a
            ``random_effect`` column.

    Returns:
        A new DataFrame (the merge result) with a ``pred`` column added; the
        intermediate ``pred_fixed_effect`` and ``random_effect`` columns are
        dropped.
    """
    panel_effects = random_eff_df[[panel_col, "random_effect"]]
    # pd.merge builds a fresh frame, so the caller's data is left untouched.
    scored = pd.merge(X_df, panel_effects, on=panel_col, how="left")

    scored["pred_fixed_effect"] = mdf.predict(scored)
    scored["pred"] = scored["pred_fixed_effect"] + scored["random_effect"]

    return scored.drop(columns=["pred_fixed_effect", "random_effect"])
|
222 |
-
|
223 |
-
def metrics_df_panel(model_dict):
    """Tabulate fit metrics for every tuned panel model.

    Args:
        model_dict: Mapping whose keys contain ``"__"`` (the part after the
            first ``"__"`` is used as the target name) and whose values are
            dicts with ``"Model_object"``, ``"feature_set"``,
            ``"X_train_tuned"`` and ``"X_test_tuned"``.

    Returns:
        DataFrame with one row per model and the columns Model, R2, ADJR2,
        Train Mape, Test Mape, Summary and Model_object, passed through
        ``np.round(..., 2)``.
    """
    metrics_df = pd.DataFrame(
        columns=[
            "Model",
            "R2",
            "ADJR2",
            "Train Mape",
            "Test Mape",
            "Summary",
            "Model_object",
        ]
    )

    for row, (key, entry) in enumerate(model_dict.items()):
        target = key.split("__")[1]
        fitted_model = entry["Model_object"]
        metrics_df.at[row, "Model"] = target

        # Predictions = model.predict output plus the per-panel random effect.
        y = entry["X_train_tuned"][target]
        random_df = get_random_effects(media_data, panel_col, fitted_model)
        pred = mdf_predict(entry["X_train_tuned"], fitted_model, random_df)[
            "pred"
        ]

        ytest = entry["X_test_tuned"][target]
        predtest = mdf_predict(entry["X_test_tuned"], fitted_model, random_df)[
            "pred"
        ]

        r2 = r2_score(y, pred)
        n_obs = len(y)
        n_features = len(entry["feature_set"])

        metrics_df.at[row, "R2"] = r2
        # Adjusted R2 penalizes R2 by the number of features in the model.
        metrics_df.at[row, "ADJR2"] = 1 - (1 - r2) * (n_obs - 1) / (
            n_obs - n_features - 1
        )
        metrics_df.at[row, "Train Mape"] = mean_absolute_percentage_error(
            y, pred
        )
        metrics_df.at[row, "Test Mape"] = mean_absolute_percentage_error(
            ytest, predtest
        )
        metrics_df.at[row, "Summary"] = fitted_model.summary()
        metrics_df.at[row, "Model_object"] = fitted_model

    return np.round(metrics_df, 2)
|
274 |
-
|
275 |
-
with open(
|
276 |
-
os.path.join(
|
277 |
-
st.session_state["project_path"], "final_df_transformed.pkl"
|
278 |
-
),
|
279 |
-
"rb",
|
280 |
-
) as f:
|
281 |
-
data = pickle.load(f)
|
282 |
-
transformed_data = data["final_df_transformed"]
|
283 |
-
with open(
|
284 |
-
os.path.join(st.session_state["project_path"], "data_import.pkl"), "rb"
|
285 |
-
) as f:
|
286 |
-
data = pickle.load(f)
|
287 |
-
st.session_state["bin_dict"] = data["bin_dict"]
|
288 |
-
with open(
|
289 |
-
os.path.join(st.session_state["project_path"], "tuned_model.pkl"), "rb"
|
290 |
-
) as file:
|
291 |
-
tuned_model_dict = pickle.load(file)
|
292 |
-
feature_set_dct = {
|
293 |
-
key.split("__")[1]: key_dict["feature_set"]
|
294 |
-
for key, key_dict in tuned_model_dict.items()
|
295 |
-
}
|
296 |
-
|
297 |
-
# """ the above part should be modified so that we are fetching features set from the saved model"""
|
298 |
-
|
299 |
-
def contributions(X, model, target):
    """Compute each column's percentage share of the total weighted sum.

    Every column of ``X`` is multiplied by the coefficient at the same
    position in ``model.params``; column totals are then expressed as a
    percentage of the grand total, rounded to 2 decimals and sorted in
    descending order.

    Args:
        X: Feature DataFrame. Not modified.
        model: Object exposing ``params`` with a ``.values`` array, matched
            to the columns of ``X`` by position.
        target: Column labels for the share column, passed as ``columns=``
            to the DataFrame constructor (e.g. a one-element list).

    Returns:
        DataFrame with a ``"Channel"`` column (column name truncated at the
        first ``_imp`` or ``_cli``) and the share column.
    """
    weighted = X.copy()
    for position, column in enumerate(weighted.columns):
        weighted[column] = weighted[column] * model.params.values[position]

    column_totals = weighted.sum()
    shares = (column_totals / sum(column_totals) * 100).sort_values(
        ascending=False
    )
    shares = np.round(shares, 2)

    table = pd.DataFrame(shares, columns=target).reset_index()
    table = table.rename(columns={"index": "Channel"})
    # Keep only the part of the name before the "_imp" / "_cli" suffix.
    table["Channel"] = [
        re.split(r"_imp|_cli", label)[0] for label in table["Channel"]
    ]

    return table
|
317 |
-
|
318 |
-
if "contribution_df" not in st.session_state:
|
319 |
-
st.session_state["contribution_df"] = None
|
320 |
-
|
321 |
-
def contributions_panel(model_dict):
    """Compute percentage contributions per channel for every tuned model.

    For each model, every fixed-effect coefficient after the first is
    multiplied by its feature column, the products are summed over all
    training rows, and the totals are scaled to percentages. The panel
    effect (fixed "Intercept" plus per-panel random effect) and the
    Week_number / Trend / sine_wave / cosine_wave contributions, when
    present, are folded into a single ``"base"`` row.

    Args:
        model_dict: Mapping whose keys contain ``"__"`` (the part after the
            first ``"__"`` is used as the target name) and whose values are
            dicts with ``"Model_object"`` and ``"X_train_tuned"``.

    Returns:
        DataFrame with a ``"Channel"`` column and one percentage column per
        target, outer-merged on ``"Channel"``.
    """
    media_data = st.session_state["media_data"]
    base_feature_contrs = {
        "Week_number_contr",
        "Trend_contr",
        "sine_wave_contr",
        "cosine_wave_contr",
    }

    contribution_df = pd.DataFrame(columns=["Channel"])
    for key, entry in model_dict.items():
        model = entry["Model_object"]
        target = key.split("__")[1]

        # Per-panel intercept = fixed intercept + that panel's random effect.
        random_eff_df = get_random_effects(media_data, panel_col, model)
        random_eff_df["fixed_effect"] = model.fe_params["Intercept"]
        random_eff_df["panel_effect"] = (
            random_eff_df["random_effect"] + random_eff_df["fixed_effect"]
        )

        x_train_contribution = mdf_predict(
            entry["X_train_tuned"].copy(), model, random_eff_df
        )
        x_train_contribution = pd.merge(
            x_train_contribution,
            random_eff_df[[panel_col, "panel_effect"]],
            on=panel_col,
            how="left",
        )

        # Skip the first fixed-effect entry (presumably the intercept, which
        # is already accounted for through "panel_effect").
        for feature, coef in model.fe_params.iloc[1:].items():
            x_train_contribution[str(feature) + "_contr"] = (
                coef * x_train_contribution[feature]
            )

        # NOTE(review): filter(regex="contr") matches any column whose name
        # contains "contr", not only the "_contr" columns created above.
        base_cols = ["panel_effect"] + [
            c
            for c in x_train_contribution.filter(regex="contr").columns
            if c in base_feature_contrs
        ]
        x_train_contribution["base_contr"] = x_train_contribution[
            base_cols
        ].sum(axis=1)
        x_train_contribution.drop(columns=base_cols, inplace=True)

        contri_df = pd.DataFrame(
            x_train_contribution.filter(regex="contr").sum(axis=0)
        )
        contri_df.reset_index(inplace=True)
        contri_df.columns = ["Channel", target]
        # Keep only the part of the name before "_impres" / "_clicks".
        contri_df["Channel"] = (
            contri_df["Channel"]
            .str.split("(_impres|_clicks)")
            .apply(lambda c: c[0])
        )
        contri_df[target] = (
            100 * contri_df[target] / contri_df[target].sum()
        )
        # Assign the result back: an in-place replace on the selected column
        # is a chained operation that does not update the frame under pandas
        # Copy-on-Write.
        contri_df["Channel"] = contri_df["Channel"].replace(
            "base_contr", "base"
        )
        contribution_df = pd.merge(
            contribution_df, contri_df, on="Channel", how="outer"
        )

    return contribution_df
|
403 |
-
|
404 |
-
metrics_table = metrics_df_panel(tuned_model_dict)
|
405 |
-
|
406 |
-
eda_columns = st.columns(2)
|
407 |
-
with eda_columns[1]:
|
408 |
-
eda = st.button(
|
409 |
-
"Generate EDA Report",
|
410 |
-
help="Click to generate a bivariate report for the selected response metric from the table below.",
|
411 |
-
)
|
412 |
-
|
413 |
-
# st.markdown('Model Metrics')
|
414 |
-
st.title("Contribution Overview")
|
415 |
-
options = st.session_state["used_response_metrics"]
|
416 |
-
options = [
|
417 |
-
opt.lower()
|
418 |
-
.replace(" ", "_")
|
419 |
-
.replace("-", "")
|
420 |
-
.replace(":", "")
|
421 |
-
.replace("__", "_")
|
422 |
-
for opt in options
|
423 |
-
]
|
424 |
-
|
425 |
-
default_options = (
|
426 |
-
st.session_state["project_dct"]["saved_model_results"].get(
|
427 |
-
"selected_options"
|
428 |
-
)
|
429 |
-
if st.session_state["project_dct"]["saved_model_results"].get(
|
430 |
-
"selected_options"
|
431 |
-
)
|
432 |
-
is not None
|
433 |
-
else [options[-1]]
|
434 |
-
)
|
435 |
-
for i in default_options:
|
436 |
-
if i not in options:
|
437 |
-
st.write(i)
|
438 |
-
default_options.remove(i)
|
439 |
-
contribution_selections = st.multiselect(
|
440 |
-
"Select the Response Metrics to compare contributions",
|
441 |
-
options,
|
442 |
-
default=default_options,
|
443 |
-
)
|
444 |
-
trace_data = []
|
445 |
-
|
446 |
-
st.session_state["contribution_df"] = contributions_panel(tuned_model_dict)
|
447 |
-
|
448 |
-
for selection in contribution_selections:
|
449 |
-
|
450 |
-
trace = go.Bar(
|
451 |
-
x=st.session_state["contribution_df"]["Channel"],
|
452 |
-
y=st.session_state["contribution_df"][selection],
|
453 |
-
name=selection,
|
454 |
-
text=np.round(st.session_state["contribution_df"][selection], 0)
|
455 |
-
.astype(int)
|
456 |
-
.astype(str)
|
457 |
-
+ "%",
|
458 |
-
textposition="outside",
|
459 |
-
)
|
460 |
-
trace_data.append(trace)
|
461 |
-
|
462 |
-
layout = go.Layout(
|
463 |
-
title="Metrics Contribution by Channel",
|
464 |
-
xaxis=dict(title="Channel Name"),
|
465 |
-
yaxis=dict(title="Metrics Contribution"),
|
466 |
-
barmode="group",
|
467 |
-
)
|
468 |
-
fig = go.Figure(data=trace_data, layout=layout)
|
469 |
-
st.plotly_chart(fig, use_container_width=True)
|
470 |
-
|
471 |
-
############################################ Waterfall Chart ############################################
|
472 |
-
# import plotly.graph_objects as go
|
473 |
-
|
474 |
-
# # Initialize a Plotly figure
|
475 |
-
# fig = go.Figure()
|
476 |
-
|
477 |
-
# for selection in contribution_selections:
|
478 |
-
# # Ensure y_values are numeric
|
479 |
-
# y_values = st.session_state["contribution_df"][selection].values.astype(float)
|
480 |
-
|
481 |
-
# # Generating text labels for each bar, ensuring operations are compatible with string formats
|
482 |
-
# text_values = [f"{val}%" for val in np.round(y_values, 0).astype(int)]
|
483 |
-
|
484 |
-
# fig.add_trace(
|
485 |
-
# go.Waterfall(
|
486 |
-
# name=selection,
|
487 |
-
# orientation="v",
|
488 |
-
# measure=["relative"]
|
489 |
-
# * len(y_values), # Adjust if you have absolute values at certain points
|
490 |
-
# x=st.session_state["contribution_df"]["Channel"].tolist(),
|
491 |
-
# text=text_values,
|
492 |
-
# textposition="outside",
|
493 |
-
# y=y_values,
|
494 |
-
# increasing={"marker": {"color": "green"}},
|
495 |
-
# decreasing={"marker": {"color": "red"}},
|
496 |
-
# totals={"marker": {"color": "blue"}},
|
497 |
-
# )
|
498 |
-
# )
|
499 |
-
|
500 |
-
# fig.update_layout(
|
501 |
-
# title="Metrics Contribution by Channel",
|
502 |
-
# xaxis={"title": "Channel Name"},
|
503 |
-
# yaxis={"title": "Metrics Contribution"},
|
504 |
-
# height=600,
|
505 |
-
# )
|
506 |
-
|
507 |
-
# # Displaying the waterfall chart in Streamlit
|
508 |
-
# st.plotly_chart(fig, use_container_width=True)
|
509 |
-
|
510 |
-
import plotly.graph_objects as go
|
511 |
-
|
512 |
-
# Initialize a Plotly figure
|
513 |
-
fig = go.Figure()
|
514 |
-
|
515 |
-
for selection in contribution_selections:
|
516 |
-
# Ensure contributions are numeric
|
517 |
-
contributions = (
|
518 |
-
st.session_state["contribution_df"][selection]
|
519 |
-
.values.astype(float)
|
520 |
-
.tolist()
|
521 |
-
)
|
522 |
-
channel_names = st.session_state["contribution_df"]["Channel"].tolist()
|
523 |
-
|
524 |
-
display_name, display_contribution, base_contribution = [], [], 0
|
525 |
-
for channel_name, contribution in zip(channel_names, contributions):
|
526 |
-
if channel_name != "const" and channel_name != "base":
|
527 |
-
display_name.append(channel_name)
|
528 |
-
display_contribution.append(contribution)
|
529 |
-
else:
|
530 |
-
base_contribution = contribution
|
531 |
-
|
532 |
-
display_name = ["Base Sales"] + display_name
|
533 |
-
display_contribution = [base_contribution] + display_contribution
|
534 |
-
|
535 |
-
# Generating text labels for each bar, ensuring operations are compatible with string formats
|
536 |
-
text_values = [
|
537 |
-
f"{val}%" for val in np.round(display_contribution, 0).astype(int)
|
538 |
-
]
|
539 |
-
|
540 |
-
fig.add_trace(
|
541 |
-
go.Waterfall(
|
542 |
-
orientation="v",
|
543 |
-
measure=["relative"]
|
544 |
-
* len(
|
545 |
-
display_contribution
|
546 |
-
), # Adjust if you have absolute values at certain points
|
547 |
-
x=display_name,
|
548 |
-
text=text_values,
|
549 |
-
textposition="outside",
|
550 |
-
y=display_contribution,
|
551 |
-
increasing={"marker": {"color": "green"}},
|
552 |
-
decreasing={"marker": {"color": "red"}},
|
553 |
-
totals={"marker": {"color": "blue"}},
|
554 |
-
)
|
555 |
-
)
|
556 |
-
|
557 |
-
fig.update_layout(
|
558 |
-
title="Metrics Contribution by Channel",
|
559 |
-
xaxis={"title": "Channel Name"},
|
560 |
-
yaxis={"title": "Metrics Contribution"},
|
561 |
-
height=600,
|
562 |
-
)
|
563 |
-
|
564 |
-
# Displaying the waterfall chart in Streamlit
|
565 |
-
st.plotly_chart(fig, use_container_width=True)
|
566 |
-
|
567 |
-
############################################ Waterfall Chart ############################################
|
568 |
-
|
569 |
-
st.title("Analysis of Models Result")
|
570 |
-
# st.markdown()
|
571 |
-
previous_selection = st.session_state["project_dct"][
|
572 |
-
"saved_model_results"
|
573 |
-
].get("model_grid_sel", [1])
|
574 |
-
st.write(np.round(metrics_table, 2))
|
575 |
-
gd_table = metrics_table.iloc[:, :-2]
|
576 |
-
|
577 |
-
gd = GridOptionsBuilder.from_dataframe(gd_table)
|
578 |
-
# gd.configure_pagination(enabled=True)
|
579 |
-
gd.configure_selection(
|
580 |
-
use_checkbox=True,
|
581 |
-
selection_mode="single",
|
582 |
-
pre_select_all_rows=False,
|
583 |
-
pre_selected_rows=previous_selection,
|
584 |
-
)
|
585 |
-
|
586 |
-
gridoptions = gd.build()
|
587 |
-
table = AgGrid(
|
588 |
-
gd_table,
|
589 |
-
gridOptions=gridoptions,
|
590 |
-
fit_columns_on_grid_load=True,
|
591 |
-
height=200,
|
592 |
-
)
|
593 |
-
# table=metrics_table.iloc[:,:-2]
|
594 |
-
# table.insert(0, "Select", False)
|
595 |
-
# selection_table=st.data_editor(table,column_config={"Select": st.column_config.CheckboxColumn(required=True)})
|
596 |
-
if len(table.selected_rows) > 0:
|
597 |
-
st.session_state["project_dct"]["saved_model_results"][
|
598 |
-
"model_grid_sel"
|
599 |
-
] = table.selected_rows[0]["_selectedRowNodeInfo"]["nodeRowIndex"]
|
600 |
-
if len(table.selected_rows) == 0:
|
601 |
-
st.warning(
|
602 |
-
"Click on the checkbox to view comprehensive results of the selected model."
|
603 |
-
)
|
604 |
-
st.stop()
|
605 |
-
else:
|
606 |
-
target_column = table.selected_rows[0]["Model"]
|
607 |
-
feature_set = feature_set_dct[target_column]
|
608 |
-
|
609 |
-
# with eda_columns[1]:
|
610 |
-
# if eda:
|
611 |
-
# def generate_report_with_target(channel_data, target_feature):
|
612 |
-
# report = sv.analyze(
|
613 |
-
# [channel_data, "Dataset"], target_feat=target_feature, verbose=False
|
614 |
-
# )
|
615 |
-
# temp_dir = tempfile.mkdtemp()
|
616 |
-
# report_path = os.path.join(temp_dir, "report.html")
|
617 |
-
# report.show_html(
|
618 |
-
# filepath=report_path, open_browser=False
|
619 |
-
# ) # Generate the report as an HTML file
|
620 |
-
# return report_path
|
621 |
-
#
|
622 |
-
# report_data = transformed_data[feature_set]
|
623 |
-
# report_data[target_column] = transformed_data[target_column]
|
624 |
-
# report_file = generate_report_with_target(report_data, target_column)
|
625 |
-
#
|
626 |
-
# if os.path.exists(report_file):
|
627 |
-
# with open(report_file, "rb") as f:
|
628 |
-
# st.download_button(
|
629 |
-
# label="Download EDA Report",
|
630 |
-
# data=f.read(),
|
631 |
-
# file_name="report.html",
|
632 |
-
# mime="text/html",
|
633 |
-
# )
|
634 |
-
# else:
|
635 |
-
# st.warning("Report generation failed. Unable to find the report file.")
|
636 |
-
|
637 |
-
model = metrics_table[metrics_table["Model"] == target_column][
|
638 |
-
"Model_object"
|
639 |
-
].iloc[0]
|
640 |
-
target = metrics_table[metrics_table["Model"] == target_column][
|
641 |
-
"Model"
|
642 |
-
].iloc[0]
|
643 |
-
st.header("Model Summary")
|
644 |
-
st.write(model.summary())
|
645 |
-
|
646 |
-
sel_dict = tuned_model_dict[
|
647 |
-
[k for k in tuned_model_dict.keys() if k.split("__")[1] == target][0]
|
648 |
-
]
|
649 |
-
X_train = sel_dict["X_train_tuned"]
|
650 |
-
y_train = X_train[target]
|
651 |
-
random_effects = get_random_effects(media_data, panel_col, model)
|
652 |
-
pred = mdf_predict(X_train, model, random_effects)["pred"]
|
653 |
-
|
654 |
-
X_test = sel_dict["X_test_tuned"]
|
655 |
-
y_test = X_test[target]
|
656 |
-
predtest = mdf_predict(X_test, model, random_effects)["pred"]
|
657 |
-
metrics_table_train, _, fig_train = plot_actual_vs_predicted(
|
658 |
-
X_train[date_col],
|
659 |
-
y_train,
|
660 |
-
pred,
|
661 |
-
model,
|
662 |
-
target_column=target_column,
|
663 |
-
flag=None,
|
664 |
-
repeat_all_years=False,
|
665 |
-
is_panel=is_panel,
|
666 |
-
)
|
667 |
-
|
668 |
-
metrics_table_test, _, fig_test = plot_actual_vs_predicted(
|
669 |
-
X_test[date_col],
|
670 |
-
y_test,
|
671 |
-
predtest,
|
672 |
-
model,
|
673 |
-
target_column=target_column,
|
674 |
-
flag=None,
|
675 |
-
repeat_all_years=False,
|
676 |
-
is_panel=is_panel,
|
677 |
-
)
|
678 |
-
|
679 |
-
metrics_table_train = metrics_table_train.set_index("Metric").transpose()
|
680 |
-
metrics_table_train.index = ["Train"]
|
681 |
-
metrics_table_test = metrics_table_test.set_index("Metric").transpose()
|
682 |
-
metrics_table_test.index = ["test"]
|
683 |
-
metrics_table = np.round(
|
684 |
-
pd.concat([metrics_table_train, metrics_table_test]), 2
|
685 |
-
)
|
686 |
-
|
687 |
-
st.markdown("Result Overview")
|
688 |
-
st.dataframe(np.round(metrics_table, 2), use_container_width=True)
|
689 |
-
|
690 |
-
st.subheader("Actual vs Predicted Plot Train")
|
691 |
-
|
692 |
-
st.plotly_chart(fig_train, use_container_width=True)
|
693 |
-
st.subheader("Actual vs Predicted Plot Test")
|
694 |
-
st.plotly_chart(fig_test, use_container_width=True)
|
695 |
-
|
696 |
-
st.markdown("## Residual Analysis")
|
697 |
-
columns = st.columns(2)
|
698 |
-
|
699 |
-
Xtrain1 = X_train.copy()
|
700 |
-
with columns[0]:
|
701 |
-
fig = plot_residual_predicted(y_train, model.predict(Xtrain1), Xtrain1)
|
702 |
-
st.plotly_chart(fig)
|
703 |
-
|
704 |
-
with columns[1]:
|
705 |
-
st.empty()
|
706 |
-
fig = qqplot(y_train, model.predict(X_train))
|
707 |
-
st.plotly_chart(fig)
|
708 |
-
|
709 |
-
with columns[0]:
|
710 |
-
fig = residual_distribution(y_train, model.predict(X_train))
|
711 |
-
st.pyplot(fig)
|
712 |
-
|
713 |
-
update_db("6_AI_Model_Result.py")
|
714 |
-
|
715 |
-
|
716 |
-
elif auth_status == False:
|
717 |
-
st.error("Username/Password is incorrect")
|
718 |
-
try:
|
719 |
-
username_forgot_pw, email_forgot_password, random_password = (
|
720 |
-
authenticator.forgot_password("Forgot password")
|
721 |
-
)
|
722 |
-
if username_forgot_pw:
|
723 |
-
st.success("New password sent securely")
|
724 |
-
# Random password to be transferred to the user securely
|
725 |
-
elif username_forgot_pw == False:
|
726 |
-
st.error("Username not found")
|
727 |
-
except Exception as e:
|
728 |
-
st.error(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/7_Current_Media_Performance.py
DELETED
@@ -1,573 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
MMO Build Sprint 3
|
3 |
-
additions : contributions calculated using tuned Mixed LM model
|
4 |
-
pending : contributions calculations using - 1. not tuned Mixed LM model, 2. tuned OLS model, 3. not tuned OLS model
|
5 |
-
|
6 |
-
MMO Build Sprint 4
|
7 |
-
additions : response metrics selection
|
8 |
-
pending : contributions calculations using - 1. not tuned Mixed LM model, 2. tuned OLS model, 3. not tuned OLS model
|
9 |
-
"""
|
10 |
-
|
11 |
-
import streamlit as st
|
12 |
-
import pandas as pd
|
13 |
-
from sklearn.preprocessing import MinMaxScaler
|
14 |
-
import pickle
|
15 |
-
import os
|
16 |
-
|
17 |
-
from utilities_with_panel import load_local_css, set_header
|
18 |
-
import yaml
|
19 |
-
from yaml import SafeLoader
|
20 |
-
import streamlit_authenticator as stauth
|
21 |
-
import sqlite3
|
22 |
-
from utilities import update_db
|
23 |
-
|
24 |
-
st.set_page_config(layout="wide")
|
25 |
-
load_local_css("styles.css")
|
26 |
-
set_header()
|
27 |
-
for k, v in st.session_state.items():
|
28 |
-
# print(k, v)
|
29 |
-
if k not in [
|
30 |
-
"logout",
|
31 |
-
"login",
|
32 |
-
"config",
|
33 |
-
"build_tuned_model",
|
34 |
-
] and not k.startswith("FormSubmitter"):
|
35 |
-
st.session_state[k] = v
|
36 |
-
with open("config.yaml") as file:
|
37 |
-
config = yaml.load(file, Loader=SafeLoader)
|
38 |
-
st.session_state["config"] = config
|
39 |
-
authenticator = stauth.Authenticate(
|
40 |
-
config["credentials"],
|
41 |
-
config["cookie"]["name"],
|
42 |
-
config["cookie"]["key"],
|
43 |
-
config["cookie"]["expiry_days"],
|
44 |
-
config["preauthorized"],
|
45 |
-
)
|
46 |
-
st.session_state["authenticator"] = authenticator
|
47 |
-
name, authentication_status, username = authenticator.login("Login", "main")
|
48 |
-
auth_status = st.session_state.get("authentication_status")
|
49 |
-
|
50 |
-
if auth_status == True:
|
51 |
-
authenticator.logout("Logout", "main")
|
52 |
-
is_state_initiaized = st.session_state.get("initialized", False)
|
53 |
-
|
54 |
-
if "project_dct" not in st.session_state:
|
55 |
-
st.error("Please load a project from Home page")
|
56 |
-
st.stop()
|
57 |
-
|
58 |
-
conn = sqlite3.connect(
|
59 |
-
r"DB/User.db", check_same_thread=False
|
60 |
-
) # connection with sql db
|
61 |
-
c = conn.cursor()
|
62 |
-
|
63 |
-
if not os.path.exists(
|
64 |
-
os.path.join(st.session_state["project_path"], "tuned_model.pkl")
|
65 |
-
):
|
66 |
-
st.error("Please save a tuned model")
|
67 |
-
st.stop()
|
68 |
-
|
69 |
-
# Restore the session-state entries the rest of this page reads. Values saved
# by the Model Tuning page take precedence; otherwise fall back to the ones
# saved by the Model Build page.
if (
    "session_state_saved" in st.session_state["project_dct"]["model_tuning"]
    and st.session_state["project_dct"]["model_tuning"]["session_state_saved"]
    != []
):
    for key in [
        "used_response_metrics",
        "is_tuned_model",
        "media_data",
        "X_test_spends",
    ]:
        st.session_state[key] = st.session_state["project_dct"][
            "model_tuning"
        ]["session_state_saved"][key]
elif (
    "session_state_saved" in st.session_state["project_dct"]["model_build"]
    and st.session_state["project_dct"]["model_build"]["session_state_saved"]
    != []
):
    for key in [
        "used_response_metrics",
        "date",
        "saved_model_names",
        "media_data",
        "X_test_spends",
    ]:
        st.session_state[key] = st.session_state["project_dct"][
            "model_build"
        ]["session_state_saved"][key]
else:
    st.error("Please tune a model first")
    # Nothing was restored, so the statements that follow would fail on the
    # missing saved state. Halt this script run like the other guards on this
    # page (missing project / missing tuned model) do.
    st.stop()
|
106 |
-
st.session_state["bin_dict"] = st.session_state["project_dct"][
|
107 |
-
"model_build"
|
108 |
-
]["session_state_saved"]["bin_dict"]
|
109 |
-
st.session_state["media_data"].columns = [
|
110 |
-
c.lower() for c in st.session_state["media_data"].columns
|
111 |
-
]
|
112 |
-
|
113 |
-
from utilities_with_panel import (
|
114 |
-
overview_test_data_prep_panel,
|
115 |
-
overview_test_data_prep_nonpanel,
|
116 |
-
initialize_data,
|
117 |
-
create_channel_summary,
|
118 |
-
create_contribution_pie,
|
119 |
-
create_contribuion_stacked_plot,
|
120 |
-
create_channel_spends_sales_plot,
|
121 |
-
format_numbers,
|
122 |
-
channel_name_formating,
|
123 |
-
)
|
124 |
-
|
125 |
-
import plotly.graph_objects as go
|
126 |
-
import streamlit_authenticator as stauth
|
127 |
-
import yaml
|
128 |
-
from yaml import SafeLoader
|
129 |
-
import time
|
130 |
-
|
131 |
-
def get_random_effects(media_data, panel_col, mdf):
    """Collect the first random-effect value of every panel into a DataFrame.

    Args:
        media_data: DataFrame containing a ``panel_col`` column. Its unique
            values, in order of first appearance, are the panels looked up.
        panel_col: Name of the panel column in ``media_data``.
        mdf: Fitted model object exposing ``random_effects``, indexable by
            panel value and yielding an object with a ``.values`` array.

    Returns:
        DataFrame with the columns ``[panel_col, "random_effect"]``, one row
        per unique panel, indexed ``0..n-1``.
    """
    panels = list(media_data[panel_col].unique())
    # Only the first random-effect entry of each panel is used.
    effects = [mdf.random_effects[panel].values[0] for panel in panels]

    # Build the frame in one go: growing it row by row with ``.loc`` is
    # quadratic in the number of panels and leaves both columns as ``object``
    # dtype, which then leaks into every sum computed from "random_effect".
    return pd.DataFrame(
        {panel_col: panels, "random_effect": effects},
        columns=[panel_col, "random_effect"],
    )
|
140 |
-
|
141 |
-
def process_train_and_test(train, test, features, panel_col, target_col):
    """Min-max scale the feature columns of the train (and optional test) set.

    The scaler is fitted on ``train[features]`` only and then applied to
    ``test[features]``. The panel and target columns are copied over unscaled.

    Args:
        train: DataFrame holding ``features``, ``panel_col`` and ``target_col``.
        test: DataFrame with the same columns, or ``None`` to scale only
            ``train``.
        features: Column names to scale.
        panel_col: Name of the panel column to carry over.
        target_col: Name of the target column to carry over.

    Returns:
        ``(X1, X2)`` (scaled train, scaled test) when ``test`` is given,
        otherwise only ``X1``.
    """
    scaler = MinMaxScaler()

    train_features = train[features]
    # Keep the caller's index on the scaled frame. The column assignments
    # below align on index labels, so a fresh 0..n-1 index would fill the
    # panel/target columns with NaN (or values from the wrong rows) for any
    # input whose index is not already 0..n-1.
    X1 = pd.DataFrame(
        scaler.fit_transform(train_features),
        columns=train_features.columns,
        index=train.index,
    )
    X1[panel_col] = train[panel_col]
    X1[target_col] = train[target_col]

    if test is None:
        return X1

    test_features = test[features]
    X2 = pd.DataFrame(
        scaler.transform(test_features),
        columns=test_features.columns,
        index=test.index,
    )
    X2[panel_col] = test[panel_col]
    X2[target_col] = test[target_col]
    return X1, X2
|
157 |
-
|
158 |
-
def mdf_predict(X_df, mdf, random_eff_df, panel_column=None):
    """Predict with a mixed-effects model: fixed-effect prediction + random effect.

    Args:
        X_df: DataFrame of model inputs, including the panel column. It is
            not modified.
        mdf: Fitted model object whose ``predict(frame)`` returns the
            fixed-effect prediction for each row.
        random_eff_df: DataFrame with the panel column and a
            ``"random_effect"`` column (one row per panel).
        panel_column: Name of the panel column to join on. When omitted, the
            module-level ``panel_col`` of this page is used, which is what
            existing three-argument callers rely on.

    Returns:
        A new DataFrame with the columns of ``X_df`` plus ``"pred"``. Rows
        whose panel has no entry in ``random_eff_df`` get a NaN ``"pred"``
        because the join is a left join.
    """
    if panel_column is None:
        panel_column = panel_col

    X = pd.merge(
        X_df,
        random_eff_df[[panel_column, "random_effect"]],
        on=panel_column,
        how="left",
    )
    X["pred_fixed_effect"] = mdf.predict(X)
    X["pred"] = X["pred_fixed_effect"] + X["random_effect"]

    # Debug dump of the merged frame (including the intermediate columns).
    # Create the folder first so a missing "Test" directory cannot abort the
    # prediction.
    os.makedirs("Test", exist_ok=True)
    X.to_csv("Test/merged_df_contri.csv", index=False)

    return X.drop(columns=["pred_fixed_effect", "random_effect"])
|
173 |
-
|
174 |
-
# target='Revenue'
|
175 |
-
|
176 |
-
# is_panel=False
|
177 |
-
# is_panel = st.session_state['is_panel']
|
178 |
-
# Normalised names of the "Panel Level 1" columns chosen for this project.
_panel_level_cols = [
    col.lower()
    .replace(".", "_")
    .replace("@", "_")
    .replace(" ", "_")
    .replace("-", "")
    .replace(":", "")
    .replace("__", "_")
    for col in st.session_state["bin_dict"]["Panel Level 1"]
]
# Use the first panel column. Indexing ``[0]`` unconditionally raised
# IndexError on an empty list, so ``is_panel`` could never be evaluated for a
# project without a panel column; fall back to "" in that case.
panel_col = _panel_level_cols[0] if _panel_level_cols else ""
is_panel = len(panel_col) > 0
|
191 |
-
date_col = "date"
|
192 |
-
|
193 |
-
# Sprint4 - if used_response_metrics is not blank, then select one of the used_response_metrics, else target is revenue by default
|
194 |
-
if (
|
195 |
-
"used_response_metrics" in st.session_state
|
196 |
-
and st.session_state["used_response_metrics"] != []
|
197 |
-
):
|
198 |
-
sel_target_col = st.selectbox(
|
199 |
-
"Select the response metric",
|
200 |
-
st.session_state["used_response_metrics"],
|
201 |
-
)
|
202 |
-
target_col = (
|
203 |
-
sel_target_col.lower()
|
204 |
-
.replace(" ", "_")
|
205 |
-
.replace("-", "")
|
206 |
-
.replace(":", "")
|
207 |
-
.replace("__", "_")
|
208 |
-
)
|
209 |
-
else:
|
210 |
-
sel_target_col = "Total Approved Accounts - Revenue"
|
211 |
-
target_col = "total_approved_accounts_revenue"
|
212 |
-
|
213 |
-
target = sel_target_col
|
214 |
-
|
215 |
-
# Sprint4 - Look through all saved tuned models, only show saved models of the sel resp metric (target_col)
|
216 |
-
# saved_models = st.session_state['saved_model_names']
|
217 |
-
# Sprint4 - get the model obj of the selected model
|
218 |
-
# st.write(sel_model_dict)
|
219 |
-
|
220 |
-
# Sprint3 - Contribution
|
221 |
-
if is_panel:
|
222 |
-
# read tuned mixedLM model
|
223 |
-
# if st.session_state["tuned_model"] is not None :
|
224 |
-
if st.session_state["is_tuned_model"][target_col] == True: # Sprint4
|
225 |
-
with open(
|
226 |
-
os.path.join(
|
227 |
-
st.session_state["project_path"], "tuned_model.pkl"
|
228 |
-
),
|
229 |
-
"rb",
|
230 |
-
) as file:
|
231 |
-
model_dict = pickle.load(file)
|
232 |
-
saved_models = list(model_dict.keys())
|
233 |
-
# st.write(saved_models)
|
234 |
-
required_saved_models = [
|
235 |
-
m.split("__")[0]
|
236 |
-
for m in saved_models
|
237 |
-
if m.split("__")[1] == target_col
|
238 |
-
]
|
239 |
-
sel_model = st.selectbox(
|
240 |
-
"Select the model to review", required_saved_models
|
241 |
-
)
|
242 |
-
sel_model_dict = model_dict[sel_model + "__" + target_col]
|
243 |
-
|
244 |
-
model = sel_model_dict["Model_object"]
|
245 |
-
X_train = sel_model_dict["X_train_tuned"]
|
246 |
-
X_test = sel_model_dict["X_test_tuned"]
|
247 |
-
best_feature_set = sel_model_dict["feature_set"]
|
248 |
-
|
249 |
-
else: # if non tuned model to be used # Pending
|
250 |
-
with open(
|
251 |
-
os.path.join(
|
252 |
-
st.session_state["project_path"], "best_models.pkl"
|
253 |
-
),
|
254 |
-
"rb",
|
255 |
-
) as file:
|
256 |
-
model_dict = pickle.load(file)
|
257 |
-
# st.write(model_dict)
|
258 |
-
saved_models = list(model_dict.keys())
|
259 |
-
required_saved_models = [
|
260 |
-
m.split("__")[0]
|
261 |
-
for m in saved_models
|
262 |
-
if m.split("__")[1] == target_col
|
263 |
-
]
|
264 |
-
sel_model = st.selectbox(
|
265 |
-
"Select the model to review", required_saved_models
|
266 |
-
)
|
267 |
-
sel_model_dict = model_dict[sel_model + "__" + target_col]
|
268 |
-
# st.write(sel_model, sel_model_dict)
|
269 |
-
model = sel_model_dict["Model_object"]
|
270 |
-
X_train = sel_model_dict["X_train"]
|
271 |
-
X_test = sel_model_dict["X_test"]
|
272 |
-
best_feature_set = sel_model_dict["feature_set"]
|
273 |
-
|
274 |
-
# Calculate contributions
|
275 |
-
|
276 |
-
with open(
|
277 |
-
os.path.join(st.session_state["project_path"], "data_import.pkl"),
|
278 |
-
"rb",
|
279 |
-
) as f:
|
280 |
-
data = pickle.load(f)
|
281 |
-
|
282 |
-
# Accessing the loaded objects
|
283 |
-
st.session_state["orig_media_data"] = data["final_df"]
|
284 |
-
|
285 |
-
st.session_state["orig_media_data"].columns = [
|
286 |
-
col.lower()
|
287 |
-
.replace(".", "_")
|
288 |
-
.replace("@", "_")
|
289 |
-
.replace(" ", "_")
|
290 |
-
.replace("-", "")
|
291 |
-
.replace(":", "")
|
292 |
-
.replace("__", "_")
|
293 |
-
for col in st.session_state["orig_media_data"].columns
|
294 |
-
]
|
295 |
-
|
296 |
-
media_data = st.session_state["media_data"]
|
297 |
-
|
298 |
-
# st.session_state['orig_media_data']=st.session_state["media_data"]
|
299 |
-
|
300 |
-
# st.write(media_data)
|
301 |
-
|
302 |
-
contri_df = pd.DataFrame()
|
303 |
-
|
304 |
-
y = []
|
305 |
-
y_pred = []
|
306 |
-
|
307 |
-
random_eff_df = get_random_effects(media_data, panel_col, model)
|
308 |
-
random_eff_df["fixed_effect"] = model.fe_params["Intercept"]
|
309 |
-
random_eff_df["panel_effect"] = (
|
310 |
-
random_eff_df["random_effect"] + random_eff_df["fixed_effect"]
|
311 |
-
)
|
312 |
-
# random_eff_df.to_csv("Test/random_eff_df_contri.csv", index=False)
|
313 |
-
|
314 |
-
coef_df = pd.DataFrame(model.fe_params)
|
315 |
-
coef_df.reset_index(inplace=True)
|
316 |
-
coef_df.columns = ["feature", "coef"]
|
317 |
-
|
318 |
-
# coef_df.reset_index().to_csv("Test/coef_df_contri1.csv",index=False)
|
319 |
-
# print(model.fe_params)
|
320 |
-
|
321 |
-
x_train_contribution = X_train.copy()
|
322 |
-
x_test_contribution = X_test.copy()
|
323 |
-
|
324 |
-
# preprocessing not needed since X_train is already preprocessed
|
325 |
-
# X1, X2 = process_train_and_test(x_train_contribution, x_test_contribution, best_feature_set, panel_col, target_col)
|
326 |
-
# x_train_contribution[best_feature_set] = X1[best_feature_set]
|
327 |
-
# x_test_contribution[best_feature_set] = X2[best_feature_set]
|
328 |
-
|
329 |
-
x_train_contribution = mdf_predict(
|
330 |
-
x_train_contribution, model, random_eff_df
|
331 |
-
)
|
332 |
-
x_test_contribution = mdf_predict(
|
333 |
-
x_test_contribution, model, random_eff_df
|
334 |
-
)
|
335 |
-
|
336 |
-
x_train_contribution = pd.merge(
|
337 |
-
x_train_contribution,
|
338 |
-
random_eff_df[[panel_col, "panel_effect"]],
|
339 |
-
on=panel_col,
|
340 |
-
how="left",
|
341 |
-
)
|
342 |
-
x_test_contribution = pd.merge(
|
343 |
-
x_test_contribution,
|
344 |
-
random_eff_df[[panel_col, "panel_effect"]],
|
345 |
-
on=panel_col,
|
346 |
-
how="left",
|
347 |
-
)
|
348 |
-
|
349 |
-
# Per-feature contribution = coefficient * feature value, for every row of
# coef_df after the first one (presumably the "Intercept" row, which is
# accounted for separately through "panel_effect" - TODO confirm).
for i in range(1, len(coef_df)):
    coef = coef_df.loc[i, "coef"]
    col = coef_df.loc[i, "feature"]
    x_train_contribution[str(col) + "_contr"] = (
        coef * x_train_contribution[col]
    )
    # Test contributions must come from the test rows. The previous code
    # multiplied the train column here, which index-aligned train values
    # onto the test frame.
    x_test_contribution[str(col) + "_contr"] = (
        coef * x_test_contribution[col]
    )
|
358 |
-
|
359 |
-
x_train_contribution["sum_contributions"] = (
|
360 |
-
x_train_contribution.filter(regex="contr").sum(axis=1)
|
361 |
-
)
|
362 |
-
x_train_contribution["sum_contributions"] = (
|
363 |
-
x_train_contribution["sum_contributions"]
|
364 |
-
+ x_train_contribution["panel_effect"]
|
365 |
-
)
|
366 |
-
|
367 |
-
x_test_contribution["sum_contributions"] = x_test_contribution.filter(
|
368 |
-
regex="contr"
|
369 |
-
).sum(axis=1)
|
370 |
-
x_test_contribution["sum_contributions"] = (
|
371 |
-
x_test_contribution["sum_contributions"]
|
372 |
-
+ x_test_contribution["panel_effect"]
|
373 |
-
)
|
374 |
-
|
375 |
-
# # test
|
376 |
-
x_train_contribution.to_csv(
|
377 |
-
"Test/x_train_contribution.csv", index=False
|
378 |
-
)
|
379 |
-
x_test_contribution.to_csv("Test/x_test_contribution.csv", index=False)
|
380 |
-
#
|
381 |
-
# st.session_state['orig_media_data'].to_csv("Test/transformed_data.csv",index=False)
|
382 |
-
# st.session_state['X_test_spends'].to_csv("Test/test_spends.csv",index=False)
|
383 |
-
# # st.write(st.session_state['orig_media_data'].columns)
|
384 |
-
|
385 |
-
# st.write(date_col,panel_col)
|
386 |
-
# st.write(x_test_contribution)
|
387 |
-
|
388 |
-
overview_test_data_prep_panel(
|
389 |
-
x_test_contribution,
|
390 |
-
st.session_state["orig_media_data"],
|
391 |
-
st.session_state["X_test_spends"],
|
392 |
-
date_col,
|
393 |
-
panel_col,
|
394 |
-
target_col,
|
395 |
-
)
|
396 |
-
|
397 |
-
else: # NON PANEL
|
398 |
-
if st.session_state["is_tuned_model"][target_col] == True: # Sprint4
|
399 |
-
with open(
|
400 |
-
os.path.join(
|
401 |
-
st.session_state["project_path"], "tuned_model.pkl"
|
402 |
-
),
|
403 |
-
"rb",
|
404 |
-
) as file:
|
405 |
-
model_dict = pickle.load(file)
|
406 |
-
saved_models = list(model_dict.keys())
|
407 |
-
required_saved_models = [
|
408 |
-
m.split("__")[0]
|
409 |
-
for m in saved_models
|
410 |
-
if m.split("__")[1] == target_col
|
411 |
-
]
|
412 |
-
sel_model = st.selectbox(
|
413 |
-
"Select the model to review", required_saved_models
|
414 |
-
)
|
415 |
-
sel_model_dict = model_dict[sel_model + "__" + target_col]
|
416 |
-
|
417 |
-
model = sel_model_dict["Model_object"]
|
418 |
-
X_train = sel_model_dict["X_train_tuned"]
|
419 |
-
X_test = sel_model_dict["X_test_tuned"]
|
420 |
-
best_feature_set = sel_model_dict["feature_set"]
|
421 |
-
|
422 |
-
else: # Sprint4
|
423 |
-
with open(
|
424 |
-
os.path.join(
|
425 |
-
st.session_state["project_path"], "best_models.pkl"
|
426 |
-
),
|
427 |
-
"rb",
|
428 |
-
) as file:
|
429 |
-
model_dict = pickle.load(file)
|
430 |
-
saved_models = list(model_dict.keys())
|
431 |
-
required_saved_models = [
|
432 |
-
m.split("__")[0]
|
433 |
-
for m in saved_models
|
434 |
-
if m.split("__")[1] == target_col
|
435 |
-
]
|
436 |
-
sel_model = st.selectbox(
|
437 |
-
"Select the model to review", required_saved_models
|
438 |
-
)
|
439 |
-
sel_model_dict = model_dict[sel_model + "__" + target_col]
|
440 |
-
|
441 |
-
model = sel_model_dict["Model_object"]
|
442 |
-
X_train = sel_model_dict["X_train"]
|
443 |
-
X_test = sel_model_dict["X_test"]
|
444 |
-
best_feature_set = sel_model_dict["feature_set"]
|
445 |
-
|
446 |
-
x_train_contribution = X_train.copy()
|
447 |
-
x_test_contribution = X_test.copy()
|
448 |
-
|
449 |
-
x_train_contribution["pred"] = model.predict(
|
450 |
-
x_train_contribution[best_feature_set]
|
451 |
-
)
|
452 |
-
x_test_contribution["pred"] = model.predict(
|
453 |
-
x_test_contribution[best_feature_set]
|
454 |
-
)
|
455 |
-
|
456 |
-
for num, i in enumerate(model.params.values):
|
457 |
-
col = best_feature_set[num]
|
458 |
-
x_train_contribution[col + "_contr"] = X_train[col] * i
|
459 |
-
x_test_contribution[col + "_contr"] = X_test[col] * i
|
460 |
-
|
461 |
-
x_test_contribution.to_csv(
|
462 |
-
"Test/x_test_contribution_non_panel.csv", index=False
|
463 |
-
)
|
464 |
-
overview_test_data_prep_nonpanel(
|
465 |
-
x_test_contribution,
|
466 |
-
st.session_state["orig_media_data"].copy(),
|
467 |
-
st.session_state["X_test_spends"].copy(),
|
468 |
-
date_col,
|
469 |
-
target_col,
|
470 |
-
)
|
471 |
-
# for k, v in st.session_sta
|
472 |
-
# te.items():
|
473 |
-
|
474 |
-
# if k not in ['logout', 'login','config'] and not k.startswith('FormSubmitter'):
|
475 |
-
# st.session_state[k] = v
|
476 |
-
|
477 |
-
# authenticator = st.session_state.get('authenticator')
|
478 |
-
|
479 |
-
# if authenticator is None:
|
480 |
-
# authenticator = load_authenticator()
|
481 |
-
|
482 |
-
# name, authentication_status, username = authenticator.login('Login', 'main')
|
483 |
-
# auth_status = st.session_state['authentication_status']
|
484 |
-
|
485 |
-
# if auth_status:
|
486 |
-
# authenticator.logout('Logout', 'main')
|
487 |
-
|
488 |
-
# is_state_initiaized = st.session_state.get('initialized',False)
|
489 |
-
# if not is_state_initiaized:
|
490 |
-
|
491 |
-
initialize_data(target_col)
|
492 |
-
scenario = st.session_state["scenario"]
|
493 |
-
raw_df = st.session_state["raw_df"]
|
494 |
-
st.header("Overview of previous spends")
|
495 |
-
|
496 |
-
# st.write(scenario.actual_total_spends)
|
497 |
-
# st.write(scenario.actual_total_sales)
|
498 |
-
columns = st.columns((1, 1, 3))
|
499 |
-
|
500 |
-
with columns[0]:
|
501 |
-
st.metric(
|
502 |
-
label="Spends",
|
503 |
-
value=format_numbers(float(scenario.actual_total_spends)),
|
504 |
-
)
|
505 |
-
###print(f"##################### {scenario.actual_total_sales} ##################")
|
506 |
-
with columns[1]:
|
507 |
-
st.metric(
|
508 |
-
label=target,
|
509 |
-
value=format_numbers(
|
510 |
-
float(scenario.actual_total_sales), include_indicator=False
|
511 |
-
),
|
512 |
-
)
|
513 |
-
|
514 |
-
actual_summary_df = create_channel_summary(scenario)
|
515 |
-
actual_summary_df["Channel"] = actual_summary_df["Channel"].apply(
|
516 |
-
channel_name_formating
|
517 |
-
)
|
518 |
-
|
519 |
-
columns = st.columns((2, 1))
|
520 |
-
with columns[0]:
|
521 |
-
with st.expander("Channel wise overview"):
|
522 |
-
st.markdown(
|
523 |
-
actual_summary_df.style.set_table_styles(
|
524 |
-
[
|
525 |
-
{
|
526 |
-
"selector": "th",
|
527 |
-
"props": [("background-color", "#11B6BD")],
|
528 |
-
},
|
529 |
-
{
|
530 |
-
"selector": "tr:nth-child(even)",
|
531 |
-
"props": [("background-color", "#11B6BD")],
|
532 |
-
},
|
533 |
-
]
|
534 |
-
).to_html(),
|
535 |
-
unsafe_allow_html=True,
|
536 |
-
)
|
537 |
-
|
538 |
-
st.markdown("<hr>", unsafe_allow_html=True)
|
539 |
-
##############################
|
540 |
-
|
541 |
-
st.plotly_chart(
|
542 |
-
create_contribution_pie(scenario), use_container_width=True
|
543 |
-
)
|
544 |
-
st.markdown("<hr>", unsafe_allow_html=True)
|
545 |
-
|
546 |
-
################################3
|
547 |
-
st.plotly_chart(
|
548 |
-
create_contribuion_stacked_plot(scenario), use_container_width=True
|
549 |
-
)
|
550 |
-
st.markdown("<hr>", unsafe_allow_html=True)
|
551 |
-
#######################################
|
552 |
-
|
553 |
-
selected_channel_name = st.selectbox(
|
554 |
-
"Channel",
|
555 |
-
st.session_state["channels_list"] + ["non media"],
|
556 |
-
format_func=channel_name_formating,
|
557 |
-
)
|
558 |
-
selected_channel = scenario.channels.get(selected_channel_name, None)
|
559 |
-
|
560 |
-
st.plotly_chart(
|
561 |
-
create_channel_spends_sales_plot(selected_channel),
|
562 |
-
use_container_width=True,
|
563 |
-
)
|
564 |
-
|
565 |
-
st.markdown("<hr>", unsafe_allow_html=True)
|
566 |
-
|
567 |
-
if st.checkbox("Save this session", key="save"):
|
568 |
-
project_dct_path = os.path.join(
|
569 |
-
st.session_state["session_path"], "project_dct.pkl"
|
570 |
-
)
|
571 |
-
with open(project_dct_path, "wb") as f:
|
572 |
-
pickle.dump(st.session_state["project_dct"], f)
|
573 |
-
update_db("7_Current_Media_Performance.py")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/8_Build_Response_Curves.py
DELETED
@@ -1,596 +0,0 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
import plotly.express as px
|
3 |
-
import numpy as np
|
4 |
-
import plotly.graph_objects as go
|
5 |
-
from utilities import (
|
6 |
-
channel_name_formating,
|
7 |
-
load_authenticator,
|
8 |
-
initialize_data,
|
9 |
-
fetch_actual_data,
|
10 |
-
)
|
11 |
-
from sklearn.metrics import r2_score
|
12 |
-
from collections import OrderedDict
|
13 |
-
from classes import class_from_dict, class_to_dict
|
14 |
-
import pickle
|
15 |
-
import json
|
16 |
-
import sqlite3
|
17 |
-
from utilities import update_db
|
18 |
-
|
19 |
-
# Write every session-state entry back to itself, except the authentication
# keys and the form-submitter keys.
# NOTE(review): presumably this keeps widget-backed values alive across page
# switches -- confirm.
for state_key, state_value in st.session_state.items():
    if state_key in ["logout", "login", "config"]:
        continue
    if state_key.startswith("FormSubmitter"):
        continue
    st.session_state[state_key] = state_value
|
24 |
-
|
25 |
-
|
26 |
-
def s_curve(x, K, b, a, x0):
    """Evaluate the S-shaped curve ``K / (1 + b * exp(-a * (x - x0)))``.

    ``x`` may be a scalar or a NumPy array; the result has the same shape.
    """
    decay = np.exp(-a * (x - x0))
    return K / (1 + b * decay)
|
28 |
-
|
29 |
-
|
30 |
-
def save_scenario(scenario_name):
    """
    Save the current scenario with the mentioned name in the session state
    and write all saved scenarios to ``../saved_scenarios.pkl``.

    Parameters
    ----------
    scenario_name
        Name of the scenario to be saved
    """
    # Create the container on first use. Rebinding ``st.session_state`` itself
    # would throw the session state away and make the lookup below fail.
    if "saved_scenarios" not in st.session_state:
        st.session_state["saved_scenarios"] = OrderedDict()

    st.session_state["saved_scenarios"][scenario_name] = class_to_dict(
        st.session_state["scenario"]
    )
    # Reset the value stored under the "scenario_input" key.
    st.session_state["scenario_input"] = ""

    with open("../saved_scenarios.pkl", "wb") as f:
        pickle.dump(st.session_state["saved_scenarios"], f)
|
50 |
-
|
51 |
-
|
52 |
-
def reset_curve_parameters(
    metrics=None, panel=None, selected_channel_name=None
):
    """
    Drop the curve parameters held in the session state.

    Removes the ``K``, ``b``, ``a`` and ``x0`` entries from the session state
    (when present). If ``metrics``, ``panel`` and ``selected_channel_name``
    are all given, the matching entry of ``st.session_state["update_rcs"]``
    is removed as well.

    Parameters
    ----------
    metrics, panel, selected_channel_name
        Components of the ``"<metrics>#@<panel>#@<channel>"`` key used in
        ``st.session_state["update_rcs"]``. All default to ``None``.
    """
    # Guard each delete: this runs as a widget callback and must not fail when
    # a parameter has not been seeded yet.
    for param in ("K", "b", "a", "x0"):
        if param in st.session_state:
            del st.session_state[param]

    if metrics is None or panel is None or selected_channel_name is None:
        return

    rcs_key = f"{metrics}#@{panel}#@{selected_channel_name}"
    if rcs_key in st.session_state["update_rcs"]:
        del st.session_state["update_rcs"][rcs_key]
|
71 |
-
|
72 |
-
|
73 |
-
def update_response_curve(
    K_updated,
    b_updated,
    a_updated,
    x0_updated,
    metrics=None,
    panel=None,
    selected_channel_name=None,
):
    """
    Store the current curve parameters on the selected channel.

    The values are read from the ``K``, ``b``, ``a`` and ``x0`` session-state
    entries and written to ``response_curve_params`` of the channel
    ``selected_channel_name`` in the scenario stored under
    ``st.session_state["project_dct"]["scenario_planner"]``.

    Parameters
    ----------
    K_updated, b_updated, a_updated, x0_updated
        Accepted for call compatibility; the stored values come from the
        session state.
    metrics, panel
        Accepted for call compatibility; not used.
    selected_channel_name
        Name of the channel whose parameters are updated.
    """
    # Build the scenario key from the selector widgets' state rather than
    # relying on a module-level ``unique_key`` assigned later in the script.
    unique_key = (
        f"{st.session_state['response_metrics_selectbox']}"
        f"-{st.session_state['panel_selected_selectbox']}"
    )
    scenario = st.session_state["project_dct"]["scenario_planner"][unique_key]
    scenario.channels[selected_channel_name].response_curve_params = {
        "K": st.session_state["K"],
        "b": st.session_state["b"],
        "a": st.session_state["a"],
        "x0": st.session_state["x0"],
    }
|
94 |
-
|
95 |
-
# if (
|
96 |
-
# metrics is not None
|
97 |
-
# and panel is not None
|
98 |
-
# and selected_channel_name is not None
|
99 |
-
# ):
|
100 |
-
# st.session_state["update_rcs"][
|
101 |
-
# f"{metrics}#@{panel}#@{selected_channel_name}"
|
102 |
-
# ] = {
|
103 |
-
# "K": K_updated,
|
104 |
-
# "b": b_updated,
|
105 |
-
# "a": a_updated,
|
106 |
-
# "x0": x0_updated,
|
107 |
-
# }
|
108 |
-
|
109 |
-
# st.session_state["scenario"].channels[
|
110 |
-
# selected_channel_name
|
111 |
-
# ].response_curve_params = {
|
112 |
-
# "K": K_updated,
|
113 |
-
# "b": b_updated,
|
114 |
-
# "a": a_updated,
|
115 |
-
# "x0": x0_updated,
|
116 |
-
# }
|
117 |
-
|
118 |
-
|
119 |
-
# authenticator = st.session_state.get('authenticator')
|
120 |
-
# if authenticator is None:
|
121 |
-
# authenticator = load_authenticator()
|
122 |
-
|
123 |
-
# name, authentication_status, username = authenticator.login('Login', 'main')
|
124 |
-
# auth_status = st.session_state.get('authentication_status')
|
125 |
-
|
126 |
-
# if auth_status == True:
|
127 |
-
# is_state_initiaized = st.session_state.get('initialized',False)
|
128 |
-
# if not is_state_initiaized:
|
129 |
-
# print("Scenario page state reloaded")
|
130 |
-
|
131 |
-
import pandas as pd
|
132 |
-
|
133 |
-
|
134 |
-
@st.cache_resource(show_spinner=False)
def panel_fetch(file_selected):
    """
    Return the distinct values of the "Panel" column of the "RAW DATA MMM"
    sheet in ``file_selected``, or ``None`` when the sheet has no such column.
    """
    raw_df = pd.read_excel(file_selected, sheet_name="RAW DATA MMM")
    if "Panel" not in raw_df.columns:
        return None
    return list(set(raw_df["Panel"]))
|
145 |
-
|
146 |
-
|
147 |
-
import glob
|
148 |
-
import os
|
149 |
-
|
150 |
-
|
151 |
-
def get_excel_names(directory):
    """
    Collect the name part that follows the last ``@#`` in Excel file names.

    Looks in ``directory`` for ``.xlsx`` and ``.xls`` files whose name
    contains ``@#`` and returns, for each file, the text after the last
    ``@#`` with the file extension removed.

    Parameters
    ----------
    directory
        Directory to search (not recursive).

    Returns
    -------
    list of str
        One entry per matching file, ``.xlsx`` matches first.
    """
    last_portions = []

    for extension in (".xlsx", ".xls"):
        pattern = os.path.join(directory, f"*@#*{extension}")
        for file in glob.glob(pattern):
            # splitext removes only the real extension; a plain str.replace
            # would also eat ".xls"/".xlsx" appearing inside the name.
            stem, _ = os.path.splitext(os.path.basename(file))
            last_portions.append(stem.split("@#")[-1])

    return last_portions
|
175 |
-
|
176 |
-
|
177 |
-
def name_formating(channel_name):
    """Return ``channel_name`` with underscores turned into spaces, in title case."""
    return channel_name.replace("_", " ").title()
|
185 |
-
|
186 |
-
|
187 |
-
def fetch_panel_data():
    """
    Load data and scenario for the selected response metric and panel.

    Reads the selections from the ``response_metrics_selectbox`` and
    ``panel_selected_selectbox`` session-state entries, then:

    * stores the result of ``fetch_actual_data`` in ``actual_input_df`` and
      ``actual_contribution_df``;
    * if no scenario exists yet for this metric/panel pair, calls
      ``initialize_data`` and stores ``st.session_state["scenario"]`` under
      that pair in ``project_dct["scenario_planner"]``;
    * otherwise restores the stored scenario and rebuilds the ``rcs`` and
      ``powers`` session-state entries from its channels;
    * removes the ``K``, ``b``, ``a`` and ``x0`` session-state entries.
    """
    # Read both selections from the widget state: when this runs as an
    # on_change callback, module-level variables still hold the previous
    # run's values.
    metrics_selected = st.session_state["response_metrics_selectbox"]
    panel_selected = st.session_state["panel_selected_selectbox"]
    file_selected = (
        "./metrics_level_data/"
        f"Overview_data_test_panel@#{metrics_selected}.xlsx"
    )

    (
        st.session_state["actual_input_df"],
        st.session_state["actual_contribution_df"],
    ) = fetch_actual_data(panel=panel_selected, target_file=file_selected)

    unique_key = f"{metrics_selected}-{panel_selected}"
    scenario_planner = st.session_state["project_dct"]["scenario_planner"]

    if unique_key not in scenario_planner:
        # The "Aggregated" and per-panel cases use the same call.
        initialize_data(
            panel=panel_selected,
            target_file=file_selected,
            updated_rcs={},
            metrics=metrics_selected,
        )
        scenario_planner[unique_key] = st.session_state["scenario"]
    else:
        # NOTE(review): indentation was lost in the reviewed copy; the rcs /
        # powers rebuild is assumed to belong to this branch -- confirm.
        stored_scenario = scenario_planner[unique_key]
        st.session_state["scenario"] = stored_scenario
        st.session_state["rcs"] = {}
        st.session_state["powers"] = {}

        for channel_name, _channel in stored_scenario.channels.items():
            st.session_state["rcs"][
                channel_name
            ] = _channel.response_curve_params
            st.session_state["powers"][channel_name] = _channel.power

    # NOTE(review): assumed to run for both branches so the parameter inputs
    # are re-seeded after a metric/panel change -- confirm.
    for param in ("K", "b", "a", "x0"):
        if param in st.session_state:
            del st.session_state[param]
|
257 |
-
|
258 |
-
|
259 |
-
if "project_dct" not in st.session_state:
|
260 |
-
st.error("Please load a project from home")
|
261 |
-
st.stop()
|
262 |
-
|
263 |
-
database_file = r"DB\User.db"
|
264 |
-
|
265 |
-
conn = sqlite3.connect(
|
266 |
-
database_file, check_same_thread=False
|
267 |
-
) # connection with sql db
|
268 |
-
c = conn.cursor()
|
269 |
-
|
270 |
-
st.subheader("Build Response Curves")
|
271 |
-
|
272 |
-
|
273 |
-
if "update_rcs" not in st.session_state:
|
274 |
-
st.session_state["update_rcs"] = {}
|
275 |
-
|
276 |
-
st.session_state["first_time"] = True
|
277 |
-
|
278 |
-
col1, col2, col3 = st.columns([1, 1, 1])
|
279 |
-
|
280 |
-
directory = "metrics_level_data"
|
281 |
-
metrics_list = get_excel_names(directory)
|
282 |
-
|
283 |
-
|
284 |
-
metrics_selected = col1.selectbox(
|
285 |
-
"Response Metrics",
|
286 |
-
metrics_list,
|
287 |
-
on_change=fetch_panel_data,
|
288 |
-
format_func=name_formating,
|
289 |
-
key="response_metrics_selectbox",
|
290 |
-
)
|
291 |
-
|
292 |
-
|
293 |
-
file_selected = (
|
294 |
-
f"./metrics_level_data/Overview_data_test_panel@#{metrics_selected}.xlsx"
|
295 |
-
)
|
296 |
-
|
297 |
-
panel_list = panel_fetch(file_selected)
# panel_fetch returns None when the sheet has no "Panel" column; offer only
# the aggregated view in that case instead of failing on ``list + None``.
final_panel_list = ["Aggregated"] + (panel_list or [])
|
299 |
-
|
300 |
-
panel_selected = col3.selectbox(
|
301 |
-
"Panel",
|
302 |
-
final_panel_list,
|
303 |
-
on_change=fetch_panel_data,
|
304 |
-
key="panel_selected_selectbox",
|
305 |
-
)
|
306 |
-
|
307 |
-
|
308 |
-
is_state_initiaized = st.session_state.get("initialized_rcs", False)
|
309 |
-
print(is_state_initiaized)
|
310 |
-
if not is_state_initiaized:
|
311 |
-
print("DEBUG.....", "Here")
|
312 |
-
fetch_panel_data()
|
313 |
-
# if panel_selected == "Aggregated":
|
314 |
-
# initialize_data(panel=panel_selected, target_file=file_selected)
|
315 |
-
# panel = None
|
316 |
-
# else:
|
317 |
-
# initialize_data(panel=panel_selected, target_file=file_selected)
|
318 |
-
|
319 |
-
st.session_state["initialized_rcs"] = True
|
320 |
-
|
321 |
-
# channels_list = st.session_state["channels_list"]
|
322 |
-
unique_key = f"{st.session_state['response_metrics_selectbox']}-{st.session_state['panel_selected_selectbox']}"
|
323 |
-
chanel_list_final = list(
|
324 |
-
st.session_state["project_dct"]["scenario_planner"][
|
325 |
-
unique_key
|
326 |
-
].channels.keys()
|
327 |
-
) + ["Others"]
|
328 |
-
|
329 |
-
|
330 |
-
selected_channel_name = col2.selectbox(
|
331 |
-
"Channel",
|
332 |
-
chanel_list_final,
|
333 |
-
format_func=channel_name_formating,
|
334 |
-
on_change=reset_curve_parameters,
|
335 |
-
key="selected_channel_name_selectbox",
|
336 |
-
)
|
337 |
-
|
338 |
-
|
339 |
-
rcs = st.session_state["rcs"]
|
340 |
-
|
341 |
-
# Seed any missing curve parameter from the selected channel's stored
# response-curve parameters.
for curve_param in ("K", "b", "a", "x0"):
    if curve_param not in st.session_state:
        st.session_state[curve_param] = rcs[selected_channel_name][curve_param]
|
353 |
-
|
354 |
-
|
355 |
-
x = st.session_state["actual_input_df"][selected_channel_name].values
|
356 |
-
y = st.session_state["actual_contribution_df"][selected_channel_name].values
|
357 |
-
|
358 |
-
|
359 |
-
power = np.ceil(np.log(x.max()) / np.log(10)) - 3
|
360 |
-
|
361 |
-
print(f"DEBUG BUILD RCS: {selected_channel_name}")
|
362 |
-
print(f"DEBUG BUILD RCS: K : {st.session_state['K']}")
|
363 |
-
print(f"DEBUG BUILD RCS: b : {st.session_state['b']}")
|
364 |
-
print(f"DEBUG BUILD RCS: a : {st.session_state['a']}")
|
365 |
-
print(f"DEBUG BUILD RCS: x0: {st.session_state['x0']}")
|
366 |
-
|
367 |
-
# fig = px.scatter(x, s_curve(x/10**power,
|
368 |
-
# st.session_state['K'],
|
369 |
-
# st.session_state['b'],
|
370 |
-
# st.session_state['a'],
|
371 |
-
# st.session_state['x0']))
|
372 |
-
|
373 |
-
x_plot = np.linspace(0, 5 * max(x), 50)
|
374 |
-
|
375 |
-
fig = px.scatter(x=x, y=y)
|
376 |
-
fig.add_trace(
|
377 |
-
go.Scatter(
|
378 |
-
x=x_plot,
|
379 |
-
y=s_curve(
|
380 |
-
x_plot / 10**power,
|
381 |
-
st.session_state["K"],
|
382 |
-
st.session_state["b"],
|
383 |
-
st.session_state["a"],
|
384 |
-
st.session_state["x0"],
|
385 |
-
),
|
386 |
-
line=dict(color="red"),
|
387 |
-
name="Modified",
|
388 |
-
),
|
389 |
-
)
|
390 |
-
|
391 |
-
fig.add_trace(
|
392 |
-
go.Scatter(
|
393 |
-
x=x_plot,
|
394 |
-
y=s_curve(
|
395 |
-
x_plot / 10**power,
|
396 |
-
rcs[selected_channel_name]["K"],
|
397 |
-
rcs[selected_channel_name]["b"],
|
398 |
-
rcs[selected_channel_name]["a"],
|
399 |
-
rcs[selected_channel_name]["x0"],
|
400 |
-
),
|
401 |
-
line=dict(color="rgba(0, 255, 0, 0.4)"),
|
402 |
-
name="Actual",
|
403 |
-
),
|
404 |
-
)
|
405 |
-
|
406 |
-
fig.update_layout(title_text="Response Curve", showlegend=True)
|
407 |
-
fig.update_annotations(font_size=10)
|
408 |
-
fig.update_xaxes(title="Spends")
|
409 |
-
fig.update_yaxes(title="Revenue")
|
410 |
-
|
411 |
-
st.plotly_chart(fig, use_container_width=True)
|
412 |
-
|
413 |
-
r2 = r2_score(
|
414 |
-
y,
|
415 |
-
s_curve(
|
416 |
-
x / 10**power,
|
417 |
-
st.session_state["K"],
|
418 |
-
st.session_state["b"],
|
419 |
-
st.session_state["a"],
|
420 |
-
st.session_state["x0"],
|
421 |
-
),
|
422 |
-
)
|
423 |
-
|
424 |
-
r2_actual = r2_score(
|
425 |
-
y,
|
426 |
-
s_curve(
|
427 |
-
x / 10**power,
|
428 |
-
rcs[selected_channel_name]["K"],
|
429 |
-
rcs[selected_channel_name]["b"],
|
430 |
-
rcs[selected_channel_name]["a"],
|
431 |
-
rcs[selected_channel_name]["x0"],
|
432 |
-
),
|
433 |
-
)
|
434 |
-
|
435 |
-
columns = st.columns((1, 1, 2))
|
436 |
-
with columns[0]:
|
437 |
-
st.metric("R2 Modified", round(r2, 2))
|
438 |
-
with columns[1]:
|
439 |
-
st.metric("R2 Actual", round(r2_actual, 2))
|
440 |
-
|
441 |
-
|
442 |
-
st.markdown("#### Set Parameters", unsafe_allow_html=True)
|
443 |
-
columns = st.columns(4)
|
444 |
-
|
445 |
-
if "updated_parms" not in st.session_state:
|
446 |
-
st.session_state["updated_parms"] = {
|
447 |
-
"K_updated": 0,
|
448 |
-
"b_updated": 0,
|
449 |
-
"a_updated": 0,
|
450 |
-
"x0_updated": 0,
|
451 |
-
}
|
452 |
-
|
453 |
-
with columns[0]:
|
454 |
-
st.session_state["updated_parms"]["K_updated"] = st.number_input(
|
455 |
-
"K", key="K", format="%0.5f"
|
456 |
-
)
|
457 |
-
with columns[1]:
|
458 |
-
st.session_state["updated_parms"]["b_updated"] = st.number_input(
|
459 |
-
"b", key="b", format="%0.5f"
|
460 |
-
)
|
461 |
-
with columns[2]:
|
462 |
-
st.session_state["updated_parms"]["a_updated"] = st.number_input(
|
463 |
-
"a", key="a", step=0.0001, format="%0.5f"
|
464 |
-
)
|
465 |
-
with columns[3]:
|
466 |
-
st.session_state["updated_parms"]["x0_updated"] = st.number_input(
|
467 |
-
"x0", key="x0", format="%0.5f"
|
468 |
-
)
|
469 |
-
|
470 |
-
# st.session_state["project_dct"]["scenario_planner"]["K_number_input"] = (
|
471 |
-
# st.session_state["updated_parms"]["K_updated"]
|
472 |
-
# )
|
473 |
-
# st.session_state["project_dct"]["scenario_planner"]["b_number_input"] = (
|
474 |
-
# st.session_state["updated_parms"]["b_updated"]
|
475 |
-
# )
|
476 |
-
# st.session_state["project_dct"]["scenario_planner"]["a_number_input"] = (
|
477 |
-
# st.session_state["updated_parms"]["a_updated"]
|
478 |
-
# )
|
479 |
-
# st.session_state["project_dct"]["scenario_planner"]["x0_number_input"] = (
|
480 |
-
# st.session_state["updated_parms"]["x0_updated"]
|
481 |
-
# )
|
482 |
-
|
483 |
-
update_col, reset_col = st.columns([1, 1])
|
484 |
-
if update_col.button(
|
485 |
-
"Update Parameters",
|
486 |
-
on_click=update_response_curve,
|
487 |
-
args=(
|
488 |
-
st.session_state["updated_parms"]["K_updated"],
|
489 |
-
st.session_state["updated_parms"]["b_updated"],
|
490 |
-
st.session_state["updated_parms"]["a_updated"],
|
491 |
-
st.session_state["updated_parms"]["x0_updated"],
|
492 |
-
metrics_selected,
|
493 |
-
panel_selected,
|
494 |
-
selected_channel_name,
|
495 |
-
),
|
496 |
-
use_container_width=True,
|
497 |
-
):
|
498 |
-
st.session_state["rcs"][selected_channel_name]["K"] = st.session_state[
|
499 |
-
"updated_parms"
|
500 |
-
]["K_updated"]
|
501 |
-
st.session_state["rcs"][selected_channel_name]["b"] = st.session_state[
|
502 |
-
"updated_parms"
|
503 |
-
]["b_updated"]
|
504 |
-
st.session_state["rcs"][selected_channel_name]["a"] = st.session_state[
|
505 |
-
"updated_parms"
|
506 |
-
]["a_updated"]
|
507 |
-
st.session_state["rcs"][selected_channel_name]["x0"] = st.session_state[
|
508 |
-
"updated_parms"
|
509 |
-
]["x0_updated"]
|
510 |
-
|
511 |
-
reset_col.button(
|
512 |
-
"Reset Parameters",
|
513 |
-
on_click=reset_curve_parameters,
|
514 |
-
args=(metrics_selected, panel_selected, selected_channel_name),
|
515 |
-
use_container_width=True,
|
516 |
-
)
|
517 |
-
|
518 |
-
st.divider()
|
519 |
-
save_col, down_col = st.columns([1, 1])
|
520 |
-
|
521 |
-
|
522 |
-
with save_col:
|
523 |
-
file_name = st.text_input(
|
524 |
-
"rcs download file name",
|
525 |
-
key="file_name_input",
|
526 |
-
placeholder="File name",
|
527 |
-
label_visibility="collapsed",
|
528 |
-
)
|
529 |
-
down_col.download_button(
|
530 |
-
label="Download response curves",
|
531 |
-
data=json.dumps(rcs),
|
532 |
-
file_name=f"{file_name}.json",
|
533 |
-
mime="application/json",
|
534 |
-
disabled=len(file_name) == 0,
|
535 |
-
use_container_width=True,
|
536 |
-
)
|
537 |
-
|
538 |
-
|
539 |
-
def s_curve_derivative(x, K, b, a, x0):
    """
    Slope of the S-curve ``K / (1 + b * exp(-a * (x - x0)))`` at ``x``.

    ``x`` may be a scalar or a NumPy array.
    """
    decay = np.exp(-a * (x - x0))
    return a * b * K * decay / ((1 + b * decay) ** 2)
|
548 |
-
|
549 |
-
|
550 |
-
# Parameters of the S-curve
|
551 |
-
K = st.session_state["K"]
|
552 |
-
b = st.session_state["b"]
|
553 |
-
a = st.session_state["a"]
|
554 |
-
x0 = st.session_state["x0"]
|
555 |
-
|
556 |
-
# # Optimized spend value obtained from the tool
|
557 |
-
# optimized_spend = st.number_input(
|
558 |
-
# "value of x"
|
559 |
-
# ) # Replace this with your optimized spend value
|
560 |
-
|
561 |
-
# # Calculate the slope at the optimized spend value
|
562 |
-
# slope_at_optimized_spend = s_curve_derivative(optimized_spend, K, b, a, x0)
|
563 |
-
|
564 |
-
# st.write("Slope ", slope_at_optimized_spend)
|
565 |
-
|
566 |
-
|
567 |
-
# Initialize a list to hold our rows
|
568 |
-
rows = []
|
569 |
-
|
570 |
-
# Iterate over the dictionary
|
571 |
-
for key, value in st.session_state["update_rcs"].items():
|
572 |
-
# Split the key into its components
|
573 |
-
metrics, panel, channel_name = key.split("#@")
|
574 |
-
# Create a new row with the components and the values
|
575 |
-
row = {
|
576 |
-
"Metrics": name_formating(metrics),
|
577 |
-
"Panel": name_formating(panel),
|
578 |
-
"Channel Name": channel_name,
|
579 |
-
"K": value["K"],
|
580 |
-
"b": value["b"],
|
581 |
-
"a": value["a"],
|
582 |
-
"x0": value["x0"],
|
583 |
-
}
|
584 |
-
# Append the row to our list
|
585 |
-
rows.append(row)
|
586 |
-
|
587 |
-
# Convert the list of rows into a DataFrame
|
588 |
-
updated_parms_df = pd.DataFrame(rows)
|
589 |
-
|
590 |
-
if len(list(st.session_state["update_rcs"].keys())) > 0:
|
591 |
-
st.markdown("#### Updated Parameters", unsafe_allow_html=True)
|
592 |
-
st.dataframe(updated_parms_df, hide_index=True)
|
593 |
-
else:
|
594 |
-
st.info("No parameters are updated")
|
595 |
-
|
596 |
-
update_db("8_Build_Response_Curves.py")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/9_Scenario_Planner.py
DELETED
@@ -1,1712 +0,0 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
from numerize.numerize import numerize
|
3 |
-
import numpy as np
|
4 |
-
from functools import partial
|
5 |
-
from collections import OrderedDict
|
6 |
-
from plotly.subplots import make_subplots
|
7 |
-
import plotly.graph_objects as go
|
8 |
-
from utilities import (
|
9 |
-
format_numbers,
|
10 |
-
load_local_css,
|
11 |
-
set_header,
|
12 |
-
initialize_data,
|
13 |
-
load_authenticator,
|
14 |
-
send_email,
|
15 |
-
channel_name_formating,
|
16 |
-
)
|
17 |
-
from classes import class_from_dict, class_to_dict
|
18 |
-
import pickle
|
19 |
-
import streamlit_authenticator as stauth
|
20 |
-
import yaml
|
21 |
-
from yaml import SafeLoader
|
22 |
-
import re
|
23 |
-
import pandas as pd
|
24 |
-
import plotly.express as px
|
25 |
-
import logging
|
26 |
-
from utilities import update_db
|
27 |
-
import sqlite3
|
28 |
-
|
29 |
-
|
30 |
-
st.set_page_config(layout="wide")
|
31 |
-
load_local_css("styles.css")
|
32 |
-
set_header()
|
33 |
-
|
34 |
-
# Write every session-state entry back to itself, except the authentication
# keys and the form-submitter keys.
# NOTE(review): presumably this keeps widget-backed values alive across page
# switches -- confirm.
for state_key, state_value in st.session_state.items():
    if state_key in ["logout", "login", "config"]:
        continue
    if state_key.startswith("FormSubmitter"):
        continue
    st.session_state[state_key] = state_value
|
39 |
-
# ======================================================== #
|
40 |
-
# ======================= Functions ====================== #
|
41 |
-
# ======================================================== #
|
42 |
-
|
43 |
-
|
44 |
-
def optimize(key, status_placeholder):
    """
    Optimize the spends for the sales

    Runs the scenario's optimizer over the channels flagged in
    ``st.session_state["optimization_channels"]`` and writes, per channel,
    the formatted spend and the percentage change to the session state.
    Does nothing when no channel is flagged.

    Parameters
    ----------
    key
        ``"media spends"`` (case-insensitive) optimizes with
        ``total_spends_change``; any other value optimizes spends with
        ``total_sales_change``.
    status_placeholder
        Container in which the "Optimizing" spinner is shown.
    """
    channel_list = [
        name
        for name, selected in st.session_state["optimization_channels"].items()
        if selected
    ]
    if len(channel_list) == 0:
        return

    scenario = st.session_state["scenario"]
    by_media_spends = key.lower() == "media spends"

    with status_placeholder:
        with st.spinner("Optimizing"):
            if by_media_spends:
                result = st.session_state["scenario"].optimize(
                    st.session_state["total_spends_change"], channel_list
                )
            else:
                result = st.session_state["scenario"].optimize_spends(
                    st.session_state["total_sales_change"], channel_list
                )

    for channel_name, modified_spends in result:
        converted_spends = (
            modified_spends * scenario.channels[channel_name].conversion_rate
        )
        st.session_state[channel_name] = numerize(converted_spends, 1)

        channel = st.session_state["scenario"].channels[channel_name]
        prev_spends = channel.actual_total_spends
        st.session_state[f"{channel_name}_change"] = round(
            100 * (modified_spends - prev_spends) / prev_spends, 2
        )
|
86 |
-
|
87 |
-
|
88 |
-
def save_scenario(scenario_name):
    """
    Save the current scenario with the mentioned name in the session state
    and write all saved scenarios to ``../saved_scenarios.pkl``.

    Parameters
    ----------
    scenario_name
        Name of the scenario to be saved
    """
    # Create the container on first use. Rebinding ``st.session_state`` itself
    # would throw the session state away and make the lookup below fail.
    if "saved_scenarios" not in st.session_state:
        st.session_state["saved_scenarios"] = OrderedDict()

    st.session_state["saved_scenarios"][scenario_name] = class_to_dict(
        st.session_state["scenario"]
    )
    # Reset the value stored under the "scenario_input" key.
    st.session_state["scenario_input"] = ""

    with open("../saved_scenarios.pkl", "wb") as f:
        pickle.dump(st.session_state["saved_scenarios"], f)
|
108 |
-
|
109 |
-
|
110 |
-
def update_sales_abs_slider():
    """
    Sync the sales percentage and absolute text value from the sales slider.

    When the ``total_sales_change_abs_slider`` value validates, the
    percentage change against the scenario's actual total sales is stored
    (rounded) in ``total_sales_change``, the formatted amount in
    ``total_sales_change_abs``, and the percentage is saved to
    ``project_dct["scenario_planner"]``.
    """
    baseline_sales = st.session_state["scenario"].actual_total_sales
    # NOTE(review): indentation was lost in the reviewed copy; the project_dct
    # write is assumed to sit inside the validation branch -- confirm.
    if not validate_input(st.session_state["total_sales_change_abs_slider"]):
        return

    target_sales = extract_number_for_string(
        st.session_state["total_sales_change_abs_slider"]
    )
    percent_change = ((target_sales / baseline_sales) - 1) * 100
    st.session_state["total_sales_change"] = round(percent_change)
    st.session_state["total_sales_change_abs"] = numerize(target_sales, 1)

    st.session_state["project_dct"]["scenario_planner"][
        "total_sales_change"
    ] = st.session_state.total_sales_change
|
126 |
-
|
127 |
-
|
128 |
-
def update_sales_abs():
    """
    Sync the sales percentage and slider from the absolute sales text value.

    When ``total_sales_change_abs`` validates, the percentage change against
    the scenario's actual total sales is stored (rounded) in
    ``total_sales_change`` and the formatted amount in
    ``total_sales_change_abs_slider``.
    """
    baseline_sales = st.session_state["scenario"].actual_total_sales
    if not validate_input(st.session_state["total_sales_change_abs"]):
        return

    target_sales = extract_number_for_string(
        st.session_state["total_sales_change_abs"]
    )
    percent_change = ((target_sales / baseline_sales) - 1) * 100
    st.session_state["total_sales_change"] = round(percent_change)
    st.session_state["total_sales_change_abs_slider"] = numerize(
        target_sales, 1
    )
|
140 |
-
|
141 |
-
|
142 |
-
def update_sales():
    """
    Recompute the absolute sales values from the sales percentage change.

    Applies ``total_sales_change`` (percent) to the scenario's actual total
    sales and stores the formatted result in both ``total_sales_change_abs``
    and ``total_sales_change_abs_slider``.
    """
    for target_key in (
        "total_sales_change_abs",
        "total_sales_change_abs_slider",
    ):
        growth_factor = 1 + st.session_state["total_sales_change"] / 100
        st.session_state[target_key] = numerize(
            growth_factor * st.session_state["scenario"].actual_total_sales,
            1,
        )
|
161 |
-
# update_spends()
|
162 |
-
|
163 |
-
|
164 |
-
def update_all_spends_abs_slider():
    """
    Sync the spends percentage and absolute text value from the spends slider.

    When the ``total_spends_change_abs_slider`` value validates, the
    percentage change against the scenario's actual total spends is stored
    (rounded) in ``total_spends_change``, the formatted amount in
    ``total_spends_change_abs``, the percentage is saved to
    ``project_dct["scenario_planner"]`` and ``update_all_spends`` is called.
    """
    baseline_spends = st.session_state["scenario"].actual_total_spends
    # NOTE(review): indentation was lost in the reviewed copy; the project_dct
    # write and update_all_spends() are assumed to sit inside the validation
    # branch -- confirm.
    if not validate_input(st.session_state["total_spends_change_abs_slider"]):
        return

    target_spends = extract_number_for_string(
        st.session_state["total_spends_change_abs_slider"]
    )
    percent_change = ((target_spends / baseline_spends) - 1) * 100
    st.session_state["total_spends_change"] = round(percent_change)
    st.session_state["total_spends_change_abs"] = numerize(target_spends, 1)

    st.session_state["project_dct"]["scenario_planner"][
        "total_spends_change"
    ] = st.session_state.total_spends_change

    update_all_spends()
|
182 |
-
|
183 |
-
|
184 |
-
# def update_all_spends_abs_slider():
|
185 |
-
# actual_spends = _scenario.actual_total_spends
|
186 |
-
# if validate_input(st.session_state["total_spends_change_abs_slider"]):
|
187 |
-
# print("#" * 100)
|
188 |
-
# print(st.session_state["total_spends_change_abs_slider"])
|
189 |
-
# print("#" * 100)
|
190 |
-
|
191 |
-
# modified_spends = extract_number_for_string(
|
192 |
-
# st.session_state["total_spends_change_abs_slider"]
|
193 |
-
# )
|
194 |
-
# st.session_state["total_spends_change"] = (
|
195 |
-
# (modified_spends / actual_spends) - 1
|
196 |
-
# ) * 100
|
197 |
-
# st.session_state["total_spends_change_abs"] = st.session_state[
|
198 |
-
# "total_spends_change_abs_slider"
|
199 |
-
# ]
|
200 |
-
|
201 |
-
# update_all_spends()
|
202 |
-
|
203 |
-
|
204 |
-
def update_all_spends_abs():
    """Apply the total spend typed into the absolute-value text box.

    When the text passes ``validate_input`` the overall percent change
    (unrounded) is derived against the scenario's actual total spends, the
    slider value is synchronised, the percent is recorded in the project
    dictionary and every channel is updated. Invalid text is ignored.
    """
    print("DEBUG: ", "inside update_all_spends_abs")

    baseline_total = st.session_state["scenario"].actual_total_spends
    typed_text = st.session_state["total_spends_change_abs"]
    if not validate_input(typed_text):
        return

    target_total = extract_number_for_string(typed_text)
    st.session_state["total_spends_change"] = (
        (target_total / baseline_total) - 1
    ) * 100
    # Re-render the parsed number so the slider shows the canonical label.
    st.session_state["total_spends_change_abs_slider"] = numerize(
        target_total, 1
    )

    planner_state = st.session_state["project_dct"]["scenario_planner"]
    planner_state["total_spends_change"] = st.session_state.total_spends_change

    update_all_spends()
|
232 |
-
|
233 |
-
|
234 |
-
def update_spends():
    """Propagate the overall percent spend change to the absolute widgets.

    Converts ``total_spends_change`` (a percent) into an absolute total based
    on the scenario's actual total spends, writes its abbreviated label to
    both the text box and the slider keys, records the percent in the project
    dictionary and updates every channel.
    """
    print("update_spends")
    growth = 1 + st.session_state["total_spends_change"] / 100
    absolute_label = numerize(
        growth * st.session_state["scenario"].actual_total_spends, 1
    )
    for widget_key in (
        "total_spends_change_abs",
        "total_spends_change_abs_slider",
    ):
        st.session_state[widget_key] = absolute_label

    planner_state = st.session_state["project_dct"]["scenario_planner"]
    planner_state["total_spends_change"] = st.session_state.total_spends_change

    update_all_spends()
|
252 |
-
|
253 |
-
|
254 |
-
def update_all_spends():
    """Scale every channel's spend by the overall percent change.

    For each channel of the scenario stored under ``unique_key`` the
    per-channel percent widgets are set to ``total_spends_change``, the
    scenario is updated with the scaled spend, and the channel's text box
    receives the abbreviated, conversion-rate-adjusted amount.
    """
    percent_change = st.session_state["total_spends_change"]
    print("runs update_all")
    planner_scenario = st.session_state["project_dct"]["scenario_planner"][
        unique_key
    ]
    for channel_name in list(planner_scenario.channels.keys()):
        st.session_state[f"{channel_name}_percent"] = percent_change
        channel = st.session_state["scenario"].channels[channel_name]
        scaled_spends = (1 + percent_change / 100) * channel.actual_total_spends
        st.session_state["scenario"].update(channel_name, scaled_spends)
        st.session_state[channel_name] = numerize(
            scaled_spends * channel.conversion_rate, 1
        )
        st.session_state[f"{channel_name}_change"] = percent_change
|
274 |
-
|
275 |
-
|
276 |
-
def extract_number_for_string(string_input):
    """Convert an abbreviated amount such as ``"1.5M"`` into a float.

    Args:
        string_input: Text holding a number, optionally followed by one of
            the magnitude suffixes ``K``, ``M``, ``B`` or ``T`` (any case).
            Surrounding whitespace is ignored.

    Returns:
        The numeric value as a float (``"2K"`` -> ``2000.0``). Text without a
        suffix is parsed as a plain number instead of silently yielding
        ``None``.

    Raises:
        ValueError: If the numeric part cannot be parsed.
    """
    multipliers = {"K": 10**3, "M": 10**6, "B": 10**9, "T": 10**12}
    text = string_input.strip().upper()
    if text and text[-1] in multipliers:
        return float(text[:-1]) * multipliers[text[-1]]
    # No magnitude suffix: treat the text as a plain number.
    return float(text)
|
284 |
-
|
285 |
-
|
286 |
-
def validate_input(string_input):
    """Return True when the text is a number followed by K, M or B.

    Accepts digits with an optional decimal part and exactly one upper-case
    suffix, e.g. ``"250K"`` or ``"1.5M"``. The whole string must match: a
    literal ``|`` (which the former ``[K|M|B]`` class accepted) or a trailing
    newline is rejected.

    Args:
        string_input: Text typed by the user.

    Returns:
        bool: Whether the text is a valid abbreviated amount.
    """
    return re.fullmatch(r"\d+\.?\d*[KMB]", string_input) is not None
|
292 |
-
|
293 |
-
|
294 |
-
def update_data_by_percent(channel_name):
    """Apply the channel's percent widget to its spend.

    The channel's actual total spend (times its conversion rate) is scaled by
    the ``<channel>_percent`` session value; the abbreviated result is written
    to the channel's text box and the scenario is updated with the amount
    divided back by the conversion rate.
    """
    channel = st.session_state["scenario"].channels[channel_name]
    baseline = channel.actual_total_spends * channel.conversion_rate
    growth = 1 + st.session_state[f"{channel_name}_percent"] / 100
    converted_spends = baseline * growth

    st.session_state[channel_name] = numerize(converted_spends, 1)

    st.session_state["scenario"].update(
        channel_name,
        converted_spends / channel.conversion_rate,
    )
|
310 |
-
|
311 |
-
|
312 |
-
def update_data(channel_name):
    """Apply the amount typed into a channel's text box.

    When the text passes ``validate_input`` the channel's percent widget is
    set to the change relative to the actual (conversion-rate-adjusted) spend,
    rounded to two decimals, and the scenario is updated with the typed amount
    divided by the conversion rate. Invalid text is ignored.
    """
    print("tuns update_Data")
    typed_text = st.session_state[channel_name]
    if not validate_input(typed_text):
        return

    converted_spends = extract_number_for_string(typed_text)
    channel = st.session_state["scenario"].channels[channel_name]
    baseline = channel.actual_total_spends * channel.conversion_rate

    st.session_state[f"{channel_name}_percent"] = round(
        100 * (converted_spends - baseline) / baseline, 2
    )
    st.session_state["scenario"].update(
        channel_name,
        converted_spends / channel.conversion_rate,
    )
|
340 |
-
# st.session_state['scenario'].update(channel_name, modified_spends)
|
341 |
-
# else:
|
342 |
-
# try:
|
343 |
-
# modified_spends = float(st.session_state[channel_name])
|
344 |
-
# prev_spends = st.session_state['scenario'].channels[channel_name].actual_total_spends * st.session_state['scenario'].channels[channel_name].conversion_rate
|
345 |
-
# st.session_state[f'{channel_name}_change'] = round(100*(modified_spends - prev_spends) / prev_spends,2)
|
346 |
-
# st.session_state['scenario'].update(channel_name, modified_spends/st.session_state['scenario'].channels[channel_name].conversion_rate)
|
347 |
-
# st.session_state[f'{channel_name}'] = numerize(modified_spends,1)
|
348 |
-
# except ValueError:
|
349 |
-
# st.write('Invalid input')
|
350 |
-
|
351 |
-
|
352 |
-
def select_channel_for_optimization(channel_name):
    """Copy the channel's checkbox state into ``optimization_channels``."""
    is_selected = st.session_state[f"{channel_name}_selected"]
    st.session_state["optimization_channels"][channel_name] = is_selected
|
359 |
-
|
360 |
-
|
361 |
-
def select_all_channels_for_optimization():
    """Set every channel's selection to the "optimize all" checkbox value.

    Both the per-channel ``<channel>_selected`` widget keys and the
    ``optimization_channels`` mapping are updated.
    """
    channel_flags = st.session_state["optimization_channels"]
    for channel_name in channel_flags:
        select_all = st.session_state["optimze_all_channels"]
        st.session_state[f"{channel_name}_selected"] = select_all
        channel_flags[channel_name] = select_all
|
377 |
-
from pprint import pprint
|
378 |
-
|
379 |
-
|
380 |
-
def update_penalty():
    """Forward the ``apply_penalty`` session value to the scenario."""
    scenario = st.session_state["scenario"]
    scenario.update_penalty(st.session_state["apply_penalty"])
|
387 |
-
|
388 |
-
|
389 |
-
def reset_optimization():
    """Clear the optimization selections and the cached sales-change values.

    Unticks every channel of the scenario stored under ``unique_key``, clears
    the "optimize all" and ``initialized`` flags and deletes the three
    ``total_sales_change*`` session keys.
    """
    print("DEBUG: ", "Running reset_optimization")
    planner_scenario = st.session_state["project_dct"]["scenario_planner"][
        unique_key
    ]
    for channel_name in list(planner_scenario.channels.keys()):
        st.session_state[f"{channel_name}_selected"] = False

    st.session_state["optimze_all_channels"] = False
    st.session_state["initialized"] = False

    for stale_key in (
        "total_sales_change_abs_slider",
        "total_sales_change_abs",
        "total_sales_change",
    ):
        del st.session_state[stale_key]
|
403 |
-
|
404 |
-
|
405 |
-
def reset_scenario():
    """Discard the current scenario and its optimization state.

    Unticks every channel of the scenario stored under ``unique_key``, clears
    the "optimize all" and ``initialized`` flags, removes the
    ``optimization_channels`` mapping, deletes the scenario saved in the
    project dictionary for the selected metric/panel pair, and deletes the
    three ``total_sales_change*`` session keys.

    Raises:
        KeyError: If any of the deleted session or project keys is absent.
    """
    print("[DEBUG]: reset_scenario")
    planner_state = st.session_state["project_dct"]["scenario_planner"]
    for channel_name in list(planner_state[unique_key].channels.keys()):
        st.session_state[f"{channel_name}_selected"] = False

    st.session_state["optimze_all_channels"] = False
    st.session_state["initialized"] = False

    del st.session_state["optimization_channels"]

    # The saved scenario is keyed by the metric and panel currently selected.
    scenario_key = (
        f"{st.session_state['metric_selected']}-{st.session_state['panel_selected']}"
    )
    del st.session_state["project_dct"]["scenario_planner"][scenario_key]

    for stale_key in (
        "total_sales_change_abs_slider",
        "total_sales_change_abs",
        "total_sales_change",
    ):
        del st.session_state[stale_key]
|
434 |
-
# if panel_selected == "Aggregated":
|
435 |
-
# initialize_data(
|
436 |
-
# panel=panel_selected,
|
437 |
-
# target_file=file_selected,
|
438 |
-
# updated_rcs=updated_rcs,
|
439 |
-
# metrics=metrics_selected,
|
440 |
-
# )
|
441 |
-
# panel = None
|
442 |
-
# else:
|
443 |
-
# initialize_data(
|
444 |
-
# panel=panel_selected,
|
445 |
-
# target_file=file_selected,
|
446 |
-
# updated_rcs=updated_rcs,
|
447 |
-
# metrics=metrics_selected,
|
448 |
-
# )
|
449 |
-
# st.session_state["total_spends_change"] = 0
|
450 |
-
# update_all_spends()
|
451 |
-
|
452 |
-
|
453 |
-
def format_number(num):
    """Abbreviate a number with an ``M`` or ``K`` suffix.

    Values of at least one million are shown in millions with two decimals,
    values of at least one thousand in thousands with no decimals, and
    anything smaller with two decimals and no suffix.
    """
    for threshold, template in ((1_000_000, "{:.2f}M"), (1_000, "{:.0f}K")):
        if num >= threshold:
            return template.format(num / threshold)
    return "{:.2f}".format(num)
|
460 |
-
|
461 |
-
|
462 |
-
def summary_plot(data, x, y, title, text_column):
    """Build a horizontal bar chart coloured by ``Channel_name``.

    Args:
        data: DataFrame holding the plotted columns; ``text_column`` is
            converted to numeric in place after the figure is created.
        x: Column used for bar length (also the x-axis title).
        y: Column used for bar position.
        title: Figure title.
        text_column: Column whose values label the bars.

    Returns:
        The Plotly figure.
    """
    bar_options = dict(
        x=x,
        y=y,
        orientation="h",
        title=title,
        text=text_column,
        color="Channel_name",
    )
    fig = px.bar(data, **bar_options)

    # Coerce the label column to numbers (unparseable entries become NaN).
    data[text_column] = pd.to_numeric(data[text_column], errors="coerce")

    # SI-prefixed, two-significant-digit labels outside each bar.
    fig.update_traces(
        texttemplate="%{text:.2s}",
        textposition="outside",
        hovertemplate="%{x:.2s}",
    )
    fig.update_layout(
        xaxis_title=x, yaxis_title="Channel Name", showlegend=False
    )
    return fig
|
487 |
-
|
488 |
-
|
489 |
-
def s_curve(x, K, b, a, x0):
    """Evaluate the logistic curve ``K / (1 + b * exp(-a * (x - x0)))``."""
    decay = np.exp(-a * (x - x0))
    return K / (1 + b * decay)
|
491 |
-
|
492 |
-
|
493 |
-
def find_segment_value(x, roi, mroi):
    """Locate the region of ``x`` where both ROI and marginal ROI exceed 1.

    Returns:
        tuple: ``(start, end, left, right)`` where ``start``/``end`` are the
        first and last entries of ``x`` and ``left``/``right`` are the first
        and last ``x`` values satisfying ``roi > 1`` and ``mroi > 1``. When no
        point qualifies, both ``left`` and ``right`` fall back to ``x[0]``.
    """
    profitable = np.where((roi > 1) & (mroi > 1))[0]
    has_green = profitable.size > 0

    first_x = x[0]
    last_x = x[len(x) - 1]
    green_start = x[profitable[0]] if has_green else x[0]
    green_end = x[profitable[-1]] if has_green else x[0]

    return first_x, last_x, green_start, green_end
|
506 |
-
|
507 |
-
|
508 |
-
def calculate_rgba(
    start_value, end_value, left_value, right_value, current_channel_spends
):
    """Return the CSS ``rgba(...)`` colour for a spend level.

    The spend axis is split into three segments: yellow
    (``start_value``..``left_value``), green (``left_value``..``right_value``)
    and red (``right_value``..``end_value``). Opacity fades from 0.8 to 0.2
    across the yellow and green segments and rises from 0.2 to 0.8 across the
    red one. Spends outside ``start_value``..``end_value`` are grey.

    Args:
        start_value: Lower bound of the yellow segment.
        end_value: Upper bound of the red segment.
        left_value: Boundary between the yellow and green segments.
        right_value: Boundary between the green and red segments.
        current_channel_spends: Spend level to colour.

    Returns:
        str: A colour string such as ``"rgba(255, 255, 0, 0.8)"``.
    """

    def _relative_position(lower, upper):
        # A zero-width segment (e.g. left_value == start_value when the green
        # region begins at the first point) has no interior: place the spend
        # at the segment start instead of dividing by zero.
        width = upper - lower
        return (current_channel_spends - lower) / width if width else 0.0

    if start_value <= current_channel_spends <= left_value:
        color = "yellow"
        # Alpha decreases from start to end
        alpha = 0.8 - 0.6 * _relative_position(start_value, left_value)
    elif left_value < current_channel_spends <= right_value:
        color = "green"
        # Alpha decreases from start to end
        alpha = 0.8 - 0.6 * _relative_position(left_value, right_value)
    elif right_value < current_channel_spends <= end_value:
        color = "red"
        # Alpha increases from start to end
        alpha = 0.2 + 0.6 * _relative_position(right_value, end_value)
    else:
        # Grey for values outside the defined ranges
        return "rgba(136, 136, 136, 0.5)"

    # Keep alpha within the intended range in case of calculation overshoot
    alpha = max(0.2, min(alpha, 0.8))

    color_codes = {
        "yellow": "255, 255, 0",
        "green": "0, 128, 0",
        "red": "255, 0, 0",
    }
    return f"rgba({color_codes[color]}, {alpha})"
|
558 |
-
|
559 |
-
|
560 |
-
def debug_temp(x_test, power, K, b, a, x0):
    """Print how many ``x_test`` values fall into three fixed spend bins.

    The bins are ``<= 2524``, ``(2524, 3377]`` and ``> 3377``. The remaining
    parameters are accepted but not used.
    """
    print("*" * 100)
    # Calculate the count of bins
    count_lower_bin = len([v for v in x_test if v <= 2524])
    count_center_bin = len([v for v in x_test if 2524 < v and v <= 3377])
    count_ = len([v for v in x_test if v > 3377])

    print(
        f"""
lower : {count_lower_bin}
center : {count_center_bin}
upper : {count_}
"""
    )
|
574 |
-
|
575 |
-
|
576 |
-
# @st.cache
|
577 |
-
def plot_response_curves():
    """Build a grid of response-curve subplots, one per channel.

    Reads the module-level ``channels_list`` and ``target`` plus the ``rcs``
    (curve parameters) and ``scenario`` entries of the session state. Each
    subplot shows the fitted curve, a marker with dashed guide lines at the
    channel's modified spend, and yellow/green/red background bands derived
    from ``find_segment_value``.

    Returns:
        The Plotly figure containing all subplots.
    """
    cols = 4
    # Ceiling division: enough rows to hold every channel, four per row.
    rows = (
        len(channels_list) // cols
        if len(channels_list) % cols == 0
        else len(channels_list) // cols + 1
    )
    rcs = st.session_state["rcs"]
    shapes = []
    fig = make_subplots(rows=rows, cols=cols, subplot_titles=channels_list)
    for i in range(0, len(channels_list)):
        col = channels_list[i]
        # NOTE(review): presumably an array of per-period spends (it is used
        # with .max() and .sum()) -- confirm against the scenario classes.
        x_actual = st.session_state["scenario"].channels[col].actual_spends
        # x_modified = st.session_state["scenario"].channels[col].modified_spends

        # Order of magnitude of the largest spend minus 3; spends are divided
        # by 10**power before being fed to s_curve.
        power = np.ceil(np.log(x_actual.max()) / np.log(10)) - 3

        K = rcs[col]["K"]
        b = rcs[col]["b"]
        a = rcs[col]["a"]
        x0 = rcs[col]["x0"]

        # 50 total-spend levels from 0 up to five times the actual total.
        x_plot = np.linspace(0, 5 * x_actual.sum(), 50)

        x, y, marginal_roi = [], [], []
        # Spread each total across the entries of x_actual in proportion to
        # their share of the actual total.
        for x_p in x_plot:
            x.append(x_p * x_actual / x_actual.sum())

        for index in range(len(x_plot)):
            y.append(s_curve(x[index] / 10**power, K, b, a, x0))

        # a * y * (1 - y / K) is the derivative of s_curve with respect to its
        # (scaled) input; the eps floor guards against K == 0.
        for index in range(len(x_plot)):
            marginal_roi.append(
                a
                * y[index]
                * (1 - y[index] / np.maximum(K, np.finfo(float).eps))
            )

        # Collapse each spend level back to one number per curve point.
        x = (
            np.sum(x, axis=1)
            * st.session_state["scenario"].channels[col].conversion_rate
        )
        y = np.sum(y, axis=1)
        marginal_roi = (
            np.average(marginal_roi, axis=1)
            / st.session_state["scenario"].channels[col].conversion_rate
        )

        # The eps floor avoids dividing by zero at the first point (x == 0).
        roi = y / np.maximum(x, np.finfo(float).eps)

        # Response curve; ROI and marginal ROI are exposed in the hover text.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                name=col,
                customdata=np.stack((roi, marginal_roi), axis=-1),
                hovertemplate="Spend:%{x:$.2s}<br>Sale:%{y:$.2s}<br>ROI:%{customdata[0]:.3f}<br>MROI:%{customdata[1]:.3f}",
                line=dict(color="blue"),
            ),
            row=1 + (i) // cols,
            col=i % cols + 1,
        )

        # Position of the scenario's modified spend and sales on this curve.
        x_optimal = (
            st.session_state["scenario"].channels[col].modified_total_spends
            * st.session_state["scenario"].channels[col].conversion_rate
        )
        y_optimal = (
            st.session_state["scenario"].channels[col].modified_total_sales
        )

        # if col == "Paid_social_others":
        #     debug_temp(x_optimal * x_actual / x_actual.sum(), power, K, b, a, x0)

        # Black marker at the modified spend/sales point.
        fig.add_trace(
            go.Scatter(
                x=[x_optimal],
                y=[y_optimal],
                name=col,
                legendgroup=col,
                showlegend=False,
                marker=dict(color=["black"]),
            ),
            row=1 + (i) // cols,
            col=i % cols + 1,
        )

        # Horizontal dashed guide from the y-axis to the marker.
        shapes.append(
            go.layout.Shape(
                type="line",
                x0=0,
                y0=y_optimal,
                x1=x_optimal,
                y1=y_optimal,
                line_width=1,
                line_dash="dash",
                line_color="black",
                xref=f"x{i+1}",
                yref=f"y{i+1}",
            )
        )

        # Vertical dashed guide from the x-axis to the marker.
        shapes.append(
            go.layout.Shape(
                type="line",
                x0=x_optimal,
                y0=0,
                x1=x_optimal,
                y1=y_optimal,
                line_width=1,
                line_dash="dash",
                line_color="black",
                xref=f"x{i+1}",
                yref=f"y{i+1}",
            )
        )

        start_value, end_value, left_value, right_value = find_segment_value(
            x,
            roi,
            marginal_roi,
        )

        # Adding background colors
        y_max = y.max() * 1.3  # 30% extra space above the max

        # Yellow region
        shapes.append(
            go.layout.Shape(
                type="rect",
                x0=start_value,
                y0=0,
                x1=left_value,
                y1=y_max,
                line=dict(width=0),
                fillcolor="rgba(255, 255, 0, 0.3)",
                layer="below",
                xref=f"x{i+1}",
                yref=f"y{i+1}",
            )
        )

        # Green region
        shapes.append(
            go.layout.Shape(
                type="rect",
                x0=left_value,
                y0=0,
                x1=right_value,
                y1=y_max,
                line=dict(width=0),
                fillcolor="rgba(0, 255, 0, 0.3)",
                layer="below",
                xref=f"x{i+1}",
                yref=f"y{i+1}",
            )
        )

        # Red region
        shapes.append(
            go.layout.Shape(
                type="rect",
                x0=right_value,
                y0=0,
                x1=end_value,
                y1=y_max,
                line=dict(width=0),
                fillcolor="rgba(255, 0, 0, 0.3)",
                layer="below",
                xref=f"x{i+1}",
                yref=f"y{i+1}",
            )
        )

    # All shapes are attached at once, after every subplot has been populated.
    fig.update_layout(
        # height=1000,
        # width=1000,
        title_text=f"Response Curves (X: Spends Vs Y: {target})",
        showlegend=False,
        shapes=shapes,
    )
    fig.update_annotations(font_size=10)
    # fig.update_xaxes(title="Spends")
    # fig.update_yaxes(title=target)
    fig.update_yaxes(
        gridcolor="rgba(136, 136, 136, 0.5)", gridwidth=0.5, griddash="dash"
    )

    return fig
|
766 |
-
|
767 |
-
|
768 |
-
# ======================================================== #
|
769 |
-
# ==================== HTML Components =================== #
|
770 |
-
# ======================================================== #
|
771 |
-
|
772 |
-
|
773 |
-
def generate_spending_header(heading):
    """Render ``heading`` as an ``<h2>`` with the ``spends-header`` class."""
    markup = f"""<h2 class="spends-header">{heading}</h2>"""
    return st.markdown(markup, unsafe_allow_html=True)
|
777 |
-
|
778 |
-
|
779 |
-
def save_checkpoint():
    """Pickle the project dictionary to ``project_dct.pkl`` in the project path.

    The dictionary is serialised before the file is opened, so an
    unpicklable object does not truncate an existing checkpoint. Any failure
    is reported with a toast.
    """
    checkpoint_path = os.path.join(
        st.session_state["project_path"], "project_dct.pkl"
    )

    try:
        serialized = pickle.dumps(st.session_state["project_dct"])
        with open(checkpoint_path, "wb") as handle:
            handle.write(serialized)
    except Exception:
        st.toast("Unknown Issue, please reload the page.")
|
791 |
-
|
792 |
-
|
793 |
-
def reset_checkpoint():
    """Empty the saved scenario-planner state and persist the checkpoint."""
    project_state = st.session_state["project_dct"]
    project_state["scenario_planner"] = {}
    save_checkpoint()
|
796 |
-
|
797 |
-
|
798 |
-
# ======================================================== #
|
799 |
-
# =================== Session variables ================== #
|
800 |
-
# ======================================================== #
|
801 |
-
|
802 |
-
# Load the authentication settings (parsed with the safe YAML loader) and keep
# them in the session state.
with open("config.yaml") as file:
    config = yaml.load(file, Loader=SafeLoader)
    st.session_state["config"] = config

# Build the authenticator from the credentials, cookie and preauthorized
# sections of the config and store it in the session state.
authenticator = stauth.Authenticate(
    config["credentials"],
    config["cookie"]["name"],
    config["cookie"]["key"],
    config["cookie"]["expiry_days"],
    config["preauthorized"],
)
st.session_state["authenticator"] = authenticator
# NOTE(review): presumably renders the login form in the main page area --
# confirm against the installed streamlit_authenticator version.
name, authentication_status, username = authenticator.login("Login", "main")
# The status checked further down is read back from the session state rather
# than taken from the tuple returned by login().
auth_status = st.session_state.get("authentication_status")
|
816 |
-
|
817 |
-
import os
|
818 |
-
import glob
|
819 |
-
|
820 |
-
|
821 |
-
def get_excel_names(directory):
    """List the metric names encoded in the Excel file names of a directory.

    Only ``.xlsx`` and ``.xls`` files whose name contains ``@#`` are
    considered; the metric name is the part after the last ``@#`` with the
    file extension removed. Only the trailing extension is stripped, so a
    name that happens to contain ``.xls`` elsewhere is left intact.

    Args:
        directory: Folder to scan.

    Returns:
        list[str]: Metric names, ``.xlsx`` matches first, then ``.xls``.
    """
    last_portions = []
    for extension in (".xlsx", ".xls"):
        pattern = os.path.join(directory, f"*@#*{extension}")
        for file in glob.glob(pattern):
            stem, _ = os.path.splitext(os.path.basename(file))
            last_portions.append(stem.split("@#")[-1])
    return last_portions
|
845 |
-
|
846 |
-
|
847 |
-
def name_formating(channel_name):
    """Turn ``snake_case`` into a title-cased, space-separated label."""
    spaced = " ".join(channel_name.split("_"))
    return spaced.title()
|
855 |
-
|
856 |
-
|
857 |
-
@st.cache_resource(show_spinner=False)
def panel_fetch(file_selected):
    """Return the distinct panel values found in the metrics workbook.

    Reads the ``RAW DATA MMM`` sheet of ``file_selected``. The result is the
    unique entries of its ``Panel`` column in arbitrary order, or ``None``
    when the sheet has no such column.
    """
    sheet = pd.read_excel(file_selected, sheet_name="RAW DATA MMM")
    if "Panel" not in sheet.columns:
        return None
    return list(set(sheet["Panel"]))
|
868 |
-
|
869 |
-
|
870 |
-
if auth_status is True:
|
871 |
-
authenticator.logout("Logout", "main")
|
872 |
-
|
873 |
-
if "project_dct" not in st.session_state:
|
874 |
-
st.error("Please load a project from home")
|
875 |
-
st.stop()
|
876 |
-
|
877 |
-
database_file = r"DB\User.db"
|
878 |
-
|
879 |
-
conn = sqlite3.connect(
|
880 |
-
database_file, check_same_thread=False
|
881 |
-
) # connection with sql db
|
882 |
-
c = conn.cursor()
|
883 |
-
|
884 |
-
with st.sidebar:
|
885 |
-
st.button("Save checkpoint", on_click=save_checkpoint)
|
886 |
-
st.button("Reset Checkpoint", on_click=reset_checkpoint)
|
887 |
-
|
888 |
-
warning_placeholder = st.empty()
|
889 |
-
st.header("Scenario Planner")
|
890 |
-
|
891 |
-
# st.subheader("Simulation")
|
892 |
-
col1, col2 = st.columns([1, 1])
|
893 |
-
|
894 |
-
# Get metric and panel from last saved state
|
895 |
-
if "last_saved_metric" not in st.session_state:
|
896 |
-
st.session_state["last_saved_metric"] = st.session_state[
|
897 |
-
"project_dct"
|
898 |
-
]["scenario_planner"].get("metric_selected", 0)
|
899 |
-
# st.session_state["last_saved_metric"] = st.session_state[
|
900 |
-
# "project_dct"
|
901 |
-
# ]["scenario_planner"].get("metric_selected", 0)
|
902 |
-
|
903 |
-
if "last_saved_panel" not in st.session_state:
|
904 |
-
st.session_state["last_saved_panel"] = st.session_state["project_dct"][
|
905 |
-
"scenario_planner"
|
906 |
-
].get("panel_selected", 0)
|
907 |
-
# st.session_state["last_saved_panel"] = st.session_state["project_dct"][
|
908 |
-
# "scenario_planner"
|
909 |
-
# ].get("panel_selected", 0)
|
910 |
-
|
911 |
-
# Response Metrics
|
912 |
-
directory = "metrics_level_data"
|
913 |
-
metrics_list = get_excel_names(directory)
|
914 |
-
metrics_selected = col1.selectbox(
|
915 |
-
"Response Metrics",
|
916 |
-
metrics_list,
|
917 |
-
format_func=name_formating,
|
918 |
-
index=st.session_state["last_saved_metric"],
|
919 |
-
on_change=reset_optimization,
|
920 |
-
key="metric_selected",
|
921 |
-
)
|
922 |
-
|
923 |
-
# Target
|
924 |
-
target = name_formating(metrics_selected)
|
925 |
-
|
926 |
-
file_selected = f"./metrics_level_data/Overview_data_test_panel@#{metrics_selected}.xlsx"
|
927 |
-
# print(f"[DEBUG]: {metrics_selected}")
|
928 |
-
# print(f"[DEBUG]: {file_selected}")
|
929 |
-
st.session_state["file_selected"] = file_selected
|
930 |
-
# Panel List
|
931 |
-
panel_list = panel_fetch(file_selected)
|
932 |
-
panel_list_final = ["Aggregated"] + panel_list
|
933 |
-
|
934 |
-
# Panel Selected
|
935 |
-
panel_selected = col2.selectbox(
|
936 |
-
"Panel",
|
937 |
-
panel_list_final,
|
938 |
-
on_change=reset_optimization,
|
939 |
-
key="panel_selected",
|
940 |
-
index=st.session_state["last_saved_panel"],
|
941 |
-
)
|
942 |
-
|
943 |
-
unique_key = f"{st.session_state['metric_selected']}-{st.session_state['panel_selected']}"
|
944 |
-
|
945 |
-
if "update_rcs" in st.session_state:
|
946 |
-
updated_rcs = st.session_state["update_rcs"]
|
947 |
-
else:
|
948 |
-
updated_rcs = None
|
949 |
-
|
950 |
-
if unique_key not in st.session_state["project_dct"]["scenario_planner"]:
|
951 |
-
if panel_selected == "Aggregated":
|
952 |
-
initialize_data(
|
953 |
-
panel=panel_selected,
|
954 |
-
target_file=file_selected,
|
955 |
-
updated_rcs=updated_rcs,
|
956 |
-
metrics=metrics_selected,
|
957 |
-
)
|
958 |
-
panel = None
|
959 |
-
else:
|
960 |
-
initialize_data(
|
961 |
-
panel=panel_selected,
|
962 |
-
target_file=file_selected,
|
963 |
-
updated_rcs=updated_rcs,
|
964 |
-
metrics=metrics_selected,
|
965 |
-
)
|
966 |
-
st.session_state["project_dct"]["scenario_planner"][unique_key] = (
|
967 |
-
st.session_state["scenario"]
|
968 |
-
)
|
969 |
-
|
970 |
-
else:
|
971 |
-
st.session_state["scenario"] = st.session_state["project_dct"][
|
972 |
-
"scenario_planner"
|
973 |
-
][unique_key]
|
974 |
-
st.session_state["rcs"] = {}
|
975 |
-
st.session_state["powers"] = {}
|
976 |
-
if "optimization_channels" not in st.session_state:
|
977 |
-
st.session_state["optimization_channels"] = {}
|
978 |
-
|
979 |
-
for channel_name, _channel in st.session_state["project_dct"][
|
980 |
-
"scenario_planner"
|
981 |
-
][unique_key].channels.items():
|
982 |
-
st.session_state[channel_name] = numerize(
|
983 |
-
_channel.modified_total_spends, 1
|
984 |
-
)
|
985 |
-
st.session_state["rcs"][
|
986 |
-
channel_name
|
987 |
-
] = _channel.response_curve_params
|
988 |
-
st.session_state["powers"][channel_name] = _channel.power
|
989 |
-
if channel_name not in st.session_state["optimization_channels"]:
|
990 |
-
st.session_state["optimization_channels"][channel_name] = False
|
991 |
-
|
992 |
-
if "first_time" not in st.session_state:
|
993 |
-
st.session_state["first_time"] = True
|
994 |
-
st.session_state["first_run_scenario"] = True
|
995 |
-
|
996 |
-
# Check if state is initiaized
|
997 |
-
is_state_initiaized = st.session_state.get("initialized", False)
|
998 |
-
|
999 |
-
# if not is_state_initiaized:
|
1000 |
-
# print("running initialize...")
|
1001 |
-
# # initialize_data()
|
1002 |
-
# if panel_selected == "Aggregated":
|
1003 |
-
# initialize_data(
|
1004 |
-
# panel=panel_selected,
|
1005 |
-
# target_file=file_selected,
|
1006 |
-
# updated_rcs=updated_rcs,
|
1007 |
-
# metrics=metrics_selected,
|
1008 |
-
# )
|
1009 |
-
# panel = None
|
1010 |
-
# else:
|
1011 |
-
# initialize_data(
|
1012 |
-
# panel=panel_selected,
|
1013 |
-
# target_file=file_selected,
|
1014 |
-
# updated_rcs=updated_rcs,
|
1015 |
-
# metrics=metrics_selected,
|
1016 |
-
# )
|
1017 |
-
# st.session_state["initialized"] = True
|
1018 |
-
# st.session_state["first_time"] = False
|
1019 |
-
|
1020 |
-
# Channels List
|
1021 |
-
channels_list = list(
|
1022 |
-
st.session_state["project_dct"]["scenario_planner"][
|
1023 |
-
unique_key
|
1024 |
-
].channels.keys()
|
1025 |
-
)
|
1026 |
-
|
1027 |
-
# ======================================================== #
|
1028 |
-
# ========================== UI ========================== #
|
1029 |
-
# ======================================================== #
|
1030 |
-
|
1031 |
-
main_header = st.columns((2, 2))
|
1032 |
-
sub_header = st.columns((1, 1, 1, 1))
|
1033 |
-
# _scenario = st.session_state["scenario"]
|
1034 |
-
|
1035 |
-
st.session_state.total_spends_change = round(
|
1036 |
-
(
|
1037 |
-
st.session_state["scenario"].modified_total_spends
|
1038 |
-
/ st.session_state["scenario"].actual_total_spends
|
1039 |
-
- 1
|
1040 |
-
)
|
1041 |
-
* 100
|
1042 |
-
)
|
1043 |
-
|
1044 |
-
if "total_sales_change" not in st.session_state:
|
1045 |
-
st.session_state.total_sales_change = round(
|
1046 |
-
(
|
1047 |
-
st.session_state["scenario"].modified_total_sales
|
1048 |
-
/ st.session_state["scenario"].actual_total_sales
|
1049 |
-
- 1
|
1050 |
-
)
|
1051 |
-
* 100
|
1052 |
-
)
|
1053 |
-
|
1054 |
-
st.session_state["total_spends_change_abs"] = numerize(
|
1055 |
-
st.session_state["scenario"].modified_total_spends,
|
1056 |
-
1,
|
1057 |
-
)
|
1058 |
-
if "total_sales_change_abs" not in st.session_state:
|
1059 |
-
st.session_state["total_sales_change_abs"] = numerize(
|
1060 |
-
st.session_state["scenario"].modified_total_sales,
|
1061 |
-
1,
|
1062 |
-
)
|
1063 |
-
|
1064 |
-
# if "total_spends_change_abs_slider" not in st.session_state:
|
1065 |
-
st.session_state.total_spends_change_abs_slider = numerize(
|
1066 |
-
st.session_state["scenario"].modified_total_spends, 1
|
1067 |
-
)
|
1068 |
-
|
1069 |
-
if "total_sales_change_abs_slider" not in st.session_state:
|
1070 |
-
st.session_state.total_sales_change_abs_slider = numerize(
|
1071 |
-
st.session_state["scenario"].actual_total_sales, 1
|
1072 |
-
)
|
1073 |
-
|
1074 |
-
st.session_state["allow_sales_update"] = True
|
1075 |
-
|
1076 |
-
st.session_state["allow_spends_update"] = True
|
1077 |
-
|
1078 |
-
# if "panel_selected" not in st.session_state:
|
1079 |
-
# st.session_state["panel_selected"] = 0
|
1080 |
-
|
1081 |
-
with main_header[0]:
|
1082 |
-
st.subheader("Actual")
|
1083 |
-
|
1084 |
-
with main_header[-1]:
|
1085 |
-
st.subheader("Simulated")
|
1086 |
-
|
1087 |
-
with sub_header[0]:
|
1088 |
-
st.metric(
|
1089 |
-
label="Spends",
|
1090 |
-
value=format_numbers(
|
1091 |
-
st.session_state["scenario"].actual_total_spends
|
1092 |
-
),
|
1093 |
-
)
|
1094 |
-
|
1095 |
-
with sub_header[1]:
|
1096 |
-
st.metric(
|
1097 |
-
label=target,
|
1098 |
-
value=format_numbers(
|
1099 |
-
float(st.session_state["scenario"].actual_total_sales),
|
1100 |
-
include_indicator=False,
|
1101 |
-
),
|
1102 |
-
)
|
1103 |
-
|
1104 |
-
with sub_header[2]:
|
1105 |
-
st.metric(
|
1106 |
-
label="Spends",
|
1107 |
-
value=format_numbers(
|
1108 |
-
st.session_state["scenario"].modified_total_spends
|
1109 |
-
),
|
1110 |
-
delta=numerize(st.session_state["scenario"].delta_spends, 1),
|
1111 |
-
)
|
1112 |
-
|
1113 |
-
with sub_header[3]:
|
1114 |
-
st.metric(
|
1115 |
-
label=target,
|
1116 |
-
value=format_numbers(
|
1117 |
-
float(st.session_state["scenario"].modified_total_sales),
|
1118 |
-
include_indicator=False,
|
1119 |
-
),
|
1120 |
-
delta=numerize(st.session_state["scenario"].delta_sales, 1),
|
1121 |
-
)
|
1122 |
-
|
1123 |
-
with st.expander("Channel Spends Simulator", expanded=True):
|
1124 |
-
_columns1 = st.columns((2, 2, 1, 1))
|
1125 |
-
with _columns1[0]:
|
1126 |
-
optimization_selection = st.selectbox(
|
1127 |
-
"Optimize",
|
1128 |
-
options=["Media Spends", target],
|
1129 |
-
key="optimization_key_value",
|
1130 |
-
)
|
1131 |
-
|
1132 |
-
with _columns1[1]:
|
1133 |
-
st.markdown("#")
|
1134 |
-
# if st.checkbox(
|
1135 |
-
# label="Optimize all Channels",
|
1136 |
-
# key="optimze_all_channels",
|
1137 |
-
# value=False,
|
1138 |
-
# # on_change=select_all_channels_for_optimization,
|
1139 |
-
# ):
|
1140 |
-
# select_all_channels_for_optimization()
|
1141 |
-
|
1142 |
-
st.checkbox(
|
1143 |
-
label="Optimize all Channels",
|
1144 |
-
key="optimze_all_channels",
|
1145 |
-
on_change=select_all_channels_for_optimization,
|
1146 |
-
)
|
1147 |
-
|
1148 |
-
with _columns1[2]:
|
1149 |
-
st.markdown("#")
|
1150 |
-
# st.button(
|
1151 |
-
# "Optimize",
|
1152 |
-
# on_click=optimize,
|
1153 |
-
# args=(st.session_state["optimization_key_value"]),
|
1154 |
-
# use_container_width=True,
|
1155 |
-
# )
|
1156 |
-
|
1157 |
-
optimize_placeholder = st.empty()
|
1158 |
-
|
1159 |
-
with _columns1[3]:
|
1160 |
-
st.markdown("#")
|
1161 |
-
st.button(
|
1162 |
-
"Reset",
|
1163 |
-
on_click=reset_scenario,
|
1164 |
-
# args=(panel_selected, file_selected, updated_rcs),
|
1165 |
-
use_container_width=True,
|
1166 |
-
)
|
1167 |
-
|
1168 |
-
_columns2 = st.columns((2, 2, 2))
|
1169 |
-
if st.session_state["optimization_key_value"] == "Media Spends":
|
1170 |
-
|
1171 |
-
# update_spends()
|
1172 |
-
|
1173 |
-
with _columns2[0]:
|
1174 |
-
spend_input = st.text_input(
|
1175 |
-
"Absolute",
|
1176 |
-
key="total_spends_change_abs",
|
1177 |
-
# label_visibility="collapsed",
|
1178 |
-
on_change=update_all_spends_abs,
|
1179 |
-
)
|
1180 |
-
|
1181 |
-
with _columns2[1]:
|
1182 |
-
st.number_input(
|
1183 |
-
"Percent Change",
|
1184 |
-
key="total_spends_change",
|
1185 |
-
min_value=-50,
|
1186 |
-
max_value=50,
|
1187 |
-
step=1,
|
1188 |
-
on_change=update_spends,
|
1189 |
-
)
|
1190 |
-
|
1191 |
-
with _columns2[2]:
|
1192 |
-
scenario = st.session_state["project_dct"]["scenario_planner"][
|
1193 |
-
unique_key
|
1194 |
-
]
|
1195 |
-
min_value = round(scenario.actual_total_spends * 0.5)
|
1196 |
-
max_value = round(scenario.actual_total_spends * 1.5)
|
1197 |
-
st.session_state["total_spends_change_abs_slider_options"] = [
|
1198 |
-
numerize(value, 1)
|
1199 |
-
for value in range(min_value, max_value + 1, int(1e4))
|
1200 |
-
]
|
1201 |
-
|
1202 |
-
st.select_slider(
|
1203 |
-
"Absolute Slider",
|
1204 |
-
options=st.session_state[
|
1205 |
-
"total_spends_change_abs_slider_options"
|
1206 |
-
],
|
1207 |
-
key="total_spends_change_abs_slider",
|
1208 |
-
on_change=update_all_spends_abs_slider,
|
1209 |
-
)
|
1210 |
-
|
1211 |
-
elif st.session_state["optimization_key_value"] == target:
|
1212 |
-
# update_sales()
|
1213 |
-
|
1214 |
-
with _columns2[0]:
|
1215 |
-
sales_input = st.text_input(
|
1216 |
-
"Absolute",
|
1217 |
-
key="total_sales_change_abs",
|
1218 |
-
on_change=update_sales_abs,
|
1219 |
-
)
|
1220 |
-
|
1221 |
-
with _columns2[1]:
|
1222 |
-
st.number_input(
|
1223 |
-
"Percent Change",
|
1224 |
-
key="total_sales_change",
|
1225 |
-
min_value=-50,
|
1226 |
-
max_value=50,
|
1227 |
-
step=1,
|
1228 |
-
on_change=update_sales,
|
1229 |
-
)
|
1230 |
-
|
1231 |
-
with _columns2[2]:
|
1232 |
-
min_value = round(
|
1233 |
-
st.session_state["scenario"].actual_total_sales * 0.5
|
1234 |
-
)
|
1235 |
-
max_value = round(
|
1236 |
-
st.session_state["scenario"].actual_total_sales * 1.5
|
1237 |
-
)
|
1238 |
-
st.session_state["total_sales_change_abs_slider_options"] = [
|
1239 |
-
numerize(value, 1)
|
1240 |
-
for value in range(min_value, max_value + 1, int(1e5))
|
1241 |
-
]
|
1242 |
-
|
1243 |
-
st.select_slider(
|
1244 |
-
"Absolute Slider",
|
1245 |
-
options=st.session_state[
|
1246 |
-
"total_sales_change_abs_slider_options"
|
1247 |
-
],
|
1248 |
-
key="total_sales_change_abs_slider",
|
1249 |
-
on_change=update_sales_abs_slider,
|
1250 |
-
)
|
1251 |
-
|
1252 |
-
if (
|
1253 |
-
not st.session_state["allow_sales_update"]
|
1254 |
-
and optimization_selection == target
|
1255 |
-
):
|
1256 |
-
st.warning("Invalid Input")
|
1257 |
-
|
1258 |
-
if (
|
1259 |
-
not st.session_state["allow_spends_update"]
|
1260 |
-
and optimization_selection == "Media Spends"
|
1261 |
-
):
|
1262 |
-
st.warning("Invalid Input")
|
1263 |
-
|
1264 |
-
status_placeholder = st.empty()
|
1265 |
-
|
1266 |
-
# if optimize_placeholder.button("Optimize", use_container_width=True):
|
1267 |
-
# optimize(st.session_state["optimization_key_value"], status_placeholder)
|
1268 |
-
# st.rerun()
|
1269 |
-
|
1270 |
-
optimize_placeholder.button(
|
1271 |
-
"Optimize",
|
1272 |
-
on_click=optimize,
|
1273 |
-
args=(
|
1274 |
-
st.session_state["optimization_key_value"],
|
1275 |
-
status_placeholder,
|
1276 |
-
),
|
1277 |
-
use_container_width=True,
|
1278 |
-
)
|
1279 |
-
|
1280 |
-
st.markdown(
|
1281 |
-
"""<hr class="spends-heading-seperator">""", unsafe_allow_html=True
|
1282 |
-
)
|
1283 |
-
_columns = st.columns((2.5, 2, 1.5, 1.5, 1))
|
1284 |
-
with _columns[0]:
|
1285 |
-
generate_spending_header("Channel")
|
1286 |
-
with _columns[1]:
|
1287 |
-
generate_spending_header("Spends Input")
|
1288 |
-
with _columns[2]:
|
1289 |
-
generate_spending_header("Spends")
|
1290 |
-
with _columns[3]:
|
1291 |
-
generate_spending_header(target)
|
1292 |
-
with _columns[4]:
|
1293 |
-
generate_spending_header("Optimize")
|
1294 |
-
|
1295 |
-
st.markdown(
|
1296 |
-
"""<hr class="spends-heading-seperator">""", unsafe_allow_html=True
|
1297 |
-
)
|
1298 |
-
|
1299 |
-
if "acutual_predicted" not in st.session_state:
|
1300 |
-
st.session_state["acutual_predicted"] = {
|
1301 |
-
"Channel_name": [],
|
1302 |
-
"Actual_spend": [],
|
1303 |
-
"Optimized_spend": [],
|
1304 |
-
"Delta": [],
|
1305 |
-
}
|
1306 |
-
for i, channel_name in enumerate(channels_list):
|
1307 |
-
_channel_class = st.session_state["scenario"].channels[
|
1308 |
-
channel_name
|
1309 |
-
]
|
1310 |
-
|
1311 |
-
st.session_state[f"{channel_name}_percent"] = round(
|
1312 |
-
(
|
1313 |
-
_channel_class.modified_total_spends
|
1314 |
-
/ _channel_class.actual_total_spends
|
1315 |
-
- 1
|
1316 |
-
)
|
1317 |
-
* 100
|
1318 |
-
)
|
1319 |
-
|
1320 |
-
_columns = st.columns((2.5, 1.5, 1.5, 1.5, 1))
|
1321 |
-
with _columns[0]:
|
1322 |
-
st.write(channel_name_formating(channel_name))
|
1323 |
-
bin_placeholder = st.container()
|
1324 |
-
|
1325 |
-
with _columns[1]:
|
1326 |
-
channel_bounds = _channel_class.bounds
|
1327 |
-
channel_spends = float(_channel_class.actual_total_spends)
|
1328 |
-
min_value = float(
|
1329 |
-
(1 + channel_bounds[0] / 100) * channel_spends
|
1330 |
-
)
|
1331 |
-
max_value = float(
|
1332 |
-
(1 + channel_bounds[1] / 100) * channel_spends
|
1333 |
-
)
|
1334 |
-
# print("##########", st.session_state[channel_name])
|
1335 |
-
spend_input = st.text_input(
|
1336 |
-
channel_name,
|
1337 |
-
key=channel_name,
|
1338 |
-
label_visibility="collapsed",
|
1339 |
-
on_change=partial(update_data, channel_name),
|
1340 |
-
)
|
1341 |
-
if not validate_input(spend_input):
|
1342 |
-
st.error("Invalid input")
|
1343 |
-
|
1344 |
-
channel_name_current = f"{channel_name}_change"
|
1345 |
-
|
1346 |
-
st.number_input(
|
1347 |
-
"Percent Change",
|
1348 |
-
key=f"{channel_name}_percent",
|
1349 |
-
step=1,
|
1350 |
-
on_change=partial(update_data_by_percent, channel_name),
|
1351 |
-
)
|
1352 |
-
|
1353 |
-
with _columns[2]:
|
1354 |
-
# spends
|
1355 |
-
current_channel_spends = float(
|
1356 |
-
_channel_class.modified_total_spends
|
1357 |
-
* _channel_class.conversion_rate
|
1358 |
-
)
|
1359 |
-
actual_channel_spends = float(
|
1360 |
-
_channel_class.actual_total_spends
|
1361 |
-
* _channel_class.conversion_rate
|
1362 |
-
)
|
1363 |
-
spends_delta = float(
|
1364 |
-
_channel_class.delta_spends
|
1365 |
-
* _channel_class.conversion_rate
|
1366 |
-
)
|
1367 |
-
st.session_state["acutual_predicted"]["Channel_name"].append(
|
1368 |
-
channel_name
|
1369 |
-
)
|
1370 |
-
st.session_state["acutual_predicted"]["Actual_spend"].append(
|
1371 |
-
actual_channel_spends
|
1372 |
-
)
|
1373 |
-
st.session_state["acutual_predicted"][
|
1374 |
-
"Optimized_spend"
|
1375 |
-
].append(current_channel_spends)
|
1376 |
-
st.session_state["acutual_predicted"]["Delta"].append(
|
1377 |
-
spends_delta
|
1378 |
-
)
|
1379 |
-
## REMOVE
|
1380 |
-
st.metric(
|
1381 |
-
"Spends",
|
1382 |
-
format_numbers(current_channel_spends),
|
1383 |
-
delta=numerize(spends_delta, 1),
|
1384 |
-
label_visibility="collapsed",
|
1385 |
-
)
|
1386 |
-
|
1387 |
-
with _columns[3]:
|
1388 |
-
# sales
|
1389 |
-
current_channel_sales = float(
|
1390 |
-
_channel_class.modified_total_sales
|
1391 |
-
)
|
1392 |
-
actual_channel_sales = float(_channel_class.actual_total_sales)
|
1393 |
-
sales_delta = float(_channel_class.delta_sales)
|
1394 |
-
st.metric(
|
1395 |
-
target,
|
1396 |
-
format_numbers(
|
1397 |
-
current_channel_sales, include_indicator=False
|
1398 |
-
),
|
1399 |
-
delta=numerize(sales_delta, 1),
|
1400 |
-
label_visibility="collapsed",
|
1401 |
-
)
|
1402 |
-
|
1403 |
-
with _columns[4]:
|
1404 |
-
|
1405 |
-
# if st.checkbox(
|
1406 |
-
# label="select for optimization",
|
1407 |
-
# key=f"{channel_name}_selected",
|
1408 |
-
# value=False,
|
1409 |
-
# # on_change=partial(select_channel_for_optimization, channel_name),
|
1410 |
-
# label_visibility="collapsed",
|
1411 |
-
# ):
|
1412 |
-
# select_channel_for_optimization(channel_name)
|
1413 |
-
|
1414 |
-
st.checkbox(
|
1415 |
-
label="select for optimization",
|
1416 |
-
key=f"{channel_name}_selected",
|
1417 |
-
value=False,
|
1418 |
-
on_change=partial(
|
1419 |
-
select_channel_for_optimization, channel_name
|
1420 |
-
),
|
1421 |
-
label_visibility="collapsed",
|
1422 |
-
)
|
1423 |
-
|
1424 |
-
st.markdown(
|
1425 |
-
"""<hr class="spends-child-seperator">""",
|
1426 |
-
unsafe_allow_html=True,
|
1427 |
-
)
|
1428 |
-
|
1429 |
-
# Bins
|
1430 |
-
col = channels_list[i]
|
1431 |
-
x_actual = st.session_state["scenario"].channels[col].actual_spends
|
1432 |
-
x_modified = (
|
1433 |
-
st.session_state["scenario"].channels[col].modified_spends
|
1434 |
-
)
|
1435 |
-
|
1436 |
-
x_total = x_modified.sum()
|
1437 |
-
power = np.ceil(np.log(x_actual.max()) / np.log(10)) - 3
|
1438 |
-
|
1439 |
-
updated_rcs_key = (
|
1440 |
-
f"{metrics_selected}#@{panel_selected}#@{channel_name}"
|
1441 |
-
)
|
1442 |
-
|
1443 |
-
if updated_rcs and updated_rcs_key in list(updated_rcs.keys()):
|
1444 |
-
K = updated_rcs[updated_rcs_key]["K"]
|
1445 |
-
b = updated_rcs[updated_rcs_key]["b"]
|
1446 |
-
a = updated_rcs[updated_rcs_key]["a"]
|
1447 |
-
x0 = updated_rcs[updated_rcs_key]["x0"]
|
1448 |
-
else:
|
1449 |
-
K = st.session_state["rcs"][col]["K"]
|
1450 |
-
b = st.session_state["rcs"][col]["b"]
|
1451 |
-
a = st.session_state["rcs"][col]["a"]
|
1452 |
-
x0 = st.session_state["rcs"][col]["x0"]
|
1453 |
-
|
1454 |
-
x_plot = np.linspace(0, 5 * x_actual.sum(), 200)
|
1455 |
-
|
1456 |
-
# Append current_channel_spends to the end of x_plot
|
1457 |
-
x_plot = np.append(x_plot, current_channel_spends)
|
1458 |
-
|
1459 |
-
x, y, marginal_roi = [], [], []
|
1460 |
-
for x_p in x_plot:
|
1461 |
-
x.append(x_p * x_actual / x_actual.sum())
|
1462 |
-
|
1463 |
-
for index in range(len(x_plot)):
|
1464 |
-
y.append(s_curve(x[index] / 10**power, K, b, a, x0))
|
1465 |
-
|
1466 |
-
for index in range(len(x_plot)):
|
1467 |
-
marginal_roi.append(
|
1468 |
-
a
|
1469 |
-
* y[index]
|
1470 |
-
* (1 - y[index] / np.maximum(K, np.finfo(float).eps))
|
1471 |
-
)
|
1472 |
-
|
1473 |
-
x = (
|
1474 |
-
np.sum(x, axis=1)
|
1475 |
-
* st.session_state["scenario"].channels[col].conversion_rate
|
1476 |
-
)
|
1477 |
-
y = np.sum(y, axis=1)
|
1478 |
-
marginal_roi = (
|
1479 |
-
np.average(marginal_roi, axis=1)
|
1480 |
-
/ st.session_state["scenario"].channels[col].conversion_rate
|
1481 |
-
)
|
1482 |
-
|
1483 |
-
roi = y / np.maximum(x, np.finfo(float).eps)
|
1484 |
-
|
1485 |
-
roi_current, marginal_roi_current = roi[-1], marginal_roi[-1]
|
1486 |
-
x, y, roi, marginal_roi = (
|
1487 |
-
x[:-1],
|
1488 |
-
y[:-1],
|
1489 |
-
roi[:-1],
|
1490 |
-
marginal_roi[:-1],
|
1491 |
-
) # Drop data for current spends
|
1492 |
-
|
1493 |
-
start_value, end_value, left_value, right_value = (
|
1494 |
-
find_segment_value(
|
1495 |
-
x,
|
1496 |
-
roi,
|
1497 |
-
marginal_roi,
|
1498 |
-
)
|
1499 |
-
)
|
1500 |
-
|
1501 |
-
rgba = calculate_rgba(
|
1502 |
-
start_value,
|
1503 |
-
end_value,
|
1504 |
-
left_value,
|
1505 |
-
right_value,
|
1506 |
-
current_channel_spends,
|
1507 |
-
)
|
1508 |
-
|
1509 |
-
with bin_placeholder:
|
1510 |
-
st.markdown(
|
1511 |
-
f"""
|
1512 |
-
<div style="
|
1513 |
-
border-radius: 12px;
|
1514 |
-
background-color: {rgba};
|
1515 |
-
padding: 10px;
|
1516 |
-
text-align: center;
|
1517 |
-
color: #006EC0;
|
1518 |
-
">
|
1519 |
-
<p style="margin: 0; font-size: 20px;">ROI: {round(roi_current,1)}</p>
|
1520 |
-
<p style="margin: 0; font-size: 20px;">Marginal ROI: {round(marginal_roi_current,1)}</p>
|
1521 |
-
</div>
|
1522 |
-
""",
|
1523 |
-
unsafe_allow_html=True,
|
1524 |
-
)
|
1525 |
-
|
1526 |
-
st.session_state["project_dct"]["scenario_planner"]["scenario"] = (
|
1527 |
-
st.session_state["scenario"]
|
1528 |
-
)
|
1529 |
-
|
1530 |
-
with st.expander("See Response Curves", expanded=True):
|
1531 |
-
fig = plot_response_curves()
|
1532 |
-
st.plotly_chart(fig, use_container_width=True)
|
1533 |
-
|
1534 |
-
def update_optimization_bounds(channel_name, bound_type):
    """Copy one channel's bound widget value into the stored scenario.

    When ``bound_type`` is ``"lower"`` the value is read from the
    ``<channel_name>_b_lower`` session key and written to slot 0 of the
    channel's ``bounds``; any other ``bound_type`` reads
    ``<channel_name>_b_upper`` and writes slot 1.
    """
    if bound_type == "lower":
        slot, widget_key = 0, f"{channel_name}_b_lower"
    else:
        slot, widget_key = 1, f"{channel_name}_b_upper"

    # Read the widget value first, then resolve the channel it belongs to.
    new_bound = st.session_state[widget_key]
    planner = st.session_state["project_dct"]["scenario_planner"]
    channel = planner[unique_key].channels[channel_name]
    channel.bounds[slot] = new_bound
|
1544 |
-
|
1545 |
-
def update_optimization_bounds_all(bound_type):
    """Apply the "update all channels" bound widget value to every channel.

    When ``bound_type`` is ``"lower"`` the ``all_b_lower`` session key is
    copied into slot 0 of each channel's ``bounds``; any other ``bound_type``
    copies ``all_b_upper`` into slot 1.
    """
    is_lower = bound_type == "lower"
    slot = 0 if is_lower else 1
    widget_key = "all_b_lower" if is_lower else "all_b_upper"

    scenario = st.session_state["project_dct"]["scenario_planner"][unique_key]
    for _name, channel in scenario.channels.items():
        channel.bounds[slot] = st.session_state[widget_key]
|
1555 |
-
|
1556 |
-
with st.expander("Optimization setup"):
|
1557 |
-
bounds_placeholder = st.container()
|
1558 |
-
with bounds_placeholder:
|
1559 |
-
st.subheader("Optimization Bounds")
|
1560 |
-
with st.container():
|
1561 |
-
bounds_columns = st.columns((1, 0.35, 0.35, 1))
|
1562 |
-
with bounds_columns[0]:
|
1563 |
-
st.write("##")
|
1564 |
-
st.write("Update all channels")
|
1565 |
-
|
1566 |
-
with bounds_columns[1]:
|
1567 |
-
st.number_input(
|
1568 |
-
"Lower",
|
1569 |
-
min_value=-100,
|
1570 |
-
max_value=500,
|
1571 |
-
key=f"all_b_lower",
|
1572 |
-
# label_visibility="hidden",
|
1573 |
-
on_change=update_optimization_bounds_all,
|
1574 |
-
args=("lower",),
|
1575 |
-
step=5,
|
1576 |
-
value=-10,
|
1577 |
-
)
|
1578 |
-
|
1579 |
-
with bounds_columns[2]:
|
1580 |
-
st.number_input(
|
1581 |
-
"Higher",
|
1582 |
-
value=10,
|
1583 |
-
min_value=-100,
|
1584 |
-
max_value=500,
|
1585 |
-
key=f"all_b_upper",
|
1586 |
-
# label_visibility="hidden",
|
1587 |
-
on_change=update_optimization_bounds_all,
|
1588 |
-
args=("upper",),
|
1589 |
-
step=5,
|
1590 |
-
)
|
1591 |
-
st.divider()
|
1592 |
-
|
1593 |
-
st.write("#### Channel wise bounds")
|
1594 |
-
# st.divider()
|
1595 |
-
# bounds_columns = st.columns((1, 0.35, 0.35, 1))
|
1596 |
-
|
1597 |
-
# with bounds_columns[0]:
|
1598 |
-
# st.write("Channel")
|
1599 |
-
# with bounds_columns[1]:
|
1600 |
-
# st.write("Lower")
|
1601 |
-
# with bounds_columns[2]:
|
1602 |
-
# st.write("Upper")
|
1603 |
-
# st.divider()
|
1604 |
-
|
1605 |
-
for channel_name, _channel in st.session_state["project_dct"][
|
1606 |
-
"scenario_planner"
|
1607 |
-
][unique_key].channels.items():
|
1608 |
-
st.session_state[f"{channel_name}_b_lower"] = _channel.bounds[0]
|
1609 |
-
st.session_state[f"{channel_name}_b_upper"] = _channel.bounds[1]
|
1610 |
-
with bounds_placeholder:
|
1611 |
-
with st.container():
|
1612 |
-
bounds_columns = st.columns((1, 0.35, 0.35, 1))
|
1613 |
-
with bounds_columns[0]:
|
1614 |
-
st.write("##")
|
1615 |
-
st.write(channel_name)
|
1616 |
-
with bounds_columns[1]:
|
1617 |
-
st.number_input(
|
1618 |
-
"Lower",
|
1619 |
-
min_value=-100,
|
1620 |
-
max_value=500,
|
1621 |
-
key=f"{channel_name}_b_lower",
|
1622 |
-
label_visibility="hidden",
|
1623 |
-
on_change=update_optimization_bounds,
|
1624 |
-
args=(
|
1625 |
-
channel_name,
|
1626 |
-
"lower",
|
1627 |
-
),
|
1628 |
-
)
|
1629 |
-
|
1630 |
-
with bounds_columns[2]:
|
1631 |
-
st.number_input(
|
1632 |
-
"Higher",
|
1633 |
-
min_value=-100,
|
1634 |
-
max_value=500,
|
1635 |
-
key=f"{channel_name}_b_upper",
|
1636 |
-
label_visibility="hidden",
|
1637 |
-
on_change=update_optimization_bounds,
|
1638 |
-
args=(
|
1639 |
-
channel_name,
|
1640 |
-
"upper",
|
1641 |
-
),
|
1642 |
-
)
|
1643 |
-
|
1644 |
-
st.divider()
|
1645 |
-
_columns = st.columns(2)
|
1646 |
-
with _columns[0]:
|
1647 |
-
st.subheader("Save Scenario")
|
1648 |
-
scenario_name = st.text_input(
|
1649 |
-
"Scenario name",
|
1650 |
-
key="scenario_input",
|
1651 |
-
placeholder="Scenario name",
|
1652 |
-
label_visibility="collapsed",
|
1653 |
-
)
|
1654 |
-
st.button(
|
1655 |
-
"Save",
|
1656 |
-
on_click=lambda: save_scenario(scenario_name),
|
1657 |
-
disabled=len(st.session_state["scenario_input"]) == 0,
|
1658 |
-
)
|
1659 |
-
|
1660 |
-
summary_df = pd.DataFrame(st.session_state["acutual_predicted"])
|
1661 |
-
summary_df.drop_duplicates(
|
1662 |
-
subset="Channel_name", keep="last", inplace=True
|
1663 |
-
)
|
1664 |
-
|
1665 |
-
summary_df_sorted = summary_df.sort_values(by="Delta", ascending=False)
|
1666 |
-
summary_df_sorted["Delta_percent"] = np.round(
|
1667 |
-
(
|
1668 |
-
(
|
1669 |
-
summary_df_sorted["Optimized_spend"]
|
1670 |
-
/ summary_df_sorted["Actual_spend"]
|
1671 |
-
)
|
1672 |
-
- 1
|
1673 |
-
)
|
1674 |
-
* 100,
|
1675 |
-
2,
|
1676 |
-
)
|
1677 |
-
|
1678 |
-
with open("summary_df.pkl", "wb") as f:
|
1679 |
-
pickle.dump(summary_df_sorted, f)
|
1680 |
-
# st.dataframe(summary_df_sorted)
|
1681 |
-
# ___columns=st.columns(3)
|
1682 |
-
# with ___columns[2]:
|
1683 |
-
# fig=summary_plot(summary_df_sorted, x='Delta_percent', y='Channel_name', title='Delta', text_column='Delta_percent')
|
1684 |
-
# st.plotly_chart(fig,use_container_width=True)
|
1685 |
-
# with ___columns[0]:
|
1686 |
-
# fig=summary_plot(summary_df_sorted, x='Actual_spend', y='Channel_name', title='Actual Spend', text_column='Actual_spend')
|
1687 |
-
# st.plotly_chart(fig,use_container_width=True)
|
1688 |
-
# with ___columns[1]:
|
1689 |
-
# fig=summary_plot(summary_df_sorted, x='Optimized_spend', y='Channel_name', title='Planned Spend', text_column='Optimized_spend')
|
1690 |
-
# st.plotly_chart(fig,use_container_width=True)
|
1691 |
-
|
1692 |
-
elif auth_status == False:
|
1693 |
-
st.error("Username/Password is incorrect")
|
1694 |
-
|
1695 |
-
# Show the password-reset form whenever auth_status is not True.
if auth_status != True:
    try:
        reset_result = authenticator.forgot_password("Forgot password")
        username_forgot_pw, email_forgot_password, random_password = reset_result
        if username_forgot_pw:
            # Only the hash is written to the session config; the plain-text
            # password is handed to send_email for delivery to the user.
            hashed_password = stauth.Hasher([random_password]).generate()[0]
            user_entry = st.session_state["config"]["credentials"]["usernames"][
                username_forgot_pw
            ]
            user_entry["password"] = hashed_password
            send_email(email_forgot_password, random_password)
            st.success("New password sent securely")
        # Explicit comparison: only a value equal to False reports an
        # unknown username; other falsy values such as None show nothing.
        elif username_forgot_pw == False:
            st.error("Username not found")
    except Exception as e:
        st.error(e)
|
1711 |
-
|
1712 |
-
update_db("9_Scenario_Planner.py")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|