BlendMMM commited on
Commit
fdbbbbf
·
verified ·
1 Parent(s): ef850e9

Delete pages

Browse files
pages/10_Saved_Scenarios.py DELETED
@@ -1,407 +0,0 @@
1
- import streamlit as st
2
- from numerize.numerize import numerize
3
- import io
4
- import pandas as pd
5
- from utilities import (
6
- format_numbers,
7
- decimal_formater,
8
- channel_name_formating,
9
- load_local_css,
10
- set_header,
11
- initialize_data,
12
- load_authenticator,
13
- )
14
- from openpyxl import Workbook
15
- from openpyxl.styles import Alignment, Font, PatternFill
16
- import pickle
17
- import streamlit_authenticator as stauth
18
- import yaml
19
- from yaml import SafeLoader
20
- from classes import class_from_dict
21
- from utilities import update_db
22
-
23
- st.set_page_config(layout="wide")
24
- load_local_css("styles.css")
25
- set_header()
26
-
27
- # for k, v in st.session_state.items():
28
- # if k not in ['logout', 'login','config'] and not k.startswith('FormSubmitter'):
29
- # st.session_state[k] = v
30
-
31
-
32
- def create_scenario_summary(scenario_dict):
33
- summary_rows = []
34
- for channel_dict in scenario_dict["channels"]:
35
- name_mod = channel_name_formating(channel_dict["name"])
36
- summary_rows.append(
37
- [
38
- name_mod,
39
- channel_dict.get("actual_total_spends")
40
- * channel_dict.get("conversion_rate"),
41
- channel_dict.get("modified_total_spends")
42
- * channel_dict.get("conversion_rate"),
43
- channel_dict.get("actual_total_sales"),
44
- channel_dict.get("modified_total_sales"),
45
- channel_dict.get("actual_total_sales")
46
- / (
47
- channel_dict.get("actual_total_spends")
48
- * channel_dict.get("conversion_rate")
49
- ),
50
- channel_dict.get("modified_total_sales")
51
- / (
52
- channel_dict.get("modified_total_spends")
53
- * channel_dict.get("conversion_rate")
54
- ),
55
- channel_dict.get("actual_mroi"),
56
- channel_dict.get("modified_mroi"),
57
- channel_dict.get("actual_total_spends")
58
- * channel_dict.get("conversion_rate")
59
- / channel_dict.get("actual_total_sales"),
60
- channel_dict.get("modified_total_spends")
61
- * channel_dict.get("conversion_rate")
62
- / channel_dict.get("modified_total_sales"),
63
- ]
64
- )
65
-
66
- summary_rows.append(
67
- [
68
- "Total",
69
- scenario_dict.get("actual_total_spends"),
70
- scenario_dict.get("modified_total_spends"),
71
- scenario_dict.get("actual_total_sales"),
72
- scenario_dict.get("modified_total_sales"),
73
- scenario_dict.get("actual_total_sales")
74
- / scenario_dict.get("actual_total_spends"),
75
- scenario_dict.get("modified_total_sales")
76
- / scenario_dict.get("modified_total_spends"),
77
- "-",
78
- "-",
79
- scenario_dict.get("actual_total_spends")
80
- / scenario_dict.get("actual_total_sales"),
81
- scenario_dict.get("modified_total_spends")
82
- / scenario_dict.get("modified_total_sales"),
83
- ]
84
- )
85
-
86
- columns_index = pd.MultiIndex.from_product(
87
- [[""], ["Channel"]], names=["first", "second"]
88
- )
89
- columns_index = columns_index.append(
90
- pd.MultiIndex.from_product(
91
- [
92
- ["Spends", "NRPU", "ROI", "MROI", "Spend per NRPU"],
93
- ["Actual", "Simulated"],
94
- ],
95
- names=["first", "second"],
96
- )
97
- )
98
- return pd.DataFrame(summary_rows, columns=columns_index)
99
-
100
-
101
- def summary_df_to_worksheet(df, ws):
102
- heading_fill = PatternFill(
103
- fill_type="solid", start_color="FF11B6BD", end_color="FF11B6BD"
104
- )
105
- for j, header in enumerate(df.columns.values):
106
- col = j + 1
107
- for i in range(1, 3):
108
- ws.cell(row=i, column=j + 1, value=header[i - 1]).font = Font(
109
- bold=True, color="FF11B6BD"
110
- )
111
- ws.cell(row=i, column=j + 1).fill = heading_fill
112
- if col > 1 and (col - 6) % 5 == 0:
113
- ws.merge_cells(start_row=1, end_row=1, start_column=col - 3, end_column=col)
114
- ws.cell(row=1, column=col).alignment = Alignment(horizontal="center")
115
- for i, row in enumerate(df.itertuples()):
116
- for j, value in enumerate(row):
117
- if j == 0:
118
- continue
119
- elif (j - 2) % 4 == 0 or (j - 3) % 4 == 0:
120
- ws.cell(row=i + 3, column=j, value=value).number_format = "$#,##0.0"
121
- else:
122
- ws.cell(row=i + 3, column=j, value=value)
123
-
124
-
125
- from openpyxl.utils import get_column_letter
126
- from openpyxl.styles import Font, PatternFill
127
- import logging
128
-
129
-
130
- def scenario_df_to_worksheet(df, ws):
131
- heading_fill = PatternFill(
132
- start_color="FF11B6BD", end_color="FF11B6BD", fill_type="solid"
133
- )
134
-
135
- for j, header in enumerate(df.columns.values):
136
- cell = ws.cell(row=1, column=j + 1, value=header)
137
- cell.font = Font(bold=True, color="FF11B6BD")
138
- cell.fill = heading_fill
139
-
140
- for i, row in enumerate(df.itertuples()):
141
- for j, value in enumerate(
142
- row[1:], start=1
143
- ): # Start from index 1 to skip the index column
144
- try:
145
- cell = ws.cell(row=i + 2, column=j, value=value)
146
- if isinstance(value, (int, float)):
147
- cell.number_format = "$#,##0.0"
148
- elif isinstance(value, str):
149
- cell.value = value[:32767]
150
- else:
151
- cell.value = str(value)
152
- except ValueError as e:
153
- logging.error(
154
- f"Error assigning value '{value}' to cell {get_column_letter(j)}{i+2}: {e}"
155
- )
156
- cell.value = None # Assign None to the cell where the error occurred
157
-
158
- return ws
159
-
160
-
161
- def download_scenarios():
162
- """
163
- Makes a excel with all saved scenarios and saves it locally
164
- """
165
- ## create summary page
166
- if len(scenarios_to_download) == 0:
167
- return
168
- wb = Workbook()
169
- wb.iso_dates = True
170
- wb.remove(wb.active)
171
- st.session_state["xlsx_buffer"] = io.BytesIO()
172
- summary_df = None
173
- # print(scenarios_to_download)
174
- for scenario_name in scenarios_to_download:
175
- scenario_dict = st.session_state["saved_scenarios"][scenario_name]
176
- _spends = []
177
- column_names = ["Date"]
178
- _sales = None
179
- dates = None
180
- summary_rows = []
181
- for channel in scenario_dict["channels"]:
182
- if dates is None:
183
- dates = channel.get("dates")
184
- _spends.append(dates)
185
- if _sales is None:
186
- _sales = channel.get("modified_sales")
187
- else:
188
- _sales += channel.get("modified_sales")
189
- _spends.append(
190
- channel.get("modified_spends") * channel.get("conversion_rate")
191
- )
192
- column_names.append(channel.get("name"))
193
-
194
- name_mod = channel_name_formating(channel["name"])
195
- summary_rows.append(
196
- [
197
- name_mod,
198
- channel.get("modified_total_spends")
199
- * channel.get("conversion_rate"),
200
- channel.get("modified_total_sales"),
201
- channel.get("modified_total_sales")
202
- / channel.get("modified_total_spends")
203
- * channel.get("conversion_rate"),
204
- channel.get("modified_mroi"),
205
- channel.get("modified_total_sales")
206
- / channel.get("modified_total_spends")
207
- * channel.get("conversion_rate"),
208
- ]
209
- )
210
- _spends.append(_sales)
211
- column_names.append("NRPU")
212
- scenario_df = pd.DataFrame(_spends).T
213
- scenario_df.columns = column_names
214
- ## write to sheet
215
- ws = wb.create_sheet(scenario_name)
216
- scenario_df_to_worksheet(scenario_df, ws)
217
- summary_rows.append(
218
- [
219
- "Total",
220
- scenario_dict.get("modified_total_spends"),
221
- scenario_dict.get("modified_total_sales"),
222
- scenario_dict.get("modified_total_sales")
223
- / scenario_dict.get("modified_total_spends"),
224
- "-",
225
- scenario_dict.get("modified_total_spends")
226
- / scenario_dict.get("modified_total_sales"),
227
- ]
228
- )
229
- columns_index = pd.MultiIndex.from_product(
230
- [[""], ["Channel"]], names=["first", "second"]
231
- )
232
- columns_index = columns_index.append(
233
- pd.MultiIndex.from_product(
234
- [[scenario_name], ["Spends", "NRPU", "ROI", "MROI", "Spends per NRPU"]],
235
- names=["first", "second"],
236
- )
237
- )
238
- if summary_df is None:
239
- summary_df = pd.DataFrame(summary_rows, columns=columns_index)
240
- summary_df = summary_df.set_index(("", "Channel"))
241
- else:
242
- _df = pd.DataFrame(summary_rows, columns=columns_index)
243
- _df = _df.set_index(("", "Channel"))
244
- summary_df = summary_df.merge(_df, left_index=True, right_index=True)
245
- ws = wb.create_sheet("Summary", 0)
246
- summary_df_to_worksheet(summary_df.reset_index(), ws)
247
- wb.save(st.session_state["xlsx_buffer"])
248
- st.session_state["disable_download_button"] = False
249
-
250
-
251
- def disable_download_button():
252
- st.session_state["disable_download_button"] = True
253
-
254
-
255
- def transform(x):
256
- if x.name == ("", "Channel"):
257
- return x
258
- elif x.name[0] == "ROI" or x.name[0] == "MROI":
259
- return x.apply(
260
- lambda y: (
261
- y
262
- if isinstance(y, str)
263
- else decimal_formater(
264
- format_numbers(y, include_indicator=False, n_decimals=4),
265
- n_decimals=4,
266
- )
267
- )
268
- )
269
- else:
270
- return x.apply(lambda y: y if isinstance(y, str) else format_numbers(y))
271
-
272
-
273
- def delete_scenario():
274
- if selected_scenario in st.session_state["saved_scenarios"]:
275
- del st.session_state["saved_scenarios"][selected_scenario]
276
- with open("../saved_scenarios.pkl", "wb") as f:
277
- pickle.dump(st.session_state["saved_scenarios"], f)
278
-
279
-
280
- def load_scenario():
281
- if selected_scenario in st.session_state["saved_scenarios"]:
282
- st.session_state["scenario"] = class_from_dict(selected_scenario_details)
283
-
284
-
285
- authenticator = st.session_state.get("authenticator")
286
- if authenticator is None:
287
- authenticator = load_authenticator()
288
-
289
- name, authentication_status, username = authenticator.login("Login", "main")
290
- auth_status = st.session_state.get("authentication_status")
291
-
292
- if auth_status == True:
293
- is_state_initiaized = st.session_state.get("initialized", False)
294
- if not is_state_initiaized:
295
- # print("Scenario page state reloaded")
296
- initialize_data()
297
-
298
- saved_scenarios = st.session_state["saved_scenarios"]
299
-
300
- if len(saved_scenarios) == 0:
301
- st.header("No saved scenarios")
302
-
303
- else:
304
- selected_scenario_list = list(saved_scenarios.keys())
305
- if "selected_scenario_selectbox_key" not in st.session_state:
306
- st.session_state["selected_scenario_selectbox_key"] = (
307
- selected_scenario_list[
308
- st.session_state["project_dct"]["saved_scenarios"][
309
- "selected_scenario_selectbox_key"
310
- ]
311
- ]
312
- )
313
-
314
- col_a, col_b = st.columns(2)
315
- selected_scenario = col_a.selectbox(
316
- "Pick a scenario to view details",
317
- selected_scenario_list,
318
- # key="selected_scenario_selectbox_key",
319
- index=st.session_state["project_dct"]["saved_scenarios"][
320
- "selected_scenario_selectbox_key"
321
- ],
322
- )
323
- st.session_state["project_dct"]["saved_scenarios"][
324
- "selected_scenario_selectbox_key"
325
- ] = selected_scenario_list.index(selected_scenario)
326
-
327
- scenarios_to_download = col_b.multiselect(
328
- "Select scenarios to download",
329
- list(saved_scenarios.keys()),
330
- on_change=disable_download_button,
331
- )
332
-
333
- with col_a:
334
- col3, col4 = st.columns(2)
335
-
336
- col4.button(
337
- "Delete scenarios",
338
- on_click=delete_scenario,
339
- use_container_width=True,
340
- )
341
- col3.button(
342
- "Load Scenario",
343
- on_click=load_scenario,
344
- use_container_width=True,
345
- )
346
-
347
- with col_b:
348
- col1, col2 = st.columns(2)
349
-
350
- col1.button(
351
- "Prepare download",
352
- on_click=download_scenarios,
353
- use_container_width=True,
354
- )
355
- col2.download_button(
356
- label="Download Scenarios",
357
- data=st.session_state["xlsx_buffer"].getvalue(),
358
- file_name="scenarios.xlsx",
359
- mime="application/vnd.ms-excel",
360
- disabled=st.session_state["disable_download_button"],
361
- on_click=disable_download_button,
362
- use_container_width=True,
363
- )
364
-
365
- # column_1, column_2, column_3 = st.columns((6, 1, 1))
366
- # with column_1:
367
- # st.header(selected_scenario)
368
- # with column_2:
369
- # st.button("Delete scenarios", on_click=delete_scenario)
370
- # with column_3:
371
- # st.button("Load Scenario", on_click=load_scenario)
372
-
373
- selected_scenario_details = saved_scenarios[selected_scenario]
374
-
375
- pd.set_option("display.max_colwidth", 100)
376
-
377
- st.markdown(
378
- create_scenario_summary(selected_scenario_details)
379
- .transform(transform)
380
- .style.set_table_styles(
381
- [
382
- {"selector": "th", "props": [("background-color", "#11B6BD")]},
383
- {
384
- "selector": "tr:nth-child(even)",
385
- "props": [("background-color", "#11B6BD")],
386
- },
387
- ]
388
- )
389
- .to_html(),
390
- unsafe_allow_html=True,
391
- )
392
-
393
- elif auth_status == False:
394
- st.error("Username/Password is incorrect")
395
-
396
- if auth_status != True:
397
- try:
398
- username_forgot_pw, email_forgot_password, random_password = (
399
- authenticator.forgot_password("Forgot password")
400
- )
401
- if username_forgot_pw:
402
- st.success("New password sent securely")
403
- # Random password to be transferred to user securely
404
- elif username_forgot_pw == False:
405
- st.error("Username not found")
406
- except Exception as e:
407
- st.error(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/11_Optimized_Result_Analysis.py DELETED
@@ -1,453 +0,0 @@
1
- import streamlit as st
2
- from numerize.numerize import numerize
3
- import pandas as pd
4
- from utilities import (format_numbers,decimal_formater,
5
- load_local_css,set_header,
6
- initialize_data,
7
- load_authenticator)
8
- import pickle
9
- import streamlit_authenticator as stauth
10
- import yaml
11
- from yaml import SafeLoader
12
- from classes import class_from_dict
13
- import plotly.express as px
14
- import numpy as np
15
- import plotly.graph_objects as go
16
- import pandas as pd
17
- from plotly.subplots import make_subplots
18
- import sqlite3
19
- from utilities import update_db
20
- def format_number(x):
21
- if x >= 1_000_000:
22
- return f'{x / 1_000_000:.2f}M'
23
- elif x >= 1_000:
24
- return f'{x / 1_000:.2f}K'
25
- else:
26
- return f'{x:.2f}'
27
-
28
- def summary_plot(data, x, y, title, text_column, color, format_as_percent=False, format_as_decimal=False):
29
- fig = px.bar(data, x=x, y=y, orientation='h',
30
- title=title, text=text_column, color=color)
31
- fig.update_layout(showlegend=False)
32
- data[text_column] = pd.to_numeric(data[text_column], errors='coerce')
33
-
34
- # Update the format of the displayed text based on the chosen format
35
- if format_as_percent:
36
- fig.update_traces(texttemplate='%{text:.0%}', textposition='outside', hovertemplate='%{x:.0%}')
37
- elif format_as_decimal:
38
- fig.update_traces(texttemplate='%{text:.2f}', textposition='outside', hovertemplate='%{x:.2f}')
39
- else:
40
- fig.update_traces(texttemplate='%{text:.2s}', textposition='outside', hovertemplate='%{x:.2s}')
41
-
42
- fig.update_layout(xaxis_title=x, yaxis_title='Channel Name', showlegend=False)
43
- return fig
44
-
45
-
46
- def stacked_summary_plot(data, x, y, title, text_column, color_column, stack_column=None, format_as_percent=False, format_as_decimal=False):
47
- fig = px.bar(data, x=x, y=y, orientation='h',
48
- title=title, text=text_column, color=color_column, facet_col=stack_column)
49
- fig.update_layout(showlegend=False)
50
- data[text_column] = pd.to_numeric(data[text_column], errors='coerce')
51
-
52
- # Update the format of the displayed text based on the chosen format
53
- if format_as_percent:
54
- fig.update_traces(texttemplate='%{text:.0%}', textposition='outside', hovertemplate='%{x:.0%}')
55
- elif format_as_decimal:
56
- fig.update_traces(texttemplate='%{text:.2f}', textposition='outside', hovertemplate='%{x:.2f}')
57
- else:
58
- fig.update_traces(texttemplate='%{text:.2s}', textposition='outside', hovertemplate='%{x:.2s}')
59
-
60
- fig.update_layout(xaxis_title=x, yaxis_title='', showlegend=False)
61
- return fig
62
-
63
-
64
-
65
- def funnel_plot(data, x, y, title, text_column, color_column, format_as_percent=False, format_as_decimal=False):
66
- data[text_column] = pd.to_numeric(data[text_column], errors='coerce')
67
-
68
- # Round the numeric values in the text column to two decimal points
69
- data[text_column] = data[text_column].round(2)
70
-
71
- # Create a color map for categorical data
72
- color_map = {category: f'rgb({i * 30 % 255},{i * 50 % 255},{i * 70 % 255})' for i, category in enumerate(data[color_column].unique())}
73
-
74
- fig = go.Figure(go.Funnel(
75
- y=data[y],
76
- x=data[x],
77
- text=data[text_column],
78
- marker=dict(color=data[color_column].map(color_map)),
79
- textinfo="value",
80
- hoverinfo='y+x+text'
81
- ))
82
-
83
- # Update the format of the displayed text based on the chosen format
84
- if format_as_percent:
85
- fig.update_layout(title=title, funnelmode="percent")
86
- elif format_as_decimal:
87
- fig.update_layout(title=title, funnelmode="overlay")
88
- else:
89
- fig.update_layout(title=title, funnelmode="group")
90
-
91
- return fig
92
-
93
-
94
- st.set_page_config(layout='wide')
95
- load_local_css('styles.css')
96
- set_header()
97
-
98
- # for k, v in st.session_state.items():
99
- # if k not in ['logout', 'login','config'] and not k.startswith('FormSubmitter'):
100
- # st.session_state[k] = v
101
-
102
- st.empty()
103
- st.header('Model Result Analysis')
104
- spends_data=pd.read_excel('Overview_data_test.xlsx')
105
-
106
- with open('summary_df.pkl', 'rb') as file:
107
- summary_df_sorted = pickle.load(file)
108
- #st.write(summary_df_sorted)
109
-
110
- selected_scenario= st.selectbox('Select Saved Scenarios',['S1','S2'])
111
- summary_df_sorted=summary_df_sorted.sort_values(by=['Optimized_spend'],ascending=False)
112
- st.header('Optimized Spends Overview')
113
- ___columns=st.columns(3)
114
- with ___columns[2]:
115
- fig=summary_plot(summary_df_sorted, x='Delta_percent', y='Channel_name', title='Delta', text_column='Delta_percent',color='Channel_name')
116
- st.plotly_chart(fig,use_container_width=True)
117
- with ___columns[0]:
118
- fig=summary_plot(summary_df_sorted, x='Actual_spend', y='Channel_name', title='Actual Spend', text_column='Actual_spend',color='Channel_name')
119
- st.plotly_chart(fig,use_container_width=True)
120
- with ___columns[1]:
121
- fig=summary_plot(summary_df_sorted, x='Optimized_spend', y='Channel_name', title='Planned Spend', text_column='Optimized_spend',color='Channel_name')
122
- st.plotly_chart(fig,use_container_width=False)
123
-
124
- st.header(' Budget Allocation')
125
- summary_df_sorted['Perc_alloted']=np.round(summary_df_sorted['Optimized_spend']/summary_df_sorted['Optimized_spend'].sum(),2)
126
- columns2=st.columns(2)
127
- with columns2[0]:
128
- fig=summary_plot(summary_df_sorted, x='Optimized_spend', y='Channel_name', title='Planned Spend', text_column='Optimized_spend',color='Channel_name')
129
- st.plotly_chart(fig,use_container_width=True)
130
- with columns2[1]:
131
- fig=summary_plot(summary_df_sorted, x='Perc_alloted', y='Channel_name', title='% Split', text_column='Perc_alloted',color='Channel_name',format_as_percent=True)
132
- st.plotly_chart(fig,use_container_width=True)
133
-
134
-
135
- if 'raw_data' not in st.session_state:
136
- st.session_state['raw_data']=pd.read_excel('raw_data_nov7_combined1.xlsx')
137
- st.session_state['raw_data']=st.session_state['raw_data'][st.session_state['raw_data']['MediaChannelName'].isin(summary_df_sorted['Channel_name'].unique())]
138
- st.session_state['raw_data']=st.session_state['raw_data'][st.session_state['raw_data']['Date'].isin(spends_data["Date"].unique())]
139
-
140
-
141
-
142
- #st.write(st.session_state['raw_data']['ResponseMetricName'])
143
- # st.write(st.session_state['raw_data'])
144
-
145
-
146
- st.header('Response Forecast Overview')
147
- raw_data=st.session_state['raw_data']
148
- effectiveness_overall=raw_data.groupby('ResponseMetricName').agg({'ResponseMetricValue': 'sum'}).reset_index()
149
- effectiveness_overall['Efficiency']=effectiveness_overall['ResponseMetricValue'].map(lambda x: x/raw_data['Media Spend'].sum() )
150
- # st.write(effectiveness_overall)
151
-
152
- columns6=st.columns(3)
153
-
154
- effectiveness_overall.sort_values(by=['ResponseMetricValue'],ascending=False,inplace=True)
155
- effectiveness_overall=np.round(effectiveness_overall,2)
156
- effectiveness_overall['ResponseMetric'] = effectiveness_overall['ResponseMetricName'].apply(lambda x: 'BAU' if 'BAU' in x else ('Gamified' if 'Gamified' in x else x))
157
- # effectiveness_overall=np.where(effectiveness_overall[effectiveness_overall['ResponseMetricName']=="Adjusted Account Approval BAU"],"Adjusted Account Approval BAU",effectiveness_overall['ResponseMetricName'])
158
-
159
- effectiveness_overall.replace({'ResponseMetricName':{'BAU approved clients - Appsflyer':'Approved clients - Appsflyer',
160
- 'Gamified approved clients - Appsflyer':'Approved clients - Appsflyer'}},inplace=True)
161
-
162
- # st.write(effectiveness_overall.sort_values(by=['ResponseMetricValue'],ascending=False))
163
-
164
-
165
- condition = effectiveness_overall['ResponseMetricName'] == "Adjusted Account Approval BAU"
166
- condition1= effectiveness_overall['ResponseMetricName'] == "Approved clients - Appsflyer"
167
- effectiveness_overall['ResponseMetric'] = np.where(condition, "Adjusted Account Approval BAU", effectiveness_overall['ResponseMetric'])
168
-
169
- effectiveness_overall['ResponseMetricName'] = np.where(condition1, "Approved clients - Appsflyer (BAU, Gamified)", effectiveness_overall['ResponseMetricName'])
170
- # effectiveness_overall=pd.DataFrame({'ResponseMetricName':["App Installs - Appsflyer",'Account Requests - Appsflyer',
171
- # 'Total Adjusted Account Approval','Adjusted Account Approval BAU',
172
- # 'Approved clients - Appsflyer','Approved clients - Appsflyer'],
173
- # 'ResponseMetricValue':[683067,367020,112315,79768,36661,16834],
174
- # 'Efficiency':[1.24,0.67,0.2,0.14,0.07,0.03],
175
- custom_colors = {
176
- 'App Installs - Appsflyer': 'rgb(255, 135, 0)', # Steel Blue (Blue)
177
- 'Account Requests - Appsflyer': 'rgb(125, 239, 161)', # Cornflower Blue (Blue)
178
- 'Adjusted Account Approval': 'rgb(129, 200, 255)', # Dodger Blue (Blue)
179
- 'Adjusted Account Approval BAU': 'rgb(255, 207, 98)', # Light Sky Blue (Blue)
180
- 'Approved clients - Appsflyer': 'rgb(0, 97, 198)', # Light Blue (Blue)
181
- "BAU": 'rgb(41, 176, 157)', # Steel Blue (Blue)
182
- "Gamified": 'rgb(213, 218, 229)' # Silver (Gray)
183
- # Add more categories and their respective shades of blue as needed
184
- }
185
-
186
-
187
-
188
-
189
-
190
-
191
- with columns6[0]:
192
- revenue=(effectiveness_overall[effectiveness_overall['ResponseMetricName']=='Total Approved Accounts - Revenue']['ResponseMetricValue']).iloc[0]
193
- revenue=round(revenue / 1_000_000, 2)
194
-
195
- # st.metric('Total Revenue', f"${revenue} M")
196
- # with columns6[1]:
197
- # BAU=(effectiveness_overall[effectiveness_overall['ResponseMetricName']=='BAU approved clients - Revenue']['ResponseMetricValue']).iloc[0]
198
- # BAU=round(BAU / 1_000_000, 2)
199
- # st.metric('BAU approved clients - Revenue', f"${BAU} M")
200
- # with columns6[2]:
201
- # Gam=(effectiveness_overall[effectiveness_overall['ResponseMetricName']=='Gamified approved clients - Revenue']['ResponseMetricValue']).iloc[0]
202
- # Gam=round(Gam / 1_000_000, 2)
203
- # st.metric('Gamified approved clients - Revenue', f"${Gam} M")
204
-
205
- # st.write(effectiveness_overall)
206
- data = {'Revenue': ['BAU approved clients - Revenue', 'Gamified approved clients- Revenue'],
207
- 'ResponseMetricValue': [70200000, 1770000],
208
- 'Efficiency':[127.54,3.21]}
209
- df = pd.DataFrame(data)
210
-
211
-
212
- columns9=st.columns([0.60,0.40])
213
- with columns9[0]:
214
- figd = px.pie(df,
215
- names='Revenue',
216
- values='ResponseMetricValue',
217
- hole=0.3, # set the size of the hole in the donut
218
- title='Effectiveness')
219
- figd.update_layout(
220
- margin=dict(l=0, r=0, b=0, t=0),width=100, height=180,legend=dict(
221
- orientation='v', # set orientation to horizontal
222
- x=0, # set x to 0 to move to the left
223
- y=0.8 # adjust y as needed
224
- )
225
- )
226
-
227
- st.plotly_chart(figd, use_container_width=True)
228
-
229
- with columns9[1]:
230
- figd1 = px.pie(df,
231
- names='Revenue',
232
- values='Efficiency',
233
- hole=0.3, # set the size of the hole in the donut
234
- title='Efficiency')
235
- figd1.update_layout(
236
- margin=dict(l=0, r=0, b=0, t=0),width=100,height=180,showlegend=False
237
- )
238
- st.plotly_chart(figd1, use_container_width=True)
239
-
240
- effectiveness_overall['Response Metric Name']=effectiveness_overall['ResponseMetricName']
241
-
242
-
243
-
244
- columns4= st.columns([0.55,0.45])
245
- with columns4[0]:
246
- fig=px.funnel(effectiveness_overall[~(effectiveness_overall['ResponseMetricName'].isin(['Total Approved Accounts - Revenue',
247
- 'BAU approved clients - Revenue',
248
- 'Gamified approved clients - Revenue',
249
- "Total Approved Accounts - Appsflyer"]))],
250
- x='ResponseMetricValue', y='Response Metric Name',color='ResponseMetric',
251
- color_discrete_map=custom_colors,title='Effectiveness',
252
- labels=None)
253
- custom_y_labels=['App Installs - Appsflyer','Account Requests - Appsflyer','Adjusted Account Approval','Adjusted Account Approval BAU',
254
- "Approved clients - Appsflyer (BAU, Gamified)"
255
- ]
256
- fig.update_layout(showlegend=False,
257
- yaxis=dict(
258
- tickmode='array',
259
- ticktext=custom_y_labels,
260
- )
261
- )
262
- fig.update_traces(textinfo='value', textposition='inside', texttemplate='%{x:.2s} ', hoverinfo='y+x+percent initial')
263
-
264
- last_trace_index = len(fig.data) - 1
265
- fig.update_traces(marker=dict(line=dict(color='black', width=2)), selector=dict(marker=dict(color='blue')))
266
-
267
- st.plotly_chart(fig,use_container_width=True)
268
-
269
-
270
-
271
-
272
-
273
- with columns4[1]:
274
-
275
- # Your existing code for creating the bar chart
276
- fig1 = px.bar((effectiveness_overall[~(effectiveness_overall['ResponseMetricName'].isin(['Total Approved Accounts - Revenue',
277
- 'BAU approved clients - Revenue',
278
- 'Gamified approved clients - Revenue',
279
- "Total Approved Accounts - Appsflyer"]))]).sort_values(by='ResponseMetricValue'),
280
- x='Efficiency', y='Response Metric Name',
281
- color_discrete_map=custom_colors, color='ResponseMetric',
282
- labels=None,text_auto=True,title='Efficiency'
283
- )
284
-
285
- # Update layout and traces
286
- fig1.update_traces(customdata=effectiveness_overall['Efficiency'],
287
- textposition='auto')
288
- fig1.update_layout(showlegend=False)
289
- fig1.update_yaxes(title='',showticklabels=False)
290
- fig1.update_xaxes(title='',showticklabels=False)
291
- fig1.update_xaxes(tickfont=dict(size=20))
292
- fig1.update_yaxes(tickfont=dict(size=20))
293
- st.plotly_chart(fig1, use_container_width=True)
294
-
295
-
296
- effectiveness_overall_revenue=pd.DataFrame({'ResponseMetricName':['Approved Clients','Approved Clients'],
297
- 'ResponseMetricValue':[70201070,1768900],
298
- 'Efficiency':[127.54,3.21],
299
- 'ResponseMetric':['BAU','Gamified']
300
- })
301
- # from plotly.subplots import make_subplots
302
- # fig = make_subplots(rows=1, cols=2,
303
- # subplot_titles=["Effectiveness", "Efficiency"])
304
-
305
- # # Add first plot as subplot
306
- # fig.add_trace(go.Funnel(
307
- # x = fig.data[0].x,
308
- # y = fig.data[0].y,
309
- # textinfo = 'value+percent initial',
310
- # hoverinfo = 'x+y+percent initial'
311
- # ), row=1, col=1)
312
-
313
- # # Update layout for first subplot
314
- # fig.update_xaxes(title_text="Response Metric Value", row=1, col=1)
315
- # fig.update_yaxes(ticktext = custom_y_labels, row=1, col=1)
316
-
317
- # # Add second plot as subplot
318
- # fig.add_trace(go.Bar(
319
- # x = fig1.data[0].x,
320
- # y = fig1.data[0].y,
321
- # customdata = fig1.data[0].customdata,
322
- # textposition = 'auto'
323
- # ), row=1, col=2)
324
-
325
- # # Update layout for second subplot
326
- # fig.update_xaxes(title_text="Efficiency", showticklabels=False, row=1, col=2)
327
- # fig.update_yaxes(title='', showticklabels=False, row=1, col=2)
328
-
329
- # fig.update_layout(height=600, width=800, title_text="Key Metrics")
330
- # st.plotly_chart(fig)
331
-
332
-
333
- st.header('Return Forecast by Media Channel')
334
- with st.expander("Return Forecast by Media Channel"):
335
- metric_data=[val for val in list(st.session_state['raw_data']['ResponseMetricName'].unique()) if val!=np.NaN]
336
- # st.write(metric_data)
337
- metric=st.selectbox('Select Metric',metric_data,index=1)
338
-
339
- selected_metric=st.session_state['raw_data'][st.session_state['raw_data']['ResponseMetricName']==metric]
340
- # st.dataframe(selected_metric.head(2))
341
- selected_metric=st.session_state['raw_data'][st.session_state['raw_data']['ResponseMetricName']==metric]
342
- effectiveness=selected_metric.groupby(by=['MediaChannelName'])['ResponseMetricValue'].sum()
343
- effectiveness_df=pd.DataFrame({'Channel':effectiveness.index,"ResponseMetricValue":effectiveness.values})
344
-
345
- summary_df_sorted=summary_df_sorted.merge(effectiveness_df,left_on="Channel_name",right_on='Channel')
346
-
347
- #
348
- summary_df_sorted['Efficiency'] = summary_df_sorted['ResponseMetricValue'] / summary_df_sorted['Optimized_spend']
349
- summary_df_sorted=summary_df_sorted.sort_values(by='Optimized_spend',ascending=True)
350
- #st.dataframe(summary_df_sorted)
351
-
352
- channel_colors = px.colors.qualitative.Plotly
353
-
354
- fig = make_subplots(rows=1, cols=3, subplot_titles=('Optimized Spends', 'Effectiveness', 'Efficiency'), horizontal_spacing=0.05)
355
-
356
- for i, channel in enumerate(summary_df_sorted['Channel_name'].unique()):
357
- channel_df = summary_df_sorted[summary_df_sorted['Channel_name'] == channel]
358
- channel_color = channel_colors[i % len(channel_colors)]
359
-
360
- fig.add_trace(go.Bar(x=channel_df['Optimized_spend'],
361
- y=channel_df['Channel_name'],
362
- text=channel_df['Optimized_spend'].apply(format_number),
363
- marker_color=channel_color,
364
- orientation='h'), row=1, col=1)
365
-
366
- fig.add_trace(go.Bar(x=channel_df['ResponseMetricValue'],
367
- y=channel_df['Channel_name'],
368
- text=channel_df['ResponseMetricValue'].apply(format_number),
369
- marker_color=channel_color,
370
- orientation='h', showlegend=False), row=1, col=2)
371
-
372
- fig.add_trace(go.Bar(x=channel_df['Efficiency'],
373
- y=channel_df['Channel_name'],
374
- text=channel_df['Efficiency'].apply(format_number),
375
- marker_color=channel_color,
376
- orientation='h', showlegend=False), row=1, col=3)
377
-
378
- fig.update_layout(
379
- height=600,
380
- width=900,
381
- title='Media Channel Performance',
382
- showlegend=False
383
- )
384
-
385
- fig.update_yaxes(showticklabels=False ,row=1, col=2 )
386
- fig.update_yaxes(showticklabels=False, row=1, col=3)
387
-
388
- fig.update_xaxes(showticklabels=False, row=1, col=1)
389
- fig.update_xaxes(showticklabels=False, row=1, col=2)
390
- fig.update_xaxes(showticklabels=False, row=1, col=3)
391
-
392
-
393
- st.plotly_chart(fig, use_container_width=True)
394
-
395
-
396
-
397
- # columns= st.columns(3)
398
- # with columns[0]:
399
- # fig=summary_plot(summary_df_sorted, x='Optimized_spend', y='Channel_name', title='', text_column='Optimized_spend',color='Channel_name')
400
- # st.plotly_chart(fig,use_container_width=True)
401
- # with columns[1]:
402
-
403
- # # effectiveness=(selected_metric.groupby(by=['MediaChannelName'])['ResponseMetricValue'].sum()).values
404
- # # effectiveness_df=pd.DataFrame({'Channel':st.session_state['raw_data']['MediaChannelName'].unique(),"ResponseMetricValue":effectiveness})
405
- # # # effectiveness.reset_index(inplace=True)
406
- # # # st.dataframe(effectiveness.head())
407
-
408
-
409
- # fig=summary_plot(summary_df_sorted, x='ResponseMetricValue', y='Channel_name', title='Effectiveness', text_column='ResponseMetricValue',color='Channel_name')
410
- # st.plotly_chart(fig,use_container_width=True)
411
-
412
- # with columns[2]:
413
- # fig=summary_plot(summary_df_sorted, x='Efficiency', y='Channel_name', title='Efficiency', text_column='Efficiency',color='Channel_name',format_as_decimal=True)
414
- # st.plotly_chart(fig,use_container_width=True)
415
-
416
-
417
- # Create figure with subplots
418
- # fig = make_subplots(rows=1, cols=2)
419
-
420
- # # Add funnel plot to subplot 1
421
- # fig.add_trace(
422
- # go.Funnel(
423
- # x=effectiveness_overall[~(effectiveness_overall['ResponseMetricName'].isin(['Total Approved Accounts - Revenue', 'BAU approved clients - Revenue', 'Gamified approved clients - Revenue', "Total Approved Accounts - Appsflyer"]))]['ResponseMetricValue'],
424
- # y=effectiveness_overall[~(effectiveness_overall['ResponseMetricName'].isin(['Total Approved Accounts - Revenue', 'BAU approved clients - Revenue', 'Gamified approved clients - Revenue', "Total Approved Accounts - Appsflyer"]))]['ResponseMetricName'],
425
- # textposition="inside",
426
- # texttemplate="%{x:.2s}",
427
- # customdata=effectiveness_overall['Efficiency'],
428
- # hovertemplate="%{customdata:.2f}<extra></extra>"
429
- # ),
430
- # row=1, col=1
431
- # )
432
-
433
- # # Add bar plot to subplot 2
434
- # fig.add_trace(
435
- # go.Bar(
436
- # x=effectiveness_overall.sort_values(by='ResponseMetricValue')['Efficiency'],
437
- # y=effectiveness_overall.sort_values(by='ResponseMetricValue')['ResponseMetricName'],
438
- # marker_color=effectiveness_overall['ResponseMetric'],
439
- # customdata=effectiveness_overall['Efficiency'],
440
- # hovertemplate="%{customdata:.2f}<extra></extra>",
441
- # textposition="outside"
442
- # ),
443
- # row=1, col=2
444
- # )
445
-
446
- # # Update layout
447
- # fig.update_layout(title_text="Effectiveness")
448
- # fig.update_yaxes(title_text="", row=1, col=1)
449
- # fig.update_yaxes(title_text="", showticklabels=False, row=1, col=2)
450
- # fig.update_xaxes(title_text="Efficiency", showticklabels=False, row=1, col=2)
451
-
452
- # # Show figure
453
- # st.plotly_chart(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/1_Data_Import.py DELETED
@@ -1,1547 +0,0 @@
1
- # Importing necessary libraries
2
- import streamlit as st
3
- import os
4
-
5
- # from Home_redirecting import home
6
- from utilities import update_db
7
-
8
- st.set_page_config(
9
- page_title="Data Import",
10
- page_icon=":shark:",
11
- layout="wide",
12
- initial_sidebar_state="collapsed",
13
- )
14
-
15
- import pickle
16
- import pandas as pd
17
- from utilities import set_header, load_local_css
18
- import streamlit_authenticator as stauth
19
- import yaml
20
- from yaml import SafeLoader
21
- import sqlite3
22
-
23
- load_local_css("styles.css")
24
- set_header()
25
-
26
- for k, v in st.session_state.items():
27
- if (
28
- k not in ["logout", "login", "config"]
29
- and not k.startswith("FormSubmitter")
30
- and not k.startswith("data-editor")
31
- ):
32
- st.session_state[k] = v
33
- with open("config.yaml") as file:
34
- config = yaml.load(file, Loader=SafeLoader)
35
- st.session_state["config"] = config
36
- authenticator = stauth.Authenticate(
37
- config["credentials"],
38
- config["cookie"]["name"],
39
- config["cookie"]["key"],
40
- config["cookie"]["expiry_days"],
41
- config["preauthorized"],
42
- )
43
- st.session_state["authenticator"] = authenticator
44
- name, authentication_status, username = authenticator.login("Login", "main")
45
- auth_status = st.session_state.get("authentication_status")
46
-
47
- if auth_status == True:
48
- authenticator.logout("Logout", "main")
49
- is_state_initiaized = st.session_state.get("initialized", False)
50
-
51
- if not is_state_initiaized:
52
-
53
- if "session_name" not in st.session_state:
54
- st.session_state["session_name"] = None
55
-
56
- # Function to validate date column in dataframe
57
-
58
- if "project_dct" not in st.session_state:
59
- # home()
60
- st.warning("please select a project from Home page")
61
- st.stop()
62
-
63
- def validate_date_column(df):
64
- try:
65
- # Attempt to convert the 'Date' column to datetime
66
- df["date"] = pd.to_datetime(df["date"], format="%d-%m-%Y")
67
- return True
68
- except:
69
- return False
70
-
71
- # Function to determine data interval
72
- def determine_data_interval(common_freq):
73
- if common_freq == 1:
74
- return "daily"
75
- elif common_freq == 7:
76
- return "weekly"
77
- elif 28 <= common_freq <= 31:
78
- return "monthly"
79
- else:
80
- return "irregular"
81
-
82
- # Function to read each uploaded Excel file into a pandas DataFrame and stores them in a dictionary
83
- st.cache_resource(show_spinner=False)
84
-
85
- def files_to_dataframes(uploaded_files):
86
- df_dict = {}
87
- for uploaded_file in uploaded_files:
88
- # Extract file name without extension
89
- file_name = uploaded_file.name.rsplit(".", 1)[0]
90
-
91
- # Check for duplicate file names
92
- if file_name in df_dict:
93
- st.warning(
94
- f"Duplicate File: {file_name}. This file will be skipped.",
95
- icon="⚠️",
96
- )
97
- continue
98
-
99
- # Read the file into a DataFrame
100
- df = pd.read_excel(uploaded_file)
101
-
102
- # Convert all column names to lowercase
103
- df.columns = df.columns.str.lower().str.strip()
104
-
105
- # Separate numeric and non-numeric columns
106
- numeric_cols = list(df.select_dtypes(include=["number"]).columns)
107
- non_numeric_cols = [
108
- col
109
- for col in df.select_dtypes(exclude=["number"]).columns
110
- if col.lower() != "date"
111
- ]
112
-
113
- # Check for 'Date' column
114
- if not (validate_date_column(df) and len(numeric_cols) > 0):
115
- st.warning(
116
- f"File Name: {file_name} ➜ Please upload data with Date column in 'DD-MM-YYYY' format and at least one media/exogenous column. This file will be skipped.",
117
- icon="⚠️",
118
- )
119
- continue
120
-
121
- # Check for interval
122
- common_freq = common_freq = (
123
- pd.Series(df["date"].unique())
124
- .diff()
125
- .dt.days.dropna()
126
- .mode()[0]
127
- )
128
- # Calculate the data interval (daily, weekly, monthly or irregular)
129
- interval = determine_data_interval(common_freq)
130
- if interval == "irregular":
131
- st.warning(
132
- f"File Name: {file_name} ➜ Please upload data in daily, weekly or monthly interval. This file will be skipped.",
133
- icon="⚠️",
134
- )
135
- continue
136
-
137
- # Store both DataFrames in the dictionary under their respective keys
138
- df_dict[file_name] = {
139
- "numeric": numeric_cols,
140
- "non_numeric": non_numeric_cols,
141
- "interval": interval,
142
- "df": df,
143
- }
144
-
145
- return df_dict
146
-
147
- # Function to adjust dataframe granularity
148
- def adjust_dataframe_granularity(
149
- df, current_granularity, target_granularity
150
- ):
151
- # Set index
152
- df.set_index("date", inplace=True)
153
-
154
- # Define aggregation rules for resampling
155
- aggregation_rules = {
156
- col: "sum" if pd.api.types.is_numeric_dtype(df[col]) else "first"
157
- for col in df.columns
158
- }
159
-
160
- # Initialize resampled_df
161
- resampled_df = df
162
- if current_granularity == "daily" and target_granularity == "weekly":
163
- resampled_df = df.resample(
164
- "W-MON", closed="left", label="left"
165
- ).agg(aggregation_rules)
166
-
167
- elif (
168
- current_granularity == "daily" and target_granularity == "monthly"
169
- ):
170
- resampled_df = df.resample("MS", closed="left", label="left").agg(
171
- aggregation_rules
172
- )
173
-
174
- elif current_granularity == "daily" and target_granularity == "daily":
175
- resampled_df = df.resample("D").agg(aggregation_rules)
176
-
177
- elif (
178
- current_granularity in ["weekly", "monthly"]
179
- and target_granularity == "daily"
180
- ):
181
- # For higher to lower granularity, distribute numeric and replicate non-numeric values equally across the new period
182
- expanded_data = []
183
- for _, row in df.iterrows():
184
- if current_granularity == "weekly":
185
- period_range = pd.date_range(start=row.name, periods=7)
186
- elif current_granularity == "monthly":
187
- period_range = pd.date_range(
188
- start=row.name, periods=row.name.days_in_month
189
- )
190
-
191
- for date in period_range:
192
- new_row = {}
193
- for col in df.columns:
194
- if pd.api.types.is_numeric_dtype(df[col]):
195
- if current_granularity == "weekly":
196
- new_row[col] = row[col] / 7
197
- elif current_granularity == "monthly":
198
- new_row[col] = (
199
- row[col] / row.name.days_in_month
200
- )
201
- else:
202
- new_row[col] = row[col]
203
- expanded_data.append((date, new_row))
204
-
205
- resampled_df = pd.DataFrame(
206
- [data for _, data in expanded_data],
207
- index=[date for date, _ in expanded_data],
208
- )
209
-
210
- # Reset index
211
- resampled_df = resampled_df.reset_index().rename(
212
- columns={"index": "date"}
213
- )
214
-
215
- return resampled_df
216
-
217
- # Function to clean and extract unique values of Panel_1 and Panel_2
218
- st.cache_resource(show_spinner=False)
219
-
220
- def clean_and_extract_unique_values(files_dict, selections):
221
- all_panel1_values = set()
222
- all_panel2_values = set()
223
-
224
- for file_name, file_data in files_dict.items():
225
- df = file_data["df"]
226
-
227
- # 'Panel_1' and 'Panel_2' selections
228
- selected_panel1 = selections[file_name].get("Panel_1")
229
- selected_panel2 = selections[file_name].get("Panel_2")
230
-
231
- # Clean and standardize Panel_1 column if it exists and is selected
232
- if (
233
- selected_panel1
234
- and selected_panel1 != "N/A"
235
- and selected_panel1 in df.columns
236
- ):
237
- df[selected_panel1] = (
238
- df[selected_panel1]
239
- .str.lower()
240
- .str.strip()
241
- .str.replace("_", " ")
242
- )
243
- all_panel1_values.update(df[selected_panel1].dropna().unique())
244
-
245
- # Clean and standardize Panel_2 column if it exists and is selected
246
- if (
247
- selected_panel2
248
- and selected_panel2 != "N/A"
249
- and selected_panel2 in df.columns
250
- ):
251
- df[selected_panel2] = (
252
- df[selected_panel2]
253
- .str.lower()
254
- .str.strip()
255
- .str.replace("_", " ")
256
- )
257
- all_panel2_values.update(df[selected_panel2].dropna().unique())
258
-
259
- # Update the processed DataFrame back in the dictionary
260
- files_dict[file_name]["df"] = df
261
-
262
- return all_panel1_values, all_panel2_values
263
-
264
- # Function to format values for display
265
- st.cache_resource(show_spinner=False)
266
-
267
- def format_values_for_display(values_list):
268
- # Capitalize the first letter of each word and replace underscores with spaces
269
- formatted_list = [
270
- value.replace("_", " ").title() for value in values_list
271
- ]
272
- # Join values with commas and 'and' before the last value
273
- if len(formatted_list) > 1:
274
- return (
275
- ", ".join(formatted_list[:-1]) + ", and " + formatted_list[-1]
276
- )
277
- elif formatted_list:
278
- return formatted_list[0]
279
- return "No values available"
280
-
281
- # Function to normalizes all data within files_dict to a daily granularity
282
- st.cache(show_spinner=False, allow_output_mutation=True)
283
-
284
- def standardize_data_to_daily(files_dict, selections):
285
- # Normalize all data to a daily granularity using a provided function
286
- files_dict = apply_granularity_to_all(files_dict, "daily", selections)
287
-
288
- # Update the "interval" attribute for each dataset to indicate the new granularity
289
- for files_name, files_data in files_dict.items():
290
- files_data["interval"] = "daily"
291
-
292
- return files_dict
293
-
294
- # Function to apply granularity transformation to all DataFrames in files_dict
295
- st.cache_resource(show_spinner=False)
296
-
297
- def apply_granularity_to_all(
298
- files_dict, granularity_selection, selections
299
- ):
300
- for file_name, file_data in files_dict.items():
301
- df = file_data["df"].copy()
302
-
303
- # Handling when Panel_1 or Panel_2 might be 'N/A'
304
- selected_panel1 = selections[file_name].get("Panel_1")
305
- selected_panel2 = selections[file_name].get("Panel_2")
306
-
307
- # Correcting the segment selection logic & handling 'N/A'
308
- if selected_panel1 != "N/A" and selected_panel2 != "N/A":
309
- unique_combinations = df[
310
- [selected_panel1, selected_panel2]
311
- ].drop_duplicates()
312
- elif selected_panel1 != "N/A":
313
- unique_combinations = df[[selected_panel1]].drop_duplicates()
314
- selected_panel2 = None # Ensure Panel_2 is ignored if N/A
315
- elif selected_panel2 != "N/A":
316
- unique_combinations = df[[selected_panel2]].drop_duplicates()
317
- selected_panel1 = None # Ensure Panel_1 is ignored if N/A
318
- else:
319
- # If both are 'N/A', process the entire dataframe as is
320
- df = adjust_dataframe_granularity(
321
- df, file_data["interval"], granularity_selection
322
- )
323
- files_dict[file_name]["df"] = df
324
- continue # Skip to the next file
325
-
326
- transformed_segments = []
327
- for _, combo in unique_combinations.iterrows():
328
- if selected_panel1 and selected_panel2:
329
- segment = df[
330
- (df[selected_panel1] == combo[selected_panel1])
331
- & (df[selected_panel2] == combo[selected_panel2])
332
- ]
333
- elif selected_panel1:
334
- segment = df[df[selected_panel1] == combo[selected_panel1]]
335
- elif selected_panel2:
336
- segment = df[df[selected_panel2] == combo[selected_panel2]]
337
-
338
- # Adjust granularity of the segment
339
- transformed_segment = adjust_dataframe_granularity(
340
- segment, file_data["interval"], granularity_selection
341
- )
342
- transformed_segments.append(transformed_segment)
343
-
344
- # Combine all transformed segments into a single DataFrame for this file
345
- transformed_df = pd.concat(transformed_segments, ignore_index=True)
346
- files_dict[file_name]["df"] = transformed_df
347
-
348
- return files_dict
349
-
350
- # Function to create main dataframe structure
351
- st.cache_resource(show_spinner=False)
352
-
353
- def create_main_dataframe(
354
- files_dict, all_panel1_values, all_panel2_values, granularity_selection
355
- ):
356
- # Determine the global start and end dates across all DataFrames
357
- global_start = min(
358
- df["df"]["date"].min() for df in files_dict.values()
359
- )
360
- global_end = max(df["df"]["date"].max() for df in files_dict.values())
361
-
362
- # Adjust the date_range generation based on the granularity_selection
363
- if granularity_selection == "weekly":
364
- # Generate a weekly range, with weeks starting on Monday
365
- date_range = pd.date_range(
366
- start=global_start, end=global_end, freq="W-MON"
367
- )
368
- elif granularity_selection == "monthly":
369
- # Generate a monthly range, starting from the first day of each month
370
- date_range = pd.date_range(
371
- start=global_start, end=global_end, freq="MS"
372
- )
373
- else: # Default to daily if not weekly or monthly
374
- date_range = pd.date_range(
375
- start=global_start, end=global_end, freq="D"
376
- )
377
-
378
- # Collect all unique Panel_1 and Panel_2 values, excluding 'N/A'
379
- all_panel1s = all_panel1_values
380
- all_panel2s = all_panel2_values
381
-
382
- # Dynamically build the list of dimensions (Panel_1, Panel_2) to include in the main DataFrame based on availability
383
- dimensions, merge_keys = [], []
384
- if all_panel1s:
385
- dimensions.append(all_panel1s)
386
- merge_keys.append("Panel_1")
387
- if all_panel2s:
388
- dimensions.append(all_panel2s)
389
- merge_keys.append("Panel_2")
390
-
391
- dimensions.append(date_range) # Date range is always included
392
- merge_keys.append("date") # Date range is always included
393
-
394
- # Create a main DataFrame template with the dimensions
395
- main_df = pd.MultiIndex.from_product(
396
- dimensions,
397
- names=[name for name, _ in zip(merge_keys, dimensions)],
398
- ).to_frame(index=False)
399
-
400
- return main_df.reset_index(drop=True)
401
-
402
- # Function to prepare and merge dataFrames
403
- st.cache_resource(show_spinner=False)
404
-
405
- def merge_into_main_df(main_df, files_dict, selections):
406
- for file_name, file_data in files_dict.items():
407
- df = file_data["df"].copy()
408
-
409
- # Rename selected Panel_1 and Panel_2 columns if not 'N/A'
410
- selected_panel1 = selections[file_name].get("Panel_1", "N/A")
411
- selected_panel2 = selections[file_name].get("Panel_2", "N/A")
412
- if selected_panel1 != "N/A":
413
- df.rename(columns={selected_panel1: "Panel_1"}, inplace=True)
414
- if selected_panel2 != "N/A":
415
- df.rename(columns={selected_panel2: "Panel_2"}, inplace=True)
416
-
417
- # Merge current DataFrame into main_df based on 'date', and where applicable, 'Panel_1' and 'Panel_2'
418
- merge_keys = ["date"]
419
- if "Panel_1" in df.columns:
420
- merge_keys.append("Panel_1")
421
- if "Panel_2" in df.columns:
422
- merge_keys.append("Panel_2")
423
- main_df = pd.merge(main_df, df, on=merge_keys, how="left")
424
-
425
- # After all merges, sort by 'date' and reset index for cleanliness
426
- sort_by = ["date"]
427
- if "Panel_1" in main_df.columns:
428
- sort_by.append("Panel_1")
429
- if "Panel_2" in main_df.columns:
430
- sort_by.append("Panel_2")
431
- main_df.sort_values(by=sort_by, inplace=True)
432
- main_df.reset_index(drop=True, inplace=True)
433
-
434
- return main_df
435
-
436
- # Function to categorize column
437
- def categorize_column(column_name):
438
- # Define keywords for each category
439
- internal_keywords = [
440
- "Price",
441
- "Discount",
442
- "product_price",
443
- "cost",
444
- "margin",
445
- "inventory",
446
- "sales",
447
- "revenue",
448
- "turnover",
449
- "expense",
450
- ]
451
- exogenous_keywords = [
452
- "GDP",
453
- "Tax",
454
- "Inflation",
455
- "interest_rate",
456
- "employment_rate",
457
- "exchange_rate",
458
- "consumer_spending",
459
- "retail_sales",
460
- "oil_prices",
461
- "weather",
462
- ]
463
-
464
- # Check if the column name matches any of the keywords for Internal or Exogenous categories
465
-
466
- if (
467
- column_name
468
- in st.session_state["project_dct"]["data_import"]["cat_dct"].keys()
469
- and st.session_state["project_dct"]["data_import"]["cat_dct"][
470
- column_name
471
- ]
472
- is not None
473
- ):
474
-
475
- return st.session_state["project_dct"]["data_import"]["cat_dct"][
476
- column_name
477
- ] # resume project manoj
478
-
479
- else:
480
- for keyword in internal_keywords:
481
- if keyword.lower() in column_name.lower():
482
- return "Internal"
483
- for keyword in exogenous_keywords:
484
- if keyword.lower() in column_name.lower():
485
- return "Exogenous"
486
-
487
- # Default to Media if no match found
488
- return "Media"
489
-
490
- # Function to calculate missing stats and prepare for editable DataFrame
491
- st.cache_resource(show_spinner=False)
492
-
493
- def prepare_missing_stats_df(df):
494
- missing_stats = []
495
- for column in df.columns:
496
- if (
497
- column == "date" or column == "Panel_2" or column == "Panel_1"
498
- ): # Skip Date, Panel_1 and Panel_2 column
499
- continue
500
-
501
- missing = df[column].isnull().sum()
502
- pct_missing = round((missing / len(df)) * 100, 2)
503
-
504
- # Dynamically assign category based on column name
505
- category = categorize_column(column)
506
- # category = "Media" # Keep default bin as Media
507
-
508
- missing_stats.append(
509
- {
510
- "Column": column,
511
- "Missing Values": missing,
512
- "Missing Percentage": pct_missing,
513
- "Impute Method": "Fill with 0", # Default value
514
- "Category": category,
515
- }
516
- )
517
- stats_df = pd.DataFrame(missing_stats)
518
-
519
- return stats_df
520
-
521
- # Function to add API DataFrame details to the files dictionary
522
- st.cache_resource(show_spinner=False)
523
-
524
- def add_api_dataframe_to_dict(main_df, files_dict):
525
- files_dict["API"] = {
526
- "numeric": list(main_df.select_dtypes(include=["number"]).columns),
527
- "non_numeric": [
528
- col
529
- for col in main_df.select_dtypes(exclude=["number"]).columns
530
- if col.lower() != "date"
531
- ],
532
- "interval": determine_data_interval(
533
- pd.Series(main_df["date"].unique())
534
- .diff()
535
- .dt.days.dropna()
536
- .mode()[0]
537
- ),
538
- "df": main_df,
539
- }
540
-
541
- return files_dict
542
-
543
- # Function to reads an API into a DataFrame, parsing specified columns as datetime
544
- @st.cache_resource(show_spinner=False)
545
- def read_API_data():
546
- return pd.read_excel(
547
- r"./upf_data_converted_randomized_resp_metrics.xlsx",
548
- parse_dates=["Date"],
549
- )
550
-
551
- # Function to set the 'Panel_1_Panel_2_Selected' session state variable to False
552
- def set_Panel_1_Panel_2_Selected_false():
553
-
554
- st.session_state["Panel_1_Panel_2_Selected"] = False
555
-
556
- # restoring project_dct to default values when user modify any widjets
557
- st.session_state["project_dct"]["data_import"][
558
- "edited_stats_df"
559
- ] = None
560
- st.session_state["project_dct"]["data_import"]["merged_df"] = None
561
- st.session_state["project_dct"]["data_import"][
562
- "missing_stats_df"
563
- ] = None
564
- st.session_state["project_dct"]["data_import"]["cat_dct"] = {}
565
- st.session_state["project_dct"]["data_import"][
566
- "numeric_columns"
567
- ] = None
568
- st.session_state["project_dct"]["data_import"]["default_df"] = None
569
- st.session_state["project_dct"]["data_import"]["final_df"] = None
570
- st.session_state["project_dct"]["data_import"]["edited_df"] = None
571
-
572
- # Function to serialize and save the objects into a pickle file
573
- @st.cache_resource(show_spinner=False)
574
- def save_to_pickle(file_path, final_df, bin_dict):
575
- # Open the file in write-binary mode and dump the objects
576
- with open(file_path, "wb") as f:
577
- pickle.dump({"final_df": final_df, "bin_dict": bin_dict}, f)
578
- # Data is now saved to file
579
-
580
- # Function to processes the merged_df DataFrame based on operations defined in edited_df
581
- @st.cache_resource(show_spinner=False)
582
- def process_dataframes(merged_df, edited_df, edited_stats_df):
583
- # Ensure there are operations defined by the user
584
- if edited_df.empty:
585
-
586
- return merged_df, edited_stats_df # No operations to apply
587
-
588
- # Perform operations as defined by the user
589
- else:
590
-
591
- for index, row in edited_df.iterrows():
592
- result_column_name = (
593
- f"{row['Column 1']}{row['Operator']}{row['Column 2']}"
594
- )
595
- col1 = row["Column 1"]
596
- col2 = row["Column 2"]
597
- op = row["Operator"]
598
-
599
- # Apply the specified operation
600
- if op == "+":
601
- merged_df[result_column_name] = (
602
- merged_df[col1] + merged_df[col2]
603
- )
604
- elif op == "-":
605
- merged_df[result_column_name] = (
606
- merged_df[col1] - merged_df[col2]
607
- )
608
- elif op == "*":
609
- merged_df[result_column_name] = (
610
- merged_df[col1] * merged_df[col2]
611
- )
612
- elif op == "/":
613
- merged_df[result_column_name] = merged_df[
614
- col1
615
- ] / merged_df[col2].replace(0, 1e-9)
616
-
617
- # Add summary of operation to edited_stats_df
618
- new_row = {
619
- "Column": result_column_name,
620
- "Missing Values": None,
621
- "Missing Percentage": None,
622
- "Impute Method": None,
623
- "Category": row["Category"],
624
- }
625
- new_row_df = pd.DataFrame([new_row])
626
-
627
- # Use pd.concat to add the new_row_df to edited_stats_df
628
- edited_stats_df = pd.concat(
629
- [edited_stats_df, new_row_df], ignore_index=True, axis=0
630
- )
631
-
632
- # Combine column names from edited_df for cleanup
633
- combined_columns = set(edited_df["Column 1"]).union(
634
- set(edited_df["Column 2"])
635
- )
636
-
637
- # Filter out rows in edited_stats_df and drop columns from merged_df
638
- edited_stats_df = edited_stats_df[
639
- ~edited_stats_df["Column"].isin(combined_columns)
640
- ]
641
- merged_df.drop(
642
- columns=list(combined_columns), errors="ignore", inplace=True
643
- )
644
-
645
- return merged_df, edited_stats_df
646
-
647
- # Function to prepare a list of numeric column names and initialize an empty DataFrame with predefined structure
648
- st.cache_resource(show_spinner=False)
649
-
650
- def prepare_numeric_columns_and_default_df(merged_df, edited_stats_df):
651
- # Get columns categorized as 'Response Metrics'
652
- columns_response_metrics = edited_stats_df[
653
- edited_stats_df["Category"] == "Response Metrics"
654
- ]["Column"].tolist()
655
-
656
- # Filter numeric columns, excluding those categorized as 'Response Metrics'
657
- numeric_columns = [
658
- col
659
- for col in merged_df.select_dtypes(include=["number"]).columns
660
- if col not in columns_response_metrics
661
- ]
662
-
663
- # Define the structure of the empty DataFrame
664
- data = {
665
- "Column 1": pd.Series([], dtype="str"),
666
- "Operator": pd.Series([], dtype="str"),
667
- "Column 2": pd.Series([], dtype="str"),
668
- "Category": pd.Series([], dtype="str"),
669
- }
670
- default_df = pd.DataFrame(data)
671
-
672
- return numeric_columns, default_df
673
-
674
- # function to reset to default values in project_dct:
675
-
676
- # Initialize 'final_df' in session state
677
- if "final_df" not in st.session_state:
678
- st.session_state["final_df"] = pd.DataFrame()
679
-
680
- # Initialize 'bin_dict' in session state
681
- if "bin_dict" not in st.session_state:
682
- st.session_state["bin_dict"] = {}
683
-
684
- # Initialize 'Panel_1_Panel_2_Selected' in session state
685
- if "Panel_1_Panel_2_Selected" not in st.session_state:
686
- st.session_state["Panel_1_Panel_2_Selected"] = False
687
-
688
- # Page Title
689
- st.write("") # Top padding
690
- st.title("Data Import")
691
-
692
- conn = sqlite3.connect(
693
- r"DB\User.db", check_same_thread=False
694
- ) # connection with sql db
695
- c = conn.cursor()
696
-
697
- #########################################################################################################################################################
698
- # Create a dictionary to hold all DataFrames and collect user input to specify "Panel_2" and "Panel_1" columns for each file
699
- #########################################################################################################################################################
700
-
701
- # Read the Excel file, parsing 'Date' column as datetime
702
- main_df = read_API_data()
703
-
704
- # Convert all column names to lowercase
705
- main_df.columns = main_df.columns.str.lower().str.strip()
706
-
707
- # File uploader
708
- uploaded_files = st.file_uploader(
709
- "Upload additional data",
710
- type=["xlsx"],
711
- accept_multiple_files=True,
712
- on_change=set_Panel_1_Panel_2_Selected_false,
713
- )
714
-
715
- # Custom HTML for upload instructions
716
- recommendation_html = f"""
717
- <div style="text-align: justify;">
718
- <strong>Recommendation:</strong> For optimal processing, please ensure that all uploaded datasets including panel, media, internal, and exogenous data adhere to the following guidelines: Each dataset must include a <code>Date</code> column formatted as <code>DD-MM-YYYY</code>, be free of missing values.
719
- </div>
720
- """
721
- st.markdown(recommendation_html, unsafe_allow_html=True)
722
-
723
- # Choose Desired Granularity
724
- st.markdown("#### Choose Desired Granularity")
725
- # Granularity Selection
726
-
727
- granularity_selection = st.selectbox(
728
- "Choose Date Granularity",
729
- ["Daily", "Weekly", "Monthly"],
730
- label_visibility="collapsed",
731
- on_change=set_Panel_1_Panel_2_Selected_false,
732
- index=st.session_state["project_dct"]["data_import"][
733
- "granularity_selection"
734
- ], # resume
735
- )
736
-
737
- # st.write(st.session_state['project_dct']['data_import']['granularity_selection'])
738
-
739
- st.session_state["project_dct"]["data_import"]["granularity_selection"] = [
740
- "Daily",
741
- "Weekly",
742
- "Monthly",
743
- ].index(granularity_selection)
744
- # st.write(st.session_state['project_dct']['data_import']['granularity_selection'])
745
- granularity_selection = str(granularity_selection).lower()
746
-
747
- # Convert files to dataframes
748
- files_dict = files_to_dataframes(uploaded_files)
749
-
750
- # Add API Dataframe
751
- if main_df is not None:
752
- files_dict = add_api_dataframe_to_dict(main_df, files_dict)
753
-
754
- # Display a warning message if no files have been uploaded and halt further execution
755
- if not files_dict:
756
- st.warning(
757
- "Please upload at least one file to proceed.",
758
- icon="⚠️",
759
- )
760
- st.stop() # Halts further execution until file is uploaded
761
-
762
- # Select Panel_1 and Panel_2 columns
763
- st.markdown("#### Select Panel columns")
764
- selections = {}
765
- with st.expander("Select Panel columns", expanded=False):
766
- count = (
767
- 0 # Initialize counter to manage the visibility of labels and keys
768
- )
769
- for file_name, file_data in files_dict.items():
770
-
771
- # generatimg project dct keys dynamically
772
- if (
773
- f"Panel_1_selectbox{file_name}"
774
- not in st.session_state["project_dct"]["data_import"].keys()
775
- ):
776
- st.session_state["project_dct"]["data_import"][
777
- f"Panel_1_selectbox{file_name}"
778
- ] = 0
779
-
780
- if (
781
- f"Panel_2_selectbox{file_name}"
782
- not in st.session_state["project_dct"]["data_import"].keys()
783
- ):
784
-
785
- st.session_state["project_dct"]["data_import"][
786
- f"Panel_2_selectbox{file_name}"
787
- ] = 0
788
-
789
- # Determine visibility of the label based on the count
790
- if count == 0:
791
- label_visibility = "visible"
792
- else:
793
- label_visibility = "collapsed"
794
-
795
- # Extract non-numeric columns
796
- non_numeric_cols = file_data["non_numeric"]
797
-
798
- # Prepare Panel_1 and Panel_2 values for dropdown, adding "N/A" as an option
799
- panel1_values = non_numeric_cols + ["N/A"]
800
- panel2_values = non_numeric_cols + ["N/A"]
801
-
802
- # Skip if only one option is available
803
- if len(panel1_values) == 1 and len(panel2_values) == 1:
804
- selected_panel1, selected_panel2 = "N/A", "N/A"
805
- # Update the selections for Panel_1 and Panel_2 for the current file
806
- selections[file_name] = {
807
- "Panel_1": selected_panel1,
808
- "Panel_2": selected_panel2,
809
- }
810
- continue
811
-
812
- # Create layout columns for File Name, Panel_2, and Panel_1 selections
813
- file_name_col, Panel_1_col, Panel_2_col = st.columns([2, 4, 4])
814
-
815
- with file_name_col:
816
- # Display "File Name" label only for the first file
817
- if count == 0:
818
- st.write("File Name")
819
- else:
820
- st.write("")
821
- st.write(file_name) # Display the file name
822
-
823
- with Panel_1_col:
824
- # Display a selectbox for Panel_1 values
825
- selected_panel1 = st.selectbox(
826
- "Select Panel Level 1",
827
- panel2_values,
828
- on_change=set_Panel_1_Panel_2_Selected_false,
829
- label_visibility=label_visibility, # Control visibility of the label
830
- key=f"Panel_1_selectbox{count}", # Ensure unique key for each selectbox
831
- index=st.session_state["project_dct"]["data_import"][
832
- f"Panel_1_selectbox{file_name}"
833
- ],
834
- )
835
-
836
- st.session_state["project_dct"]["data_import"][
837
- f"Panel_1_selectbox{file_name}"
838
- ] = panel2_values.index(selected_panel1)
839
-
840
- with Panel_2_col:
841
- # Display a selectbox for Panel_2 values
842
- selected_panel2 = st.selectbox(
843
- "Select Panel Level 2",
844
- panel1_values,
845
- on_change=set_Panel_1_Panel_2_Selected_false,
846
- label_visibility=label_visibility, # Control visibility of the label
847
- key=f"Panel_2_selectbox{count}", # Ensure unique key for each selectbox
848
- index=st.session_state["project_dct"]["data_import"][
849
- f"Panel_2_selectbox{file_name}"
850
- ],
851
- )
852
-
853
- st.session_state["project_dct"]["data_import"][
854
- f"Panel_2_selectbox{file_name}"
855
- ] = panel1_values.index(selected_panel2)
856
-
857
- # st.write(st.session_state['project_dct']['data_import'][f"Panel_2_selectbox{file_name}"])
858
-
859
- # Skip processing if the same column is selected for both Panel_1 and Panel_2 due to potential data integrity issues
860
-
861
- if selected_panel2 == selected_panel1 and not (
862
- selected_panel2 == "N/A" and selected_panel1 == "N/A"
863
- ):
864
- st.warning(
865
- f"File: {file_name} → The same column cannot serve as both Panel_1 and Panel_2. Please adjust your selections.",
866
- )
867
- selected_panel1, selected_panel2 = "N/A", "N/A"
868
- st.stop()
869
-
870
- # Update the selections for Panel_1 and Panel_2 for the current file
871
- selections[file_name] = {
872
- "Panel_1": selected_panel1,
873
- "Panel_2": selected_panel2,
874
- }
875
-
876
- count += 1 # Increment the counter after processing each file
877
- st.write()
878
- # Accept Panel_1 and Panel_2 selection
879
- accept = st.button(
880
- "Accept and Process", use_container_width=True
881
- ) # resume project manoj
882
-
883
- if (
884
- accept == False
885
- and st.session_state["project_dct"]["data_import"]["edited_stats_df"]
886
- is not None
887
- ):
888
-
889
- # st.write(st.session_state['project_dct'])
890
- st.markdown("#### Unique Panel values")
891
- # Display Panel_1 and Panel_2 values
892
- with st.expander("Unique Panel values"):
893
- st.write("")
894
- st.markdown(
895
- f"""
896
- <style>
897
- .justify-text {{
898
- text-align: justify;
899
- }}
900
- </style>
901
- <div class="justify-text">
902
- <strong>Panel Level 1 Values:</strong> {st.session_state['project_dct']['data_import']['formatted_panel1_values']}<br>
903
- <strong>Panel Level 2 Values:</strong> {st.session_state['project_dct']['data_import']['formatted_panel2_values']}
904
- </div>
905
- """,
906
- unsafe_allow_html=True,
907
- )
908
-
909
- # Display total Panel_1 and Panel_2
910
- st.write("")
911
- st.markdown(
912
- f"""
913
- <div style="text-align: justify;">
914
- <strong>Number of Level 1 Panels detected:</strong> {len(st.session_state['project_dct']['data_import']['formatted_panel2_values'])}<br>
915
- <strong>Number of Level 2 Panels detected:</strong> {len(st.session_state['project_dct']['data_import']['formatted_panel2_values'])}
916
- </div>
917
- """,
918
- unsafe_allow_html=True,
919
- )
920
- st.write("")
921
-
922
- # Create an editable DataFrame in Streamlit
923
-
924
- st.markdown("#### Select Variables Category & Impute Missing Values")
925
-
926
- # data_temp_path=os.path.join(st.session_state['project_path'],"edited_stats_df.pkl")
927
-
928
- # with open(data_temp_path,"rb") as f:
929
- # saved_edited_stats_df=pickle.load(f)
930
-
931
- # a=st.data_editor(saved_edited_stats_df)
932
-
933
- merged_df = st.session_state["project_dct"]["data_import"][
934
- "merged_df"
935
- ].copy()
936
-
937
- missing_stats_df = st.session_state["project_dct"]["data_import"][
938
- "missing_stats_df"
939
- ]
940
-
941
- edited_stats_df = st.data_editor(
942
- st.session_state["project_dct"]["data_import"]["edited_stats_df"],
943
- column_config={
944
- "Impute Method": st.column_config.SelectboxColumn(
945
- options=[
946
- "Drop Column",
947
- "Fill with Mean",
948
- "Fill with Median",
949
- "Fill with 0",
950
- ],
951
- required=True,
952
- default="Fill with 0",
953
- ),
954
- "Category": st.column_config.SelectboxColumn(
955
- options=[
956
- "Media",
957
- "Exogenous",
958
- "Internal",
959
- "Response Metrics",
960
- ],
961
- required=True,
962
- default="Media",
963
- ),
964
- },
965
- disabled=["Column", "Missing Values", "Missing Percentage"],
966
- hide_index=True,
967
- use_container_width=True,
968
- key="data-editor-1",
969
- )
970
-
971
- st.session_state["project_dct"]["data_import"]["cat_dct"] = {
972
- col: cat
973
- for col, cat in zip(
974
- edited_stats_df["Column"], edited_stats_df["Category"]
975
- )
976
- }
977
-
978
- for i, row in edited_stats_df.iterrows():
979
- column = row["Column"]
980
- if row["Impute Method"] == "Drop Column":
981
- merged_df.drop(columns=[column], inplace=True)
982
-
983
- elif row["Impute Method"] == "Fill with Mean":
984
- merged_df[column].fillna(
985
- st.session_state["project_dct"]["data_import"][
986
- "merged_df"
987
- ][column].mean(),
988
- inplace=True,
989
- )
990
-
991
- elif row["Impute Method"] == "Fill with Median":
992
- merged_df[column].fillna(
993
- st.session_state["project_dct"]["data_import"][
994
- "merged_df"
995
- ][column].median(),
996
- inplace=True,
997
- )
998
-
999
- elif row["Impute Method"] == "Fill with 0":
1000
- merged_df[column].fillna(0, inplace=True)
1001
-
1002
- # st.session_state['project_dct']['data_import']['edited_stats_df']=edited_stats_df
1003
- #########################################################################################################################################################
1004
- # Group columns
1005
- #########################################################################################################################################################
1006
-
1007
- # Display Group columns header
1008
- numeric_columns = st.session_state["project_dct"]["data_import"][
1009
- "numeric_columns"
1010
- ]
1011
- default_df = st.session_state["project_dct"]["data_import"][
1012
- "default_df"
1013
- ]
1014
-
1015
- st.markdown("#### Feature engineering")
1016
-
1017
- edited_df = st.data_editor(
1018
- st.session_state["project_dct"]["data_import"]["edited_df"],
1019
- column_config={
1020
- "Column 1": st.column_config.SelectboxColumn(
1021
- options=numeric_columns,
1022
- required=True,
1023
- width=400,
1024
- ),
1025
- "Operator": st.column_config.SelectboxColumn(
1026
- options=["+", "-", "*", "/"],
1027
- required=True,
1028
- default="+",
1029
- width=100,
1030
- ),
1031
- "Column 2": st.column_config.SelectboxColumn(
1032
- options=numeric_columns,
1033
- required=True,
1034
- default=numeric_columns[0],
1035
- width=400,
1036
- ),
1037
- "Category": st.column_config.SelectboxColumn(
1038
- options=[
1039
- "Media",
1040
- "Exogenous",
1041
- "Internal",
1042
- "Response Metrics",
1043
- ],
1044
- required=True,
1045
- default="Media",
1046
- width=200,
1047
- ),
1048
- },
1049
- num_rows="dynamic",
1050
- key="data-editor-4",
1051
- )
1052
-
1053
- final_df, edited_stats_df = process_dataframes(
1054
- merged_df, edited_df, edited_stats_df
1055
- )
1056
-
1057
- st.markdown("#### Final DataFrame")
1058
- st.dataframe(final_df, hide_index=True)
1059
-
1060
- # Initialize an empty dictionary to hold categories and their variables
1061
- category_dict = {}
1062
-
1063
- # Iterate over each row in the edited DataFrame to populate the dictionary
1064
- for i, row in edited_stats_df.iterrows():
1065
- column = row["Column"]
1066
- category = row[
1067
- "Category"
1068
- ] # The category chosen by the user for this variable
1069
-
1070
- # Check if the category already exists in the dictionary
1071
- if category not in category_dict:
1072
- # If not, initialize it with the current column as its first element
1073
- category_dict[category] = [column]
1074
- else:
1075
- # If it exists, append the current column to the list of variables under this category
1076
- category_dict[category].append(column)
1077
-
1078
- # Add Date, Panel_1 and Panel_12 in category dictionary
1079
- category_dict.update({"Date": ["date"]})
1080
- if "Panel_1" in final_df.columns:
1081
- category_dict["Panel Level 1"] = ["Panel_1"]
1082
- if "Panel_2" in final_df.columns:
1083
- category_dict["Panel Level 2"] = ["Panel_2"]
1084
-
1085
- # Display the dictionary
1086
- st.markdown("#### Variable Category")
1087
- for category, variables in category_dict.items():
1088
- # Check if there are multiple variables to handle "and" insertion correctly
1089
- if len(variables) > 1:
1090
- # Join all but the last variable with ", ", then add " and " before the last variable
1091
- variables_str = (
1092
- ", ".join(variables[:-1]) + " and " + variables[-1]
1093
- )
1094
- else:
1095
- # If there's only one variable, no need for "and"
1096
- variables_str = variables[0]
1097
-
1098
- # Display the category and its variables in the desired format
1099
- st.markdown(
1100
- f"<div style='text-align: justify;'><strong>{category}:</strong> {variables_str}</div>",
1101
- unsafe_allow_html=True,
1102
- )
1103
-
1104
- # Function to check if Response Metrics is selected
1105
- st.write("")
1106
- response_metrics_col = category_dict.get("Response Metrics", [])
1107
- if len(response_metrics_col) == 0:
1108
- st.warning("Please select Response Metrics column", icon="⚠️")
1109
- st.stop()
1110
- # elif len(response_metrics_col) > 1:
1111
- # st.warning("Please select only one Response Metrics column", icon="⚠️")
1112
- # st.stop()
1113
-
1114
- # Store final dataframe and bin dictionary into session state
1115
- st.session_state["final_df"], st.session_state["bin_dict"] = (
1116
- final_df,
1117
- category_dict,
1118
- )
1119
-
1120
- # Save the DataFrame and dictionary from the session state to the pickle file
1121
- if st.button(
1122
- "Accept and Save",
1123
- use_container_width=True,
1124
- key="data-editor-button",
1125
- ):
1126
- print("test*************")
1127
- update_db("1_Data_Import.py")
1128
- final_df = final_df.loc[:, ~final_df.columns.duplicated()]
1129
-
1130
- project_dct_path = os.path.join(
1131
- st.session_state["project_path"], "project_dct.pkl"
1132
- )
1133
-
1134
- with open(project_dct_path, "wb") as f:
1135
- pickle.dump(st.session_state["project_dct"], f)
1136
-
1137
- data_path = os.path.join(
1138
- st.session_state["project_path"], "data_import.pkl"
1139
- )
1140
-
1141
- st.session_state["data_path"] = data_path
1142
-
1143
- save_to_pickle(
1144
- data_path,
1145
- st.session_state["final_df"],
1146
- st.session_state["bin_dict"],
1147
- )
1148
-
1149
- st.session_state["project_dct"]["data_import"][
1150
- "edited_stats_df"
1151
- ] = edited_stats_df
1152
- st.session_state["project_dct"]["data_import"][
1153
- "merged_df"
1154
- ] = merged_df
1155
- st.session_state["project_dct"]["data_import"][
1156
- "missing_stats_df"
1157
- ] = missing_stats_df
1158
- st.session_state["project_dct"]["data_import"]["cat_dct"] = {
1159
- col: cat
1160
- for col, cat in zip(
1161
- edited_stats_df["Column"], edited_stats_df["Category"]
1162
- )
1163
- }
1164
- st.session_state["project_dct"]["data_import"][
1165
- "numeric_columns"
1166
- ] = numeric_columns
1167
- st.session_state["project_dct"]["data_import"][
1168
- "default_df"
1169
- ] = default_df
1170
- st.session_state["project_dct"]["data_import"][
1171
- "final_df"
1172
- ] = final_df
1173
- st.session_state["project_dct"]["data_import"][
1174
- "edited_df"
1175
- ] = edited_df
1176
-
1177
- st.toast("💾 Saved Successfully!")
1178
-
1179
- if accept:
1180
- # Normalize all data to a daily granularity. This initial standardization simplifies subsequent conversions to other levels of granularity
1181
- with st.spinner("Processing..."):
1182
- files_dict = standardize_data_to_daily(files_dict, selections)
1183
-
1184
- # Convert all data to daily level granularity
1185
- files_dict = apply_granularity_to_all(
1186
- files_dict, granularity_selection, selections
1187
- )
1188
-
1189
- # Update the 'files_dict' in the session state
1190
- st.session_state["files_dict"] = files_dict
1191
-
1192
- # Set a flag in the session state to indicate that selection has been made
1193
- st.session_state["Panel_1_Panel_2_Selected"] = True
1194
-
1195
- #########################################################################################################################################################
1196
- # Display unique Panel_1 and Panel_2 values
1197
- #########################################################################################################################################################
1198
-
1199
- # Halts further execution until Panel_1 and Panel_2 columns are selected
1200
- if (
1201
- st.session_state["project_dct"]["data_import"]["edited_stats_df"]
1202
- is None
1203
- ):
1204
-
1205
- if (
1206
- "files_dict" in st.session_state
1207
- and st.session_state["Panel_1_Panel_2_Selected"]
1208
- ):
1209
- files_dict = st.session_state["files_dict"]
1210
-
1211
- st.session_state["project_dct"]["data_import"][
1212
- "files_dict"
1213
- ] = files_dict # resume
1214
- else:
1215
- st.stop()
1216
-
1217
- # Set to store unique values of Panel_1 and Panel_2
1218
- with st.spinner("Fetching Panel values..."):
1219
- all_panel1_values, all_panel2_values = (
1220
- clean_and_extract_unique_values(files_dict, selections)
1221
- )
1222
-
1223
- # List of Panel_1 and Panel_2 columns unique values
1224
- list_of_all_panel1_values = list(all_panel1_values)
1225
- list_of_all_panel2_values = list(all_panel2_values)
1226
-
1227
- # Format Panel_1 and Panel_2 values for display
1228
- formatted_panel1_values = format_values_for_display(
1229
- list_of_all_panel1_values
1230
- ) ##
1231
- formatted_panel2_values = format_values_for_display(
1232
- list_of_all_panel2_values
1233
- ) ##
1234
-
1235
- # storing panel values in project_dct
1236
-
1237
- st.session_state["project_dct"]["data_import"][
1238
- "formatted_panel1_values"
1239
- ] = formatted_panel1_values
1240
- st.session_state["project_dct"]["data_import"][
1241
- "formatted_panel2_values"
1242
- ] = formatted_panel2_values
1243
-
1244
- # Unique Panel_1 and Panel_2 values
1245
- st.markdown("#### Unique Panel values")
1246
- # Display Panel_1 and Panel_2 values
1247
- with st.expander("Unique Panel values"):
1248
- st.write("")
1249
- st.markdown(
1250
- f"""
1251
- <style>
1252
- .justify-text {{
1253
- text-align: justify;
1254
- }}
1255
- </style>
1256
- <div class="justify-text">
1257
- <strong>Panel Level 1 Values:</strong> {formatted_panel1_values}<br>
1258
- <strong>Panel Level 2 Values:</strong> {formatted_panel2_values}
1259
- </div>
1260
- """,
1261
- unsafe_allow_html=True,
1262
- )
1263
-
1264
- # Display total Panel_1 and Panel_2
1265
- st.write("")
1266
- st.markdown(
1267
- f"""
1268
- <div style="text-align: justify;">
1269
- <strong>Number of Level 1 Panels detected:</strong> {len(list_of_all_panel1_values)}<br>
1270
- <strong>Number of Level 2 Panels detected:</strong> {len(list_of_all_panel2_values)}
1271
- </div>
1272
- """,
1273
- unsafe_allow_html=True,
1274
- )
1275
- st.write("")
1276
-
1277
- #########################################################################################################################################################
1278
- # Merge all DataFrames
1279
- #########################################################################################################################################################
1280
-
1281
- # Merge all DataFrames selected
1282
-
1283
- main_df = create_main_dataframe(
1284
- files_dict,
1285
- all_panel1_values,
1286
- all_panel2_values,
1287
- granularity_selection,
1288
- )
1289
-
1290
- merged_df = merge_into_main_df(main_df, files_dict, selections) ##
1291
-
1292
- #########################################################################################################################################################
1293
- # Categorize Variables and Impute Missing Values
1294
- #########################################################################################################################################################
1295
-
1296
- # Create an editable DataFrame in Streamlit
1297
-
1298
- st.markdown("#### Select Variables Category & Impute Missing Values")
1299
-
1300
- # Prepare missing stats DataFrame for editing
1301
- missing_stats_df = prepare_missing_stats_df(merged_df)
1302
-
1303
- # storing missing stats df
1304
-
1305
- edited_stats_df = st.data_editor(
1306
- missing_stats_df,
1307
- column_config={
1308
- "Impute Method": st.column_config.SelectboxColumn(
1309
- options=[
1310
- "Drop Column",
1311
- "Fill with Mean",
1312
- "Fill with Median",
1313
- "Fill with 0",
1314
- ],
1315
- required=True,
1316
- default="Fill with 0",
1317
- ),
1318
- "Category": st.column_config.SelectboxColumn(
1319
- options=[
1320
- "Media",
1321
- "Exogenous",
1322
- "Internal",
1323
- "Response Metrics",
1324
- ],
1325
- required=True,
1326
- default="Media",
1327
- ),
1328
- },
1329
- disabled=["Column", "Missing Values", "Missing Percentage"],
1330
- hide_index=True,
1331
- use_container_width=True,
1332
- key="data-editor-2",
1333
- )
1334
-
1335
- # edited_stats_df_path=os.path.join(st.session_state['project_path'],"edited_stats_df.pkl")
1336
-
1337
- # edited_stats_df.to_pickle(edited_stats_df_path)
1338
-
1339
- # Apply changes based on edited DataFrame
1340
- for i, row in edited_stats_df.iterrows():
1341
- column = row["Column"]
1342
- if row["Impute Method"] == "Drop Column":
1343
- merged_df.drop(columns=[column], inplace=True)
1344
-
1345
- elif row["Impute Method"] == "Fill with Mean":
1346
- merged_df[column].fillna(
1347
- merged_df[column].mean(), inplace=True
1348
- )
1349
-
1350
- elif row["Impute Method"] == "Fill with Median":
1351
- merged_df[column].fillna(
1352
- merged_df[column].median(), inplace=True
1353
- )
1354
-
1355
- elif row["Impute Method"] == "Fill with 0":
1356
- merged_df[column].fillna(0, inplace=True)
1357
-
1358
- # st.session_state['project_dct']['data_import']['edited_stats_df']=edited_stats_df
1359
-
1360
- #########################################################################################################################################################
1361
- # Group columns
1362
- #########################################################################################################################################################
1363
-
1364
- # Display Group columns header
1365
- st.markdown("#### Feature engineering")
1366
-
1367
- # Prepare the numeric columns and an empty DataFrame for user input
1368
- numeric_columns, default_df = prepare_numeric_columns_and_default_df(
1369
- merged_df, edited_stats_df
1370
- )
1371
-
1372
- # st.session_state['project_dct']['data_import']['edited_stats_df']=edited_stats_df
1373
-
1374
- # Display editable Dataframe
1375
- edited_df = st.data_editor(
1376
- default_df,
1377
- column_config={
1378
- "Column 1": st.column_config.SelectboxColumn(
1379
- options=numeric_columns,
1380
- required=True,
1381
- width=400,
1382
- ),
1383
- "Operator": st.column_config.SelectboxColumn(
1384
- options=["+", "-", "*", "/"],
1385
- required=True,
1386
- default="+",
1387
- width=100,
1388
- ),
1389
- "Column 2": st.column_config.SelectboxColumn(
1390
- options=numeric_columns,
1391
- required=True,
1392
- default=numeric_columns[0],
1393
- width=400,
1394
- ),
1395
- "Category": st.column_config.SelectboxColumn(
1396
- options=[
1397
- "Media",
1398
- "Exogenous",
1399
- "Internal",
1400
- "Response Metrics",
1401
- ],
1402
- required=True,
1403
- default="Media",
1404
- width=200,
1405
- ),
1406
- },
1407
- num_rows="dynamic",
1408
- key="data-editor-3",
1409
- )
1410
-
1411
- # Process the DataFrame based on user inputs and operations specified in edited_df
1412
- final_df, edited_stats_df = process_dataframes(
1413
- merged_df, edited_df, edited_stats_df
1414
- )
1415
-
1416
- # edited_df_path=os.path.join(st.session_state['project_path'],'edited_df.pkl')
1417
- # edited_df.to_pickle(edited_df_path)
1418
-
1419
- #########################################################################################################################################################
1420
- # Display the Final DataFrame and variables
1421
- #########################################################################################################################################################
1422
-
1423
- # Display the Final DataFrame and variables
1424
-
1425
- st.markdown("#### Final DataFrame")
1426
-
1427
- st.dataframe(final_df, hide_index=True)
1428
-
1429
- # Initialize an empty dictionary to hold categories and their variables
1430
- category_dict = {}
1431
-
1432
- # Iterate over each row in the edited DataFrame to populate the dictionary
1433
- for i, row in edited_stats_df.iterrows():
1434
- column = row["Column"]
1435
- category = row[
1436
- "Category"
1437
- ] # The category chosen by the user for this variable
1438
-
1439
- # Check if the category already exists in the dictionary
1440
- if category not in category_dict:
1441
- # If not, initialize it with the current column as its first element
1442
- category_dict[category] = [column]
1443
- else:
1444
- # If it exists, append the current column to the list of variables under this category
1445
- category_dict[category].append(column)
1446
-
1447
- # Add Date, Panel_1 and Panel_12 in category dictionary
1448
- category_dict.update({"Date": ["date"]})
1449
- if "Panel_1" in final_df.columns:
1450
- category_dict["Panel Level 1"] = ["Panel_1"]
1451
- if "Panel_2" in final_df.columns:
1452
- category_dict["Panel Level 2"] = ["Panel_2"]
1453
-
1454
- # Display the dictionary
1455
- st.markdown("#### Variable Category")
1456
- for category, variables in category_dict.items():
1457
- # Check if there are multiple variables to handle "and" insertion correctly
1458
- if len(variables) > 1:
1459
- # Join all but the last variable with ", ", then add " and " before the last variable
1460
- variables_str = (
1461
- ", ".join(variables[:-1]) + " and " + variables[-1]
1462
- )
1463
- else:
1464
- # If there's only one variable, no need for "and"
1465
- variables_str = variables[0]
1466
-
1467
- # Display the category and its variables in the desired format
1468
- st.markdown(
1469
- f"<div style='text-align: justify;'><strong>{category}:</strong> {variables_str}</div>",
1470
- unsafe_allow_html=True,
1471
- )
1472
-
1473
- # Function to check if Response Metrics is selected
1474
- st.write("")
1475
-
1476
- response_metrics_col = category_dict.get("Response Metrics", [])
1477
- if len(response_metrics_col) == 0:
1478
- st.warning("Please select Response Metrics column", icon="⚠️")
1479
- st.stop()
1480
- # elif len(response_metrics_col) > 1:
1481
- # st.warning("Please select only one Response Metrics column", icon="⚠️")
1482
- # st.stop()
1483
-
1484
- # Store final dataframe and bin dictionary into session state
1485
-
1486
- st.session_state["final_df"], st.session_state["bin_dict"] = (
1487
- final_df,
1488
- category_dict,
1489
- )
1490
-
1491
- # Save the DataFrame and dictionary from the session state to the pickle file
1492
-
1493
- if st.button("Accept and Save", use_container_width=True):
1494
-
1495
- print("test*************")
1496
- update_db("1_Data_Import.py")
1497
-
1498
- project_dct_path = os.path.join(
1499
- st.session_state["project_path"], "project_dct.pkl"
1500
- )
1501
-
1502
- with open(project_dct_path, "wb") as f:
1503
- pickle.dump(st.session_state["project_dct"], f)
1504
-
1505
- data_path = os.path.join(
1506
- st.session_state["project_path"], "data_import.pkl"
1507
- )
1508
- st.session_state["data_path"] = data_path
1509
-
1510
- save_to_pickle(
1511
- data_path,
1512
- st.session_state["final_df"],
1513
- st.session_state["bin_dict"],
1514
- )
1515
-
1516
- st.session_state["project_dct"]["data_import"][
1517
- "edited_stats_df"
1518
- ] = edited_stats_df
1519
- st.session_state["project_dct"]["data_import"][
1520
- "merged_df"
1521
- ] = merged_df
1522
- st.session_state["project_dct"]["data_import"][
1523
- "missing_stats_df"
1524
- ] = missing_stats_df
1525
- st.session_state["project_dct"]["data_import"]["cat_dct"] = {
1526
- col: cat
1527
- for col, cat in zip(
1528
- edited_stats_df["Column"], edited_stats_df["Category"]
1529
- )
1530
- }
1531
- st.session_state["project_dct"]["data_import"][
1532
- "numeric_columns"
1533
- ] = numeric_columns
1534
- st.session_state["project_dct"]["data_import"][
1535
- "default_df"
1536
- ] = default_df
1537
- st.session_state["project_dct"]["data_import"][
1538
- "final_df"
1539
- ] = final_df
1540
- st.session_state["project_dct"]["data_import"][
1541
- "edited_df"
1542
- ] = edited_df
1543
-
1544
- st.toast("💾 Saved Successfully!")
1545
-
1546
- # *****************************************************************
1547
- # *********************************Persistant flow****************
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/2_Data_Validation.py DELETED
@@ -1,509 +0,0 @@
1
- import streamlit as st
2
- import pandas as pd
3
- import plotly.express as px
4
- import plotly.graph_objects as go
5
- from Eda_functions import *
6
- import numpy as np
7
- import pickle
8
-
9
- # from streamlit_pandas_profiling import st_profile_report
10
- import streamlit as st
11
- import streamlit.components.v1 as components
12
- import sweetviz as sv
13
- from utilities import set_header, load_local_css
14
- from st_aggrid import GridOptionsBuilder, GridUpdateMode
15
- from st_aggrid import GridOptionsBuilder
16
- from st_aggrid import AgGrid
17
- import base64
18
- import os
19
- import tempfile
20
- #import pandas_profiling
21
- #from pydantic_settings import BaseSettings
22
- from ydata_profiling import ProfileReport
23
- import re
24
-
25
- # from pygwalker.api.streamlit import StreamlitRenderer
26
- # from Home_redirecting import home
27
- import sqlite3
28
- from utilities import update_db
29
-
30
- st.set_page_config(
31
- page_title="Data Validation",
32
- page_icon=":shark:",
33
- layout="wide",
34
- initial_sidebar_state="collapsed",
35
- )
36
- load_local_css("styles.css")
37
- set_header()
38
-
39
-
40
- if "project_dct" not in st.session_state:
41
- # home()
42
- st.warning("Please select a project from home page")
43
- st.stop()
44
-
45
-
46
- data_path = os.path.join(st.session_state["project_path"], "data_import.pkl")
47
-
48
- try:
49
- with open(data_path, "rb") as f:
50
- data = pickle.load(f)
51
- except Exception as e:
52
- st.error(f"Please import data from the Data Import Page")
53
- st.stop()
54
-
55
- conn = sqlite3.connect(
56
- r"DB\User.db", check_same_thread=False
57
- ) # connection with sql db
58
- c = conn.cursor()
59
- st.session_state["cleaned_data"] = data["final_df"]
60
- st.session_state["category_dict"] = data["bin_dict"]
61
- # st.write(st.session_state['category_dict'])
62
-
63
- st.title("Data Validation and Insights")
64
-
65
-
66
- target_variables = [
67
- st.session_state["category_dict"][key]
68
- for key in st.session_state["category_dict"].keys()
69
- if key == "Response Metrics"
70
- ]
71
- target_variables = list(*target_variables)
72
- target_column = st.selectbox(
73
- "Select the Target Feature/Dependent Variable (will be used in all charts as reference)",
74
- target_variables,
75
- index=st.session_state["project_dct"]["data_validation"]["target_column"],
76
- )
77
-
78
- st.session_state["project_dct"]["data_validation"]["target_column"] = (
79
- target_variables.index(target_column)
80
- )
81
-
82
- st.session_state["target_column"] = target_column
83
-
84
- panels = st.session_state["category_dict"]["Panel Level 1"][0]
85
-
86
- selected_panels = st.multiselect(
87
- "Please choose the panels you wish to analyze.If no panels are selected, insights will be derived from the overall data.",
88
- st.session_state["cleaned_data"][panels].unique(),
89
- default=st.session_state["project_dct"]["data_validation"][
90
- "selected_panels"
91
- ],
92
- )
93
-
94
- st.session_state["project_dct"]["data_validation"][
95
- "selected_panels"
96
- ] = selected_panels
97
-
98
- aggregation_dict = {
99
- item: "sum" if key == "Media" else "mean"
100
- for key, value in st.session_state["category_dict"].items()
101
- for item in value
102
- if item not in ["date", "Panel_1"]
103
- }
104
-
105
- with st.expander("**Reponse Metric Analysis**"):
106
-
107
- if len(selected_panels) > 0:
108
- st.session_state["Cleaned_data_panel"] = st.session_state[
109
- "cleaned_data"
110
- ][st.session_state["cleaned_data"]["Panel_1"].isin(selected_panels)]
111
-
112
- st.session_state["Cleaned_data_panel"] = (
113
- st.session_state["Cleaned_data_panel"]
114
- .groupby(by="date")
115
- .agg(aggregation_dict)
116
- )
117
- st.session_state["Cleaned_data_panel"] = st.session_state[
118
- "Cleaned_data_panel"
119
- ].reset_index()
120
- else:
121
- # st.write(st.session_state['cleaned_data'])
122
- st.session_state["Cleaned_data_panel"] = (
123
- st.session_state["cleaned_data"]
124
- .groupby(by="date")
125
- .agg(aggregation_dict)
126
- )
127
- st.session_state["Cleaned_data_panel"] = st.session_state[
128
- "Cleaned_data_panel"
129
- ].reset_index()
130
-
131
- fig = line_plot_target(
132
- st.session_state["Cleaned_data_panel"],
133
- target=target_column,
134
- title=f"{target_column} Over Time",
135
- )
136
- st.plotly_chart(fig, use_container_width=True)
137
-
138
- media_channel = list(
139
- *[
140
- st.session_state["category_dict"][key]
141
- for key in st.session_state["category_dict"].keys()
142
- if key == "Media"
143
- ]
144
- )
145
- # st.write(media_channel)
146
-
147
- exo_var = list(
148
- *[
149
- st.session_state["category_dict"][key]
150
- for key in st.session_state["category_dict"].keys()
151
- if key == "Exogenous"
152
- ]
153
- )
154
- internal_var = list(
155
- *[
156
- st.session_state["category_dict"][key]
157
- for key in st.session_state["category_dict"].keys()
158
- if key == "Internal"
159
- ]
160
- )
161
- Non_media_variables = exo_var + internal_var
162
-
163
- st.markdown("### Annual Data Summary")
164
-
165
- st.dataframe(
166
- summary(
167
- st.session_state["Cleaned_data_panel"],
168
- media_channel + [target_column],
169
- spends=None,
170
- Target=True,
171
- ),
172
- use_container_width=True,
173
- )
174
-
175
- if st.checkbox("Show raw data"):
176
- st.write(
177
- pd.concat(
178
- [
179
- pd.to_datetime(
180
- st.session_state["Cleaned_data_panel"]["date"]
181
- ).dt.strftime("%m/%d/%Y"),
182
- st.session_state["Cleaned_data_panel"]
183
- .select_dtypes(np.number)
184
- .applymap(format_numbers),
185
- ],
186
- axis=1,
187
- )
188
- )
189
- col1 = st.columns(1)
190
-
191
- if "selected_feature" not in st.session_state:
192
- st.session_state["selected_feature"] = None
193
-
194
-
195
- def generate_report_with_target(channel_data, target_feature):
196
- report = sv.analyze([channel_data, "Dataset"], target_feat=target_feature)
197
- temp_dir = tempfile.mkdtemp()
198
- report_path = os.path.join(temp_dir, "report.html")
199
- report.show_html(
200
- filepath=report_path, open_browser=False
201
- ) # Generate the report as an HTML file
202
- return report_path
203
-
204
-
205
- def generate_profile_report(df):
206
- pr = df.profile_report()
207
- temp_dir = tempfile.mkdtemp()
208
- report_path = os.path.join(temp_dir, "report.html")
209
- pr.to_file(report_path)
210
- return report_path
211
-
212
-
213
- # st.header()
214
- with st.expander("Univariate and Bivariate Report"):
215
- eda_columns = st.columns(2)
216
- with eda_columns[0]:
217
- if st.button(
218
- "Generate Profile Report",
219
- help="Univariate report which inlcudes all statistical analysis",
220
- ):
221
- with st.spinner("Generating Report"):
222
- report_file = generate_profile_report(
223
- st.session_state["Cleaned_data_panel"]
224
- )
225
-
226
- if os.path.exists(report_file):
227
- with open(report_file, "rb") as f:
228
- st.success("Report Generated")
229
- st.download_button(
230
- label="Download EDA Report",
231
- data=f.read(),
232
- file_name="pandas_profiling_report.html",
233
- mime="text/html",
234
- )
235
- else:
236
- st.warning(
237
- "Report generation failed. Unable to find the report file."
238
- )
239
-
240
- with eda_columns[1]:
241
- if st.button(
242
- "Generate Sweetviz Report",
243
- help="Bivariate report for selected response metric",
244
- ):
245
- with st.spinner("Generating Report"):
246
- report_file = generate_report_with_target(
247
- st.session_state["Cleaned_data_panel"], target_column
248
- )
249
-
250
- if os.path.exists(report_file):
251
- with open(report_file, "rb") as f:
252
- st.success("Report Generated")
253
- st.download_button(
254
- label="Download EDA Report",
255
- data=f.read(),
256
- file_name="report.html",
257
- mime="text/html",
258
- )
259
- else:
260
- st.warning(
261
- "Report generation failed. Unable to find the report file."
262
- )
263
-
264
-
265
- # st.warning('Work in Progress')
266
- with st.expander("Media Variables Analysis"):
267
- # Get the selected feature
268
-
269
- media_variables = [
270
- col
271
- for col in media_channel
272
- if "cost" not in col.lower() and "spend" not in col.lower()
273
- ]
274
-
275
- st.session_state["selected_feature"] = st.selectbox(
276
- "Select media", media_variables
277
- )
278
-
279
- st.session_state["project_dct"]["data_validation"]["selected_feature"] = (
280
- media_variables.index(st.session_state["selected_feature"])
281
- )
282
-
283
- # Filter spends features based on the selected feature
284
- spends_features = [
285
- col
286
- for col in st.session_state["Cleaned_data_panel"].columns
287
- if any(keyword in col.lower() for keyword in ["cost", "spend"])
288
- ]
289
- spends_feature = [
290
- col
291
- for col in spends_features
292
- if re.split(r"_cost|_spend", col.lower())[0]
293
- in st.session_state["selected_feature"]
294
- ]
295
-
296
- if "validation" not in st.session_state:
297
-
298
- st.session_state["validation"] = st.session_state["project_dct"][
299
- "data_validation"
300
- ]["validated_variables"]
301
-
302
- val_variables = [col for col in media_channel if col != "date"]
303
-
304
- if not set(
305
- st.session_state["project_dct"]["data_validation"][
306
- "validated_variables"
307
- ]
308
- ).issubset(set(val_variables)):
309
-
310
- st.session_state["validation"] = []
311
-
312
- if len(spends_feature) == 0:
313
- st.warning(
314
- "No spends varaible available for the selected metric in data"
315
- )
316
-
317
- else:
318
- fig_row1 = line_plot(
319
- st.session_state["Cleaned_data_panel"],
320
- x_col="date",
321
- y1_cols=[st.session_state["selected_feature"]],
322
- y2_cols=[target_column],
323
- title=f'Analysis of {st.session_state["selected_feature"]} and {[target_column][0]} Over Time',
324
- )
325
- st.plotly_chart(fig_row1, use_container_width=True)
326
- st.markdown("### Summary")
327
- st.dataframe(
328
- summary(
329
- st.session_state["cleaned_data"],
330
- [st.session_state["selected_feature"]],
331
- spends=spends_feature[0],
332
- ),
333
- use_container_width=True,
334
- )
335
-
336
- cols2 = st.columns(2)
337
-
338
- if len(
339
- set(st.session_state["validation"]).intersection(val_variables)
340
- ) == len(val_variables):
341
- disable = True
342
- help = "All media variables are validated"
343
- else:
344
- disable = False
345
- help = ""
346
-
347
- with cols2[0]:
348
- if st.button("Validate", disabled=disable, help=help):
349
- st.session_state["validation"].append(
350
- st.session_state["selected_feature"]
351
- )
352
- with cols2[1]:
353
-
354
- if st.checkbox("Validate all", disabled=disable, help=help):
355
- st.session_state["validation"].extend(val_variables)
356
- st.success("All media variables are validated ✅")
357
-
358
- if len(
359
- set(st.session_state["validation"]).intersection(val_variables)
360
- ) != len(val_variables):
361
- validation_data = pd.DataFrame(
362
- {
363
- "Validate": [
364
- (
365
- True
366
- if col in st.session_state["validation"]
367
- else False
368
- )
369
- for col in val_variables
370
- ],
371
- "Variables": val_variables,
372
- }
373
- )
374
- cols3 = st.columns([1, 30])
375
- with cols3[1]:
376
- validation_df = st.data_editor(
377
- validation_data,
378
- # column_config={
379
- # 'Validate':st.column_config.CheckboxColumn(wi)
380
- # },
381
- column_config={
382
- "Validate": st.column_config.CheckboxColumn(
383
- default=False,
384
- width=100,
385
- ),
386
- "Variables": st.column_config.TextColumn(width=1000),
387
- },
388
- hide_index=True,
389
- )
390
-
391
- selected_rows = validation_df[
392
- validation_df["Validate"] == True
393
- ]["Variables"]
394
-
395
- # st.write(selected_rows)
396
-
397
- st.session_state["validation"].extend(selected_rows)
398
-
399
- st.session_state["project_dct"]["data_validation"][
400
- "validated_variables"
401
- ] = st.session_state["validation"]
402
-
403
- not_validated_variables = [
404
- col
405
- for col in val_variables
406
- if col not in st.session_state["validation"]
407
- ]
408
-
409
- if not_validated_variables:
410
- not_validated_message = f'The following variables are not validated:\n{" , ".join(not_validated_variables)}'
411
- st.warning(not_validated_message)
412
-
413
-
414
- with st.expander("Non Media Variables Analysis"):
415
- selected_columns_row4 = st.selectbox(
416
- "Select Channel",
417
- Non_media_variables,
418
- index=st.session_state["project_dct"]["data_validation"][
419
- "Non_media_variables"
420
- ],
421
- )
422
-
423
- st.session_state["project_dct"]["data_validation"][
424
- "Non_media_variables"
425
- ] = Non_media_variables.index(selected_columns_row4)
426
-
427
- # # Create the dual-axis line plot
428
- fig_row4 = line_plot(
429
- st.session_state["Cleaned_data_panel"],
430
- x_col="date",
431
- y1_cols=[selected_columns_row4],
432
- y2_cols=[target_column],
433
- title=f"Analysis of {selected_columns_row4} and {target_column} Over Time",
434
- )
435
- st.plotly_chart(fig_row4, use_container_width=True)
436
- selected_non_media = selected_columns_row4
437
- sum_df = st.session_state["Cleaned_data_panel"][
438
- ["date", selected_non_media, target_column]
439
- ]
440
- sum_df["Year"] = pd.to_datetime(
441
- st.session_state["Cleaned_data_panel"]["date"]
442
- ).dt.year
443
- # st.dataframe(df)
444
- # st.dataframe(sum_df.head(2))
445
- print(sum_df)
446
- sum_df = sum_df.drop("date", axis=1).groupby("Year").agg("sum")
447
- sum_df.loc["Grand Total"] = sum_df.sum()
448
- sum_df = sum_df.applymap(format_numbers)
449
- sum_df.fillna("-", inplace=True)
450
- sum_df = sum_df.replace({"0.0": "-", "nan": "-"})
451
- st.markdown("### Summary")
452
- st.dataframe(sum_df, use_container_width=True)
453
-
454
- # with st.expander('Interactive Dashboard'):
455
-
456
- # pygg_app=StreamlitRenderer(st.session_state['cleaned_data'])
457
-
458
- # pygg_app.explorer()
459
-
460
- with st.expander("Correlation Analysis"):
461
- options = list(
462
- st.session_state["Cleaned_data_panel"].select_dtypes(np.number).columns
463
- )
464
-
465
- # selected_options = []
466
- # num_columns = 4
467
- # num_rows = -(-len(options) // num_columns) # Ceiling division to calculate rows
468
-
469
- # # Create a grid of checkboxes
470
- # st.header('Select Features for Correlation Plot')
471
- # tick=False
472
- # if st.checkbox('Select all'):
473
- # tick=True
474
- # selected_options = []
475
- # for row in range(num_rows):
476
- # cols = st.columns(num_columns)
477
- # for col in cols:
478
- # if options:
479
- # option = options.pop(0)
480
- # selected = col.checkbox(option,value=tick)
481
- # if selected:
482
- # selected_options.append(option)
483
- # # Display selected options
484
-
485
- selected_options = st.multiselect(
486
- "Select Variables For correlation plot",
487
- [var for var in options if var != target_column],
488
- default=options[3],
489
- )
490
-
491
- st.pyplot(
492
- correlation_plot(
493
- st.session_state["Cleaned_data_panel"],
494
- selected_options,
495
- target_column,
496
- )
497
- )
498
-
499
- if st.button("Save Changes", use_container_width=True):
500
-
501
- update_db("2_Data_Validation.py")
502
-
503
- project_dct_path = os.path.join(
504
- st.session_state["project_path"], "project_dct.pkl"
505
- )
506
-
507
- with open(project_dct_path, "wb") as f:
508
- pickle.dump(st.session_state["project_dct"], f)
509
- st.success("Changes saved")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/3_Transformations.py DELETED
@@ -1,686 +0,0 @@
1
- # Importing necessary libraries
2
- import streamlit as st
3
-
4
- st.set_page_config(
5
- page_title="Transformations",
6
- page_icon=":shark:",
7
- layout="wide",
8
- initial_sidebar_state="collapsed",
9
- )
10
-
11
- import pickle
12
- import numpy as np
13
- import pandas as pd
14
- from utilities import set_header, load_local_css
15
- import streamlit_authenticator as stauth
16
- import yaml
17
- from yaml import SafeLoader
18
- import os
19
- import sqlite3
20
- from utilities import update_db
21
-
22
-
23
- load_local_css("styles.css")
24
- set_header()
25
-
26
-
27
- # Check for authentication status
28
- for k, v in st.session_state.items():
29
- if k not in ["logout", "login", "config"] and not k.startswith(
30
- "FormSubmitter"
31
- ):
32
- st.session_state[k] = v
33
- with open("config.yaml") as file:
34
- config = yaml.load(file, Loader=SafeLoader)
35
- st.session_state["config"] = config
36
- authenticator = stauth.Authenticate(
37
- config["credentials"],
38
- config["cookie"]["name"],
39
- config["cookie"]["key"],
40
- config["cookie"]["expiry_days"],
41
- config["preauthorized"],
42
- )
43
- st.session_state["authenticator"] = authenticator
44
- name, authentication_status, username = authenticator.login("Login", "main")
45
- auth_status = st.session_state.get("authentication_status")
46
-
47
- if auth_status == True:
48
- authenticator.logout("Logout", "main")
49
- is_state_initiaized = st.session_state.get("initialized", False)
50
-
51
- if "project_dct" not in st.session_state:
52
- st.error("Please load a project from Home page")
53
- st.stop()
54
-
55
- conn = sqlite3.connect(
56
- r"DB/User.db", check_same_thread=False
57
- ) # connection with sql db
58
- c = conn.cursor()
59
-
60
- if not is_state_initiaized:
61
- if "session_name" not in st.session_state:
62
- st.session_state["session_name"] = None
63
-
64
- if not os.path.exists(
65
- os.path.join(st.session_state["project_path"], "data_import.pkl")
66
- ):
67
- st.error("Please move to Data Import page")
68
- # Deserialize and load the objects from the pickle file
69
- with open(
70
- os.path.join(st.session_state["project_path"], "data_import.pkl"), "rb"
71
- ) as f:
72
- data = pickle.load(f)
73
-
74
- # Accessing the loaded objects
75
- final_df_loaded = data["final_df"]
76
- bin_dict_loaded = data["bin_dict"]
77
- # final_df_loaded.to_csv("Test/final_df_loaded.csv",index=False)
78
- # Initialize session state==-
79
- if "transformed_columns_dict" not in st.session_state:
80
- st.session_state["transformed_columns_dict"] = (
81
- {}
82
- ) # Default empty dictionary
83
-
84
- if "final_df" not in st.session_state:
85
- st.session_state["final_df"] = (
86
- final_df_loaded # Default as original dataframe
87
- )
88
-
89
- if "summary_string" not in st.session_state:
90
- st.session_state["summary_string"] = None # Default as None
91
-
92
- # Extract original columns for specified categories
93
- original_columns = {
94
- category: bin_dict_loaded[category]
95
- for category in ["Media", "Internal", "Exogenous"]
96
- if category in bin_dict_loaded
97
- }
98
-
99
- # Retrive Panel columns
100
- panel_1 = bin_dict_loaded.get("Panel Level 1")
101
- panel_2 = bin_dict_loaded.get("Panel Level 2")
102
-
103
- # # For testing on non panel level
104
- # final_df_loaded = final_df_loaded.drop("Panel_1", axis=1)
105
- # final_df_loaded = final_df_loaded.groupby("date").mean().reset_index()
106
- # panel_1 = None
107
-
108
- # Apply transformations on panel level
109
- if panel_1:
110
- panel = panel_1 + panel_2 if panel_2 else panel_1
111
- else:
112
- panel = []
113
-
114
- # Function to build transformation widgets
115
- def transformation_widgets(category, transform_params, date_granularity):
116
-
117
- if (
118
- st.session_state["project_dct"]["transformations"] is None
119
- or st.session_state["project_dct"]["transformations"] == {}
120
- ):
121
- st.session_state["project_dct"]["transformations"] = {}
122
- if (
123
- category
124
- not in st.session_state["project_dct"]["transformations"].keys()
125
- ):
126
- st.session_state["project_dct"]["transformations"][category] = {}
127
-
128
- # Define a dict of pre-defined default values of every transformation
129
- predefined_defualts = {
130
- "Lag": (1, 2),
131
- "Lead": (1, 2),
132
- "Moving Average": (1, 2),
133
- "Saturation": (10, 20),
134
- "Power": (2, 4),
135
- "Adstock": (0.5, 0.7),
136
- }
137
-
138
- def selection_change():
139
- # Handles removing transformations
140
- if f"transformation_{category}" in st.session_state:
141
- current_selection = st.session_state[
142
- f"transformation_{category}"
143
- ]
144
- past_selection = st.session_state["project_dct"][
145
- "transformations"
146
- ][category][f"transformation_{category}"]
147
- removed_selection = list(
148
- set(past_selection) - set(current_selection)
149
- )
150
- for selection in removed_selection:
151
- # Option 1 - revert to defualt
152
- # st.session_state['project_dct']['transformations'][category][selection] = predefined_defualts[selection]
153
-
154
- # option 2 - delete from dict
155
- del st.session_state["project_dct"]["transformations"][
156
- category
157
- ][selection]
158
-
159
- # Transformation Options
160
- transformation_options = {
161
- "Media": [
162
- "Lag",
163
- "Moving Average",
164
- "Saturation",
165
- "Power",
166
- "Adstock",
167
- ],
168
- "Internal": ["Lead", "Lag", "Moving Average"],
169
- "Exogenous": ["Lead", "Lag", "Moving Average"],
170
- }
171
-
172
- expanded = st.session_state["project_dct"]["transformations"][
173
- category
174
- ].get("expanded", False)
175
- st.session_state["project_dct"]["transformations"][category][
176
- "expanded"
177
- ] = False
178
- with st.expander(f"{category} Transformations", expanded=expanded):
179
- st.session_state["project_dct"]["transformations"][category][
180
- "expanded"
181
- ] = True
182
-
183
- # Let users select which transformations to apply
184
- sel_transformations = st.session_state["project_dct"][
185
- "transformations"
186
- ][category].get(f"transformation_{category}", [])
187
- transformations_to_apply = st.multiselect(
188
- "Select transformations to apply",
189
- options=transformation_options[category],
190
- default=sel_transformations,
191
- key=f"transformation_{category}",
192
- # on_change=selection_change(),
193
- )
194
- st.session_state["project_dct"]["transformations"][category][
195
- "transformation_" + category
196
- ] = transformations_to_apply
197
- # Determine the number of transformations to put in each column
198
- transformations_per_column = (
199
- len(transformations_to_apply) // 2
200
- + len(transformations_to_apply) % 2
201
- )
202
-
203
- # Create two columns
204
- col1, col2 = st.columns(2)
205
-
206
- # Assign transformations to each column
207
- transformations_col1 = transformations_to_apply[
208
- :transformations_per_column
209
- ]
210
- transformations_col2 = transformations_to_apply[
211
- transformations_per_column:
212
- ]
213
-
214
- # Define a helper function to create widgets for each transformation
215
- def create_transformation_widgets(column, transformations):
216
- with column:
217
- for transformation in transformations:
218
- # Conditionally create widgets for selected transformations
219
- if transformation == "Lead":
220
- lead_default = st.session_state["project_dct"][
221
- "transformations"
222
- ][category].get(
223
- "Lead", predefined_defualts["Lead"]
224
- )
225
- st.markdown(f"**Lead ({date_granularity})**")
226
- lead = st.slider(
227
- "Lead periods",
228
- 1,
229
- 10,
230
- lead_default,
231
- 1,
232
- key=f"lead_{category}",
233
- label_visibility="collapsed",
234
- )
235
- st.session_state["project_dct"]["transformations"][
236
- category
237
- ]["Lead"] = lead
238
- start = lead[0]
239
- end = lead[1]
240
- step = 1
241
- transform_params[category]["Lead"] = np.arange(
242
- start, end + step, step
243
- )
244
-
245
- if transformation == "Lag":
246
- lag_default = st.session_state["project_dct"][
247
- "transformations"
248
- ][category].get("Lag", predefined_defualts["Lag"])
249
- st.markdown(f"**Lag ({date_granularity})**")
250
- lag = st.slider(
251
- "Lag periods",
252
- 1,
253
- 10,
254
- (1, 2), # lag_default,
255
- 1,
256
- key=f"lag_{category}",
257
- label_visibility="collapsed",
258
- )
259
- st.session_state["project_dct"]["transformations"][
260
- category
261
- ]["Lag"] = lag
262
- start = lag[0]
263
- end = lag[1]
264
- step = 1
265
- transform_params[category]["Lag"] = np.arange(
266
- start, end + step, step
267
- )
268
-
269
- if transformation == "Moving Average":
270
- ma_default = st.session_state["project_dct"][
271
- "transformations"
272
- ][category].get(
273
- "MA", predefined_defualts["Moving Average"]
274
- )
275
- st.markdown(
276
- f"**Moving Average ({date_granularity})**"
277
- )
278
- window = st.slider(
279
- "Window size for Moving Average",
280
- 1,
281
- 10,
282
- ma_default,
283
- 1,
284
- key=f"ma_{category}",
285
- label_visibility="collapsed",
286
- )
287
- st.session_state["project_dct"]["transformations"][
288
- category
289
- ]["MA"] = window
290
- start = window[0]
291
- end = window[1]
292
- step = 1
293
- transform_params[category]["Moving Average"] = (
294
- np.arange(start, end + step, step)
295
- )
296
-
297
- if transformation == "Saturation":
298
- st.markdown("**Saturation (%)**")
299
- saturation_default = st.session_state[
300
- "project_dct"
301
- ]["transformations"][category].get(
302
- "Saturation", predefined_defualts["Saturation"]
303
- )
304
- saturation_point = st.slider(
305
- f"Saturation Percentage",
306
- 0,
307
- 100,
308
- saturation_default,
309
- 10,
310
- key=f"sat_{category}",
311
- label_visibility="collapsed",
312
- )
313
- st.session_state["project_dct"]["transformations"][
314
- category
315
- ]["Saturation"] = saturation_point
316
- start = saturation_point[0]
317
- end = saturation_point[1]
318
- step = 10
319
- transform_params[category]["Saturation"] = (
320
- np.arange(start, end + step, step)
321
- )
322
-
323
- if transformation == "Power":
324
- st.markdown("**Power**")
325
- power_default = st.session_state["project_dct"][
326
- "transformations"
327
- ][category].get(
328
- "Power", predefined_defualts["Power"]
329
- )
330
- power = st.slider(
331
- f"Power",
332
- 0,
333
- 10,
334
- power_default,
335
- 1,
336
- key=f"power_{category}",
337
- label_visibility="collapsed",
338
- )
339
- st.session_state["project_dct"]["transformations"][
340
- category
341
- ]["Power"] = power
342
- start = power[0]
343
- end = power[1]
344
- step = 1
345
- transform_params[category]["Power"] = np.arange(
346
- start, end + step, step
347
- )
348
-
349
- if transformation == "Adstock":
350
- ads_default = st.session_state["project_dct"][
351
- "transformations"
352
- ][category].get(
353
- "Adstock", predefined_defualts["Adstock"]
354
- )
355
- st.markdown("**Adstock**")
356
- rate = st.slider(
357
- f"Factor ({category})",
358
- 0.0,
359
- 1.0,
360
- ads_default,
361
- 0.05,
362
- key=f"adstock_{category}",
363
- label_visibility="collapsed",
364
- )
365
- st.session_state["project_dct"]["transformations"][
366
- category
367
- ]["Adstock"] = rate
368
- start = rate[0]
369
- end = rate[1]
370
- step = 0.05
371
- adstock_range = [
372
- round(a, 3)
373
- for a in np.arange(start, end + step, step)
374
- ]
375
- transform_params[category][
376
- "Adstock"
377
- ] = adstock_range
378
-
379
- # Create widgets in each column
380
- create_transformation_widgets(col1, transformations_col1)
381
- create_transformation_widgets(col2, transformations_col2)
382
-
383
- # Function to apply Lag transformation
384
- def apply_lag(df, lag):
385
- return df.shift(lag)
386
-
387
- # Function to apply Lead transformation
388
- def apply_lead(df, lead):
389
- return df.shift(-lead)
390
-
391
- # Function to apply Moving Average transformation
392
- def apply_moving_average(df, window_size):
393
- return df.rolling(window=window_size).mean()
394
-
395
- # Function to apply Saturation transformation
396
- def apply_saturation(df, saturation_percent_100):
397
- # Convert saturation percentage from 100-based to fraction
398
- saturation_percent = saturation_percent_100 / 100.0
399
-
400
- # Calculate saturation point and steepness
401
- column_max = df.max()
402
- column_min = df.min()
403
- saturation_point = (column_min + column_max) / 2
404
-
405
- numerator = np.log(
406
- (1 / (saturation_percent if saturation_percent != 1 else 1 - 1e-9))
407
- - 1
408
- )
409
- denominator = np.log(saturation_point / max(column_max, 1e-9))
410
-
411
- steepness = numerator / max(
412
- denominator, 1e-9
413
- ) # Avoid division by zero with a small constant
414
-
415
- # Apply the saturation transformation
416
- transformed_series = df.apply(
417
- lambda x: (1 / (1 + (saturation_point / x) ** steepness)) * x
418
- )
419
-
420
- return transformed_series
421
-
422
- # Function to apply Power transformation
423
- def apply_power(df, power):
424
- return df**power
425
-
426
- # Function to apply Adstock transformation
427
- def apply_adstock(df, factor):
428
- x = 0
429
- # Use the walrus operator to update x iteratively with the Adstock formula
430
- adstock_var = [x := x * factor + v for v in df]
431
- ans = pd.Series(adstock_var, index=df.index)
432
- return ans
433
-
434
- # Function to generate transformed columns names
435
- @st.cache_resource(show_spinner=False)
436
- def generate_transformed_columns(original_columns, transform_params):
437
- transformed_columns, summary = {}, {}
438
-
439
- for category, columns in original_columns.items():
440
- for column in columns:
441
- transformed_columns[column] = []
442
- summary_details = (
443
- []
444
- ) # List to hold transformation details for the current column
445
-
446
- if category in transform_params:
447
- for transformation, values in transform_params[
448
- category
449
- ].items():
450
- # Generate transformed column names for each value
451
- for value in values:
452
- transformed_name = (
453
- f"{column}@{transformation}_{value}"
454
- )
455
- transformed_columns[column].append(
456
- transformed_name
457
- )
458
-
459
- # Format the values list as a string with commas and "and" before the last item
460
- if len(values) > 1:
461
- formatted_values = (
462
- ", ".join(map(str, values[:-1]))
463
- + " and "
464
- + str(values[-1])
465
- )
466
- else:
467
- formatted_values = str(values[0])
468
-
469
- # Add transformation details
470
- summary_details.append(
471
- f"{transformation} ({formatted_values})"
472
- )
473
-
474
- # Only add to summary if there are transformation details for the column
475
- if summary_details:
476
- formatted_summary = "⮕ ".join(summary_details)
477
- # Use <strong> tags to make the column name bold
478
- summary[column] = (
479
- f"<strong>{column}</strong>: {formatted_summary}"
480
- )
481
-
482
- # Generate a comprehensive summary string for all columns
483
- summary_items = [
484
- f"{idx + 1}. {details}"
485
- for idx, details in enumerate(summary.values())
486
- ]
487
-
488
- summary_string = "\n".join(summary_items)
489
-
490
- return transformed_columns, summary_string
491
-
492
- # Function to apply transformations to DataFrame slices based on specified categories and parameters
493
- @st.cache_resource(show_spinner=False)
494
- def apply_category_transformations(df, bin_dict, transform_params, panel):
495
- # Dictionary for function mapping
496
- transformation_functions = {
497
- "Lead": apply_lead,
498
- "Lag": apply_lag,
499
- "Moving Average": apply_moving_average,
500
- "Saturation": apply_saturation,
501
- "Power": apply_power,
502
- "Adstock": apply_adstock,
503
- }
504
-
505
- # Initialize category_df as an empty DataFrame
506
- category_df = pd.DataFrame()
507
-
508
- # Iterate through each category specified in transform_params
509
- for category in ["Media", "Internal", "Exogenous"]:
510
- if (
511
- category not in transform_params
512
- or category not in bin_dict
513
- or not transform_params[category]
514
- ):
515
- continue # Skip categories without transformations
516
-
517
- # Slice the DataFrame based on the columns specified in bin_dict for the current category
518
- df_slice = df[bin_dict[category] + panel]
519
-
520
- # Iterate through each transformation and its parameters for the current category
521
- for transformation, parameters in transform_params[
522
- category
523
- ].items():
524
- transformation_function = transformation_functions[
525
- transformation
526
- ]
527
-
528
- # Check if there is panel data to group by
529
- if len(panel) > 0:
530
- # Apply the transformation to each group
531
- category_df = pd.concat(
532
- [
533
- df_slice.groupby(panel)
534
- .transform(transformation_function, p)
535
- .add_suffix(f"@{transformation}_{p}")
536
- for p in parameters
537
- ],
538
- axis=1,
539
- )
540
-
541
- # Replace all NaN or null values in category_df with 0
542
- category_df.fillna(0, inplace=True)
543
-
544
- # Update df_slice
545
- df_slice = pd.concat(
546
- [df[panel], category_df],
547
- axis=1,
548
- )
549
-
550
- else:
551
- for p in parameters:
552
- # Apply the transformation function to each column
553
- temp_df = df_slice.apply(
554
- lambda x: transformation_function(x, p), axis=0
555
- ).rename(
556
- lambda x: f"{x}@{transformation}_{p}",
557
- axis="columns",
558
- )
559
- # Concatenate the transformed DataFrame slice to the category DataFrame
560
- category_df = pd.concat([category_df, temp_df], axis=1)
561
-
562
- # Replace all NaN or null values in category_df with 0
563
- category_df.fillna(0, inplace=True)
564
-
565
- # Update df_slice
566
- df_slice = pd.concat(
567
- [df[panel], category_df],
568
- axis=1,
569
- )
570
-
571
- # If category_df has been modified, concatenate it with the panel and response metrics from the original DataFrame
572
- if not category_df.empty:
573
- final_df = pd.concat([df, category_df], axis=1)
574
- else:
575
- # If no transformations were applied, use the original DataFrame
576
- final_df = df
577
-
578
- return final_df
579
-
580
- # Function to infers the granularity of the date column in a DataFrame
581
- @st.cache_resource(show_spinner=False)
582
- def infer_date_granularity(df):
583
- # Find the most common difference
584
- common_freq = (
585
- pd.Series(df["date"].unique()).diff().dt.days.dropna().mode()[0]
586
- )
587
-
588
- # Map the most common difference to a granularity
589
- if common_freq == 1:
590
- return "daily"
591
- elif common_freq == 7:
592
- return "weekly"
593
- elif 28 <= common_freq <= 31:
594
- return "monthly"
595
- else:
596
- return "irregular"
597
-
598
- #########################################################################################################################################################
599
- # User input for transformations
600
- #########################################################################################################################################################
601
-
602
- # Infer date granularity
603
- date_granularity = infer_date_granularity(final_df_loaded)
604
-
605
- # Initialize the main dictionary to store the transformation parameters for each category
606
- transform_params = {"Media": {}, "Internal": {}, "Exogenous": {}}
607
-
608
- # User input for transformations
609
- st.markdown("### Select Transformations to Apply")
610
- for category in ["Media", "Internal", "Exogenous"]:
611
- # Skip Internal
612
- if category == "Internal":
613
- continue
614
-
615
- transformation_widgets(category, transform_params, date_granularity)
616
-
617
- #########################################################################################################################################################
618
- # Apply transformations
619
- #########################################################################################################################################################
620
-
621
- # Apply category-based transformations to the DataFrame
622
- if st.button("Accept and Proceed", use_container_width=True):
623
- with st.spinner("Applying transformations..."):
624
- final_df = apply_category_transformations(
625
- final_df_loaded, bin_dict_loaded, transform_params, panel
626
- )
627
-
628
- # Generate a dictionary mapping original column names to lists of transformed column names
629
- transformed_columns_dict, summary_string = (
630
- generate_transformed_columns(
631
- original_columns, transform_params
632
- )
633
- )
634
-
635
- # Store into transformed dataframe and summary session state
636
- st.session_state["final_df"] = final_df
637
- st.session_state["summary_string"] = summary_string
638
-
639
- #########################################################################################################################################################
640
- # Display the transformed DataFrame and summary
641
- #########################################################################################################################################################
642
-
643
- # Display the transformed DataFrame in the Streamlit app
644
- st.markdown("### Transformed DataFrame")
645
- st.dataframe(st.session_state["final_df"], hide_index=True)
646
-
647
- # Total rows and columns
648
- total_rows, total_columns = st.session_state["final_df"].shape
649
- st.markdown(
650
- f"<p style='text-align: justify;'>The transformed DataFrame contains <strong>{total_rows}</strong> rows and <strong>{total_columns}</strong> columns.</p>",
651
- unsafe_allow_html=True,
652
- )
653
-
654
- # Display the summary of transformations as markdown
655
- if st.session_state["summary_string"]:
656
- with st.expander("Summary of Transformations"):
657
- st.markdown("### Summary of Transformations")
658
- st.markdown(
659
- st.session_state["summary_string"], unsafe_allow_html=True
660
- )
661
-
662
- @st.cache_resource(show_spinner=False)
663
- def save_to_pickle(file_path, final_df):
664
- # Open the file in write-binary mode and dump the objects
665
- with open(file_path, "wb") as f:
666
- pickle.dump({"final_df_transformed": final_df}, f)
667
- # Data is now saved to file
668
-
669
- if st.button("Accept and Save", use_container_width=True):
670
-
671
- save_to_pickle(
672
- os.path.join(
673
- st.session_state["project_path"], "final_df_transformed.pkl"
674
- ),
675
- st.session_state["final_df"],
676
- )
677
- project_dct_path = os.path.join(
678
- st.session_state["project_path"], "project_dct.pkl"
679
- )
680
-
681
- with open(project_dct_path, "wb") as f:
682
- pickle.dump(st.session_state["project_dct"], f)
683
-
684
- update_db("3_Transformations.py")
685
-
686
- st.toast("💾 Saved Successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/4_Model_Build.py DELETED
@@ -1,1062 +0,0 @@
1
- """
2
- MMO Build Sprint 3
3
- additions : adding more variables to session state for saved model : random effect, predicted train & test
4
-
5
- MMO Build Sprint 4
6
- additions : ability to run models for different response metrics
7
- """
8
-
9
- import streamlit as st
10
- import pandas as pd
11
- import plotly.express as px
12
- import plotly.graph_objects as go
13
- from Eda_functions import format_numbers
14
- import numpy as np
15
- import pickle
16
- from st_aggrid import AgGrid
17
- from st_aggrid import GridOptionsBuilder, GridUpdateMode
18
- from utilities import set_header, load_local_css
19
- from st_aggrid import GridOptionsBuilder
20
- import time
21
- import itertools
22
- import statsmodels.api as sm
23
- import numpy as npc
24
- import re
25
- import itertools
26
- from sklearn.metrics import (
27
- mean_absolute_error,
28
- r2_score,
29
- mean_absolute_percentage_error,
30
- )
31
- from sklearn.preprocessing import MinMaxScaler
32
- import os
33
- import matplotlib.pyplot as plt
34
- from statsmodels.stats.outliers_influence import variance_inflation_factor
35
- import yaml
36
- from yaml import SafeLoader
37
- import streamlit_authenticator as stauth
38
-
39
- st.set_option("deprecation.showPyplotGlobalUse", False)
40
- import statsmodels.api as sm
41
- import statsmodels.formula.api as smf
42
-
43
- from datetime import datetime
44
- import seaborn as sns
45
- from Data_prep_functions import *
46
- import sqlite3
47
- from utilities import update_db
48
-
49
-
50
- @st.cache_resource(show_spinner=False)
51
- # def save_to_pickle(file_path, final_df):
52
- # # Open the file in write-binary mode and dump the objects
53
- # with open(file_path, "wb") as f:
54
- # pickle.dump({file_path: final_df}, f)
55
-
56
-
57
- def get_random_effects(media_data, panel_col, _mdf):
58
- random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
59
-
60
- for i, market in enumerate(media_data[panel_col].unique()):
61
- print(i, end="\r")
62
- intercept = _mdf.random_effects[market].values[0]
63
- random_eff_df.loc[i, "random_effect"] = intercept
64
- random_eff_df.loc[i, panel_col] = market
65
-
66
- return random_eff_df
67
-
68
-
69
- def mdf_predict(X_df, mdf, random_eff_df):
70
- X = X_df.copy()
71
- X["fixed_effect"] = mdf.predict(X)
72
- X = pd.merge(X, random_eff_df, on=panel_col, how="left")
73
- X["pred"] = X["fixed_effect"] + X["random_effect"]
74
- # X.to_csv('Test/megred_df.csv',index=False)
75
- X.drop(columns=["fixed_effect", "random_effect"], inplace=True)
76
- return X["pred"]
77
-
78
-
79
- st.set_page_config(
80
- page_title="Model Build",
81
- page_icon=":shark:",
82
- layout="wide",
83
- initial_sidebar_state="collapsed",
84
- )
85
-
86
- load_local_css("styles.css")
87
- set_header()
88
-
89
- # Check for authentication status
90
- for k, v in st.session_state.items():
91
- if k not in [
92
- "logout",
93
- "login",
94
- "config",
95
- "model_build_button",
96
- ] and not k.startswith("FormSubmitter"):
97
- st.session_state[k] = v
98
- with open("config.yaml") as file:
99
- config = yaml.load(file, Loader=SafeLoader)
100
- st.session_state["config"] = config
101
- authenticator = stauth.Authenticate(
102
- config["credentials"],
103
- config["cookie"]["name"],
104
- config["cookie"]["key"],
105
- config["cookie"]["expiry_days"],
106
- config["preauthorized"],
107
- )
108
- st.session_state["authenticator"] = authenticator
109
- name, authentication_status, username = authenticator.login("Login", "main")
110
- auth_status = st.session_state.get("authentication_status")
111
-
112
- if auth_status == True:
113
- authenticator.logout("Logout", "main")
114
- is_state_initiaized = st.session_state.get("initialized", False)
115
-
116
- conn = sqlite3.connect(
117
- r"DB/User.db", check_same_thread=False
118
- ) # connection with sql db
119
- c = conn.cursor()
120
-
121
- if not is_state_initiaized:
122
-
123
- if "session_name" not in st.session_state:
124
- st.session_state["session_name"] = None
125
-
126
- if "project_dct" not in st.session_state:
127
- st.error("Please load a project from Home page")
128
- st.stop()
129
-
130
- st.title("1. Build Your Model")
131
-
132
- if not os.path.exists(
133
- os.path.join(st.session_state["project_path"], "data_import.pkl")
134
- ):
135
- st.error("Please move to Data Import Page and save.")
136
- st.stop()
137
- with open(
138
- os.path.join(st.session_state["project_path"], "data_import.pkl"), "rb"
139
- ) as f:
140
- data = pickle.load(f)
141
- st.session_state["bin_dict"] = data["bin_dict"]
142
-
143
- if not os.path.exists(
144
- os.path.join(
145
- st.session_state["project_path"], "final_df_transformed.pkl"
146
- )
147
- ):
148
- st.error(
149
- "Please move to Transformation Page and save transformations."
150
- )
151
- st.stop()
152
- with open(
153
- os.path.join(
154
- st.session_state["project_path"], "final_df_transformed.pkl"
155
- ),
156
- "rb",
157
- ) as f:
158
- data = pickle.load(f)
159
- media_data = data["final_df_transformed"]
160
- #media_data.to_csv("Test/media_data.csv", index=False)
161
- train_idx = int(len(media_data) / 5) * 4
162
- # Sprint4 - available response metrics is a list of all reponse metrics in the data
163
- ## these will be put in a drop down
164
-
165
- st.session_state["media_data"] = media_data
166
-
167
- if "available_response_metrics" not in st.session_state:
168
- # st.session_state['available_response_metrics'] = ['Total Approved Accounts - Revenue',
169
- # 'Total Approved Accounts - Appsflyer',
170
- # 'Account Requests - Appsflyer',
171
- # 'App Installs - Appsflyer']
172
-
173
- st.session_state["available_response_metrics"] = st.session_state[
174
- "bin_dict"
175
- ]["Response Metrics"]
176
- # Sprint4
177
- if "is_tuned_model" not in st.session_state:
178
- st.session_state["is_tuned_model"] = {}
179
- for resp_metric in st.session_state["available_response_metrics"]:
180
- resp_metric = (
181
- resp_metric.lower()
182
- .replace(" ", "_")
183
- .replace("-", "")
184
- .replace(":", "")
185
- .replace("__", "_")
186
- )
187
- st.session_state["is_tuned_model"][resp_metric] = False
188
-
189
- # Sprint4 - used_response_metrics is a list of resp metrics for which user has created & saved a model
190
- if "used_response_metrics" not in st.session_state:
191
- st.session_state["used_response_metrics"] = []
192
-
193
- # Sprint4 - saved_model_names
194
- if "saved_model_names" not in st.session_state:
195
- st.session_state["saved_model_names"] = []
196
-
197
- if "Model" not in st.session_state:
198
- if (
199
- "session_state_saved"
200
- in st.session_state["project_dct"]["model_build"].keys()
201
- and st.session_state["project_dct"]["model_build"][
202
- "session_state_saved"
203
- ]
204
- is not None
205
- and "Model"
206
- in st.session_state["project_dct"]["model_build"][
207
- "session_state_saved"
208
- ].keys()
209
- ):
210
- st.session_state["Model"] = st.session_state["project_dct"][
211
- "model_build"
212
- ]["session_state_saved"]["Model"]
213
- else:
214
- st.session_state["Model"] = {}
215
-
216
- # Sprint4 - select a response metric
217
- default_target_idx = (
218
- st.session_state["project_dct"]["model_build"].get(
219
- "sel_target_col", None
220
- )
221
- if st.session_state["project_dct"]["model_build"].get(
222
- "sel_target_col", None
223
- )
224
- is not None
225
- else st.session_state["available_response_metrics"][0]
226
- )
227
-
228
- sel_target_col = st.selectbox(
229
- "Select the response metric",
230
- st.session_state["available_response_metrics"],
231
- index=st.session_state["available_response_metrics"].index(
232
- default_target_idx
233
- ),
234
- )
235
- # , on_change=reset_save())
236
- st.session_state["project_dct"]["model_build"][
237
- "sel_target_col"
238
- ] = sel_target_col
239
-
240
- target_col = (
241
- sel_target_col.lower()
242
- .replace(" ", "_")
243
- .replace("-", "")
244
- .replace(":", "")
245
- .replace("__", "_")
246
- )
247
- new_name_dct = {
248
- col: col.lower()
249
- .replace(".", "_")
250
- .lower()
251
- .replace("@", "_")
252
- .replace(" ", "_")
253
- .replace("-", "")
254
- .replace(":", "")
255
- .replace("__", "_")
256
- for col in media_data.columns
257
- }
258
- media_data.columns = [
259
- col.lower()
260
- .replace(".", "_")
261
- .replace("@", "_")
262
- .replace(" ", "_")
263
- .replace("-", "")
264
- .replace(":", "")
265
- .replace("__", "_")
266
- for col in media_data.columns
267
- ]
268
- panel_col = [
269
- col.lower()
270
- .replace(".", "_")
271
- .replace("@", "_")
272
- .replace(" ", "_")
273
- .replace("-", "")
274
- .replace(":", "")
275
- .replace("__", "_")
276
- for col in st.session_state["bin_dict"]["Panel Level 1"]
277
- ][
278
- 0
279
- ] # set the panel column
280
- date_col = "date"
281
-
282
- is_panel = True if len(panel_col) > 0 else False
283
-
284
- if "is_panel" not in st.session_state:
285
- st.session_state["is_panel"] = is_panel
286
-
287
- if is_panel:
288
- media_data.sort_values([date_col, panel_col], inplace=True)
289
- else:
290
- media_data.sort_values(date_col, inplace=True)
291
-
292
- media_data.reset_index(drop=True, inplace=True)
293
-
294
- date = media_data[date_col]
295
- st.session_state["date"] = date
296
- y = media_data[target_col]
297
-
298
- if is_panel:
299
- spends_data = media_data[
300
- [
301
- c
302
- for c in media_data.columns
303
- if "_cost" in c.lower() or "_spend" in c.lower()
304
- ]
305
- + [date_col, panel_col]
306
- ]
307
- # Sprint3 - spends for resp curves
308
- else:
309
- spends_data = media_data[
310
- [
311
- c
312
- for c in media_data.columns
313
- if "_cost" in c.lower() or "_spend" in c.lower()
314
- ]
315
- + [date_col]
316
- ]
317
-
318
- y = media_data[target_col]
319
- media_data.drop([date_col], axis=1, inplace=True)
320
- media_data.reset_index(drop=True, inplace=True)
321
-
322
- columns = st.columns(2)
323
-
324
- old_shape = media_data.shape
325
-
326
- if "old_shape" not in st.session_state:
327
- st.session_state["old_shape"] = old_shape
328
-
329
- if "media_data" not in st.session_state:
330
- st.session_state["media_data"] = pd.DataFrame()
331
-
332
- # Sprint3
333
- if "orig_media_data" not in st.session_state:
334
- st.session_state["orig_media_data"] = pd.DataFrame()
335
-
336
- # Sprint3 additions
337
- if "random_effects" not in st.session_state:
338
- st.session_state["random_effects"] = pd.DataFrame()
339
- if "pred_train" not in st.session_state:
340
- st.session_state["pred_train"] = []
341
- if "pred_test" not in st.session_state:
342
- st.session_state["pred_test"] = []
343
- # end of Sprint3 additions
344
-
345
- # Section 3 - Create combinations
346
-
347
- # bucket=['paid_search', 'kwai','indicacao','infleux', 'influencer','FB: Level Achieved - Tier 1 Impressions',
348
- # ' FB: Level Achieved - Tier 2 Impressions','paid_social_others',
349
- # ' GA App: Will And Cid Pequena Baixo Risco Clicks',
350
- # 'digital_tactic_others',"programmatic"
351
- # ]
352
-
353
- # srishti - bucket names changed
354
- bucket = [
355
- "paid_search",
356
- "kwai",
357
- "indicacao",
358
- "infleux",
359
- "influencer",
360
- "fb_level_achieved_tier_2",
361
- "fb_level_achieved_tier_1",
362
- "paid_social_others",
363
- "ga_app",
364
- "digital_tactic_others",
365
- "programmatic",
366
- ]
367
-
368
- # with columns[0]:
369
- # if st.button('Create Combinations of Variables'):
370
-
371
- top_3_correlated_features = []
372
- # # for col in st.session_state['media_data'].columns[:19]:
373
- # original_cols = [c for c in st.session_state['media_data'].columns if
374
- # "_clicks" in c.lower() or "_impressions" in c.lower()]
375
- # original_cols = [c for c in original_cols if "_lag" not in c.lower() and "_adstock" not in c.lower()]
376
-
377
- original_cols = (
378
- st.session_state["bin_dict"]["Media"]
379
- + st.session_state["bin_dict"]["Internal"]
380
- )
381
-
382
- original_cols = [
383
- col.lower()
384
- .replace(".", "_")
385
- .replace("@", "_")
386
- .replace(" ", "_")
387
- .replace("-", "")
388
- .replace(":", "")
389
- .replace("__", "_")
390
- for col in original_cols
391
- ]
392
- original_cols = [col for col in original_cols if "_cost" not in col]
393
- # for col in st.session_state['media_data'].columns[:19]:
394
- for col in original_cols: # srishti - new
395
- corr_df = (
396
- pd.concat(
397
- [st.session_state["media_data"].filter(regex=col), y], axis=1
398
- )
399
- .corr()[target_col]
400
- .iloc[:-1]
401
- )
402
- top_3_correlated_features.append(
403
- list(corr_df.sort_values(ascending=False).head(2).index)
404
- )
405
- flattened_list = [
406
- item for sublist in top_3_correlated_features for item in sublist
407
- ]
408
- # all_features_set={var:[col for col in flattened_list if var in col] for var in bucket}
409
- all_features_set = {
410
- var: [col for col in flattened_list if var in col]
411
- for var in bucket
412
- if len([col for col in flattened_list if var in col]) > 0
413
- } # srishti
414
- channels_all = [values for values in all_features_set.values()]
415
- st.session_state["combinations"] = list(itertools.product(*channels_all))
416
- # if 'combinations' not in st.session_state:
417
- # st.session_state['combinations']=combinations_all
418
-
419
- st.session_state["final_selection"] = st.session_state["combinations"]
420
- # st.success('Created combinations')
421
-
422
- # revenue.reset_index(drop=True,inplace=True)
423
- y.reset_index(drop=True, inplace=True)
424
- if "Model_results" not in st.session_state:
425
- st.session_state["Model_results"] = {
426
- "Model_object": [],
427
- "Model_iteration": [],
428
- "Feature_set": [],
429
- "MAPE": [],
430
- "R2": [],
431
- "ADJR2": [],
432
- "pos_count": [],
433
- }
434
-
435
- def reset_model_result_dct():
436
- st.session_state["Model_results"] = {
437
- "Model_object": [],
438
- "Model_iteration": [],
439
- "Feature_set": [],
440
- "MAPE": [],
441
- "R2": [],
442
- "ADJR2": [],
443
- "pos_count": [],
444
- }
445
-
446
- # if st.button('Build Model'):
447
-
448
- if "iterations" not in st.session_state:
449
- st.session_state["iterations"] = 0
450
-
451
- if "final_selection" not in st.session_state:
452
- st.session_state["final_selection"] = False
453
-
454
- save_path = r"Model/"
455
- if st.session_state["final_selection"]:
456
- st.write(
457
- f'Total combinations created {format_numbers(len(st.session_state["final_selection"]))}'
458
- )
459
-
460
- # st.session_state["project_dct"]["model_build"]["all_iters_check"] = False
461
-
462
- checkbox_default = (
463
- st.session_state["project_dct"]["model_build"]["all_iters_check"]
464
- if st.session_state["project_dct"]["model_build"]["all_iters_check"]
465
- is not None
466
- else False
467
- )
468
-
469
- if st.checkbox("Build all iterations", value=checkbox_default):
470
- # st.session_state["project_dct"]["model_build"]["all_iters_check"]
471
- iterations = len(st.session_state["final_selection"])
472
- st.session_state["project_dct"]["model_build"][
473
- "all_iters_check"
474
- ] = True
475
-
476
- else:
477
- iterations = st.number_input(
478
- "Select the number of iterations to perform",
479
- min_value=0,
480
- step=100,
481
- value=st.session_state["iterations"],
482
- on_change=reset_model_result_dct,
483
- )
484
- st.session_state["project_dct"]["model_build"][
485
- "all_iters_check"
486
- ] = False
487
- st.session_state["project_dct"]["model_build"][
488
- "iterations"
489
- ] = iterations
490
-
491
- # st.stop()
492
-
493
- # build_button = st.session_state["project_dct"]["model_build"]["build_button"] if \
494
- # "build_button" in st.session_state["project_dct"]["model_build"].keys() else False
495
- # model_button =st.button('Build Model', on_click=reset_model_result_dct, key='model_build_button')
496
- # if
497
- # if model_button:
498
- if st.button(
499
- "Build Model",
500
- on_click=reset_model_result_dct,
501
- key="model_build_button",
502
- ):
503
- if iterations < 1:
504
- st.error("Please select number of iterations")
505
- st.stop()
506
- st.session_state["project_dct"]["model_build"]["build_button"] = True
507
- st.session_state["iterations"] = iterations
508
-
509
- # Section 4 - Model
510
- # st.session_state['media_data'] = st.session_state['media_data'].fillna(method='ffill')
511
- st.session_state["media_data"] = st.session_state["media_data"].ffill()
512
- st.markdown(
513
- "Data Split -- Training Period: May 9th, 2023 - October 5th,2023 , Testing Period: October 6th, 2023 - November 7th, 2023 "
514
- )
515
- progress_bar = st.progress(0) # Initialize the progress bar
516
- # time_remaining_text = st.empty() # Create an empty space for time remaining text
517
- start_time = time.time() # Record the start time
518
- progress_text = st.empty()
519
-
520
- # time_elapsed_text = st.empty()
521
- # for i, selected_features in enumerate(st.session_state["final_selection"][40000:40000 + int(iterations)]):
522
- # for i, selected_features in enumerate(st.session_state["final_selection"]):
523
-
524
- if is_panel == True:
525
- for i, selected_features in enumerate(
526
- st.session_state["final_selection"][0 : int(iterations)]
527
- ): # srishti
528
- df = st.session_state["media_data"]
529
-
530
- fet = [var for var in selected_features if len(var) > 0]
531
- inp_vars_str = " + ".join(fet) # new
532
-
533
- X = df[fet]
534
- y = df[target_col]
535
- ss = MinMaxScaler()
536
- X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
537
-
538
- X[target_col] = y # Sprint2
539
- X[panel_col] = df[panel_col] # Sprint2
540
-
541
- X_train = X.iloc[:train_idx]
542
- X_test = X.iloc[train_idx:]
543
- y_train = y.iloc[:train_idx]
544
- y_test = y.iloc[train_idx:]
545
-
546
- print(X_train.shape)
547
- # model = sm.OLS(y_train, X_train).fit()
548
- md_str = target_col + " ~ " + inp_vars_str
549
- # md = smf.mixedlm("total_approved_accounts_revenue ~ {}".format(inp_vars_str),
550
- # data=X_train[[target_col] + fet],
551
- # groups=X_train[panel_col])
552
- md = smf.mixedlm(
553
- md_str,
554
- data=X_train[[target_col] + fet],
555
- groups=X_train[panel_col],
556
- )
557
- mdf = md.fit()
558
- predicted_values = mdf.fittedvalues
559
-
560
- coefficients = mdf.fe_params.to_dict()
561
- model_positive = [
562
- col for col in coefficients.keys() if coefficients[col] > 0
563
- ]
564
-
565
- pvalues = [var for var in list(mdf.pvalues) if var <= 0.06]
566
-
567
- if (len(model_positive) / len(selected_features)) > 0 and (
568
- len(pvalues) / len(selected_features)
569
- ) >= 0: # srishti - changed just for testing, revert later
570
- # predicted_values = model.predict(X_train)
571
- mape = mean_absolute_percentage_error(
572
- y_train, predicted_values
573
- )
574
- r2 = r2_score(y_train, predicted_values)
575
- adjr2 = 1 - (1 - r2) * (len(y_train) - 1) / (
576
- len(y_train) - len(selected_features) - 1
577
- )
578
-
579
- filename = os.path.join(save_path, f"model_{i}.pkl")
580
- with open(filename, "wb") as f:
581
- pickle.dump(mdf, f)
582
- # with open(r"C:\Users\ManojP\Documents\MMM\simopt\Model\model.pkl", 'rb') as file:
583
- # model = pickle.load(file)
584
-
585
- st.session_state["Model_results"]["Model_object"].append(
586
- filename
587
- )
588
- st.session_state["Model_results"][
589
- "Model_iteration"
590
- ].append(i)
591
- st.session_state["Model_results"]["Feature_set"].append(
592
- fet
593
- )
594
- st.session_state["Model_results"]["MAPE"].append(mape)
595
- st.session_state["Model_results"]["R2"].append(r2)
596
- st.session_state["Model_results"]["pos_count"].append(
597
- len(model_positive)
598
- )
599
- st.session_state["Model_results"]["ADJR2"].append(adjr2)
600
-
601
- current_time = time.time()
602
- time_taken = current_time - start_time
603
- time_elapsed_minutes = time_taken / 60
604
- completed_iterations_text = f"{i + 1}/{iterations}"
605
- progress_bar.progress((i + 1) / int(iterations))
606
- progress_text.text(
607
- f"Completed iterations: {completed_iterations_text},Time Elapsed (min): {time_elapsed_minutes:.2f}"
608
- )
609
- st.write(
610
- f'Out of {st.session_state["iterations"]} iterations : {len(st.session_state["Model_results"]["Model_object"])} valid models'
611
- )
612
-
613
- else:
614
-
615
- for i, selected_features in enumerate(
616
- st.session_state["final_selection"][0 : int(iterations)]
617
- ): # srishti
618
- df = st.session_state["media_data"]
619
-
620
- fet = [var for var in selected_features if len(var) > 0]
621
- inp_vars_str = " + ".join(fet)
622
-
623
- X = df[fet]
624
- y = df[target_col]
625
- ss = MinMaxScaler()
626
- X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
627
- X = sm.add_constant(X)
628
- X_train = X.iloc[:130]
629
- X_test = X.iloc[130:]
630
- y_train = y.iloc[:130]
631
- y_test = y.iloc[130:]
632
-
633
- model = sm.OLS(y_train, X_train).fit()
634
-
635
- coefficients = model.params.to_list()
636
- model_positive = [coef for coef in coefficients if coef > 0]
637
- predicted_values = model.predict(X_train)
638
- pvalues = [var for var in list(model.pvalues) if var <= 0.06]
639
-
640
- # if (len(model_possitive) / len(selected_features)) > 0.9 and (len(pvalues) / len(selected_features)) >= 0.8:
641
- if (len(model_positive) / len(selected_features)) > 0 and (
642
- len(pvalues) / len(selected_features)
643
- ) >= 0.5: # srishti - changed just for testing, revert later VALID MODEL CRITERIA
644
- # predicted_values = model.predict(X_train)
645
- mape = mean_absolute_percentage_error(
646
- y_train, predicted_values
647
- )
648
- adjr2 = model.rsquared_adj
649
- r2 = model.rsquared
650
-
651
- filename = os.path.join(save_path, f"model_{i}.pkl")
652
- with open(filename, "wb") as f:
653
- pickle.dump(model, f)
654
- # with open(r"C:\Users\ManojP\Documents\MMM\simopt\Model\model.pkl", 'rb') as file:
655
- # model = pickle.load(file)
656
-
657
- st.session_state["Model_results"]["Model_object"].append(
658
- filename
659
- )
660
- st.session_state["Model_results"][
661
- "Model_iteration"
662
- ].append(i)
663
- st.session_state["Model_results"]["Feature_set"].append(
664
- fet
665
- )
666
- st.session_state["Model_results"]["MAPE"].append(mape)
667
- st.session_state["Model_results"]["R2"].append(r2)
668
- st.session_state["Model_results"]["ADJR2"].append(adjr2)
669
- st.session_state["Model_results"]["pos_count"].append(
670
- len(model_positive)
671
- )
672
-
673
- current_time = time.time()
674
- time_taken = current_time - start_time
675
- time_elapsed_minutes = time_taken / 60
676
- completed_iterations_text = f"{i + 1}/{iterations}"
677
- progress_bar.progress((i + 1) / int(iterations))
678
- progress_text.text(
679
- f"Completed iterations: {completed_iterations_text},Time Elapsed (min): {time_elapsed_minutes:.2f}"
680
- )
681
- st.write(
682
- f'Out of {st.session_state["iterations"]} iterations : {len(st.session_state["Model_results"]["Model_object"])} valid models'
683
- )
684
-
685
- pd.DataFrame(st.session_state["Model_results"]).to_csv(
686
- "model_output.csv"
687
- )
688
-
689
- def to_percentage(value):
690
- return f"{value * 100:.1f}%"
691
-
692
- ## Section 5 - Select Model
693
- st.title("2. Select Models")
694
- show_results_defualt = (
695
- st.session_state["project_dct"]["model_build"]["show_results_check"]
696
- if st.session_state["project_dct"]["model_build"]["show_results_check"]
697
- is not None
698
- else False
699
- )
700
- if "tick" not in st.session_state:
701
- st.session_state["tick"] = False
702
- if st.checkbox(
703
- "Show results of top 10 models (based on MAPE and Adj. R2)",
704
- value=show_results_defualt,
705
- ):
706
- st.session_state["project_dct"]["model_build"][
707
- "show_results_check"
708
- ] = True
709
- st.session_state["tick"] = True
710
- st.write(
711
- "Select one model iteration to generate performance metrics for it:"
712
- )
713
- data = pd.DataFrame(st.session_state["Model_results"])
714
- data = data[data["pos_count"] == data["pos_count"].max()].reset_index(
715
- drop=True
716
- ) # Sprint4 -- Srishti -- only show models with the lowest num of neg coeffs
717
- data.sort_values(by=["ADJR2"], ascending=False, inplace=True)
718
- data.drop_duplicates(subset="Model_iteration", inplace=True)
719
- top_10 = data.head(10)
720
- top_10["Rank"] = np.arange(1, len(top_10) + 1, 1)
721
- top_10[["MAPE", "R2", "ADJR2"]] = np.round(
722
- top_10[["MAPE", "R2", "ADJR2"]], 4
723
- ).applymap(to_percentage)
724
- top_10_table = top_10[
725
- ["Rank", "Model_iteration", "MAPE", "ADJR2", "R2"]
726
- ]
727
- # top_10_table.columns=[['Rank','Model Iteration Index','MAPE','Adjusted R2','R2']]
728
- gd = GridOptionsBuilder.from_dataframe(top_10_table)
729
- gd.configure_pagination(enabled=True)
730
-
731
- gd.configure_selection(
732
- use_checkbox=True,
733
- selection_mode="single",
734
- pre_select_all_rows=False,
735
- pre_selected_rows=[1],
736
- )
737
-
738
- gridoptions = gd.build()
739
-
740
- table = AgGrid(
741
- top_10,
742
- gridOptions=gridoptions,
743
- update_mode=GridUpdateMode.SELECTION_CHANGED,
744
- )
745
-
746
- selected_rows = table.selected_rows
747
- # if st.session_state["selected_rows"] != selected_rows:
748
- # st.session_state["build_rc_cb"] = False
749
- st.session_state["selected_rows"] = selected_rows
750
-
751
- # Section 6 - Display Results
752
-
753
- if len(selected_rows) > 0:
754
- st.header("2.1 Results Summary")
755
-
756
- model_object = data[
757
- data["Model_iteration"] == selected_rows[0]["Model_iteration"]
758
- ]["Model_object"]
759
- features_set = data[
760
- data["Model_iteration"] == selected_rows[0]["Model_iteration"]
761
- ]["Feature_set"]
762
-
763
- with open(str(model_object.values[0]), "rb") as file:
764
- # print(file)
765
- model = pickle.load(file)
766
- st.write(model.summary())
767
- st.header("2.2 Actual vs. Predicted Plot")
768
-
769
- if is_panel:
770
- df = st.session_state["media_data"]
771
- X = df[features_set.values[0]]
772
- y = df[target_col]
773
-
774
- ss = MinMaxScaler()
775
- X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
776
-
777
- # Sprint2 changes
778
- X[target_col] = y # new
779
- X[panel_col] = df[panel_col]
780
- X[date_col] = date
781
-
782
- X_train = X.iloc[:train_idx]
783
- X_test = X.iloc[train_idx:].reset_index(drop=True)
784
- y_train = y.iloc[:train_idx]
785
- y_test = y.iloc[train_idx:].reset_index(drop=True)
786
-
787
- test_spends = spends_data[
788
- train_idx:
789
- ] # Sprint3 - test spends for resp curves
790
- random_eff_df = get_random_effects(
791
- media_data, panel_col, model
792
- )
793
- train_pred = model.fittedvalues
794
- test_pred = mdf_predict(X_test, model, random_eff_df)
795
- print("__" * 20, test_pred.isna().sum())
796
-
797
- else:
798
- df = st.session_state["media_data"]
799
- X = df[features_set.values[0]]
800
- y = df[target_col]
801
-
802
- ss = MinMaxScaler()
803
- X = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
804
- X = sm.add_constant(X)
805
-
806
- X[date_col] = date
807
-
808
- X_train = X.iloc[:130]
809
- X_test = X.iloc[130:].reset_index(drop=True)
810
- y_train = y.iloc[:130]
811
- y_test = y.iloc[130:].reset_index(drop=True)
812
-
813
- test_spends = spends_data[
814
- 130:
815
- ] # Sprint3 - test spends for resp curves
816
- train_pred = model.predict(
817
- X_train[features_set.values[0] + ["const"]]
818
- )
819
- test_pred = model.predict(
820
- X_test[features_set.values[0] + ["const"]]
821
- )
822
-
823
- # save x test to test - srishti
824
- # x_test_to_save = X_test.copy()
825
- # x_test_to_save['Actuals'] = y_test
826
- # x_test_to_save['Predictions'] = test_pred
827
- #
828
- # x_train_to_save = X_train.copy()
829
- # x_train_to_save['Actuals'] = y_train
830
- # x_train_to_save['Predictions'] = train_pred
831
- #
832
- # x_train_to_save.to_csv('Test/x_train_to_save.csv', index=False)
833
- # x_test_to_save.to_csv('Test/x_test_to_save.csv', index=False)
834
-
835
- st.session_state["X"] = X_train
836
- st.session_state["features_set"] = features_set.values[0]
837
- print(
838
- "**" * 20, "selected model features : ", features_set.values[0]
839
- )
840
- metrics_table, line, actual_vs_predicted_plot = (
841
- plot_actual_vs_predicted(
842
- X_train[date_col],
843
- y_train,
844
- train_pred,
845
- model,
846
- target_column=sel_target_col,
847
- is_panel=is_panel,
848
- )
849
- ) # Sprint2
850
-
851
- st.plotly_chart(actual_vs_predicted_plot, use_container_width=True)
852
-
853
- st.markdown("## 2.3 Residual Analysis")
854
- columns = st.columns(2)
855
- with columns[0]:
856
- fig = plot_residual_predicted(
857
- y_train, train_pred, X_train
858
- ) # Sprint2
859
- st.plotly_chart(fig)
860
-
861
- with columns[1]:
862
- st.empty()
863
- fig = qqplot(y_train, train_pred) # Sprint2
864
- st.plotly_chart(fig)
865
-
866
- with columns[0]:
867
- fig = residual_distribution(y_train, train_pred) # Sprint2
868
- st.pyplot(fig)
869
-
870
- vif_data = pd.DataFrame()
871
- # X=X.drop('const',axis=1)
872
- X_train_orig = (
873
- X_train.copy()
874
- ) # Sprint2 -- creating a copy of xtrain. Later deleting panel, target & date from xtrain
875
- del_col_list = list(
876
- set([target_col, panel_col, date_col]).intersection(
877
- set(X_train.columns)
878
- )
879
- )
880
- X_train.drop(columns=del_col_list, inplace=True) # Sprint2
881
-
882
- vif_data["Variable"] = X_train.columns
883
- vif_data["VIF"] = [
884
- variance_inflation_factor(X_train.values, i)
885
- for i in range(X_train.shape[1])
886
- ]
887
- vif_data.sort_values(by=["VIF"], ascending=False, inplace=True)
888
- vif_data = np.round(vif_data)
889
- vif_data["VIF"] = vif_data["VIF"].astype(float)
890
- st.header("2.4 Variance Inflation Factor (VIF)")
891
- # st.dataframe(vif_data)
892
- color_mapping = {
893
- "darkgreen": (vif_data["VIF"] < 3),
894
- "orange": (vif_data["VIF"] >= 3) & (vif_data["VIF"] <= 10),
895
- "darkred": (vif_data["VIF"] > 10),
896
- }
897
-
898
- # Create a horizontal bar plot
899
- fig, ax = plt.subplots()
900
- fig.set_figwidth(10) # Adjust the width of the figure as needed
901
-
902
- # Sort the bars by descending VIF values
903
- vif_data = vif_data.sort_values(by="VIF", ascending=False)
904
-
905
- # Iterate through the color mapping and plot bars with corresponding colors
906
- for color, condition in color_mapping.items():
907
- subset = vif_data[condition]
908
- bars = ax.barh(
909
- subset["Variable"], subset["VIF"], color=color, label=color
910
- )
911
-
912
- # Add text annotations on top of the bars
913
- for bar in bars:
914
- width = bar.get_width()
915
- ax.annotate(
916
- f"{width:}",
917
- xy=(width, bar.get_y() + bar.get_height() / 2),
918
- xytext=(5, 0),
919
- textcoords="offset points",
920
- va="center",
921
- )
922
-
923
- # Customize the plot
924
- ax.set_xlabel("VIF Values")
925
- # ax.set_title('2.4 Variance Inflation Factor (VIF)')
926
- # ax.legend(loc='upper right')
927
-
928
- # Display the plot in Streamlit
929
- st.pyplot(fig)
930
-
931
- with st.expander("Results Summary Test data"):
932
- # ss = MinMaxScaler()
933
- # X_test = pd.DataFrame(ss.fit_transform(X_test), columns=X_test.columns)
934
- st.header("2.2 Actual vs. Predicted Plot")
935
-
936
- metrics_table, line, actual_vs_predicted_plot = (
937
- plot_actual_vs_predicted(
938
- X_test[date_col],
939
- y_test,
940
- test_pred,
941
- model,
942
- target_column=sel_target_col,
943
- is_panel=is_panel,
944
- )
945
- ) # Sprint2
946
-
947
- st.plotly_chart(
948
- actual_vs_predicted_plot, use_container_width=True
949
- )
950
-
951
- st.markdown("## 2.3 Residual Analysis")
952
- columns = st.columns(2)
953
- with columns[0]:
954
- fig = plot_residual_predicted(
955
- y, test_pred, X_test
956
- ) # Sprint2
957
- st.plotly_chart(fig)
958
-
959
- with columns[1]:
960
- st.empty()
961
- fig = qqplot(y, test_pred) # Sprint2
962
- st.plotly_chart(fig)
963
-
964
- with columns[0]:
965
- fig = residual_distribution(y, test_pred) # Sprint2
966
- st.pyplot(fig)
967
-
968
- value = False
969
- save_button_model = st.checkbox(
970
- "Save this model to tune", key="build_rc_cb"
971
- ) # , on_click=set_save())
972
-
973
- if save_button_model:
974
- mod_name = st.text_input("Enter model name")
975
- if len(mod_name) > 0:
976
- mod_name = (
977
- mod_name + "__" + target_col
978
- ) # Sprint4 - adding target col to model name
979
- if is_panel:
980
- pred_train = model.fittedvalues
981
- pred_test = mdf_predict(X_test, model, random_eff_df)
982
- else:
983
- st.session_state["features_set"] = st.session_state[
984
- "features_set"
985
- ] + ["const"]
986
- pred_train = model.predict(
987
- X_train_orig[st.session_state["features_set"]]
988
- )
989
- pred_test = model.predict(
990
- X_test[st.session_state["features_set"]]
991
- )
992
-
993
- st.session_state["Model"][mod_name] = {
994
- "Model_object": model,
995
- "feature_set": st.session_state["features_set"],
996
- "X_train": X_train_orig,
997
- "X_test": X_test,
998
- "y_train": y_train,
999
- "y_test": y_test,
1000
- "pred_train": pred_train,
1001
- "pred_test": pred_test,
1002
- }
1003
- st.session_state["X_train"] = X_train_orig
1004
- st.session_state["X_test_spends"] = test_spends
1005
- st.session_state["saved_model_names"].append(mod_name)
1006
- # Sprint3 additions
1007
- if is_panel:
1008
- random_eff_df = get_random_effects(
1009
- media_data, panel_col, model
1010
- )
1011
- st.session_state["random_effects"] = random_eff_df
1012
-
1013
- with open(
1014
- os.path.join(
1015
- st.session_state["project_path"], "best_models.pkl"
1016
- ),
1017
- "wb",
1018
- ) as f:
1019
- pickle.dump(st.session_state["Model"], f)
1020
- st.success(
1021
- mod_name
1022
- + " model saved! Proceed to the next page to tune the model"
1023
- )
1024
-
1025
- urm = st.session_state["used_response_metrics"]
1026
- urm.append(sel_target_col)
1027
- st.session_state["used_response_metrics"] = list(
1028
- set(urm)
1029
- )
1030
- mod_name = ""
1031
- # Sprint4 - add the formatted name of the target col to used resp metrics
1032
- value = False
1033
-
1034
- st.session_state["project_dct"]["model_build"][
1035
- "session_state_saved"
1036
- ] = {}
1037
- for key in [
1038
- "Model",
1039
- "bin_dict",
1040
- "used_response_metrics",
1041
- "date",
1042
- "saved_model_names",
1043
- "media_data",
1044
- "X_test_spends",
1045
- ]:
1046
- st.session_state["project_dct"]["model_build"][
1047
- "session_state_saved"
1048
- ][key] = st.session_state[key]
1049
-
1050
- project_dct_path = os.path.join(
1051
- st.session_state["project_path"], "project_dct.pkl"
1052
- )
1053
- with open(project_dct_path, "wb") as f:
1054
- pickle.dump(st.session_state["project_dct"], f)
1055
-
1056
- update_db("4_Model_Build.py")
1057
-
1058
- st.toast("💾 Saved Successfully!")
1059
- else:
1060
- st.session_state["project_dct"]["model_build"][
1061
- "show_results_check"
1062
- ] = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/5_Model_Tuning.py DELETED
@@ -1,912 +0,0 @@
1
- """
2
- MMO Build Sprint 3
3
- date :
4
- changes : capability to tune MixedLM as well as simple LR in the same page
5
- """
6
-
7
- import os
8
-
9
- import streamlit as st
10
- import pandas as pd
11
- from Eda_functions import format_numbers
12
- import pickle
13
- from utilities import set_header, load_local_css
14
- import statsmodels.api as sm
15
- import re
16
- from sklearn.preprocessing import MinMaxScaler
17
- import matplotlib.pyplot as plt
18
- from statsmodels.stats.outliers_influence import variance_inflation_factor
19
- import yaml
20
- from yaml import SafeLoader
21
- import streamlit_authenticator as stauth
22
-
23
- st.set_option("deprecation.showPyplotGlobalUse", False)
24
- import statsmodels.formula.api as smf
25
- from Data_prep_functions import *
26
- import sqlite3
27
- from utilities import update_db
28
-
29
- # for i in ["model_tuned", "X_train_tuned", "X_test_tuned", "tuned_model_features", "tuned_model", "tuned_model_dict"] :
30
-
31
- st.set_page_config(
32
- page_title="Model Tuning",
33
- page_icon=":shark:",
34
- layout="wide",
35
- initial_sidebar_state="collapsed",
36
- )
37
- load_local_css("styles.css")
38
- set_header()
39
- # Check for authentication status
40
- for k, v in st.session_state.items():
41
- # print(k, v)
42
- if k not in [
43
- "logout",
44
- "login",
45
- "config",
46
- "build_tuned_model",
47
- ] and not k.startswith("FormSubmitter"):
48
- st.session_state[k] = v
49
- with open("config.yaml") as file:
50
- config = yaml.load(file, Loader=SafeLoader)
51
- st.session_state["config"] = config
52
- authenticator = stauth.Authenticate(
53
- config["credentials"],
54
- config["cookie"]["name"],
55
- config["cookie"]["key"],
56
- config["cookie"]["expiry_days"],
57
- config["preauthorized"],
58
- )
59
- st.session_state["authenticator"] = authenticator
60
- name, authentication_status, username = authenticator.login("Login", "main")
61
- auth_status = st.session_state.get("authentication_status")
62
-
63
- if auth_status == True:
64
- authenticator.logout("Logout", "main")
65
- is_state_initiaized = st.session_state.get("initialized", False)
66
-
67
- if "project_dct" not in st.session_state:
68
- st.error("Please load a project from Home page")
69
- st.stop()
70
-
71
- if not os.path.exists(
72
- os.path.join(st.session_state["project_path"], "best_models.pkl")
73
- ):
74
- st.error("Please save a model before tuning")
75
- st.stop()
76
-
77
- conn = sqlite3.connect(
78
- r"DB/User.db", check_same_thread=False
79
- ) # connection with sql db
80
- c = conn.cursor()
81
-
82
- if not is_state_initiaized:
83
- if "session_name" not in st.session_state:
84
- st.session_state["session_name"] = None
85
-
86
- if (
87
- "session_state_saved"
88
- in st.session_state["project_dct"]["model_build"].keys()
89
- ):
90
- for key in [
91
- "Model",
92
- "date",
93
- "saved_model_names",
94
- "media_data",
95
- "X_test_spends",
96
- ]:
97
- if key not in st.session_state:
98
- st.session_state[key] = st.session_state["project_dct"][
99
- "model_build"
100
- ]["session_state_saved"][key]
101
- st.session_state["bin_dict"] = st.session_state["project_dct"][
102
- "model_build"
103
- ]["session_state_saved"]["bin_dict"]
104
- if (
105
- "used_response_metrics" not in st.session_state
106
- or st.session_state["used_response_metrics"] == []
107
- ):
108
- st.session_state["used_response_metrics"] = st.session_state[
109
- "project_dct"
110
- ]["model_build"]["session_state_saved"][
111
- "used_response_metrics"
112
- ]
113
- else:
114
- st.error("Please load a session with a built model")
115
- st.stop()
116
-
117
- # if 'sel_model' not in st.session_state["project_dct"]["model_tuning"].keys():
118
- # st.session_state["project_dct"]["model_tuning"]['sel_model']= {}
119
-
120
- for key in ["select_all_flags_check", "selected_flags", "sel_model"]:
121
- if key not in st.session_state["project_dct"]["model_tuning"].keys():
122
- st.session_state["project_dct"]["model_tuning"][key] = {}
123
- # Sprint3
124
- # is_panel = st.session_state['is_panel']
125
- # panel_col = 'markets' # set the panel column
126
- date_col = "date"
127
-
128
- panel_col = [
129
- col.lower()
130
- .replace(".", "_")
131
- .replace("@", "_")
132
- .replace(" ", "_")
133
- .replace("-", "")
134
- .replace(":", "")
135
- .replace("__", "_")
136
- for col in st.session_state["bin_dict"]["Panel Level 1"]
137
- ][
138
- 0
139
- ] # set the panel column
140
- is_panel = True if len(panel_col) > 0 else False
141
-
142
- # flag indicating there is not tuned model till now
143
-
144
- # Sprint4 - model tuned dict
145
- if "Model_Tuned" not in st.session_state:
146
- st.session_state["Model_Tuned"] = {}
147
-
148
- st.title("1. Model Tuning")
149
-
150
- if "is_tuned_model" not in st.session_state:
151
- st.session_state["is_tuned_model"] = {}
152
- # Sprint4 - if used_response_metrics is not blank, then select one of the used_response_metrics, else target is revenue by default
153
- if (
154
- "used_response_metrics" in st.session_state
155
- and st.session_state["used_response_metrics"] != []
156
- ):
157
- default_target_idx = (
158
- st.session_state["project_dct"]["model_tuning"].get(
159
- "sel_target_col", None
160
- )
161
- if st.session_state["project_dct"]["model_tuning"].get(
162
- "sel_target_col", None
163
- )
164
- is not None
165
- else st.session_state["used_response_metrics"][0]
166
- )
167
- sel_target_col = st.selectbox(
168
- "Select the response metric",
169
- st.session_state["used_response_metrics"],
170
- index=st.session_state["used_response_metrics"].index(
171
- default_target_idx
172
- ),
173
- )
174
- target_col = (
175
- sel_target_col.lower()
176
- .replace(" ", "_")
177
- .replace("-", "")
178
- .replace(":", "")
179
- .replace("__", "_")
180
- )
181
- st.session_state["project_dct"]["model_tuning"][
182
- "sel_target_col"
183
- ] = sel_target_col
184
-
185
- else:
186
- sel_target_col = "Total Approved Accounts - Revenue"
187
- target_col = "total_approved_accounts_revenue"
188
-
189
- # Sprint4 - Look through all saved models, only show saved models of the sel resp metric (target_col)
190
- # saved_models = st.session_state['saved_model_names']
191
- with open(
192
- os.path.join(st.session_state["project_path"], "best_models.pkl"), "rb"
193
- ) as file:
194
- model_dict = pickle.load(file)
195
-
196
- saved_models = model_dict.keys()
197
- required_saved_models = [
198
- m.split("__")[0]
199
- for m in saved_models
200
- if m.split("__")[1] == target_col
201
- ]
202
-
203
- if len(required_saved_models) > 0:
204
- default_model_idx = st.session_state["project_dct"]["model_tuning"][
205
- "sel_model"
206
- ].get(sel_target_col, required_saved_models[0])
207
- sel_model = st.selectbox(
208
- "Select the model to tune",
209
- required_saved_models,
210
- index=required_saved_models.index(default_model_idx),
211
- )
212
- else:
213
- default_model_idx = st.session_state["project_dct"]["model_tuning"][
214
- "sel_model"
215
- ].get(sel_target_col, 0)
216
- sel_model = st.selectbox(
217
- "Select the model to tune", required_saved_models
218
- )
219
-
220
- st.session_state["project_dct"]["model_tuning"]["sel_model"][
221
- sel_target_col
222
- ] = default_model_idx
223
-
224
- sel_model_dict = model_dict[
225
- sel_model + "__" + target_col
226
- ] # Sprint4 - get the model obj of the selected model
227
-
228
- X_train = sel_model_dict["X_train"]
229
- X_test = sel_model_dict["X_test"]
230
- y_train = sel_model_dict["y_train"]
231
- y_test = sel_model_dict["y_test"]
232
- df = st.session_state["media_data"]
233
-
234
- if "selected_model" not in st.session_state:
235
- st.session_state["selected_model"] = 0
236
-
237
- st.markdown("### 1.1 Event Flags")
238
- st.markdown(
239
- "Helps in quantifying the impact of specific occurrences of events"
240
- )
241
-
242
- flag_expander_default = (
243
- st.session_state["project_dct"]["model_tuning"].get(
244
- "flag_expander", None
245
- )
246
- if st.session_state["project_dct"]["model_tuning"].get(
247
- "flag_expander", None
248
- )
249
- is not None
250
- else False
251
- )
252
-
253
- with st.expander("Apply Event Flags", flag_expander_default):
254
- st.session_state["project_dct"]["model_tuning"]["flag_expander"] = True
255
-
256
- model = sel_model_dict["Model_object"]
257
- date = st.session_state["date"]
258
- date = pd.to_datetime(date)
259
- X_train = sel_model_dict["X_train"]
260
-
261
- # features_set= model_dict[st.session_state["selected_model"]]['feature_set']
262
- features_set = sel_model_dict["feature_set"]
263
-
264
- col = st.columns(3)
265
- min_date = min(date)
266
- max_date = max(date)
267
-
268
- start_date_default = (
269
- st.session_state["project_dct"]["model_tuning"].get(
270
- "start_date_default"
271
- )
272
- if st.session_state["project_dct"]["model_tuning"].get(
273
- "start_date_default"
274
- )
275
- is not None
276
- else min_date
277
- )
278
- end_date_default = (
279
- st.session_state["project_dct"]["model_tuning"].get(
280
- "end_date_default"
281
- )
282
- if st.session_state["project_dct"]["model_tuning"].get(
283
- "end_date_default"
284
- )
285
- is not None
286
- else max_date
287
- )
288
- with col[0]:
289
- start_date = st.date_input(
290
- "Select Start Date",
291
- start_date_default,
292
- min_value=min_date,
293
- max_value=max_date,
294
- )
295
- with col[1]:
296
- end_date_default = (
297
- end_date_default
298
- if end_date_default >= start_date
299
- else start_date
300
- )
301
- end_date = st.date_input(
302
- "Select End Date",
303
- end_date_default,
304
- min_value=max(min_date, start_date),
305
- max_value=max_date,
306
- )
307
- with col[2]:
308
- repeat_default = (
309
- st.session_state["project_dct"]["model_tuning"].get(
310
- "repeat_default"
311
- )
312
- if st.session_state["project_dct"]["model_tuning"].get(
313
- "repeat_default"
314
- )
315
- is not None
316
- else "No"
317
- )
318
- repeat_default_idx = 0 if repeat_default.lower() == "yes" else 1
319
- repeat = st.selectbox(
320
- "Repeat Annually", ["Yes", "No"], index=repeat_default_idx
321
- )
322
- st.session_state["project_dct"]["model_tuning"][
323
- "start_date_default"
324
- ] = start_date
325
- st.session_state["project_dct"]["model_tuning"][
326
- "end_date_default"
327
- ] = end_date
328
- st.session_state["project_dct"]["model_tuning"][
329
- "repeat_default"
330
- ] = repeat
331
-
332
- if repeat == "Yes":
333
- repeat = True
334
- else:
335
- repeat = False
336
-
337
- if "Flags" not in st.session_state:
338
- st.session_state["Flags"] = {}
339
- if "flags" in st.session_state["project_dct"]["model_tuning"].keys():
340
- st.session_state["Flags"] = st.session_state["project_dct"][
341
- "model_tuning"
342
- ]["flags"]
343
- # print("**"*50)
344
- # print(y_train)
345
- # print("**"*50)
346
- # print(model.fittedvalues)
347
- if is_panel: # Sprint3
348
- met, line_values, fig_flag = plot_actual_vs_predicted(
349
- X_train[date_col],
350
- y_train,
351
- model.fittedvalues,
352
- model,
353
- target_column=sel_target_col,
354
- flag=(start_date, end_date),
355
- repeat_all_years=repeat,
356
- is_panel=True,
357
- )
358
- st.plotly_chart(fig_flag, use_container_width=True)
359
-
360
- # create flag on test
361
- met, test_line_values, fig_flag = plot_actual_vs_predicted(
362
- X_test[date_col],
363
- y_test,
364
- sel_model_dict["pred_test"],
365
- model,
366
- target_column=sel_target_col,
367
- flag=(start_date, end_date),
368
- repeat_all_years=repeat,
369
- is_panel=True,
370
- )
371
-
372
- else:
373
- pred_train = model.predict(X_train[features_set])
374
- met, line_values, fig_flag = plot_actual_vs_predicted(
375
- X_train[date_col],
376
- y_train,
377
- pred_train,
378
- model,
379
- flag=(start_date, end_date),
380
- repeat_all_years=repeat,
381
- is_panel=False,
382
- )
383
- st.plotly_chart(fig_flag, use_container_width=True)
384
-
385
- pred_test = model.predict(X_test[features_set])
386
- met, test_line_values, fig_flag = plot_actual_vs_predicted(
387
- X_test[date_col],
388
- y_test,
389
- pred_test,
390
- model,
391
- flag=(start_date, end_date),
392
- repeat_all_years=repeat,
393
- is_panel=False,
394
- )
395
- flag_name = "f1_flag"
396
- flag_name = st.text_input("Enter Flag Name")
397
- # Sprint4 - add selected target col to flag name
398
- if st.button("Update flag"):
399
- st.session_state["Flags"][flag_name + "__" + target_col] = {}
400
- st.session_state["Flags"][flag_name + "__" + target_col][
401
- "train"
402
- ] = line_values
403
- st.session_state["Flags"][flag_name + "__" + target_col][
404
- "test"
405
- ] = test_line_values
406
- st.success(f'{flag_name + "__" + target_col} stored')
407
-
408
- st.session_state["project_dct"]["model_tuning"]["flags"] = (
409
- st.session_state["Flags"]
410
- )
411
- # Sprint4 - only show flag created for the particular target col
412
- if st.session_state["Flags"] is None:
413
- st.session_state["Flags"] = {}
414
- target_model_flags = [
415
- f.split("__")[0]
416
- for f in st.session_state["Flags"].keys()
417
- if f.split("__")[1] == target_col
418
- ]
419
- options = list(target_model_flags)
420
- selected_options = []
421
- num_columns = 4
422
- num_rows = -(-len(options) // num_columns)
423
-
424
- tick = False
425
- if st.checkbox(
426
- "Select all",
427
- value=st.session_state["project_dct"]["model_tuning"][
428
- "select_all_flags_check"
429
- ].get(sel_target_col, False),
430
- ):
431
- tick = True
432
- st.session_state["project_dct"]["model_tuning"][
433
- "select_all_flags_check"
434
- ][sel_target_col] = True
435
- else:
436
- st.session_state["project_dct"]["model_tuning"][
437
- "select_all_flags_check"
438
- ][sel_target_col] = False
439
- selection_defualts = st.session_state["project_dct"]["model_tuning"][
440
- "selected_flags"
441
- ].get(sel_target_col, [])
442
- selected_options = selection_defualts
443
- for row in range(num_rows):
444
- cols = st.columns(num_columns)
445
- for col in cols:
446
- if options:
447
- option = options.pop(0)
448
- option_default = (
449
- True if option in selection_defualts else False
450
- )
451
- selected = col.checkbox(option, value=(tick or option_default))
452
- if selected:
453
- selected_options.append(option)
454
- st.session_state["project_dct"]["model_tuning"]["selected_flags"][
455
- sel_target_col
456
- ] = selected_options
457
-
458
- st.markdown("### 1.2 Select Parameters to Apply")
459
- parameters = st.columns(3)
460
- with parameters[0]:
461
- Trend = st.checkbox(
462
- "**Trend**",
463
- value=st.session_state["project_dct"]["model_tuning"].get(
464
- "trend_check", False
465
- ),
466
- )
467
- st.markdown(
468
- "Helps account for long-term trends or seasonality that could influence advertising effectiveness"
469
- )
470
- with parameters[1]:
471
- week_number = st.checkbox(
472
- "**Week_number**",
473
- value=st.session_state["project_dct"]["model_tuning"].get(
474
- "week_num_check", False
475
- ),
476
- )
477
- st.markdown(
478
- "Assists in detecting and incorporating weekly patterns or seasonality"
479
- )
480
- with parameters[2]:
481
- sine_cosine = st.checkbox(
482
- "**Sine and Cosine Waves**",
483
- value=st.session_state["project_dct"]["model_tuning"].get(
484
- "sine_cosine_check", False
485
- ),
486
- )
487
- st.markdown(
488
- "Helps in capturing cyclical patterns or seasonality in the data"
489
- )
490
- #
491
- # def get_tuned_model():
492
- # st.session_state['build_tuned_model']=True
493
-
494
- if st.button(
495
- "Build model with Selected Parameters and Flags",
496
- key="build_tuned_model",
497
- ):
498
- new_features = features_set
499
- st.header("2.1 Results Summary")
500
- # date=list(df.index)
501
- # df = df.reset_index(drop=True)
502
- # X_train=df[features_set]
503
- ss = MinMaxScaler()
504
- if is_panel == True:
505
- X_train_tuned = X_train[features_set]
506
- # X_train_tuned = pd.DataFrame(ss.fit_transform(X), columns=X.columns)
507
- X_train_tuned[target_col] = X_train[target_col]
508
- X_train_tuned[date_col] = X_train[date_col]
509
- X_train_tuned[panel_col] = X_train[panel_col]
510
-
511
- X_test_tuned = X_test[features_set]
512
- # X_test_tuned = pd.DataFrame(ss.transform(X), columns=X.columns)
513
- X_test_tuned[target_col] = X_test[target_col]
514
- X_test_tuned[date_col] = X_test[date_col]
515
- X_test_tuned[panel_col] = X_test[panel_col]
516
-
517
- else:
518
- X_train_tuned = X_train[features_set]
519
- # X_train_tuned = pd.DataFrame(ss.fit_transform(X_train_tuned), columns=X_train_tuned.columns)
520
-
521
- X_test_tuned = X_test[features_set]
522
- # X_test_tuned = pd.DataFrame(ss.transform(X_test_tuned), columns=X_test_tuned.columns)
523
-
524
- for flag in selected_options:
525
- # Spirnt4 - added target_col in flag name
526
- X_train_tuned[flag] = st.session_state["Flags"][
527
- flag + "__" + target_col
528
- ]["train"]
529
- X_test_tuned[flag] = st.session_state["Flags"][
530
- flag + "__" + target_col
531
- ]["test"]
532
-
533
- # test
534
- # X_train_tuned.to_csv("Test/X_train_tuned_flag.csv",index=False)
535
- # X_test_tuned.to_csv("Test/X_test_tuned_flag.csv",index=False)
536
-
537
- # print("()()"*20,flag, len(st.session_state['Flags'][flag]))
538
- if Trend:
539
- st.session_state["project_dct"]["model_tuning"][
540
- "trend_check"
541
- ] = True
542
- # Sprint3 - group by panel, calculate trend of each panel spearately. Add trend to new feature set
543
- if is_panel:
544
- newdata = pd.DataFrame()
545
- panel_wise_end_point_train = {}
546
- for panel, groupdf in X_train_tuned.groupby(panel_col):
547
- groupdf.sort_values(date_col, inplace=True)
548
- groupdf["Trend"] = np.arange(1, len(groupdf) + 1, 1)
549
- newdata = pd.concat([newdata, groupdf])
550
- panel_wise_end_point_train[panel] = len(groupdf)
551
- X_train_tuned = newdata.copy()
552
-
553
- test_newdata = pd.DataFrame()
554
- for panel, test_groupdf in X_test_tuned.groupby(panel_col):
555
- test_groupdf.sort_values(date_col, inplace=True)
556
- start = panel_wise_end_point_train[panel] + 1
557
- end = start + len(test_groupdf) # should be + 1? - Sprint4
558
- # print("??"*20, panel, len(test_groupdf), len(np.arange(start, end, 1)), start)
559
- test_groupdf["Trend"] = np.arange(start, end, 1)
560
- test_newdata = pd.concat([test_newdata, test_groupdf])
561
- X_test_tuned = test_newdata.copy()
562
-
563
- new_features = new_features + ["Trend"]
564
-
565
- else:
566
- X_train_tuned["Trend"] = np.arange(
567
- 1, len(X_train_tuned) + 1, 1
568
- )
569
- X_test_tuned["Trend"] = np.arange(
570
- len(X_train_tuned) + 1,
571
- len(X_train_tuned) + len(X_test_tuned) + 1,
572
- 1,
573
- )
574
- new_features = new_features + ["Trend"]
575
- else:
576
- st.session_state["project_dct"]["model_tuning"][
577
- "trend_check"
578
- ] = False
579
-
580
- if week_number:
581
- st.session_state["project_dct"]["model_tuning"][
582
- "week_num_check"
583
- ] = True
584
- # Sprint3 - create weeknumber from date column in xtrain tuned. add week num to new feature set
585
- if is_panel:
586
- X_train_tuned[date_col] = pd.to_datetime(
587
- X_train_tuned[date_col]
588
- )
589
- X_train_tuned["Week_number"] = X_train_tuned[
590
- date_col
591
- ].dt.day_of_week
592
- if X_train_tuned["Week_number"].nunique() == 1:
593
- st.write(
594
- "All dates in the data are of the same week day. Hence Week number can't be used."
595
- )
596
- else:
597
- X_test_tuned[date_col] = pd.to_datetime(
598
- X_test_tuned[date_col]
599
- )
600
- X_test_tuned["Week_number"] = X_test_tuned[
601
- date_col
602
- ].dt.day_of_week
603
- new_features = new_features + ["Week_number"]
604
-
605
- else:
606
- date = pd.to_datetime(date.values)
607
- X_train_tuned["Week_number"] = pd.to_datetime(
608
- X_train[date_col]
609
- ).dt.day_of_week
610
- X_test_tuned["Week_number"] = pd.to_datetime(
611
- X_test[date_col]
612
- ).dt.day_of_week
613
- new_features = new_features + ["Week_number"]
614
- else:
615
- st.session_state["project_dct"]["model_tuning"][
616
- "week_num_check"
617
- ] = False
618
-
619
- if sine_cosine:
620
- st.session_state["project_dct"]["model_tuning"][
621
- "sine_cosine_check"
622
- ] = True
623
- # Sprint3 - create panel wise sine cosine waves in xtrain tuned. add to new feature set
624
- if is_panel:
625
- new_features = new_features + ["sine_wave", "cosine_wave"]
626
- newdata = pd.DataFrame()
627
- newdata_test = pd.DataFrame()
628
- groups = X_train_tuned.groupby(panel_col)
629
- frequency = 2 * np.pi / 365 # Adjust the frequency as needed
630
-
631
- train_panel_wise_end_point = {}
632
- for panel, groupdf in groups:
633
- num_samples = len(groupdf)
634
- train_panel_wise_end_point[panel] = num_samples
635
- days_since_start = np.arange(num_samples)
636
- sine_wave = np.sin(frequency * days_since_start)
637
- cosine_wave = np.cos(frequency * days_since_start)
638
- sine_cosine_df = pd.DataFrame(
639
- {"sine_wave": sine_wave, "cosine_wave": cosine_wave}
640
- )
641
- assert len(sine_cosine_df) == len(groupdf)
642
- # groupdf = pd.concat([groupdf, sine_cosine_df], axis=1)
643
- groupdf["sine_wave"] = sine_wave
644
- groupdf["cosine_wave"] = cosine_wave
645
- newdata = pd.concat([newdata, groupdf])
646
-
647
- X_train_tuned = newdata.copy()
648
-
649
- test_groups = X_test_tuned.groupby(panel_col)
650
- for panel, test_groupdf in test_groups:
651
- num_samples = len(test_groupdf)
652
- start = train_panel_wise_end_point[panel]
653
- days_since_start = np.arange(start, start + num_samples, 1)
654
- # print("##", panel, num_samples, start, len(np.arange(start, start+num_samples, 1)))
655
- sine_wave = np.sin(frequency * days_since_start)
656
- cosine_wave = np.cos(frequency * days_since_start)
657
- sine_cosine_df = pd.DataFrame(
658
- {"sine_wave": sine_wave, "cosine_wave": cosine_wave}
659
- )
660
- assert len(sine_cosine_df) == len(test_groupdf)
661
- # groupdf = pd.concat([groupdf, sine_cosine_df], axis=1)
662
- test_groupdf["sine_wave"] = sine_wave
663
- test_groupdf["cosine_wave"] = cosine_wave
664
- newdata_test = pd.concat([newdata_test, test_groupdf])
665
-
666
- X_test_tuned = newdata_test.copy()
667
-
668
- else:
669
- new_features = new_features + ["sine_wave", "cosine_wave"]
670
-
671
- num_samples = len(X_train_tuned)
672
- frequency = 2 * np.pi / 365 # Adjust the frequency as needed
673
- days_since_start = np.arange(num_samples)
674
- sine_wave = np.sin(frequency * days_since_start)
675
- cosine_wave = np.cos(frequency * days_since_start)
676
- sine_cosine_df = pd.DataFrame(
677
- {"sine_wave": sine_wave, "cosine_wave": cosine_wave}
678
- )
679
- # Concatenate the sine and cosine waves with the scaled X DataFrame
680
- X_train_tuned = pd.concat(
681
- [X_train_tuned, sine_cosine_df], axis=1
682
- )
683
-
684
- test_num_samples = len(X_test_tuned)
685
- start = num_samples
686
- days_since_start = np.arange(
687
- start, start + test_num_samples, 1
688
- )
689
- sine_wave = np.sin(frequency * days_since_start)
690
- cosine_wave = np.cos(frequency * days_since_start)
691
- sine_cosine_df = pd.DataFrame(
692
- {"sine_wave": sine_wave, "cosine_wave": cosine_wave}
693
- )
694
- # Concatenate the sine and cosine waves with the scaled X DataFrame
695
- X_test_tuned = pd.concat(
696
- [X_test_tuned, sine_cosine_df], axis=1
697
- )
698
- else:
699
- st.session_state["project_dct"]["model_tuning"][
700
- "sine_cosine_check"
701
- ] = False
702
-
703
- # model
704
- if selected_options:
705
- new_features = new_features + selected_options
706
- if is_panel:
707
- inp_vars_str = " + ".join(new_features)
708
- new_features = list(set(new_features))
709
-
710
- md_str = target_col + " ~ " + inp_vars_str
711
- md_tuned = smf.mixedlm(
712
- md_str,
713
- data=X_train_tuned[[target_col] + new_features],
714
- groups=X_train_tuned[panel_col],
715
- )
716
- model_tuned = md_tuned.fit()
717
-
718
- # plot act v pred for original model and tuned model
719
- metrics_table, line, actual_vs_predicted_plot = (
720
- plot_actual_vs_predicted(
721
- X_train[date_col],
722
- y_train,
723
- model.fittedvalues,
724
- model,
725
- target_column=sel_target_col,
726
- is_panel=True,
727
- )
728
- )
729
- metrics_table_tuned, line, actual_vs_predicted_plot_tuned = (
730
- plot_actual_vs_predicted(
731
- X_train_tuned[date_col],
732
- X_train_tuned[target_col],
733
- model_tuned.fittedvalues,
734
- model_tuned,
735
- target_column=sel_target_col,
736
- is_panel=True,
737
- )
738
- )
739
-
740
- else:
741
- new_features = list(set(new_features))
742
- model_tuned = sm.OLS(y_train, X_train_tuned[new_features]).fit()
743
- metrics_table, line, actual_vs_predicted_plot = (
744
- plot_actual_vs_predicted(
745
- date[:130],
746
- y_train,
747
- model.predict(X_train[features_set]),
748
- model,
749
- target_column=sel_target_col,
750
- )
751
- )
752
- metrics_table_tuned, line, actual_vs_predicted_plot_tuned = (
753
- plot_actual_vs_predicted(
754
- date[:130],
755
- y_train,
756
- model_tuned.predict(X_train_tuned),
757
- model_tuned,
758
- target_column=sel_target_col,
759
- )
760
- )
761
-
762
- mape = np.round(metrics_table.iloc[0, 1], 2)
763
- r2 = np.round(metrics_table.iloc[1, 1], 2)
764
- adjr2 = np.round(metrics_table.iloc[2, 1], 2)
765
-
766
- mape_tuned = np.round(metrics_table_tuned.iloc[0, 1], 2)
767
- r2_tuned = np.round(metrics_table_tuned.iloc[1, 1], 2)
768
- adjr2_tuned = np.round(metrics_table_tuned.iloc[2, 1], 2)
769
-
770
- parameters_ = st.columns(3)
771
- with parameters_[0]:
772
- st.metric("R2", r2_tuned, np.round(r2_tuned - r2, 2))
773
- with parameters_[1]:
774
- st.metric(
775
- "Adjusted R2", adjr2_tuned, np.round(adjr2_tuned - adjr2, 2)
776
- )
777
- with parameters_[2]:
778
- st.metric(
779
- "MAPE", mape_tuned, np.round(mape_tuned - mape, 2), "inverse"
780
- )
781
- st.write(model_tuned.summary())
782
-
783
- X_train_tuned[date_col] = X_train[date_col]
784
- X_test_tuned[date_col] = X_test[date_col]
785
- X_train_tuned[target_col] = y_train
786
- X_test_tuned[target_col] = y_test
787
-
788
- st.header("2.2 Actual vs. Predicted Plot")
789
- # if is_panel:
790
- # metrics_table, line, actual_vs_predicted_plot = plot_actual_vs_predicted(date, y_train, model.predict(X_train),
791
- # model, target_column='Revenue',is_panel=True)
792
- # else:
793
- # metrics_table,line,actual_vs_predicted_plot=plot_actual_vs_predicted(date, y_train, model.predict(X_train), model,target_column='Revenue')
794
- if is_panel:
795
- metrics_table, line, actual_vs_predicted_plot = (
796
- plot_actual_vs_predicted(
797
- X_train_tuned[date_col],
798
- X_train_tuned[target_col],
799
- model_tuned.fittedvalues,
800
- model_tuned,
801
- target_column=sel_target_col,
802
- is_panel=True,
803
- )
804
- )
805
- else:
806
- metrics_table, line, actual_vs_predicted_plot = (
807
- plot_actual_vs_predicted(
808
- X_train_tuned[date_col],
809
- X_train_tuned[target_col],
810
- model_tuned.predict(X_train_tuned[new_features]),
811
- model_tuned,
812
- target_column=sel_target_col,
813
- is_panel=False,
814
- )
815
- )
816
- # plot_actual_vs_predicted(X_train[date_col], y_train,
817
- # model.fittedvalues, model,
818
- # target_column='Revenue',
819
- # is_panel=is_panel)
820
-
821
- st.plotly_chart(actual_vs_predicted_plot, use_container_width=True)
822
-
823
- st.markdown("## 2.3 Residual Analysis")
824
- if is_panel:
825
- columns = st.columns(2)
826
- with columns[0]:
827
- fig = plot_residual_predicted(
828
- y_train, model_tuned.fittedvalues, X_train_tuned
829
- )
830
- st.plotly_chart(fig)
831
-
832
- with columns[1]:
833
- st.empty()
834
- fig = qqplot(y_train, model_tuned.fittedvalues)
835
- st.plotly_chart(fig)
836
-
837
- with columns[0]:
838
- fig = residual_distribution(y_train, model_tuned.fittedvalues)
839
- st.pyplot(fig)
840
- else:
841
- columns = st.columns(2)
842
- with columns[0]:
843
- fig = plot_residual_predicted(
844
- y_train,
845
- model_tuned.predict(X_train_tuned[new_features]),
846
- X_train,
847
- )
848
- st.plotly_chart(fig)
849
-
850
- with columns[1]:
851
- st.empty()
852
- fig = qqplot(
853
- y_train, model_tuned.predict(X_train_tuned[new_features])
854
- )
855
- st.plotly_chart(fig)
856
-
857
- with columns[0]:
858
- fig = residual_distribution(
859
- y_train, model_tuned.predict(X_train_tuned[new_features])
860
- )
861
- st.pyplot(fig)
862
-
863
- # st.session_state['is_tuned_model'][target_col] = True
864
- # Sprint4 - saved tuned model in a dict
865
- st.session_state["Model_Tuned"][sel_model + "__" + target_col] = {
866
- "Model_object": model_tuned,
867
- "feature_set": new_features,
868
- "X_train_tuned": X_train_tuned,
869
- "X_test_tuned": X_test_tuned,
870
- }
871
-
872
- # Pending
873
- # if st.session_state['build_tuned_model']==True:
874
- if st.session_state["Model_Tuned"] is not None:
875
- if st.checkbox(
876
- "Use this model to build response curves", key="save_model"
877
- ):
878
- # save_model = st.button('Use this model to build response curves', key='saved_tuned_model')
879
- # if save_model:
880
- st.session_state["is_tuned_model"][target_col] = True
881
- with open(
882
- os.path.join(
883
- st.session_state["project_path"], "tuned_model.pkl"
884
- ),
885
- "wb",
886
- ) as f:
887
- # pickle.dump(st.session_state['tuned_model'], f)
888
- pickle.dump(st.session_state["Model_Tuned"], f) # Sprint4
889
-
890
- st.session_state["project_dct"]["model_tuning"][
891
- "session_state_saved"
892
- ] = {}
893
- for key in [
894
- "bin_dict",
895
- "used_response_metrics",
896
- "is_tuned_model",
897
- "media_data",
898
- "X_test_spends",
899
- ]:
900
- st.session_state["project_dct"]["model_tuning"][
901
- "session_state_saved"
902
- ][key] = st.session_state[key]
903
-
904
- project_dct_path = os.path.join(
905
- st.session_state["project_path"], "project_dct.pkl"
906
- )
907
- with open(project_dct_path, "wb") as f:
908
- pickle.dump(st.session_state["project_dct"], f)
909
-
910
- update_db("5_Model_Tuning.py")
911
-
912
- st.success(sel_model + "__" + target_col + " Tuned saved!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/6_AI_Model_Results.py DELETED
@@ -1,728 +0,0 @@
1
- import plotly.express as px
2
- import numpy as np
3
- import plotly.graph_objects as go
4
- import streamlit as st
5
- import pandas as pd
6
- import statsmodels.api as sm
7
- from sklearn.metrics import mean_absolute_percentage_error
8
- import sys
9
- import os
10
- from utilities import set_header, load_local_css, load_authenticator
11
- import seaborn as sns
12
- import matplotlib.pyplot as plt
13
- import sweetviz as sv
14
- import tempfile
15
- from sklearn.preprocessing import MinMaxScaler
16
- from st_aggrid import AgGrid
17
- from st_aggrid import GridOptionsBuilder, GridUpdateMode
18
- from st_aggrid import GridOptionsBuilder
19
- import sys
20
- import re
21
- import pickle
22
- from sklearn.metrics import r2_score, mean_absolute_percentage_error
23
- from Data_prep_functions import plot_actual_vs_predicted
24
- import sqlite3
25
- from utilities import update_db
26
-
27
- sys.setrecursionlimit(10**6)
28
-
29
- original_stdout = sys.stdout
30
- sys.stdout = open("temp_stdout.txt", "w")
31
- sys.stdout.close()
32
- sys.stdout = original_stdout
33
-
34
- st.set_page_config(layout="wide")
35
- load_local_css("styles.css")
36
- set_header()
37
-
38
- # TODO :
39
- ## 1. Add non panel model support
40
- ## 2. EDA Function
41
-
42
- for k, v in st.session_state.items():
43
- if k not in ["logout", "login", "config"] and not k.startswith(
44
- "FormSubmitter"
45
- ):
46
- st.session_state[k] = v
47
-
48
- authenticator = st.session_state.get("authenticator")
49
- if authenticator is None:
50
- authenticator = load_authenticator()
51
-
52
- name, authentication_status, username = authenticator.login("Login", "main")
53
- auth_status = st.session_state.get("authentication_status")
54
-
55
- if auth_status == True:
56
- is_state_initiaized = st.session_state.get("initialized", False)
57
- if not is_state_initiaized:
58
- if "session_name" not in st.session_state:
59
- st.session_state["session_name"] = None
60
-
61
- if "project_dct" not in st.session_state:
62
- st.error("Please load a project from Home page")
63
- st.stop()
64
-
65
- conn = sqlite3.connect(
66
- r"DB/User.db", check_same_thread=False
67
- ) # connection with sql db
68
- c = conn.cursor()
69
-
70
- if not os.path.exists(
71
- os.path.join(st.session_state["project_path"], "tuned_model.pkl")
72
- ):
73
- st.error("Please save a tuned model")
74
- st.stop()
75
-
76
- if (
77
- "session_state_saved"
78
- in st.session_state["project_dct"]["model_tuning"].keys()
79
- and st.session_state["project_dct"]["model_tuning"][
80
- "session_state_saved"
81
- ]
82
- != []
83
- ):
84
- for key in ["used_response_metrics", "media_data", "bin_dict"]:
85
- if key not in st.session_state:
86
- st.session_state[key] = st.session_state["project_dct"][
87
- "model_tuning"
88
- ]["session_state_saved"][key]
89
- st.session_state["bin_dict"] = st.session_state["project_dct"][
90
- "model_build"
91
- ]["session_state_saved"]["bin_dict"]
92
-
93
- media_data = st.session_state["media_data"]
94
- panel_col = [
95
- col.lower()
96
- .replace(".", "_")
97
- .replace("@", "_")
98
- .replace(" ", "_")
99
- .replace("-", "")
100
- .replace(":", "")
101
- .replace("__", "_")
102
- for col in st.session_state["bin_dict"]["Panel Level 1"]
103
- ][
104
- 0
105
- ] # set the panel column
106
- is_panel = True if len(panel_col) > 0 else False
107
- date_col = "date"
108
-
109
- def plot_residual_predicted(actual, predicted, df_):
110
- df_["Residuals"] = actual - pd.Series(predicted)
111
- df_["StdResidual"] = (
112
- df_["Residuals"] - df_["Residuals"].mean()
113
- ) / df_["Residuals"].std()
114
-
115
- # Create a Plotly scatter plot
116
- fig = px.scatter(
117
- df_,
118
- x=predicted,
119
- y="StdResidual",
120
- opacity=0.5,
121
- color_discrete_sequence=["#11B6BD"],
122
- )
123
-
124
- # Add horizontal lines
125
- fig.add_hline(y=0, line_dash="dash", line_color="darkorange")
126
- fig.add_hline(y=2, line_color="red")
127
- fig.add_hline(y=-2, line_color="red")
128
-
129
- fig.update_xaxes(title="Predicted")
130
- fig.update_yaxes(title="Standardized Residuals (Actual - Predicted)")
131
-
132
- # Set the same width and height for both figures
133
- fig.update_layout(
134
- title="Residuals over Predicted Values",
135
- autosize=False,
136
- width=600,
137
- height=400,
138
- )
139
-
140
- return fig
141
-
142
- def residual_distribution(actual, predicted):
143
- Residuals = actual - pd.Series(predicted)
144
-
145
- # Create a Seaborn distribution plot
146
- sns.set(style="whitegrid")
147
- plt.figure(figsize=(6, 4))
148
- sns.histplot(Residuals, kde=True, color="#11B6BD")
149
-
150
- plt.title(" Distribution of Residuals")
151
- plt.xlabel("Residuals")
152
- plt.ylabel("Probability Density")
153
-
154
- return plt
155
-
156
- def qqplot(actual, predicted):
157
- Residuals = actual - pd.Series(predicted)
158
- Residuals = pd.Series(Residuals)
159
- Resud_std = (Residuals - Residuals.mean()) / Residuals.std()
160
-
161
- # Create a QQ plot using Plotly with custom colors
162
- fig = go.Figure()
163
- fig.add_trace(
164
- go.Scatter(
165
- x=sm.ProbPlot(Resud_std).theoretical_quantiles,
166
- y=sm.ProbPlot(Resud_std).sample_quantiles,
167
- mode="markers",
168
- marker=dict(size=5, color="#11B6BD"),
169
- name="QQ Plot",
170
- )
171
- )
172
-
173
- # Add the 45-degree reference line
174
- diagonal_line = go.Scatter(
175
- x=[
176
- -2,
177
- 2,
178
- ], # Adjust the x values as needed to fit the range of your data
179
- y=[-2, 2], # Adjust the y values accordingly
180
- mode="lines",
181
- line=dict(color="red"), # Customize the line color and style
182
- name=" ",
183
- )
184
- fig.add_trace(diagonal_line)
185
-
186
- # Customize the layout
187
- fig.update_layout(
188
- title="QQ Plot of Residuals",
189
- title_x=0.5,
190
- autosize=False,
191
- width=600,
192
- height=400,
193
- xaxis_title="Theoretical Quantiles",
194
- yaxis_title="Sample Quantiles",
195
- )
196
-
197
- return fig
198
-
199
- def get_random_effects(media_data, panel_col, mdf):
200
- random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
201
- for i, market in enumerate(media_data[panel_col].unique()):
202
- print(i, end="\r")
203
- intercept = mdf.random_effects[market].values[0]
204
- random_eff_df.loc[i, "random_effect"] = intercept
205
- random_eff_df.loc[i, panel_col] = market
206
-
207
- return random_eff_df
208
-
209
- def mdf_predict(X_df, mdf, random_eff_df):
210
- X = X_df.copy()
211
- X = pd.merge(
212
- X,
213
- random_eff_df[[panel_col, "random_effect"]],
214
- on=panel_col,
215
- how="left",
216
- )
217
- X["pred_fixed_effect"] = mdf.predict(X)
218
-
219
- X["pred"] = X["pred_fixed_effect"] + X["random_effect"]
220
- X.drop(columns=["pred_fixed_effect", "random_effect"], inplace=True)
221
- return X
222
-
223
- def metrics_df_panel(model_dict):
224
- metrics_df = pd.DataFrame(
225
- columns=[
226
- "Model",
227
- "R2",
228
- "ADJR2",
229
- "Train Mape",
230
- "Test Mape",
231
- "Summary",
232
- "Model_object",
233
- ]
234
- )
235
- i = 0
236
- for key in model_dict.keys():
237
- target = key.split("__")[1]
238
- metrics_df.at[i, "Model"] = target
239
- y = model_dict[key]["X_train_tuned"][target]
240
-
241
- random_df = get_random_effects(
242
- media_data, panel_col, model_dict[key]["Model_object"]
243
- )
244
- pred = mdf_predict(
245
- model_dict[key]["X_train_tuned"],
246
- model_dict[key]["Model_object"],
247
- random_df,
248
- )["pred"]
249
-
250
- ytest = model_dict[key]["X_test_tuned"][target]
251
- predtest = mdf_predict(
252
- model_dict[key]["X_test_tuned"],
253
- model_dict[key]["Model_object"],
254
- random_df,
255
- )["pred"]
256
-
257
- metrics_df.at[i, "R2"] = r2_score(y, pred)
258
- metrics_df.at[i, "ADJR2"] = 1 - (1 - metrics_df.loc[i, "R2"]) * (
259
- len(y) - 1
260
- ) / (len(y) - len(model_dict[key]["feature_set"]) - 1)
261
- metrics_df.at[i, "Train Mape"] = mean_absolute_percentage_error(
262
- y, pred
263
- )
264
- metrics_df.at[i, "Test Mape"] = mean_absolute_percentage_error(
265
- ytest, predtest
266
- )
267
- metrics_df.at[i, "Summary"] = model_dict[key][
268
- "Model_object"
269
- ].summary()
270
- metrics_df.at[i, "Model_object"] = model_dict[key]["Model_object"]
271
- i += 1
272
- metrics_df = np.round(metrics_df, 2)
273
- return metrics_df
274
-
275
- with open(
276
- os.path.join(
277
- st.session_state["project_path"], "final_df_transformed.pkl"
278
- ),
279
- "rb",
280
- ) as f:
281
- data = pickle.load(f)
282
- transformed_data = data["final_df_transformed"]
283
- with open(
284
- os.path.join(st.session_state["project_path"], "data_import.pkl"), "rb"
285
- ) as f:
286
- data = pickle.load(f)
287
- st.session_state["bin_dict"] = data["bin_dict"]
288
- with open(
289
- os.path.join(st.session_state["project_path"], "tuned_model.pkl"), "rb"
290
- ) as file:
291
- tuned_model_dict = pickle.load(file)
292
- feature_set_dct = {
293
- key.split("__")[1]: key_dict["feature_set"]
294
- for key, key_dict in tuned_model_dict.items()
295
- }
296
-
297
- # """ the above part should be modified so that we are fetching features set from the saved model"""
298
-
299
- def contributions(X, model, target):
300
- X1 = X.copy()
301
- for j, col in enumerate(X1.columns):
302
- X1[col] = X1[col] * model.params.values[j]
303
-
304
- contributions = np.round(
305
- (X1.sum() / sum(X1.sum()) * 100).sort_values(ascending=False), 2
306
- )
307
- contributions = (
308
- pd.DataFrame(contributions, columns=target)
309
- .reset_index()
310
- .rename(columns={"index": "Channel"})
311
- )
312
- contributions["Channel"] = [
313
- re.split(r"_imp|_cli", col)[0] for col in contributions["Channel"]
314
- ]
315
-
316
- return contributions
317
-
318
- if "contribution_df" not in st.session_state:
319
- st.session_state["contribution_df"] = None
320
-
321
- def contributions_panel(model_dict):
322
- media_data = st.session_state["media_data"]
323
- contribution_df = pd.DataFrame(columns=["Channel"])
324
- for key in model_dict.keys():
325
- best_feature_set = model_dict[key]["feature_set"]
326
- model = model_dict[key]["Model_object"]
327
- target = key.split("__")[1]
328
- X_train = model_dict[key]["X_train_tuned"]
329
- contri_df = pd.DataFrame()
330
-
331
- y = []
332
- y_pred = []
333
-
334
- random_eff_df = get_random_effects(media_data, panel_col, model)
335
- random_eff_df["fixed_effect"] = model.fe_params["Intercept"]
336
- random_eff_df["panel_effect"] = (
337
- random_eff_df["random_effect"] + random_eff_df["fixed_effect"]
338
- )
339
-
340
- coef_df = pd.DataFrame(model.fe_params)
341
- coef_df.reset_index(inplace=True)
342
- coef_df.columns = ["feature", "coef"]
343
-
344
- x_train_contribution = X_train.copy()
345
- x_train_contribution = mdf_predict(
346
- x_train_contribution, model, random_eff_df
347
- )
348
-
349
- x_train_contribution = pd.merge(
350
- x_train_contribution,
351
- random_eff_df[[panel_col, "panel_effect"]],
352
- on=panel_col,
353
- how="left",
354
- )
355
-
356
- for i in range(len(coef_df))[1:]:
357
- coef = coef_df.loc[i, "coef"]
358
- col = coef_df.loc[i, "feature"]
359
- x_train_contribution[str(col) + "_contr"] = (
360
- coef * x_train_contribution[col]
361
- )
362
-
363
- # x_train_contribution['sum_contributions'] = x_train_contribution.filter(regex="contr").sum(axis=1)
364
- # x_train_contribution['sum_contributions'] = x_train_contribution['sum_contributions'] + x_train_contribution[
365
- # 'panel_effect']
366
-
367
- base_cols = ["panel_effect"] + [
368
- c
369
- for c in x_train_contribution.filter(regex="contr").columns
370
- if c
371
- in [
372
- "Week_number_contr",
373
- "Trend_contr",
374
- "sine_wave_contr",
375
- "cosine_wave_contr",
376
- ]
377
- ]
378
- x_train_contribution["base_contr"] = x_train_contribution[
379
- base_cols
380
- ].sum(axis=1)
381
- x_train_contribution.drop(columns=base_cols, inplace=True)
382
- # x_train_contribution.to_csv("Test/smr_x_train_contribution.csv", index=False)
383
-
384
- contri_df = pd.DataFrame(
385
- x_train_contribution.filter(regex="contr").sum(axis=0)
386
- )
387
- contri_df.reset_index(inplace=True)
388
- contri_df.columns = ["Channel", target]
389
- contri_df["Channel"] = (
390
- contri_df["Channel"]
391
- .str.split("(_impres|_clicks)")
392
- .apply(lambda c: c[0])
393
- )
394
- contri_df[target] = (
395
- 100 * contri_df[target] / contri_df[target].sum()
396
- )
397
- contri_df["Channel"].replace("base_contr", "base", inplace=True)
398
- contribution_df = pd.merge(
399
- contribution_df, contri_df, on="Channel", how="outer"
400
- )
401
- # st.session_state["contribution_df"] = contributions_panel(tuned_model_dict)
402
- return contribution_df
403
-
404
- metrics_table = metrics_df_panel(tuned_model_dict)
405
-
406
- eda_columns = st.columns(2)
407
- with eda_columns[1]:
408
- eda = st.button(
409
- "Generate EDA Report",
410
- help="Click to generate a bivariate report for the selected response metric from the table below.",
411
- )
412
-
413
- # st.markdown('Model Metrics')
414
- st.title("Contribution Overview")
415
- options = st.session_state["used_response_metrics"]
416
- options = [
417
- opt.lower()
418
- .replace(" ", "_")
419
- .replace("-", "")
420
- .replace(":", "")
421
- .replace("__", "_")
422
- for opt in options
423
- ]
424
-
425
- default_options = (
426
- st.session_state["project_dct"]["saved_model_results"].get(
427
- "selected_options"
428
- )
429
- if st.session_state["project_dct"]["saved_model_results"].get(
430
- "selected_options"
431
- )
432
- is not None
433
- else [options[-1]]
434
- )
435
- for i in default_options:
436
- if i not in options:
437
- st.write(i)
438
- default_options.remove(i)
439
- contribution_selections = st.multiselect(
440
- "Select the Response Metrics to compare contributions",
441
- options,
442
- default=default_options,
443
- )
444
- trace_data = []
445
-
446
- st.session_state["contribution_df"] = contributions_panel(tuned_model_dict)
447
-
448
- for selection in contribution_selections:
449
-
450
- trace = go.Bar(
451
- x=st.session_state["contribution_df"]["Channel"],
452
- y=st.session_state["contribution_df"][selection],
453
- name=selection,
454
- text=np.round(st.session_state["contribution_df"][selection], 0)
455
- .astype(int)
456
- .astype(str)
457
- + "%",
458
- textposition="outside",
459
- )
460
- trace_data.append(trace)
461
-
462
- layout = go.Layout(
463
- title="Metrics Contribution by Channel",
464
- xaxis=dict(title="Channel Name"),
465
- yaxis=dict(title="Metrics Contribution"),
466
- barmode="group",
467
- )
468
- fig = go.Figure(data=trace_data, layout=layout)
469
- st.plotly_chart(fig, use_container_width=True)
470
-
471
- ############################################ Waterfall Chart ############################################
472
- # import plotly.graph_objects as go
473
-
474
- # # Initialize a Plotly figure
475
- # fig = go.Figure()
476
-
477
- # for selection in contribution_selections:
478
- # # Ensure y_values are numeric
479
- # y_values = st.session_state["contribution_df"][selection].values.astype(float)
480
-
481
- # # Generating text labels for each bar, ensuring operations are compatible with string formats
482
- # text_values = [f"{val}%" for val in np.round(y_values, 0).astype(int)]
483
-
484
- # fig.add_trace(
485
- # go.Waterfall(
486
- # name=selection,
487
- # orientation="v",
488
- # measure=["relative"]
489
- # * len(y_values), # Adjust if you have absolute values at certain points
490
- # x=st.session_state["contribution_df"]["Channel"].tolist(),
491
- # text=text_values,
492
- # textposition="outside",
493
- # y=y_values,
494
- # increasing={"marker": {"color": "green"}},
495
- # decreasing={"marker": {"color": "red"}},
496
- # totals={"marker": {"color": "blue"}},
497
- # )
498
- # )
499
-
500
- # fig.update_layout(
501
- # title="Metrics Contribution by Channel",
502
- # xaxis={"title": "Channel Name"},
503
- # yaxis={"title": "Metrics Contribution"},
504
- # height=600,
505
- # )
506
-
507
- # # Displaying the waterfall chart in Streamlit
508
- # st.plotly_chart(fig, use_container_width=True)
509
-
510
- import plotly.graph_objects as go
511
-
512
- # Initialize a Plotly figure
513
- fig = go.Figure()
514
-
515
- for selection in contribution_selections:
516
- # Ensure contributions are numeric
517
- contributions = (
518
- st.session_state["contribution_df"][selection]
519
- .values.astype(float)
520
- .tolist()
521
- )
522
- channel_names = st.session_state["contribution_df"]["Channel"].tolist()
523
-
524
- display_name, display_contribution, base_contribution = [], [], 0
525
- for channel_name, contribution in zip(channel_names, contributions):
526
- if channel_name != "const" and channel_name != "base":
527
- display_name.append(channel_name)
528
- display_contribution.append(contribution)
529
- else:
530
- base_contribution = contribution
531
-
532
- display_name = ["Base Sales"] + display_name
533
- display_contribution = [base_contribution] + display_contribution
534
-
535
- # Generating text labels for each bar, ensuring operations are compatible with string formats
536
- text_values = [
537
- f"{val}%" for val in np.round(display_contribution, 0).astype(int)
538
- ]
539
-
540
- fig.add_trace(
541
- go.Waterfall(
542
- orientation="v",
543
- measure=["relative"]
544
- * len(
545
- display_contribution
546
- ), # Adjust if you have absolute values at certain points
547
- x=display_name,
548
- text=text_values,
549
- textposition="outside",
550
- y=display_contribution,
551
- increasing={"marker": {"color": "green"}},
552
- decreasing={"marker": {"color": "red"}},
553
- totals={"marker": {"color": "blue"}},
554
- )
555
- )
556
-
557
- fig.update_layout(
558
- title="Metrics Contribution by Channel",
559
- xaxis={"title": "Channel Name"},
560
- yaxis={"title": "Metrics Contribution"},
561
- height=600,
562
- )
563
-
564
- # Displaying the waterfall chart in Streamlit
565
- st.plotly_chart(fig, use_container_width=True)
566
-
567
- ############################################ Waterfall Chart ############################################
568
-
569
- st.title("Analysis of Models Result")
570
- # st.markdown()
571
- previous_selection = st.session_state["project_dct"][
572
- "saved_model_results"
573
- ].get("model_grid_sel", [1])
574
- st.write(np.round(metrics_table, 2))
575
- gd_table = metrics_table.iloc[:, :-2]
576
-
577
- gd = GridOptionsBuilder.from_dataframe(gd_table)
578
- # gd.configure_pagination(enabled=True)
579
- gd.configure_selection(
580
- use_checkbox=True,
581
- selection_mode="single",
582
- pre_select_all_rows=False,
583
- pre_selected_rows=previous_selection,
584
- )
585
-
586
- gridoptions = gd.build()
587
- table = AgGrid(
588
- gd_table,
589
- gridOptions=gridoptions,
590
- fit_columns_on_grid_load=True,
591
- height=200,
592
- )
593
- # table=metrics_table.iloc[:,:-2]
594
- # table.insert(0, "Select", False)
595
- # selection_table=st.data_editor(table,column_config={"Select": st.column_config.CheckboxColumn(required=True)})
596
- if len(table.selected_rows) > 0:
597
- st.session_state["project_dct"]["saved_model_results"][
598
- "model_grid_sel"
599
- ] = table.selected_rows[0]["_selectedRowNodeInfo"]["nodeRowIndex"]
600
- if len(table.selected_rows) == 0:
601
- st.warning(
602
- "Click on the checkbox to view comprehensive results of the selected model."
603
- )
604
- st.stop()
605
- else:
606
- target_column = table.selected_rows[0]["Model"]
607
- feature_set = feature_set_dct[target_column]
608
-
609
- # with eda_columns[1]:
610
- # if eda:
611
- # def generate_report_with_target(channel_data, target_feature):
612
- # report = sv.analyze(
613
- # [channel_data, "Dataset"], target_feat=target_feature, verbose=False
614
- # )
615
- # temp_dir = tempfile.mkdtemp()
616
- # report_path = os.path.join(temp_dir, "report.html")
617
- # report.show_html(
618
- # filepath=report_path, open_browser=False
619
- # ) # Generate the report as an HTML file
620
- # return report_path
621
- #
622
- # report_data = transformed_data[feature_set]
623
- # report_data[target_column] = transformed_data[target_column]
624
- # report_file = generate_report_with_target(report_data, target_column)
625
- #
626
- # if os.path.exists(report_file):
627
- # with open(report_file, "rb") as f:
628
- # st.download_button(
629
- # label="Download EDA Report",
630
- # data=f.read(),
631
- # file_name="report.html",
632
- # mime="text/html",
633
- # )
634
- # else:
635
- # st.warning("Report generation failed. Unable to find the report file.")
636
-
637
- model = metrics_table[metrics_table["Model"] == target_column][
638
- "Model_object"
639
- ].iloc[0]
640
- target = metrics_table[metrics_table["Model"] == target_column][
641
- "Model"
642
- ].iloc[0]
643
- st.header("Model Summary")
644
- st.write(model.summary())
645
-
646
- sel_dict = tuned_model_dict[
647
- [k for k in tuned_model_dict.keys() if k.split("__")[1] == target][0]
648
- ]
649
- X_train = sel_dict["X_train_tuned"]
650
- y_train = X_train[target]
651
- random_effects = get_random_effects(media_data, panel_col, model)
652
- pred = mdf_predict(X_train, model, random_effects)["pred"]
653
-
654
- X_test = sel_dict["X_test_tuned"]
655
- y_test = X_test[target]
656
- predtest = mdf_predict(X_test, model, random_effects)["pred"]
657
- metrics_table_train, _, fig_train = plot_actual_vs_predicted(
658
- X_train[date_col],
659
- y_train,
660
- pred,
661
- model,
662
- target_column=target_column,
663
- flag=None,
664
- repeat_all_years=False,
665
- is_panel=is_panel,
666
- )
667
-
668
- metrics_table_test, _, fig_test = plot_actual_vs_predicted(
669
- X_test[date_col],
670
- y_test,
671
- predtest,
672
- model,
673
- target_column=target_column,
674
- flag=None,
675
- repeat_all_years=False,
676
- is_panel=is_panel,
677
- )
678
-
679
- metrics_table_train = metrics_table_train.set_index("Metric").transpose()
680
- metrics_table_train.index = ["Train"]
681
- metrics_table_test = metrics_table_test.set_index("Metric").transpose()
682
- metrics_table_test.index = ["test"]
683
- metrics_table = np.round(
684
- pd.concat([metrics_table_train, metrics_table_test]), 2
685
- )
686
-
687
- st.markdown("Result Overview")
688
- st.dataframe(np.round(metrics_table, 2), use_container_width=True)
689
-
690
- st.subheader("Actual vs Predicted Plot Train")
691
-
692
- st.plotly_chart(fig_train, use_container_width=True)
693
- st.subheader("Actual vs Predicted Plot Test")
694
- st.plotly_chart(fig_test, use_container_width=True)
695
-
696
- st.markdown("## Residual Analysis")
697
- columns = st.columns(2)
698
-
699
- Xtrain1 = X_train.copy()
700
- with columns[0]:
701
- fig = plot_residual_predicted(y_train, model.predict(Xtrain1), Xtrain1)
702
- st.plotly_chart(fig)
703
-
704
- with columns[1]:
705
- st.empty()
706
- fig = qqplot(y_train, model.predict(X_train))
707
- st.plotly_chart(fig)
708
-
709
- with columns[0]:
710
- fig = residual_distribution(y_train, model.predict(X_train))
711
- st.pyplot(fig)
712
-
713
- update_db("6_AI_Model_Result.py")
714
-
715
-
716
- elif auth_status == False:
717
- st.error("Username/Password is incorrect")
718
- try:
719
- username_forgot_pw, email_forgot_password, random_password = (
720
- authenticator.forgot_password("Forgot password")
721
- )
722
- if username_forgot_pw:
723
- st.success("New password sent securely")
724
- # Random password to be transferred to the user securely
725
- elif username_forgot_pw == False:
726
- st.error("Username not found")
727
- except Exception as e:
728
- st.error(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/7_Current_Media_Performance.py DELETED
@@ -1,573 +0,0 @@
1
- """
2
- MMO Build Sprint 3
3
- additions : contributions calculated using tuned Mixed LM model
4
- pending : contributions calculations using - 1. not tuned Mixed LM model, 2. tuned OLS model, 3. not tuned OLS model
5
-
6
- MMO Build Sprint 4
7
- additions : response metrics selection
8
- pending : contributions calculations using - 1. not tuned Mixed LM model, 2. tuned OLS model, 3. not tuned OLS model
9
- """
10
-
11
- import streamlit as st
12
- import pandas as pd
13
- from sklearn.preprocessing import MinMaxScaler
14
- import pickle
15
- import os
16
-
17
- from utilities_with_panel import load_local_css, set_header
18
- import yaml
19
- from yaml import SafeLoader
20
- import streamlit_authenticator as stauth
21
- import sqlite3
22
- from utilities import update_db
23
-
24
- st.set_page_config(layout="wide")
25
- load_local_css("styles.css")
26
- set_header()
27
- for k, v in st.session_state.items():
28
- # print(k, v)
29
- if k not in [
30
- "logout",
31
- "login",
32
- "config",
33
- "build_tuned_model",
34
- ] and not k.startswith("FormSubmitter"):
35
- st.session_state[k] = v
36
- with open("config.yaml") as file:
37
- config = yaml.load(file, Loader=SafeLoader)
38
- st.session_state["config"] = config
39
- authenticator = stauth.Authenticate(
40
- config["credentials"],
41
- config["cookie"]["name"],
42
- config["cookie"]["key"],
43
- config["cookie"]["expiry_days"],
44
- config["preauthorized"],
45
- )
46
- st.session_state["authenticator"] = authenticator
47
- name, authentication_status, username = authenticator.login("Login", "main")
48
- auth_status = st.session_state.get("authentication_status")
49
-
50
- if auth_status == True:
51
- authenticator.logout("Logout", "main")
52
- is_state_initiaized = st.session_state.get("initialized", False)
53
-
54
- if "project_dct" not in st.session_state:
55
- st.error("Please load a project from Home page")
56
- st.stop()
57
-
58
- conn = sqlite3.connect(
59
- r"DB/User.db", check_same_thread=False
60
- ) # connection with sql db
61
- c = conn.cursor()
62
-
63
- if not os.path.exists(
64
- os.path.join(st.session_state["project_path"], "tuned_model.pkl")
65
- ):
66
- st.error("Please save a tuned model")
67
- st.stop()
68
-
69
- if (
70
- "session_state_saved"
71
- in st.session_state["project_dct"]["model_tuning"].keys()
72
- and st.session_state["project_dct"]["model_tuning"][
73
- "session_state_saved"
74
- ]
75
- != []
76
- ):
77
- for key in [
78
- "used_response_metrics",
79
- "is_tuned_model",
80
- "media_data",
81
- "X_test_spends",
82
- ]:
83
- st.session_state[key] = st.session_state["project_dct"][
84
- "model_tuning"
85
- ]["session_state_saved"][key]
86
- elif (
87
- "session_state_saved"
88
- in st.session_state["project_dct"]["model_build"].keys()
89
- and st.session_state["project_dct"]["model_build"][
90
- "session_state_saved"
91
- ]
92
- != []
93
- ):
94
- for key in [
95
- "used_response_metrics",
96
- "date",
97
- "saved_model_names",
98
- "media_data",
99
- "X_test_spends",
100
- ]:
101
- st.session_state[key] = st.session_state["project_dct"][
102
- "model_build"
103
- ]["session_state_saved"][key]
104
- else:
105
- st.error("Please tune a model first")
106
- st.session_state["bin_dict"] = st.session_state["project_dct"][
107
- "model_build"
108
- ]["session_state_saved"]["bin_dict"]
109
- st.session_state["media_data"].columns = [
110
- c.lower() for c in st.session_state["media_data"].columns
111
- ]
112
-
113
- from utilities_with_panel import (
114
- overview_test_data_prep_panel,
115
- overview_test_data_prep_nonpanel,
116
- initialize_data,
117
- create_channel_summary,
118
- create_contribution_pie,
119
- create_contribuion_stacked_plot,
120
- create_channel_spends_sales_plot,
121
- format_numbers,
122
- channel_name_formating,
123
- )
124
-
125
- import plotly.graph_objects as go
126
- import streamlit_authenticator as stauth
127
- import yaml
128
- from yaml import SafeLoader
129
- import time
130
-
131
- def get_random_effects(media_data, panel_col, mdf):
132
- random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
133
- for i, market in enumerate(media_data[panel_col].unique()):
134
- print(i, end="\r")
135
- intercept = mdf.random_effects[market].values[0]
136
- random_eff_df.loc[i, "random_effect"] = intercept
137
- random_eff_df.loc[i, panel_col] = market
138
-
139
- return random_eff_df
140
-
141
- def process_train_and_test(train, test, features, panel_col, target_col):
142
- X1 = train[features]
143
-
144
- ss = MinMaxScaler()
145
- X1 = pd.DataFrame(ss.fit_transform(X1), columns=X1.columns)
146
-
147
- X1[panel_col] = train[panel_col]
148
- X1[target_col] = train[target_col]
149
-
150
- if test is not None:
151
- X2 = test[features]
152
- X2 = pd.DataFrame(ss.transform(X2), columns=X2.columns)
153
- X2[panel_col] = test[panel_col]
154
- X2[target_col] = test[target_col]
155
- return X1, X2
156
- return X1
157
-
158
- def mdf_predict(X_df, mdf, random_eff_df):
159
- X = X_df.copy()
160
- X = pd.merge(
161
- X,
162
- random_eff_df[[panel_col, "random_effect"]],
163
- on=panel_col,
164
- how="left",
165
- )
166
- X["pred_fixed_effect"] = mdf.predict(X)
167
-
168
- X["pred"] = X["pred_fixed_effect"] + X["random_effect"]
169
- X.to_csv("Test/merged_df_contri.csv", index=False)
170
- X.drop(columns=["pred_fixed_effect", "random_effect"], inplace=True)
171
-
172
- return X
173
-
174
- # target='Revenue'
175
-
176
- # is_panel=False
177
- # is_panel = st.session_state['is_panel']
178
- panel_col = [
179
- col.lower()
180
- .replace(".", "_")
181
- .replace("@", "_")
182
- .replace(" ", "_")
183
- .replace("-", "")
184
- .replace(":", "")
185
- .replace("__", "_")
186
- for col in st.session_state["bin_dict"]["Panel Level 1"]
187
- ][
188
- 0
189
- ] # set the panel column
190
- is_panel = True if len(panel_col) > 0 else False
191
- date_col = "date"
192
-
193
- # Sprint4 - if used_response_metrics is not blank, then select one of the used_response_metrics, else target is revenue by default
194
- if (
195
- "used_response_metrics" in st.session_state
196
- and st.session_state["used_response_metrics"] != []
197
- ):
198
- sel_target_col = st.selectbox(
199
- "Select the response metric",
200
- st.session_state["used_response_metrics"],
201
- )
202
- target_col = (
203
- sel_target_col.lower()
204
- .replace(" ", "_")
205
- .replace("-", "")
206
- .replace(":", "")
207
- .replace("__", "_")
208
- )
209
- else:
210
- sel_target_col = "Total Approved Accounts - Revenue"
211
- target_col = "total_approved_accounts_revenue"
212
-
213
- target = sel_target_col
214
-
215
- # Sprint4 - Look through all saved tuned models, only show saved models of the sel resp metric (target_col)
216
- # saved_models = st.session_state['saved_model_names']
217
- # Sprint4 - get the model obj of the selected model
218
- # st.write(sel_model_dict)
219
-
220
- # Sprint3 - Contribution
221
- if is_panel:
222
- # read tuned mixedLM model
223
- # if st.session_state["tuned_model"] is not None :
224
- if st.session_state["is_tuned_model"][target_col] == True: # Sprint4
225
- with open(
226
- os.path.join(
227
- st.session_state["project_path"], "tuned_model.pkl"
228
- ),
229
- "rb",
230
- ) as file:
231
- model_dict = pickle.load(file)
232
- saved_models = list(model_dict.keys())
233
- # st.write(saved_models)
234
- required_saved_models = [
235
- m.split("__")[0]
236
- for m in saved_models
237
- if m.split("__")[1] == target_col
238
- ]
239
- sel_model = st.selectbox(
240
- "Select the model to review", required_saved_models
241
- )
242
- sel_model_dict = model_dict[sel_model + "__" + target_col]
243
-
244
- model = sel_model_dict["Model_object"]
245
- X_train = sel_model_dict["X_train_tuned"]
246
- X_test = sel_model_dict["X_test_tuned"]
247
- best_feature_set = sel_model_dict["feature_set"]
248
-
249
- else: # if non tuned model to be used # Pending
250
- with open(
251
- os.path.join(
252
- st.session_state["project_path"], "best_models.pkl"
253
- ),
254
- "rb",
255
- ) as file:
256
- model_dict = pickle.load(file)
257
- # st.write(model_dict)
258
- saved_models = list(model_dict.keys())
259
- required_saved_models = [
260
- m.split("__")[0]
261
- for m in saved_models
262
- if m.split("__")[1] == target_col
263
- ]
264
- sel_model = st.selectbox(
265
- "Select the model to review", required_saved_models
266
- )
267
- sel_model_dict = model_dict[sel_model + "__" + target_col]
268
- # st.write(sel_model, sel_model_dict)
269
- model = sel_model_dict["Model_object"]
270
- X_train = sel_model_dict["X_train"]
271
- X_test = sel_model_dict["X_test"]
272
- best_feature_set = sel_model_dict["feature_set"]
273
-
274
- # Calculate contributions
275
-
276
- with open(
277
- os.path.join(st.session_state["project_path"], "data_import.pkl"),
278
- "rb",
279
- ) as f:
280
- data = pickle.load(f)
281
-
282
- # Accessing the loaded objects
283
- st.session_state["orig_media_data"] = data["final_df"]
284
-
285
- st.session_state["orig_media_data"].columns = [
286
- col.lower()
287
- .replace(".", "_")
288
- .replace("@", "_")
289
- .replace(" ", "_")
290
- .replace("-", "")
291
- .replace(":", "")
292
- .replace("__", "_")
293
- for col in st.session_state["orig_media_data"].columns
294
- ]
295
-
296
- media_data = st.session_state["media_data"]
297
-
298
- # st.session_state['orig_media_data']=st.session_state["media_data"]
299
-
300
- # st.write(media_data)
301
-
302
- contri_df = pd.DataFrame()
303
-
304
- y = []
305
- y_pred = []
306
-
307
- random_eff_df = get_random_effects(media_data, panel_col, model)
308
- random_eff_df["fixed_effect"] = model.fe_params["Intercept"]
309
- random_eff_df["panel_effect"] = (
310
- random_eff_df["random_effect"] + random_eff_df["fixed_effect"]
311
- )
312
- # random_eff_df.to_csv("Test/random_eff_df_contri.csv", index=False)
313
-
314
- coef_df = pd.DataFrame(model.fe_params)
315
- coef_df.reset_index(inplace=True)
316
- coef_df.columns = ["feature", "coef"]
317
-
318
- # coef_df.reset_index().to_csv("Test/coef_df_contri1.csv",index=False)
319
- # print(model.fe_params)
320
-
321
- x_train_contribution = X_train.copy()
322
- x_test_contribution = X_test.copy()
323
-
324
- # preprocessing not needed since X_train is already preprocessed
325
- # X1, X2 = process_train_and_test(x_train_contribution, x_test_contribution, best_feature_set, panel_col, target_col)
326
- # x_train_contribution[best_feature_set] = X1[best_feature_set]
327
- # x_test_contribution[best_feature_set] = X2[best_feature_set]
328
-
329
- x_train_contribution = mdf_predict(
330
- x_train_contribution, model, random_eff_df
331
- )
332
- x_test_contribution = mdf_predict(
333
- x_test_contribution, model, random_eff_df
334
- )
335
-
336
- x_train_contribution = pd.merge(
337
- x_train_contribution,
338
- random_eff_df[[panel_col, "panel_effect"]],
339
- on=panel_col,
340
- how="left",
341
- )
342
- x_test_contribution = pd.merge(
343
- x_test_contribution,
344
- random_eff_df[[panel_col, "panel_effect"]],
345
- on=panel_col,
346
- how="left",
347
- )
348
-
349
- for i in range(len(coef_df))[1:]:
350
- coef = coef_df.loc[i, "coef"]
351
- col = coef_df.loc[i, "feature"]
352
- x_train_contribution[str(col) + "_contr"] = (
353
- coef * x_train_contribution[col]
354
- )
355
- x_test_contribution[str(col) + "_contr"] = (
356
- coef * x_train_contribution[col]
357
- )
358
-
359
- x_train_contribution["sum_contributions"] = (
360
- x_train_contribution.filter(regex="contr").sum(axis=1)
361
- )
362
- x_train_contribution["sum_contributions"] = (
363
- x_train_contribution["sum_contributions"]
364
- + x_train_contribution["panel_effect"]
365
- )
366
-
367
- x_test_contribution["sum_contributions"] = x_test_contribution.filter(
368
- regex="contr"
369
- ).sum(axis=1)
370
- x_test_contribution["sum_contributions"] = (
371
- x_test_contribution["sum_contributions"]
372
- + x_test_contribution["panel_effect"]
373
- )
374
-
375
- # # test
376
- x_train_contribution.to_csv(
377
- "Test/x_train_contribution.csv", index=False
378
- )
379
- x_test_contribution.to_csv("Test/x_test_contribution.csv", index=False)
380
- #
381
- # st.session_state['orig_media_data'].to_csv("Test/transformed_data.csv",index=False)
382
- # st.session_state['X_test_spends'].to_csv("Test/test_spends.csv",index=False)
383
- # # st.write(st.session_state['orig_media_data'].columns)
384
-
385
- # st.write(date_col,panel_col)
386
- # st.write(x_test_contribution)
387
-
388
- overview_test_data_prep_panel(
389
- x_test_contribution,
390
- st.session_state["orig_media_data"],
391
- st.session_state["X_test_spends"],
392
- date_col,
393
- panel_col,
394
- target_col,
395
- )
396
-
397
- else: # NON PANEL
398
- if st.session_state["is_tuned_model"][target_col] == True: # Sprint4
399
- with open(
400
- os.path.join(
401
- st.session_state["project_path"], "tuned_model.pkl"
402
- ),
403
- "rb",
404
- ) as file:
405
- model_dict = pickle.load(file)
406
- saved_models = list(model_dict.keys())
407
- required_saved_models = [
408
- m.split("__")[0]
409
- for m in saved_models
410
- if m.split("__")[1] == target_col
411
- ]
412
- sel_model = st.selectbox(
413
- "Select the model to review", required_saved_models
414
- )
415
- sel_model_dict = model_dict[sel_model + "__" + target_col]
416
-
417
- model = sel_model_dict["Model_object"]
418
- X_train = sel_model_dict["X_train_tuned"]
419
- X_test = sel_model_dict["X_test_tuned"]
420
- best_feature_set = sel_model_dict["feature_set"]
421
-
422
- else: # Sprint4
423
- with open(
424
- os.path.join(
425
- st.session_state["project_path"], "best_models.pkl"
426
- ),
427
- "rb",
428
- ) as file:
429
- model_dict = pickle.load(file)
430
- saved_models = list(model_dict.keys())
431
- required_saved_models = [
432
- m.split("__")[0]
433
- for m in saved_models
434
- if m.split("__")[1] == target_col
435
- ]
436
- sel_model = st.selectbox(
437
- "Select the model to review", required_saved_models
438
- )
439
- sel_model_dict = model_dict[sel_model + "__" + target_col]
440
-
441
- model = sel_model_dict["Model_object"]
442
- X_train = sel_model_dict["X_train"]
443
- X_test = sel_model_dict["X_test"]
444
- best_feature_set = sel_model_dict["feature_set"]
445
-
446
- x_train_contribution = X_train.copy()
447
- x_test_contribution = X_test.copy()
448
-
449
- x_train_contribution["pred"] = model.predict(
450
- x_train_contribution[best_feature_set]
451
- )
452
- x_test_contribution["pred"] = model.predict(
453
- x_test_contribution[best_feature_set]
454
- )
455
-
456
- for num, i in enumerate(model.params.values):
457
- col = best_feature_set[num]
458
- x_train_contribution[col + "_contr"] = X_train[col] * i
459
- x_test_contribution[col + "_contr"] = X_test[col] * i
460
-
461
- x_test_contribution.to_csv(
462
- "Test/x_test_contribution_non_panel.csv", index=False
463
- )
464
- overview_test_data_prep_nonpanel(
465
- x_test_contribution,
466
- st.session_state["orig_media_data"].copy(),
467
- st.session_state["X_test_spends"].copy(),
468
- date_col,
469
- target_col,
470
- )
471
- # for k, v in st.session_sta
472
- # te.items():
473
-
474
- # if k not in ['logout', 'login','config'] and not k.startswith('FormSubmitter'):
475
- # st.session_state[k] = v
476
-
477
- # authenticator = st.session_state.get('authenticator')
478
-
479
- # if authenticator is None:
480
- # authenticator = load_authenticator()
481
-
482
- # name, authentication_status, username = authenticator.login('Login', 'main')
483
- # auth_status = st.session_state['authentication_status']
484
-
485
- # if auth_status:
486
- # authenticator.logout('Logout', 'main')
487
-
488
- # is_state_initiaized = st.session_state.get('initialized',False)
489
- # if not is_state_initiaized:
490
-
491
- initialize_data(target_col)
492
- scenario = st.session_state["scenario"]
493
- raw_df = st.session_state["raw_df"]
494
- st.header("Overview of previous spends")
495
-
496
- # st.write(scenario.actual_total_spends)
497
- # st.write(scenario.actual_total_sales)
498
- columns = st.columns((1, 1, 3))
499
-
500
- with columns[0]:
501
- st.metric(
502
- label="Spends",
503
- value=format_numbers(float(scenario.actual_total_spends)),
504
- )
505
- ###print(f"##################### {scenario.actual_total_sales} ##################")
506
- with columns[1]:
507
- st.metric(
508
- label=target,
509
- value=format_numbers(
510
- float(scenario.actual_total_sales), include_indicator=False
511
- ),
512
- )
513
-
514
- actual_summary_df = create_channel_summary(scenario)
515
- actual_summary_df["Channel"] = actual_summary_df["Channel"].apply(
516
- channel_name_formating
517
- )
518
-
519
- columns = st.columns((2, 1))
520
- with columns[0]:
521
- with st.expander("Channel wise overview"):
522
- st.markdown(
523
- actual_summary_df.style.set_table_styles(
524
- [
525
- {
526
- "selector": "th",
527
- "props": [("background-color", "#11B6BD")],
528
- },
529
- {
530
- "selector": "tr:nth-child(even)",
531
- "props": [("background-color", "#11B6BD")],
532
- },
533
- ]
534
- ).to_html(),
535
- unsafe_allow_html=True,
536
- )
537
-
538
- st.markdown("<hr>", unsafe_allow_html=True)
539
- ##############################
540
-
541
- st.plotly_chart(
542
- create_contribution_pie(scenario), use_container_width=True
543
- )
544
- st.markdown("<hr>", unsafe_allow_html=True)
545
-
546
- ################################3
547
- st.plotly_chart(
548
- create_contribuion_stacked_plot(scenario), use_container_width=True
549
- )
550
- st.markdown("<hr>", unsafe_allow_html=True)
551
- #######################################
552
-
553
- selected_channel_name = st.selectbox(
554
- "Channel",
555
- st.session_state["channels_list"] + ["non media"],
556
- format_func=channel_name_formating,
557
- )
558
- selected_channel = scenario.channels.get(selected_channel_name, None)
559
-
560
- st.plotly_chart(
561
- create_channel_spends_sales_plot(selected_channel),
562
- use_container_width=True,
563
- )
564
-
565
- st.markdown("<hr>", unsafe_allow_html=True)
566
-
567
- if st.checkbox("Save this session", key="save"):
568
- project_dct_path = os.path.join(
569
- st.session_state["session_path"], "project_dct.pkl"
570
- )
571
- with open(project_dct_path, "wb") as f:
572
- pickle.dump(st.session_state["project_dct"], f)
573
- update_db("7_Current_Media_Performance.py")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/8_Build_Response_Curves.py DELETED
@@ -1,596 +0,0 @@
1
- import streamlit as st
2
- import plotly.express as px
3
- import numpy as np
4
- import plotly.graph_objects as go
5
- from utilities import (
6
- channel_name_formating,
7
- load_authenticator,
8
- initialize_data,
9
- fetch_actual_data,
10
- )
11
- from sklearn.metrics import r2_score
12
- from collections import OrderedDict
13
- from classes import class_from_dict, class_to_dict
14
- import pickle
15
- import json
16
- import sqlite3
17
- from utilities import update_db
18
-
19
- for k, v in st.session_state.items():
20
- if k not in ["logout", "login", "config"] and not k.startswith(
21
- "FormSubmitter"
22
- ):
23
- st.session_state[k] = v
24
-
25
-
26
- def s_curve(x, K, b, a, x0):
27
- return K / (1 + b * np.exp(-a * (x - x0)))
28
-
29
-
30
- def save_scenario(scenario_name):
31
- """
32
- Save the current scenario with the mentioned name in the session state
33
-
34
- Parameters
35
- ----------
36
- scenario_name
37
- Name of the scenario to be saved
38
- """
39
- if "saved_scenarios" not in st.session_state:
40
- st.session_state = OrderedDict()
41
-
42
- # st.session_state['saved_scenarios'][scenario_name] = st.session_state['scenario'].save()
43
- st.session_state["saved_scenarios"][scenario_name] = class_to_dict(
44
- st.session_state["scenario"]
45
- )
46
- st.session_state["scenario_input"] = ""
47
- print(type(st.session_state["saved_scenarios"]))
48
- with open("../saved_scenarios.pkl", "wb") as f:
49
- pickle.dump(st.session_state["saved_scenarios"], f)
50
-
51
-
52
- def reset_curve_parameters(
53
- metrics=None, panel=None, selected_channel_name=None
54
- ):
55
- del st.session_state["K"]
56
- del st.session_state["b"]
57
- del st.session_state["a"]
58
- del st.session_state["x0"]
59
-
60
- if (
61
- metrics is not None
62
- and panel is not None
63
- and selected_channel_name is not None
64
- ):
65
- if f"{metrics}#@{panel}#@{selected_channel_name}" in list(
66
- st.session_state["update_rcs"].keys()
67
- ):
68
- del st.session_state["update_rcs"][
69
- f"{metrics}#@{panel}#@{selected_channel_name}"
70
- ]
71
-
72
-
73
- def update_response_curve(
74
- K_updated,
75
- b_updated,
76
- a_updated,
77
- x0_updated,
78
- metrics=None,
79
- panel=None,
80
- selected_channel_name=None,
81
- ):
82
- print(
83
- "[DEBUG] update_response_curves: ",
84
- st.session_state["project_dct"]["scenario_planner"].keys(),
85
- )
86
- st.session_state["project_dct"]["scenario_planner"][unique_key].channels[
87
- selected_channel_name
88
- ].response_curve_params = {
89
- "K": st.session_state["K"],
90
- "b": st.session_state["b"],
91
- "a": st.session_state["a"],
92
- "x0": st.session_state["x0"],
93
- }
94
-
95
- # if (
96
- # metrics is not None
97
- # and panel is not None
98
- # and selected_channel_name is not None
99
- # ):
100
- # st.session_state["update_rcs"][
101
- # f"{metrics}#@{panel}#@{selected_channel_name}"
102
- # ] = {
103
- # "K": K_updated,
104
- # "b": b_updated,
105
- # "a": a_updated,
106
- # "x0": x0_updated,
107
- # }
108
-
109
- # st.session_state["scenario"].channels[
110
- # selected_channel_name
111
- # ].response_curve_params = {
112
- # "K": K_updated,
113
- # "b": b_updated,
114
- # "a": a_updated,
115
- # "x0": x0_updated,
116
- # }
117
-
118
-
119
- # authenticator = st.session_state.get('authenticator')
120
- # if authenticator is None:
121
- # authenticator = load_authenticator()
122
-
123
- # name, authentication_status, username = authenticator.login('Login', 'main')
124
- # auth_status = st.session_state.get('authentication_status')
125
-
126
- # if auth_status == True:
127
- # is_state_initiaized = st.session_state.get('initialized',False)
128
- # if not is_state_initiaized:
129
- # print("Scenario page state reloaded")
130
-
131
- import pandas as pd
132
-
133
-
134
- @st.cache_resource(show_spinner=False)
135
- def panel_fetch(file_selected):
136
- raw_data_mmm_df = pd.read_excel(file_selected, sheet_name="RAW DATA MMM")
137
-
138
- if "Panel" in raw_data_mmm_df.columns:
139
- panel = list(set(raw_data_mmm_df["Panel"]))
140
- else:
141
- raw_data_mmm_df = None
142
- panel = None
143
-
144
- return panel
145
-
146
-
147
- import glob
148
- import os
149
-
150
-
151
- def get_excel_names(directory):
152
- # Create a list to hold the final parts of the filenames
153
- last_portions = []
154
-
155
- # Patterns to match Excel files (.xlsx and .xls) that contain @#
156
- patterns = [
157
- os.path.join(directory, "*@#*.xlsx"),
158
- os.path.join(directory, "*@#*.xls"),
159
- ]
160
-
161
- # Process each pattern
162
- for pattern in patterns:
163
- files = glob.glob(pattern)
164
-
165
- # Extracting the last portion after @# for each file
166
- for file in files:
167
- base_name = os.path.basename(file)
168
- last_portion = base_name.split("@#")[-1]
169
- last_portion = last_portion.replace(".xlsx", "").replace(
170
- ".xls", ""
171
- ) # Removing extensions
172
- last_portions.append(last_portion)
173
-
174
- return last_portions
175
-
176
-
177
- def name_formating(channel_name):
178
- # Replace underscores with spaces
179
- name_mod = channel_name.replace("_", " ")
180
-
181
- # Capitalize the first letter of each word
182
- name_mod = name_mod.title()
183
-
184
- return name_mod
185
-
186
-
187
- def fetch_panel_data():
188
- print("DEBUG etch_panel_data: running... ")
189
- file_selected = f"./metrics_level_data/Overview_data_test_panel@#{st.session_state['response_metrics_selectbox']}.xlsx"
190
- panel_selected = st.session_state["panel_selected_selectbox"]
191
- print(panel_selected)
192
- if panel_selected == "Aggregated":
193
- (
194
- st.session_state["actual_input_df"],
195
- st.session_state["actual_contribution_df"],
196
- ) = fetch_actual_data(panel=panel_selected, target_file=file_selected)
197
- else:
198
- (
199
- st.session_state["actual_input_df"],
200
- st.session_state["actual_contribution_df"],
201
- ) = fetch_actual_data(panel=panel_selected, target_file=file_selected)
202
-
203
- unique_key = f"{st.session_state['response_metrics_selectbox']}-{st.session_state['panel_selected_selectbox']}"
204
- print("unique_key")
205
- if unique_key not in st.session_state["project_dct"]["scenario_planner"]:
206
- if panel_selected == "Aggregated":
207
- initialize_data(
208
- panel=panel_selected,
209
- target_file=file_selected,
210
- updated_rcs={},
211
- metrics=metrics_selected,
212
- )
213
- panel = None
214
- else:
215
- initialize_data(
216
- panel=panel_selected,
217
- target_file=file_selected,
218
- updated_rcs={},
219
- metrics=metrics_selected,
220
- )
221
- st.session_state["project_dct"]["scenario_planner"][unique_key] = (
222
- st.session_state["scenario"]
223
- )
224
- # print(
225
- # "DEBUG etch_panel_data: ",
226
- # st.session_state["project_dct"]["scenario_planner"][
227
- # unique_key
228
- # ].keys(),
229
- # )
230
-
231
- else:
232
- st.session_state["scenario"] = st.session_state["project_dct"][
233
- "scenario_planner"
234
- ][unique_key]
235
- st.session_state["rcs"] = {}
236
- st.session_state["powers"] = {}
237
-
238
- for channel_name, _channel in st.session_state["project_dct"][
239
- "scenario_planner"
240
- ][unique_key].channels.items():
241
- st.session_state["rcs"][
242
- channel_name
243
- ] = _channel.response_curve_params
244
- st.session_state["powers"][channel_name] = _channel.power
245
-
246
- if "K" in st.session_state:
247
- del st.session_state["K"]
248
-
249
- if "b" in st.session_state:
250
- del st.session_state["b"]
251
-
252
- if "a" in st.session_state:
253
- del st.session_state["a"]
254
-
255
- if "x0" in st.session_state:
256
- del st.session_state["x0"]
257
-
258
-
259
- if "project_dct" not in st.session_state:
260
- st.error("Please load a project from home")
261
- st.stop()
262
-
263
- database_file = r"DB\User.db"
264
-
265
- conn = sqlite3.connect(
266
- database_file, check_same_thread=False
267
- ) # connection with sql db
268
- c = conn.cursor()
269
-
270
- st.subheader("Build Response Curves")
271
-
272
-
273
- if "update_rcs" not in st.session_state:
274
- st.session_state["update_rcs"] = {}
275
-
276
- st.session_state["first_time"] = True
277
-
278
- col1, col2, col3 = st.columns([1, 1, 1])
279
-
280
- directory = "metrics_level_data"
281
- metrics_list = get_excel_names(directory)
282
-
283
-
284
- metrics_selected = col1.selectbox(
285
- "Response Metrics",
286
- metrics_list,
287
- on_change=fetch_panel_data,
288
- format_func=name_formating,
289
- key="response_metrics_selectbox",
290
- )
291
-
292
-
293
- file_selected = (
294
- f"./metrics_level_data/Overview_data_test_panel@#{metrics_selected}.xlsx"
295
- )
296
-
297
- panel_list = panel_fetch(file_selected)
298
- final_panel_list = ["Aggregated"] + panel_list
299
-
300
- panel_selected = col3.selectbox(
301
- "Panel",
302
- final_panel_list,
303
- on_change=fetch_panel_data,
304
- key="panel_selected_selectbox",
305
- )
306
-
307
-
308
- is_state_initiaized = st.session_state.get("initialized_rcs", False)
309
- print(is_state_initiaized)
310
- if not is_state_initiaized:
311
- print("DEBUG.....", "Here")
312
- fetch_panel_data()
313
- # if panel_selected == "Aggregated":
314
- # initialize_data(panel=panel_selected, target_file=file_selected)
315
- # panel = None
316
- # else:
317
- # initialize_data(panel=panel_selected, target_file=file_selected)
318
-
319
- st.session_state["initialized_rcs"] = True
320
-
321
- # channels_list = st.session_state["channels_list"]
322
- unique_key = f"{st.session_state['response_metrics_selectbox']}-{st.session_state['panel_selected_selectbox']}"
323
- chanel_list_final = list(
324
- st.session_state["project_dct"]["scenario_planner"][
325
- unique_key
326
- ].channels.keys()
327
- ) + ["Others"]
328
-
329
-
330
- selected_channel_name = col2.selectbox(
331
- "Channel",
332
- chanel_list_final,
333
- format_func=channel_name_formating,
334
- on_change=reset_curve_parameters,
335
- key="selected_channel_name_selectbox",
336
- )
337
-
338
-
339
- rcs = st.session_state["rcs"]
340
-
341
- if "K" not in st.session_state:
342
- st.session_state["K"] = rcs[selected_channel_name]["K"]
343
-
344
- if "b" not in st.session_state:
345
- st.session_state["b"] = rcs[selected_channel_name]["b"]
346
-
347
-
348
- if "a" not in st.session_state:
349
- st.session_state["a"] = rcs[selected_channel_name]["a"]
350
-
351
- if "x0" not in st.session_state:
352
- st.session_state["x0"] = rcs[selected_channel_name]["x0"]
353
-
354
-
355
- x = st.session_state["actual_input_df"][selected_channel_name].values
356
- y = st.session_state["actual_contribution_df"][selected_channel_name].values
357
-
358
-
359
- power = np.ceil(np.log(x.max()) / np.log(10)) - 3
360
-
361
- print(f"DEBUG BUILD RCS: {selected_channel_name}")
362
- print(f"DEBUG BUILD RCS: K : {st.session_state['K']}")
363
- print(f"DEBUG BUILD RCS: b : {st.session_state['b']}")
364
- print(f"DEBUG BUILD RCS: a : {st.session_state['a']}")
365
- print(f"DEBUG BUILD RCS: x0: {st.session_state['x0']}")
366
-
367
- # fig = px.scatter(x, s_curve(x/10**power,
368
- # st.session_state['K'],
369
- # st.session_state['b'],
370
- # st.session_state['a'],
371
- # st.session_state['x0']))
372
-
373
- x_plot = np.linspace(0, 5 * max(x), 50)
374
-
375
- fig = px.scatter(x=x, y=y)
376
- fig.add_trace(
377
- go.Scatter(
378
- x=x_plot,
379
- y=s_curve(
380
- x_plot / 10**power,
381
- st.session_state["K"],
382
- st.session_state["b"],
383
- st.session_state["a"],
384
- st.session_state["x0"],
385
- ),
386
- line=dict(color="red"),
387
- name="Modified",
388
- ),
389
- )
390
-
391
- fig.add_trace(
392
- go.Scatter(
393
- x=x_plot,
394
- y=s_curve(
395
- x_plot / 10**power,
396
- rcs[selected_channel_name]["K"],
397
- rcs[selected_channel_name]["b"],
398
- rcs[selected_channel_name]["a"],
399
- rcs[selected_channel_name]["x0"],
400
- ),
401
- line=dict(color="rgba(0, 255, 0, 0.4)"),
402
- name="Actual",
403
- ),
404
- )
405
-
406
- fig.update_layout(title_text="Response Curve", showlegend=True)
407
- fig.update_annotations(font_size=10)
408
- fig.update_xaxes(title="Spends")
409
- fig.update_yaxes(title="Revenue")
410
-
411
- st.plotly_chart(fig, use_container_width=True)
412
-
413
- r2 = r2_score(
414
- y,
415
- s_curve(
416
- x / 10**power,
417
- st.session_state["K"],
418
- st.session_state["b"],
419
- st.session_state["a"],
420
- st.session_state["x0"],
421
- ),
422
- )
423
-
424
- r2_actual = r2_score(
425
- y,
426
- s_curve(
427
- x / 10**power,
428
- rcs[selected_channel_name]["K"],
429
- rcs[selected_channel_name]["b"],
430
- rcs[selected_channel_name]["a"],
431
- rcs[selected_channel_name]["x0"],
432
- ),
433
- )
434
-
435
- columns = st.columns((1, 1, 2))
436
- with columns[0]:
437
- st.metric("R2 Modified", round(r2, 2))
438
- with columns[1]:
439
- st.metric("R2 Actual", round(r2_actual, 2))
440
-
441
-
442
- st.markdown("#### Set Parameters", unsafe_allow_html=True)
443
- columns = st.columns(4)
444
-
445
- if "updated_parms" not in st.session_state:
446
- st.session_state["updated_parms"] = {
447
- "K_updated": 0,
448
- "b_updated": 0,
449
- "a_updated": 0,
450
- "x0_updated": 0,
451
- }
452
-
453
- with columns[0]:
454
- st.session_state["updated_parms"]["K_updated"] = st.number_input(
455
- "K", key="K", format="%0.5f"
456
- )
457
- with columns[1]:
458
- st.session_state["updated_parms"]["b_updated"] = st.number_input(
459
- "b", key="b", format="%0.5f"
460
- )
461
- with columns[2]:
462
- st.session_state["updated_parms"]["a_updated"] = st.number_input(
463
- "a", key="a", step=0.0001, format="%0.5f"
464
- )
465
- with columns[3]:
466
- st.session_state["updated_parms"]["x0_updated"] = st.number_input(
467
- "x0", key="x0", format="%0.5f"
468
- )
469
-
470
- # st.session_state["project_dct"]["scenario_planner"]["K_number_input"] = (
471
- # st.session_state["updated_parms"]["K_updated"]
472
- # )
473
- # st.session_state["project_dct"]["scenario_planner"]["b_number_input"] = (
474
- # st.session_state["updated_parms"]["b_updated"]
475
- # )
476
- # st.session_state["project_dct"]["scenario_planner"]["a_number_input"] = (
477
- # st.session_state["updated_parms"]["a_updated"]
478
- # )
479
- # st.session_state["project_dct"]["scenario_planner"]["x0_number_input"] = (
480
- # st.session_state["updated_parms"]["x0_updated"]
481
- # )
482
-
483
- update_col, reset_col = st.columns([1, 1])
484
- if update_col.button(
485
- "Update Parameters",
486
- on_click=update_response_curve,
487
- args=(
488
- st.session_state["updated_parms"]["K_updated"],
489
- st.session_state["updated_parms"]["b_updated"],
490
- st.session_state["updated_parms"]["a_updated"],
491
- st.session_state["updated_parms"]["x0_updated"],
492
- metrics_selected,
493
- panel_selected,
494
- selected_channel_name,
495
- ),
496
- use_container_width=True,
497
- ):
498
- st.session_state["rcs"][selected_channel_name]["K"] = st.session_state[
499
- "updated_parms"
500
- ]["K_updated"]
501
- st.session_state["rcs"][selected_channel_name]["b"] = st.session_state[
502
- "updated_parms"
503
- ]["b_updated"]
504
- st.session_state["rcs"][selected_channel_name]["a"] = st.session_state[
505
- "updated_parms"
506
- ]["a_updated"]
507
- st.session_state["rcs"][selected_channel_name]["x0"] = st.session_state[
508
- "updated_parms"
509
- ]["x0_updated"]
510
-
511
- reset_col.button(
512
- "Reset Parameters",
513
- on_click=reset_curve_parameters,
514
- args=(metrics_selected, panel_selected, selected_channel_name),
515
- use_container_width=True,
516
- )
517
-
518
- st.divider()
519
- save_col, down_col = st.columns([1, 1])
520
-
521
-
522
- with save_col:
523
- file_name = st.text_input(
524
- "rcs download file name",
525
- key="file_name_input",
526
- placeholder="File name",
527
- label_visibility="collapsed",
528
- )
529
- down_col.download_button(
530
- label="Download response curves",
531
- data=json.dumps(rcs),
532
- file_name=f"{file_name}.json",
533
- mime="application/json",
534
- disabled=len(file_name) == 0,
535
- use_container_width=True,
536
- )
537
-
538
-
539
- def s_curve_derivative(x, K, b, a, x0):
540
- # Derivative of the S-curve function
541
- return (
542
- a
543
- * b
544
- * K
545
- * np.exp(-a * (x - x0))
546
- / ((1 + b * np.exp(-a * (x - x0))) ** 2)
547
- )
548
-
549
-
550
- # Parameters of the S-curve
551
- K = st.session_state["K"]
552
- b = st.session_state["b"]
553
- a = st.session_state["a"]
554
- x0 = st.session_state["x0"]
555
-
556
- # # Optimized spend value obtained from the tool
557
- # optimized_spend = st.number_input(
558
- # "value of x"
559
- # ) # Replace this with your optimized spend value
560
-
561
- # # Calculate the slope at the optimized spend value
562
- # slope_at_optimized_spend = s_curve_derivative(optimized_spend, K, b, a, x0)
563
-
564
- # st.write("Slope ", slope_at_optimized_spend)
565
-
566
-
567
- # Initialize a list to hold our rows
568
- rows = []
569
-
570
- # Iterate over the dictionary
571
- for key, value in st.session_state["update_rcs"].items():
572
- # Split the key into its components
573
- metrics, panel, channel_name = key.split("#@")
574
- # Create a new row with the components and the values
575
- row = {
576
- "Metrics": name_formating(metrics),
577
- "Panel": name_formating(panel),
578
- "Channel Name": channel_name,
579
- "K": value["K"],
580
- "b": value["b"],
581
- "a": value["a"],
582
- "x0": value["x0"],
583
- }
584
- # Append the row to our list
585
- rows.append(row)
586
-
587
- # Convert the list of rows into a DataFrame
588
- updated_parms_df = pd.DataFrame(rows)
589
-
590
- if len(list(st.session_state["update_rcs"].keys())) > 0:
591
- st.markdown("#### Updated Parameters", unsafe_allow_html=True)
592
- st.dataframe(updated_parms_df, hide_index=True)
593
- else:
594
- st.info("No parameters are updated")
595
-
596
- update_db("8_Build_Response_Curves.py")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pages/9_Scenario_Planner.py DELETED
@@ -1,1712 +0,0 @@
1
- import streamlit as st
2
- from numerize.numerize import numerize
3
- import numpy as np
4
- from functools import partial
5
- from collections import OrderedDict
6
- from plotly.subplots import make_subplots
7
- import plotly.graph_objects as go
8
- from utilities import (
9
- format_numbers,
10
- load_local_css,
11
- set_header,
12
- initialize_data,
13
- load_authenticator,
14
- send_email,
15
- channel_name_formating,
16
- )
17
- from classes import class_from_dict, class_to_dict
18
- import pickle
19
- import streamlit_authenticator as stauth
20
- import yaml
21
- from yaml import SafeLoader
22
- import re
23
- import pandas as pd
24
- import plotly.express as px
25
- import logging
26
- from utilities import update_db
27
- import sqlite3
28
-
29
-
30
- st.set_page_config(layout="wide")
31
- load_local_css("styles.css")
32
- set_header()
33
-
34
- for k, v in st.session_state.items():
35
- if k not in ["logout", "login", "config"] and not k.startswith(
36
- "FormSubmitter"
37
- ):
38
- st.session_state[k] = v
39
- # ======================================================== #
40
- # ======================= Functions ====================== #
41
- # ======================================================== #
42
-
43
-
44
- def optimize(key, status_placeholder):
45
- """
46
- Optimize the spends for the sales
47
- """
48
-
49
- channel_list = [
50
- key
51
- for key, value in st.session_state["optimization_channels"].items()
52
- if value
53
- ]
54
-
55
- if len(channel_list) > 0:
56
- scenario = st.session_state["scenario"]
57
- if key.lower() == "media spends":
58
- with status_placeholder:
59
- with st.spinner("Optimizing"):
60
- result = st.session_state["scenario"].optimize(
61
- st.session_state["total_spends_change"], channel_list
62
- )
63
- # elif key.lower() == "revenue":
64
- else:
65
- with status_placeholder:
66
- with st.spinner("Optimizing"):
67
-
68
- result = st.session_state["scenario"].optimize_spends(
69
- st.session_state["total_sales_change"], channel_list
70
- )
71
- for channel_name, modified_spends in result:
72
-
73
- st.session_state[channel_name] = numerize(
74
- modified_spends
75
- * scenario.channels[channel_name].conversion_rate,
76
- 1,
77
- )
78
- prev_spends = (
79
- st.session_state["scenario"]
80
- .channels[channel_name]
81
- .actual_total_spends
82
- )
83
- st.session_state[f"{channel_name}_change"] = round(
84
- 100 * (modified_spends - prev_spends) / prev_spends, 2
85
- )
86
-
87
-
88
- def save_scenario(scenario_name):
89
- """
90
- Save the current scenario with the mentioned name in the session state
91
-
92
- Parameters
93
- ----------
94
- scenario_name
95
- Name of the scenario to be saved
96
- """
97
- if "saved_scenarios" not in st.session_state:
98
- st.session_state = OrderedDict()
99
-
100
- # st.session_state['saved_scenarios'][scenario_name] = st.session_state['scenario'].save()
101
- st.session_state["saved_scenarios"][scenario_name] = class_to_dict(
102
- st.session_state["scenario"]
103
- )
104
- st.session_state["scenario_input"] = ""
105
- # print(type(st.session_state['saved_scenarios']))
106
- with open("../saved_scenarios.pkl", "wb") as f:
107
- pickle.dump(st.session_state["saved_scenarios"], f)
108
-
109
-
110
- def update_sales_abs_slider():
111
- actual_sales = st.session_state["scenario"].actual_total_sales
112
- if validate_input(st.session_state["total_sales_change_abs_slider"]):
113
- modified_sales = extract_number_for_string(
114
- st.session_state["total_sales_change_abs_slider"]
115
- )
116
- st.session_state["total_sales_change"] = round(
117
- ((modified_sales / actual_sales) - 1) * 100
118
- )
119
- st.session_state["total_sales_change_abs"] = numerize(
120
- modified_sales, 1
121
- )
122
-
123
- st.session_state["project_dct"]["scenario_planner"][
124
- "total_sales_change"
125
- ] = st.session_state.total_sales_change
126
-
127
-
128
- def update_sales_abs():
129
- actual_sales = st.session_state["scenario"].actual_total_sales
130
- if validate_input(st.session_state["total_sales_change_abs"]):
131
- modified_sales = extract_number_for_string(
132
- st.session_state["total_sales_change_abs"]
133
- )
134
- st.session_state["total_sales_change"] = round(
135
- ((modified_sales / actual_sales) - 1) * 100
136
- )
137
- st.session_state["total_sales_change_abs_slider"] = numerize(
138
- modified_sales, 1
139
- )
140
-
141
-
142
- def update_sales():
143
- # print("DEBUG: running update_sales")
144
- # st.session_state["project_dct"]["scenario_planner"][
145
- # "total_sales_change"
146
- # ] = st.session_state.total_sales_change
147
- # st.session_state["total_spends_change"] = st.session_state[
148
- # "total_sales_change"
149
- # ]
150
-
151
- st.session_state["total_sales_change_abs"] = numerize(
152
- (1 + st.session_state["total_sales_change"] / 100)
153
- * st.session_state["scenario"].actual_total_sales,
154
- 1,
155
- )
156
- st.session_state["total_sales_change_abs_slider"] = numerize(
157
- (1 + st.session_state["total_sales_change"] / 100)
158
- * st.session_state["scenario"].actual_total_sales,
159
- 1,
160
- )
161
- # update_spends()
162
-
163
-
164
- def update_all_spends_abs_slider():
165
- actual_spends = st.session_state["scenario"].actual_total_spends
166
- if validate_input(st.session_state["total_spends_change_abs_slider"]):
167
- modified_spends = extract_number_for_string(
168
- st.session_state["total_spends_change_abs_slider"]
169
- )
170
- st.session_state["total_spends_change"] = round(
171
- ((modified_spends / actual_spends) - 1) * 100
172
- )
173
- st.session_state["total_spends_change_abs"] = numerize(
174
- modified_spends, 1
175
- )
176
-
177
- st.session_state["project_dct"]["scenario_planner"][
178
- "total_spends_change"
179
- ] = st.session_state.total_spends_change
180
-
181
- update_all_spends()
182
-
183
-
184
- # def update_all_spends_abs_slider():
185
- # actual_spends = _scenario.actual_total_spends
186
- # if validate_input(st.session_state["total_spends_change_abs_slider"]):
187
- # print("#" * 100)
188
- # print(st.session_state["total_spends_change_abs_slider"])
189
- # print("#" * 100)
190
-
191
- # modified_spends = extract_number_for_string(
192
- # st.session_state["total_spends_change_abs_slider"]
193
- # )
194
- # st.session_state["total_spends_change"] = (
195
- # (modified_spends / actual_spends) - 1
196
- # ) * 100
197
- # st.session_state["total_spends_change_abs"] = st.session_state[
198
- # "total_spends_change_abs_slider"
199
- # ]
200
-
201
- # update_all_spends()
202
-
203
-
204
- def update_all_spends_abs():
205
- print("DEBUG: ", "inside update_all_spends_abs")
206
- # print(st.session_state["total_spends_change_abs_slider_options"])
207
-
208
- actual_spends = st.session_state["scenario"].actual_total_spends
209
- if validate_input(st.session_state["total_spends_change_abs"]):
210
- modified_spends = extract_number_for_string(
211
- st.session_state["total_spends_change_abs"]
212
- )
213
- st.session_state["total_spends_change"] = (
214
- (modified_spends / actual_spends) - 1
215
- ) * 100
216
- st.session_state["total_spends_change_abs_slider"] = numerize(
217
- extract_number_for_string(
218
- st.session_state["total_spends_change_abs"]
219
- ),
220
- 1,
221
- )
222
-
223
- st.session_state["project_dct"]["scenario_planner"][
224
- "total_spends_change"
225
- ] = st.session_state.total_spends_change
226
-
227
- # print(
228
- # "DEBUG UPDATE_ALL_SPENDS_ABS: ",
229
- # st.session_state["total_spends_change"],
230
- # )
231
- update_all_spends()
232
-
233
-
234
- def update_spends():
235
- print("update_spends")
236
- st.session_state["total_spends_change_abs"] = numerize(
237
- (1 + st.session_state["total_spends_change"] / 100)
238
- * st.session_state["scenario"].actual_total_spends,
239
- 1,
240
- )
241
- st.session_state["total_spends_change_abs_slider"] = numerize(
242
- (1 + st.session_state["total_spends_change"] / 100)
243
- * st.session_state["scenario"].actual_total_spends,
244
- 1,
245
- )
246
-
247
- st.session_state["project_dct"]["scenario_planner"][
248
- "total_spends_change"
249
- ] = st.session_state.total_spends_change
250
-
251
- update_all_spends()
252
-
253
-
254
- def update_all_spends():
255
- """
256
- Updates spends for all the channels with the given overall spends change
257
- """
258
- percent_change = st.session_state["total_spends_change"]
259
- print("runs update_all")
260
- for channel_name in list(
261
- st.session_state["project_dct"]["scenario_planner"][
262
- unique_key
263
- ].channels.keys()
264
- ):
265
- st.session_state[f"{channel_name}_percent"] = percent_change
266
- channel = st.session_state["scenario"].channels[channel_name]
267
- current_spends = channel.actual_total_spends
268
- modified_spends = (1 + percent_change / 100) * current_spends
269
- st.session_state["scenario"].update(channel_name, modified_spends)
270
- st.session_state[channel_name] = numerize(
271
- modified_spends * channel.conversion_rate, 1
272
- )
273
- st.session_state[f"{channel_name}_change"] = percent_change
274
-
275
-
276
- def extract_number_for_string(string_input):
277
- string_input = string_input.upper()
278
- if string_input.endswith("K"):
279
- return float(string_input[:-1]) * 10**3
280
- elif string_input.endswith("M"):
281
- return float(string_input[:-1]) * 10**6
282
- elif string_input.endswith("B"):
283
- return float(string_input[:-1]) * 10**9
284
-
285
-
286
- def validate_input(string_input):
287
- pattern = r"\d+\.?\d*[K|M|B]$"
288
- match = re.match(pattern, string_input)
289
- if match is None:
290
- return False
291
- return True
292
-
293
-
294
- def update_data_by_percent(channel_name):
295
- prev_spends = (
296
- st.session_state["scenario"].channels[channel_name].actual_total_spends
297
- * st.session_state["scenario"].channels[channel_name].conversion_rate
298
- )
299
- modified_spends = prev_spends * (
300
- 1 + st.session_state[f"{channel_name}_percent"] / 100
301
- )
302
-
303
- st.session_state[channel_name] = numerize(modified_spends, 1)
304
-
305
- st.session_state["scenario"].update(
306
- channel_name,
307
- modified_spends
308
- / st.session_state["scenario"].channels[channel_name].conversion_rate,
309
- )
310
-
311
-
312
- def update_data(channel_name):
313
- """
314
- Updates the spends for the given channel
315
- """
316
- print("tuns update_Data")
317
- if validate_input(st.session_state[channel_name]):
318
- modified_spends = extract_number_for_string(
319
- st.session_state[channel_name]
320
- )
321
-
322
- prev_spends = (
323
- st.session_state["scenario"]
324
- .channels[channel_name]
325
- .actual_total_spends
326
- * st.session_state["scenario"]
327
- .channels[channel_name]
328
- .conversion_rate
329
- )
330
- st.session_state[f"{channel_name}_percent"] = round(
331
- 100 * (modified_spends - prev_spends) / prev_spends, 2
332
- )
333
- st.session_state["scenario"].update(
334
- channel_name,
335
- modified_spends
336
- / st.session_state["scenario"]
337
- .channels[channel_name]
338
- .conversion_rate,
339
- )
340
- # st.session_state['scenario'].update(channel_name, modified_spends)
341
- # else:
342
- # try:
343
- # modified_spends = float(st.session_state[channel_name])
344
- # prev_spends = st.session_state['scenario'].channels[channel_name].actual_total_spends * st.session_state['scenario'].channels[channel_name].conversion_rate
345
- # st.session_state[f'{channel_name}_change'] = round(100*(modified_spends - prev_spends) / prev_spends,2)
346
- # st.session_state['scenario'].update(channel_name, modified_spends/st.session_state['scenario'].channels[channel_name].conversion_rate)
347
- # st.session_state[f'{channel_name}'] = numerize(modified_spends,1)
348
- # except ValueError:
349
- # st.write('Invalid input')
350
-
351
-
352
- def select_channel_for_optimization(channel_name):
353
- """
354
- Marks the given channel for optimization
355
- """
356
- st.session_state["optimization_channels"][channel_name] = st.session_state[
357
- f"{channel_name}_selected"
358
- ]
359
-
360
-
361
- def select_all_channels_for_optimization():
362
- """
363
- Marks all the channel for optimization
364
- """
365
- # print(
366
- # "DEBUG: select_all_channels_for_opt",
367
- # st.session_state["optimze_all_channels"],
368
- # )
369
-
370
- for channel_name in st.session_state["optimization_channels"].keys():
371
- st.session_state[f"{channel_name}_selected"] = st.session_state[
372
- "optimze_all_channels"
373
- ]
374
- st.session_state["optimization_channels"][channel_name] = (
375
- st.session_state["optimze_all_channels"]
376
- )
377
- from pprint import pprint
378
-
379
-
380
- def update_penalty():
381
- """
382
- Updates the penalty flag for sales calculation
383
- """
384
- st.session_state["scenario"].update_penalty(
385
- st.session_state["apply_penalty"]
386
- )
387
-
388
-
389
def reset_optimization():
    """Clear per-channel optimization selections and stale sales widgets.

    Registered as the ``on_change`` callback of the metric/panel selectors,
    so it must tolerate widget keys that have not been created yet in this
    session (the sales slider keys only exist after the sales controls have
    rendered at least once).
    """
    print("DEBUG: ", "Running reset_optimization")
    scenario = st.session_state["project_dct"]["scenario_planner"][unique_key]
    for channel_name in list(scenario.channels.keys()):
        st.session_state[f"{channel_name}_selected"] = False
    st.session_state["optimze_all_channels"] = False
    st.session_state["initialized"] = False
    # pop() instead of del: a bare del raises KeyError when the key was never
    # created (e.g. the very first selector change in a fresh session).
    for stale_key in (
        "total_sales_change_abs_slider",
        "total_sales_change_abs",
        "total_sales_change",
    ):
        st.session_state.pop(stale_key, None)
403
-
404
-
405
def reset_scenario():
    """Reset the scenario for the current metric/panel to saved defaults.

    Drops the cached scenario object and all optimization/widget session
    state so the next Streamlit rerun re-initializes everything from disk.
    """
    print("[DEBUG]: reset_scenario")
    scenario = st.session_state["project_dct"]["scenario_planner"][unique_key]
    for channel_name in list(scenario.channels.keys()):
        st.session_state[f"{channel_name}_selected"] = False
    st.session_state["optimze_all_channels"] = False
    st.session_state["initialized"] = False

    st.session_state.pop("optimization_channels", None)

    # Drop the cached scenario for this metric/panel pair so the page
    # rebuilds it from the data files on rerun.
    st.session_state["project_dct"]["scenario_planner"].pop(
        f"{st.session_state['metric_selected']}-{st.session_state['panel_selected']}",
        None,
    )
    # pop() instead of del: these widget keys may not exist yet on a fresh
    # session, and a bare del would raise KeyError mid-callback.
    for stale_key in (
        "total_sales_change_abs_slider",
        "total_sales_change_abs",
        "total_sales_change",
    ):
        st.session_state.pop(stale_key, None)
451
-
452
-
453
def format_number(num):
    """Compactly format a number: millions as "X.XXM", thousands as "XK".

    Values below 1,000 are shown with two decimals.  Negative values are
    handled by formatting the magnitude and re-applying the sign (the
    previous thresholds only matched non-negative input, so e.g.
    -2,500,000 rendered as "-2500000.00").
    """
    sign = "-" if num < 0 else ""
    magnitude = abs(num)
    if magnitude >= 1_000_000:
        return f"{sign}{magnitude / 1_000_000:.2f}M"
    elif magnitude >= 1_000:
        return f"{sign}{magnitude / 1_000:.0f}K"
    else:
        return f"{num:.2f}"
460
-
461
-
462
def summary_plot(data, x, y, title, text_column):
    """Horizontal bar chart of *data*, one colored bar group per channel.

    Args:
        data: DataFrame containing a "Channel_name" column plus the columns
            named by *x*, *y* and *text_column*.
        x, y: column names for the horizontal/vertical axes.
        title: chart title.
        text_column: column rendered as the bar labels.

    Returns the plotly Figure.

    NOTE(review): ``data[text_column]`` is coerced to numeric *after* the
    figure is built, which mutates the caller's DataFrame — confirm this
    side effect is intended.
    """
    fig = px.bar(
        data,
        x=x,
        y=y,
        orientation="h",
        title=title,
        text=text_column,
        color="Channel_name",
    )

    # Convert text_column to numeric values (mutates the input DataFrame).
    data[text_column] = pd.to_numeric(data[text_column], errors="coerce")

    # Update the format of the displayed text based on magnitude
    # (SI-style "%{text:.2s}" abbreviation, e.g. 1.2M).
    fig.update_traces(
        texttemplate="%{text:.2s}",
        textposition="outside",
        hovertemplate="%{x:.2s}",
    )

    fig.update_layout(
        xaxis_title=x, yaxis_title="Channel Name", showlegend=False
    )
    return fig
487
-
488
-
489
def s_curve(x, K, b, a, x0):
    """Generalized logistic response curve evaluated at spend ``x``.

    K is the saturation ceiling, a the growth rate, x0 the midpoint shift
    and b a shape coefficient.  Accepts scalars or numpy arrays.
    """
    decay = np.exp(-a * (x - x0))
    return K / (1 + b * decay)
491
-
492
-
493
def find_segment_value(x, roi, mroi):
    """Return ``(start, end, left, right)`` spend values bounding ROI segments.

    The "green" segment is where both ROI and marginal ROI exceed 1; *left*
    and *right* are its first and last spend values, both falling back to
    ``x[0]`` when the segment is empty.  *start*/*end* are the first/last
    spend samples overall.

    The original computed the identical ``np.where`` mask twice; it is now
    evaluated once (behavior unchanged).
    """
    start_value = x[0]
    end_value = x[len(x) - 1]

    # Condition for green region: Both MROI and ROI > 1
    green_condition = (roi > 1) & (mroi > 1)
    green_indices = np.where(green_condition)[0]

    left_value = x[green_indices[0]] if green_indices.size > 0 else x[0]
    right_value = x[green_indices[-1]] if green_indices.size > 0 else x[0]

    return start_value, end_value, left_value, right_value
506
-
507
-
508
def calculate_rgba(
    start_value, end_value, left_value, right_value, current_channel_spends
):
    """Map the current spend to an RGBA color along the ROI segments.

    Yellow fades 0.8 -> 0.2 across [start, left], green fades 0.8 -> 0.2
    across (left, right], red rises 0.2 -> 0.8 across (right, end]; spends
    outside [start, end] return a fixed grey.  Alpha is clamped to
    [0.2, 0.8].

    Fix: a zero-width segment (e.g. ``left_value == start_value``, which
    ``find_segment_value`` produces when the green region starts at the
    first sample) used to raise ZeroDivisionError; the relative position
    is now defined as 0 for empty segments.
    """

    def _relative(pos, lo, hi):
        # Guard zero-width segments that would otherwise divide by zero.
        span = hi - lo
        return (pos - lo) / span if span else 0.0

    # Determine the color and alpha based on which segment the point is in.
    if start_value <= current_channel_spends <= left_value:
        color = "yellow"
        # Alpha decreases from start to end of the segment.
        alpha = 0.8 - (
            0.6 * _relative(current_channel_spends, start_value, left_value)
        )
    elif left_value < current_channel_spends <= right_value:
        color = "green"
        alpha = 0.8 - (
            0.6 * _relative(current_channel_spends, left_value, right_value)
        )
    elif right_value < current_channel_spends <= end_value:
        color = "red"
        # Alpha increases from start to end of the segment.
        alpha = 0.2 + (
            0.6 * _relative(current_channel_spends, right_value, end_value)
        )
    else:
        # Default case, if the spends are outside the defined ranges
        return "rgba(136, 136, 136, 0.5)"  # Grey for values outside the range

    # Ensure alpha is within the intended range in case of any overshoot.
    alpha = max(0.2, min(alpha, 0.8))

    # RGB components per segment color.
    color_codes = {
        "yellow": "255, 255, 0",
        "green": "0, 128, 0",
        "red": "255, 0, 0",
    }

    return f"rgba({color_codes[color]}, {alpha})"
558
-
559
-
560
def debug_temp(x_test, power, K, b, a, x0):
    """Diagnostic helper: print how many spend samples fall in three bins.

    The bin edges (2524 and 3377) are hard-coded ad-hoc values used while
    debugging one specific channel; the curve parameters (power, K, b, a,
    x0) are accepted but unused.  Side effect only (prints); returns None.
    """
    print("*" * 100)
    # Calculate the count of bins
    count_lower_bin = sum(1 for x in x_test if x <= 2524)
    count_center_bin = sum(1 for x in x_test if x > 2524 and x <= 3377)
    count_ = sum(1 for x in x_test if x > 3377)

    print(
        f"""
        lower : {count_lower_bin}
        center : {count_center_bin}
        upper : {count_}
        """
    )
574
-
575
-
576
# @st.cache
def plot_response_curves():
    """Build a grid of per-channel response curves (spend vs. target metric).

    For every channel in ``channels_list`` this plots the fitted s-curve,
    marks the currently simulated (spend, sales) point with dashed guide
    lines, and shades the background into yellow/green/red regions derived
    from ROI and marginal ROI (see ``find_segment_value``).

    Reads ``channels_list``, ``target`` and ``st.session_state`` ("scenario",
    "rcs") from the enclosing scope.  Returns a plotly Figure.
    """
    cols = 4
    # Enough rows to fit all channels at 4 subplots per row.
    rows = (
        len(channels_list) // cols
        if len(channels_list) % cols == 0
        else len(channels_list) // cols + 1
    )
    rcs = st.session_state["rcs"]
    shapes = []
    fig = make_subplots(rows=rows, cols=cols, subplot_titles=channels_list)
    for i in range(0, len(channels_list)):
        col = channels_list[i]
        x_actual = st.session_state["scenario"].channels[col].actual_spends
        # x_modified = st.session_state["scenario"].channels[col].modified_spends

        # Scaling exponent used when the curve was fitted: spends are divided
        # by 10**power before the s-curve is applied.
        power = np.ceil(np.log(x_actual.max()) / np.log(10)) - 3

        # Fitted response-curve parameters for this channel.
        K = rcs[col]["K"]
        b = rcs[col]["b"]
        a = rcs[col]["a"]
        x0 = rcs[col]["x0"]

        # Sweep total spends from 0 to 5x the actual total, 50 samples.
        x_plot = np.linspace(0, 5 * x_actual.sum(), 50)

        x, y, marginal_roi = [], [], []
        # Distribute each total-spend sample across weeks proportionally to
        # the actual weekly spend pattern.
        for x_p in x_plot:
            x.append(x_p * x_actual / x_actual.sum())

        for index in range(len(x_plot)):
            y.append(s_curve(x[index] / 10**power, K, b, a, x0))

        # Marginal ROI of the logistic curve: a*y*(1 - y/K); eps guards K == 0.
        for index in range(len(x_plot)):
            marginal_roi.append(
                a
                * y[index]
                * (1 - y[index] / np.maximum(K, np.finfo(float).eps))
            )

        # Collapse the weekly dimension; convert spends to display currency
        # via the channel's conversion rate.
        x = (
            np.sum(x, axis=1)
            * st.session_state["scenario"].channels[col].conversion_rate
        )
        y = np.sum(y, axis=1)
        marginal_roi = (
            np.average(marginal_roi, axis=1)
            / st.session_state["scenario"].channels[col].conversion_rate
        )

        roi = y / np.maximum(x, np.finfo(float).eps)

        # Response-curve trace with ROI / marginal ROI in the hover data.
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                name=col,
                customdata=np.stack((roi, marginal_roi), axis=-1),
                hovertemplate="Spend:%{x:$.2s}<br>Sale:%{y:$.2s}<br>ROI:%{customdata[0]:.3f}<br>MROI:%{customdata[1]:.3f}",
                line=dict(color="blue"),
            ),
            row=1 + (i) // cols,
            col=i % cols + 1,
        )

        # Currently simulated operating point for this channel.
        x_optimal = (
            st.session_state["scenario"].channels[col].modified_total_spends
            * st.session_state["scenario"].channels[col].conversion_rate
        )
        y_optimal = (
            st.session_state["scenario"].channels[col].modified_total_sales
        )

        # if col == "Paid_social_others":
        #     debug_temp(x_optimal * x_actual / x_actual.sum(), power, K, b, a, x0)

        fig.add_trace(
            go.Scatter(
                x=[x_optimal],
                y=[y_optimal],
                name=col,
                legendgroup=col,
                showlegend=False,
                marker=dict(color=["black"]),
            ),
            row=1 + (i) // cols,
            col=i % cols + 1,
        )

        # Dashed guide lines from each axis to the operating point.
        shapes.append(
            go.layout.Shape(
                type="line",
                x0=0,
                y0=y_optimal,
                x1=x_optimal,
                y1=y_optimal,
                line_width=1,
                line_dash="dash",
                line_color="black",
                xref=f"x{i+1}",
                yref=f"y{i+1}",
            )
        )

        shapes.append(
            go.layout.Shape(
                type="line",
                x0=x_optimal,
                y0=0,
                x1=x_optimal,
                y1=y_optimal,
                line_width=1,
                line_dash="dash",
                line_color="black",
                xref=f"x{i+1}",
                yref=f"y{i+1}",
            )
        )

        start_value, end_value, left_value, right_value = find_segment_value(
            x,
            roi,
            marginal_roi,
        )

        # Adding background colors
        y_max = y.max() * 1.3  # 30% extra space above the max

        # Yellow region
        shapes.append(
            go.layout.Shape(
                type="rect",
                x0=start_value,
                y0=0,
                x1=left_value,
                y1=y_max,
                line=dict(width=0),
                fillcolor="rgba(255, 255, 0, 0.3)",
                layer="below",
                xref=f"x{i+1}",
                yref=f"y{i+1}",
            )
        )

        # Green region
        shapes.append(
            go.layout.Shape(
                type="rect",
                x0=left_value,
                y0=0,
                x1=right_value,
                y1=y_max,
                line=dict(width=0),
                fillcolor="rgba(0, 255, 0, 0.3)",
                layer="below",
                xref=f"x{i+1}",
                yref=f"y{i+1}",
            )
        )

        # Red region
        shapes.append(
            go.layout.Shape(
                type="rect",
                x0=right_value,
                y0=0,
                x1=end_value,
                y1=y_max,
                line=dict(width=0),
                fillcolor="rgba(255, 0, 0, 0.3)",
                layer="below",
                xref=f"x{i+1}",
                yref=f"y{i+1}",
            )
        )

    fig.update_layout(
        # height=1000,
        # width=1000,
        title_text=f"Response Curves (X: Spends Vs Y: {target})",
        showlegend=False,
        shapes=shapes,
    )
    fig.update_annotations(font_size=10)
    # fig.update_xaxes(title="Spends")
    # fig.update_yaxes(title=target)
    fig.update_yaxes(
        gridcolor="rgba(136, 136, 136, 0.5)", gridwidth=0.5, griddash="dash"
    )

    return fig
766
-
767
-
768
- # ======================================================== #
769
- # ==================== HTML Components =================== #
770
- # ======================================================== #
771
-
772
-
773
def generate_spending_header(heading):
    """Render a styled column header inside the spends table."""
    markup = f"""<h2 class="spends-header">{heading}</h2>"""
    return st.markdown(markup, unsafe_allow_html=True)
777
-
778
-
779
def save_checkpoint():
    """Persist the in-memory project dict to <project_path>/project_dct.pkl.

    Any failure is surfaced as a toast rather than crashing the page.
    """
    project_dct_path = os.path.join(
        st.session_state["project_path"], "project_dct.pkl"
    )

    try:
        # Serialize to memory first so an unpicklable object fails *before*
        # the file is opened for writing — this avoids truncating an
        # existing good checkpoint on disk.
        pickle.dumps(st.session_state["project_dct"])
        with open(project_dct_path, "wb") as f:
            pickle.dump(st.session_state["project_dct"], f)
    except Exception:
        # with warning_placeholder:
        st.toast("Unknown Issue, please reload the page.")
791
-
792
-
793
def reset_checkpoint():
    """Wipe all saved scenario-planner state and persist the empty dict."""
    st.session_state["project_dct"]["scenario_planner"] = {}
    save_checkpoint()
796
-
797
-
798
- # ======================================================== #
799
- # =================== Session variables ================== #
800
- # ======================================================== #
801
-
802
# ---- Authentication bootstrap -------------------------------------------
# Load credentials / cookie settings from config.yaml and render the login
# widget before any page content.
with open("config.yaml") as file:
    config = yaml.load(file, Loader=SafeLoader)
    st.session_state["config"] = config

authenticator = stauth.Authenticate(
    config["credentials"],
    config["cookie"]["name"],
    config["cookie"]["key"],
    config["cookie"]["expiry_days"],
    config["preauthorized"],
)
st.session_state["authenticator"] = authenticator
name, authentication_status, username = authenticator.login("Login", "main")
# The gated page body below keys off the session-state copy of the status.
auth_status = st.session_state.get("authentication_status")
816
-
817
- import os
818
- import glob
819
-
820
-
821
def get_excel_names(directory):
    """Collect the metric names encoded after "@#" in Excel filenames.

    Scans *directory* for ``*@#*.xlsx`` and ``*@#*.xls`` files and returns
    the portion of each base name following the last "@#", with the Excel
    extension stripped.  Returns an empty list when nothing matches.
    """
    metric_names = []

    # .xlsx files first, then .xls — same order the original patterns used.
    for extension in ("xlsx", "xls"):
        pattern = os.path.join(directory, f"*@#*.{extension}")
        for file_path in glob.glob(pattern):
            stem = os.path.basename(file_path).split("@#")[-1]
            stem = stem.replace(".xlsx", "").replace(".xls", "")
            metric_names.append(stem)

    return metric_names
845
-
846
-
847
def name_formating(channel_name):
    """Convert a snake_case identifier into a Title Case display name."""
    # Underscores become spaces, then each word is capitalized.
    return channel_name.replace("_", " ").title()
855
-
856
-
857
@st.cache_resource(show_spinner=False)
def panel_fetch(file_selected):
    """Return the distinct "Panel" values from the "RAW DATA MMM" sheet.

    Returns None when the sheet has no "Panel" column.  Cached per file
    path via st.cache_resource.  NOTE: the list is built from a set, so
    its order is not deterministic between runs.
    """
    raw_data_mmm_df = pd.read_excel(file_selected, sheet_name="RAW DATA MMM")

    if "Panel" in raw_data_mmm_df.columns:
        panel = list(set(raw_data_mmm_df["Panel"]))
    else:
        raw_data_mmm_df = None  # nothing else needed from the frame
        panel = None

    return panel
868
-
869
-
870
# ---- Authenticated page body --------------------------------------------
if auth_status is True:
    authenticator.logout("Logout", "main")

    # The whole page requires a loaded project; bail out early otherwise.
    if "project_dct" not in st.session_state:
        st.error("Please load a project from home")
        st.stop()

    # NOTE(review): Windows-style relative path; also `sqlite3` does not
    # appear among this file's visible imports — confirm it is in scope.
    database_file = r"DB\User.db"

    conn = sqlite3.connect(
        database_file, check_same_thread=False
    )  # connection with sql db
    c = conn.cursor()

    with st.sidebar:
        st.button("Save checkpoint", on_click=save_checkpoint)
        st.button("Reset Checkpoint", on_click=reset_checkpoint)

    warning_placeholder = st.empty()
    st.header("Scenario Planner")

    # st.subheader("Simulation")
    col1, col2 = st.columns([1, 1])

    # Get metric and panel from last saved state
    if "last_saved_metric" not in st.session_state:
        st.session_state["last_saved_metric"] = st.session_state[
            "project_dct"
        ]["scenario_planner"].get("metric_selected", 0)

    if "last_saved_panel" not in st.session_state:
        st.session_state["last_saved_panel"] = st.session_state["project_dct"][
            "scenario_planner"
        ].get("panel_selected", 0)

    # Response Metrics — one entry per "@#<metric>" Excel file on disk.
    directory = "metrics_level_data"
    metrics_list = get_excel_names(directory)
    metrics_selected = col1.selectbox(
        "Response Metrics",
        metrics_list,
        format_func=name_formating,
        index=st.session_state["last_saved_metric"],
        on_change=reset_optimization,
        key="metric_selected",
    )

    # Target (display name of the selected metric).
    target = name_formating(metrics_selected)

    file_selected = f"./metrics_level_data/Overview_data_test_panel@#{metrics_selected}.xlsx"
    st.session_state["file_selected"] = file_selected

    # Panel List
    panel_list = panel_fetch(file_selected)
    panel_list_final = ["Aggregated"] + panel_list

    # Panel Selected
    panel_selected = col2.selectbox(
        "Panel",
        panel_list_final,
        on_change=reset_optimization,
        key="panel_selected",
        index=st.session_state["last_saved_panel"],
    )

    # Cache key for the (metric, panel) pair used throughout this page.
    unique_key = f"{st.session_state['metric_selected']}-{st.session_state['panel_selected']}"

    # Optional user-tweaked response-curve parameters from another page.
    if "update_rcs" in st.session_state:
        updated_rcs = st.session_state["update_rcs"]
    else:
        updated_rcs = None
949
-
950
    # First visit for this (metric, panel): build the scenario from the data
    # files and cache it in the project dict.  Otherwise restore the cached
    # scenario and rebuild widget / response-curve session state from it.
    if unique_key not in st.session_state["project_dct"]["scenario_planner"]:
        if panel_selected == "Aggregated":
            initialize_data(
                panel=panel_selected,
                target_file=file_selected,
                updated_rcs=updated_rcs,
                metrics=metrics_selected,
            )
            panel = None
        else:
            initialize_data(
                panel=panel_selected,
                target_file=file_selected,
                updated_rcs=updated_rcs,
                metrics=metrics_selected,
            )
        st.session_state["project_dct"]["scenario_planner"][unique_key] = (
            st.session_state["scenario"]
        )

    else:
        st.session_state["scenario"] = st.session_state["project_dct"][
            "scenario_planner"
        ][unique_key]
        st.session_state["rcs"] = {}
        st.session_state["powers"] = {}
        if "optimization_channels" not in st.session_state:
            st.session_state["optimization_channels"] = {}

        # Seed the per-channel spend text inputs and curve parameters from
        # the restored scenario.
        for channel_name, _channel in st.session_state["project_dct"][
            "scenario_planner"
        ][unique_key].channels.items():
            st.session_state[channel_name] = numerize(
                _channel.modified_total_spends, 1
            )
            st.session_state["rcs"][
                channel_name
            ] = _channel.response_curve_params
            st.session_state["powers"][channel_name] = _channel.power
            if channel_name not in st.session_state["optimization_channels"]:
                st.session_state["optimization_channels"][channel_name] = False

    if "first_time" not in st.session_state:
        st.session_state["first_time"] = True
        st.session_state["first_run_scenario"] = True

    # Check if state is initialized (kept for reference; value unused below).
    is_state_initiaized = st.session_state.get("initialized", False)

    # Channels List
    channels_list = list(
        st.session_state["project_dct"]["scenario_planner"][
            unique_key
        ].channels.keys()
    )
1025
- )
1026
-
1027
- # ======================================================== #
1028
- # ========================== UI ========================== #
1029
- # ======================================================== #
1030
-
1031
    # ---- Header: Actual vs Simulated totals -----------------------------
    main_header = st.columns((2, 2))
    sub_header = st.columns((1, 1, 1, 1))
    # _scenario = st.session_state["scenario"]

    # Percentage change widgets are recomputed from the scenario each rerun;
    # the sales-side keys are only seeded when absent so user edits persist.
    st.session_state.total_spends_change = round(
        (
            st.session_state["scenario"].modified_total_spends
            / st.session_state["scenario"].actual_total_spends
            - 1
        )
        * 100
    )

    if "total_sales_change" not in st.session_state:
        st.session_state.total_sales_change = round(
            (
                st.session_state["scenario"].modified_total_sales
                / st.session_state["scenario"].actual_total_sales
                - 1
            )
            * 100
        )

    st.session_state["total_spends_change_abs"] = numerize(
        st.session_state["scenario"].modified_total_spends,
        1,
    )
    if "total_sales_change_abs" not in st.session_state:
        st.session_state["total_sales_change_abs"] = numerize(
            st.session_state["scenario"].modified_total_sales,
            1,
        )

    # if "total_spends_change_abs_slider" not in st.session_state:
    st.session_state.total_spends_change_abs_slider = numerize(
        st.session_state["scenario"].modified_total_spends, 1
    )

    if "total_sales_change_abs_slider" not in st.session_state:
        st.session_state.total_sales_change_abs_slider = numerize(
            st.session_state["scenario"].actual_total_sales, 1
        )

    # Validation flags toggled by the update_* callbacks.
    st.session_state["allow_sales_update"] = True

    st.session_state["allow_spends_update"] = True

    with main_header[0]:
        st.subheader("Actual")

    with main_header[-1]:
        st.subheader("Simulated")

    # Actual spends / sales on the left, simulated (with deltas) on the right.
    with sub_header[0]:
        st.metric(
            label="Spends",
            value=format_numbers(
                st.session_state["scenario"].actual_total_spends
            ),
        )

    with sub_header[1]:
        st.metric(
            label=target,
            value=format_numbers(
                float(st.session_state["scenario"].actual_total_sales),
                include_indicator=False,
            ),
        )

    with sub_header[2]:
        st.metric(
            label="Spends",
            value=format_numbers(
                st.session_state["scenario"].modified_total_spends
            ),
            delta=numerize(st.session_state["scenario"].delta_spends, 1),
        )

    with sub_header[3]:
        st.metric(
            label=target,
            value=format_numbers(
                float(st.session_state["scenario"].modified_total_sales),
                include_indicator=False,
            ),
            delta=numerize(st.session_state["scenario"].delta_sales, 1),
        )
1122
-
1123
    # ---- Channel Spends Simulator: global controls ----------------------
    with st.expander("Channel Spends Simulator", expanded=True):
        _columns1 = st.columns((2, 2, 1, 1))
        with _columns1[0]:
            # Choose whether optimization targets total spends or the metric.
            optimization_selection = st.selectbox(
                "Optimize",
                options=["Media Spends", target],
                key="optimization_key_value",
            )

        with _columns1[1]:
            st.markdown("#")
            st.checkbox(
                label="Optimize all Channels",
                key="optimze_all_channels",
                on_change=select_all_channels_for_optimization,
            )

        with _columns1[2]:
            st.markdown("#")
            # Placeholder so the Optimize button can be rendered later,
            # after status_placeholder exists (it is passed as an arg).
            optimize_placeholder = st.empty()

        with _columns1[3]:
            st.markdown("#")
            st.button(
                "Reset",
                on_click=reset_scenario,
                # args=(panel_selected, file_selected, updated_rcs),
                use_container_width=True,
            )

        # Absolute / percent / slider inputs for whichever quantity is being
        # optimized (spends vs. target metric).
        _columns2 = st.columns((2, 2, 2))
        if st.session_state["optimization_key_value"] == "Media Spends":

            with _columns2[0]:
                spend_input = st.text_input(
                    "Absolute",
                    key="total_spends_change_abs",
                    on_change=update_all_spends_abs,
                )

            with _columns2[1]:
                st.number_input(
                    "Percent Change",
                    key="total_spends_change",
                    min_value=-50,
                    max_value=50,
                    step=1,
                    on_change=update_spends,
                )

            with _columns2[2]:
                scenario = st.session_state["project_dct"]["scenario_planner"][
                    unique_key
                ]
                # Slider spans 50%–150% of actual spends in 10k steps.
                min_value = round(scenario.actual_total_spends * 0.5)
                max_value = round(scenario.actual_total_spends * 1.5)
                st.session_state["total_spends_change_abs_slider_options"] = [
                    numerize(value, 1)
                    for value in range(min_value, max_value + 1, int(1e4))
                ]

                st.select_slider(
                    "Absolute Slider",
                    options=st.session_state[
                        "total_spends_change_abs_slider_options"
                    ],
                    key="total_spends_change_abs_slider",
                    on_change=update_all_spends_abs_slider,
                )

        elif st.session_state["optimization_key_value"] == target:

            with _columns2[0]:
                sales_input = st.text_input(
                    "Absolute",
                    key="total_sales_change_abs",
                    on_change=update_sales_abs,
                )

            with _columns2[1]:
                st.number_input(
                    "Percent Change",
                    key="total_sales_change",
                    min_value=-50,
                    max_value=50,
                    step=1,
                    on_change=update_sales,
                )

            with _columns2[2]:
                # Slider spans 50%–150% of actual sales in 100k steps.
                min_value = round(
                    st.session_state["scenario"].actual_total_sales * 0.5
                )
                max_value = round(
                    st.session_state["scenario"].actual_total_sales * 1.5
                )
                st.session_state["total_sales_change_abs_slider_options"] = [
                    numerize(value, 1)
                    for value in range(min_value, max_value + 1, int(1e5))
                ]

                st.select_slider(
                    "Absolute Slider",
                    options=st.session_state[
                        "total_sales_change_abs_slider_options"
                    ],
                    key="total_sales_change_abs_slider",
                    on_change=update_sales_abs_slider,
                )

        # Surface validation failures from the update callbacks.
        if (
            not st.session_state["allow_sales_update"]
            and optimization_selection == target
        ):
            st.warning("Invalid Input")

        if (
            not st.session_state["allow_spends_update"]
            and optimization_selection == "Media Spends"
        ):
            st.warning("Invalid Input")

        status_placeholder = st.empty()

        optimize_placeholder.button(
            "Optimize",
            on_click=optimize,
            args=(
                st.session_state["optimization_key_value"],
                status_placeholder,
            ),
            use_container_width=True,
        )
1279
-
1280
        # ---- Spends table column headers --------------------------------
        st.markdown(
            """<hr class="spends-heading-seperator">""", unsafe_allow_html=True
        )
        _columns = st.columns((2.5, 2, 1.5, 1.5, 1))
        with _columns[0]:
            generate_spending_header("Channel")
        with _columns[1]:
            generate_spending_header("Spends Input")
        with _columns[2]:
            generate_spending_header("Spends")
        with _columns[3]:
            generate_spending_header(target)
        with _columns[4]:
            generate_spending_header("Optimize")

        st.markdown(
            """<hr class="spends-heading-seperator">""", unsafe_allow_html=True
        )
1298
-
1299
        # ---- Per-channel rows: inputs, metrics and ROI color bin --------
        # NOTE: key "acutual_predicted" is misspelled but used consistently
        # across the app — do not rename without updating all readers.
        if "acutual_predicted" not in st.session_state:
            st.session_state["acutual_predicted"] = {
                "Channel_name": [],
                "Actual_spend": [],
                "Optimized_spend": [],
                "Delta": [],
            }
        for i, channel_name in enumerate(channels_list):
            _channel_class = st.session_state["scenario"].channels[
                channel_name
            ]

            # Percent widget always mirrors the scenario state.
            st.session_state[f"{channel_name}_percent"] = round(
                (
                    _channel_class.modified_total_spends
                    / _channel_class.actual_total_spends
                    - 1
                )
                * 100
            )

            _columns = st.columns((2.5, 1.5, 1.5, 1.5, 1))
            with _columns[0]:
                st.write(channel_name_formating(channel_name))
                # Filled at the end of the iteration with the ROI color box.
                bin_placeholder = st.container()

            with _columns[1]:
                # Per-channel spend bounds (percent offsets around actual).
                channel_bounds = _channel_class.bounds
                channel_spends = float(_channel_class.actual_total_spends)
                min_value = float(
                    (1 + channel_bounds[0] / 100) * channel_spends
                )
                max_value = float(
                    (1 + channel_bounds[1] / 100) * channel_spends
                )
                spend_input = st.text_input(
                    channel_name,
                    key=channel_name,
                    label_visibility="collapsed",
                    on_change=partial(update_data, channel_name),
                )
                if not validate_input(spend_input):
                    st.error("Invalid input")

                channel_name_current = f"{channel_name}_change"

                st.number_input(
                    "Percent Change",
                    key=f"{channel_name}_percent",
                    step=1,
                    on_change=partial(update_data_by_percent, channel_name),
                )

            with _columns[2]:
                # Spends column: modified vs actual, in display currency.
                current_channel_spends = float(
                    _channel_class.modified_total_spends
                    * _channel_class.conversion_rate
                )
                actual_channel_spends = float(
                    _channel_class.actual_total_spends
                    * _channel_class.conversion_rate
                )
                spends_delta = float(
                    _channel_class.delta_spends
                    * _channel_class.conversion_rate
                )
                # Accumulate the rows used by the summary/download views.
                st.session_state["acutual_predicted"]["Channel_name"].append(
                    channel_name
                )
                st.session_state["acutual_predicted"]["Actual_spend"].append(
                    actual_channel_spends
                )
                st.session_state["acutual_predicted"][
                    "Optimized_spend"
                ].append(current_channel_spends)
                st.session_state["acutual_predicted"]["Delta"].append(
                    spends_delta
                )
                ## REMOVE
                st.metric(
                    "Spends",
                    format_numbers(current_channel_spends),
                    delta=numerize(spends_delta, 1),
                    label_visibility="collapsed",
                )

            with _columns[3]:
                # Target-metric column.
                current_channel_sales = float(
                    _channel_class.modified_total_sales
                )
                actual_channel_sales = float(_channel_class.actual_total_sales)
                sales_delta = float(_channel_class.delta_sales)
                st.metric(
                    target,
                    format_numbers(
                        current_channel_sales, include_indicator=False
                    ),
                    delta=numerize(sales_delta, 1),
                    label_visibility="collapsed",
                )

            with _columns[4]:

                st.checkbox(
                    label="select for optimization",
                    key=f"{channel_name}_selected",
                    value=False,
                    on_change=partial(
                        select_channel_for_optimization, channel_name
                    ),
                    label_visibility="collapsed",
                )

            st.markdown(
                """<hr class="spends-child-seperator">""",
                unsafe_allow_html=True,
            )

            # ---- ROI color bin for this channel -------------------------
            # Re-evaluate the response curve at 200 spend samples plus the
            # current spend point, then classify the current point into the
            # yellow/green/red ROI segments.
            col = channels_list[i]
            x_actual = st.session_state["scenario"].channels[col].actual_spends
            x_modified = (
                st.session_state["scenario"].channels[col].modified_spends
            )

            x_total = x_modified.sum()
            power = np.ceil(np.log(x_actual.max()) / np.log(10)) - 3

            # Prefer user-updated curve parameters when available.
            updated_rcs_key = (
                f"{metrics_selected}#@{panel_selected}#@{channel_name}"
            )

            if updated_rcs and updated_rcs_key in list(updated_rcs.keys()):
                K = updated_rcs[updated_rcs_key]["K"]
                b = updated_rcs[updated_rcs_key]["b"]
                a = updated_rcs[updated_rcs_key]["a"]
                x0 = updated_rcs[updated_rcs_key]["x0"]
            else:
                K = st.session_state["rcs"][col]["K"]
                b = st.session_state["rcs"][col]["b"]
                a = st.session_state["rcs"][col]["a"]
                x0 = st.session_state["rcs"][col]["x0"]

            x_plot = np.linspace(0, 5 * x_actual.sum(), 200)

            # Append current_channel_spends to the end of x_plot
            x_plot = np.append(x_plot, current_channel_spends)

            x, y, marginal_roi = [], [], []
            for x_p in x_plot:
                x.append(x_p * x_actual / x_actual.sum())

            for index in range(len(x_plot)):
                y.append(s_curve(x[index] / 10**power, K, b, a, x0))

            for index in range(len(x_plot)):
                marginal_roi.append(
                    a
                    * y[index]
                    * (1 - y[index] / np.maximum(K, np.finfo(float).eps))
                )

            x = (
                np.sum(x, axis=1)
                * st.session_state["scenario"].channels[col].conversion_rate
            )
            y = np.sum(y, axis=1)
            marginal_roi = (
                np.average(marginal_roi, axis=1)
                / st.session_state["scenario"].channels[col].conversion_rate
            )

            roi = y / np.maximum(x, np.finfo(float).eps)

            # Last sample is the appended current-spend point.
            roi_current, marginal_roi_current = roi[-1], marginal_roi[-1]
            x, y, roi, marginal_roi = (
                x[:-1],
                y[:-1],
                roi[:-1],
                marginal_roi[:-1],
            )  # Drop data for current spends

            start_value, end_value, left_value, right_value = (
                find_segment_value(
                    x,
                    roi,
                    marginal_roi,
                )
            )

            rgba = calculate_rgba(
                start_value,
                end_value,
                left_value,
                right_value,
                current_channel_spends,
            )

            with bin_placeholder:
                st.markdown(
                    f"""
                    <div style="
                        border-radius: 12px;
                        background-color: {rgba};
                        padding: 10px;
                        text-align: center;
                        color: #006EC0;
                        ">
                        <p style="margin: 0; font-size: 20px;">ROI: {round(roi_current,1)}</p>
                        <p style="margin: 0; font-size: 20px;">Marginal ROI: {round(marginal_roi_current,1)}</p>
                    </div>
                    """,
                    unsafe_allow_html=True,
                )

        # Keep the cached copy under the legacy "scenario" key in sync.
        st.session_state["project_dct"]["scenario_planner"]["scenario"] = (
            st.session_state["scenario"]
        )
1529
-
1530
- with st.expander("See Response Curves", expanded=True):
1531
- fig = plot_response_curves()
1532
- st.plotly_chart(fig, use_container_width=True)
1533
-
1534
def update_optimization_bounds(channel_name, bound_type):
    """Sync one channel's optimization bound from its widget state.

    Streamlit ``on_change`` callback for the per-channel lower/upper bound
    number inputs. Reads the widget value out of ``st.session_state`` and
    writes it into ``bounds[0]`` (lower) or ``bounds[1]`` (upper) of the
    channel on the scenario stored under the enclosing-scope ``unique_key``.

    Args:
        channel_name: Key of the channel whose bound is being edited.
        bound_type: ``"lower"`` selects index 0; anything else selects
            index 1 (the upper bound), matching the widget key suffix.
    """
    if bound_type == "lower":
        index, update_key = 0, f"{channel_name}_b_lower"
    else:
        index, update_key = 1, f"{channel_name}_b_upper"
    # NOTE(review): `unique_key` comes from the enclosing scope — confirm
    # it is defined before this callback can fire.
    scenario = st.session_state["project_dct"]["scenario_planner"][unique_key]
    scenario.channels[channel_name].bounds[index] = st.session_state[update_key]
1544
-
1545
def update_optimization_bounds_all(bound_type):
    """Apply the "all channels" bound widget value to every channel.

    Streamlit ``on_change`` callback for the global lower/upper bound
    number inputs: copies the widget value into ``bounds[0]`` (lower) or
    ``bounds[1]`` (upper) of every channel in the scenario stored under
    the enclosing-scope ``unique_key``.

    Args:
        bound_type: ``"lower"`` selects index 0 and the ``all_b_lower``
            widget; anything else selects index 1 and ``all_b_upper``.
    """
    index = 0 if bound_type == "lower" else 1
    # Fixed: these were f-strings with no placeholders (needless f-prefix).
    update_key = "all_b_lower" if bound_type == "lower" else "all_b_upper"
    new_bound = st.session_state[update_key]
    # Iterate .values() directly — the channel name was never used.
    channels = st.session_state["project_dct"]["scenario_planner"][
        unique_key
    ].channels
    for _channel in channels.values():
        _channel.bounds[index] = new_bound
1555
-
1556
- with st.expander("Optimization setup"):
1557
- bounds_placeholder = st.container()
1558
- with bounds_placeholder:
1559
- st.subheader("Optimization Bounds")
1560
- with st.container():
1561
- bounds_columns = st.columns((1, 0.35, 0.35, 1))
1562
- with bounds_columns[0]:
1563
- st.write("##")
1564
- st.write("Update all channels")
1565
-
1566
- with bounds_columns[1]:
1567
- st.number_input(
1568
- "Lower",
1569
- min_value=-100,
1570
- max_value=500,
1571
- key=f"all_b_lower",
1572
- # label_visibility="hidden",
1573
- on_change=update_optimization_bounds_all,
1574
- args=("lower",),
1575
- step=5,
1576
- value=-10,
1577
- )
1578
-
1579
- with bounds_columns[2]:
1580
- st.number_input(
1581
- "Higher",
1582
- value=10,
1583
- min_value=-100,
1584
- max_value=500,
1585
- key=f"all_b_upper",
1586
- # label_visibility="hidden",
1587
- on_change=update_optimization_bounds_all,
1588
- args=("upper",),
1589
- step=5,
1590
- )
1591
- st.divider()
1592
-
1593
- st.write("#### Channel wise bounds")
1594
- # st.divider()
1595
- # bounds_columns = st.columns((1, 0.35, 0.35, 1))
1596
-
1597
- # with bounds_columns[0]:
1598
- # st.write("Channel")
1599
- # with bounds_columns[1]:
1600
- # st.write("Lower")
1601
- # with bounds_columns[2]:
1602
- # st.write("Upper")
1603
- # st.divider()
1604
-
1605
- for channel_name, _channel in st.session_state["project_dct"][
1606
- "scenario_planner"
1607
- ][unique_key].channels.items():
1608
- st.session_state[f"{channel_name}_b_lower"] = _channel.bounds[0]
1609
- st.session_state[f"{channel_name}_b_upper"] = _channel.bounds[1]
1610
- with bounds_placeholder:
1611
- with st.container():
1612
- bounds_columns = st.columns((1, 0.35, 0.35, 1))
1613
- with bounds_columns[0]:
1614
- st.write("##")
1615
- st.write(channel_name)
1616
- with bounds_columns[1]:
1617
- st.number_input(
1618
- "Lower",
1619
- min_value=-100,
1620
- max_value=500,
1621
- key=f"{channel_name}_b_lower",
1622
- label_visibility="hidden",
1623
- on_change=update_optimization_bounds,
1624
- args=(
1625
- channel_name,
1626
- "lower",
1627
- ),
1628
- )
1629
-
1630
- with bounds_columns[2]:
1631
- st.number_input(
1632
- "Higher",
1633
- min_value=-100,
1634
- max_value=500,
1635
- key=f"{channel_name}_b_upper",
1636
- label_visibility="hidden",
1637
- on_change=update_optimization_bounds,
1638
- args=(
1639
- channel_name,
1640
- "upper",
1641
- ),
1642
- )
1643
-
1644
- st.divider()
1645
- _columns = st.columns(2)
1646
- with _columns[0]:
1647
- st.subheader("Save Scenario")
1648
- scenario_name = st.text_input(
1649
- "Scenario name",
1650
- key="scenario_input",
1651
- placeholder="Scenario name",
1652
- label_visibility="collapsed",
1653
- )
1654
- st.button(
1655
- "Save",
1656
- on_click=lambda: save_scenario(scenario_name),
1657
- disabled=len(st.session_state["scenario_input"]) == 0,
1658
- )
1659
-
1660
- summary_df = pd.DataFrame(st.session_state["acutual_predicted"])
1661
- summary_df.drop_duplicates(
1662
- subset="Channel_name", keep="last", inplace=True
1663
- )
1664
-
1665
- summary_df_sorted = summary_df.sort_values(by="Delta", ascending=False)
1666
- summary_df_sorted["Delta_percent"] = np.round(
1667
- (
1668
- (
1669
- summary_df_sorted["Optimized_spend"]
1670
- / summary_df_sorted["Actual_spend"]
1671
- )
1672
- - 1
1673
- )
1674
- * 100,
1675
- 2,
1676
- )
1677
-
1678
- with open("summary_df.pkl", "wb") as f:
1679
- pickle.dump(summary_df_sorted, f)
1680
- # st.dataframe(summary_df_sorted)
1681
- # ___columns=st.columns(3)
1682
- # with ___columns[2]:
1683
- # fig=summary_plot(summary_df_sorted, x='Delta_percent', y='Channel_name', title='Delta', text_column='Delta_percent')
1684
- # st.plotly_chart(fig,use_container_width=True)
1685
- # with ___columns[0]:
1686
- # fig=summary_plot(summary_df_sorted, x='Actual_spend', y='Channel_name', title='Actual Spend', text_column='Actual_spend')
1687
- # st.plotly_chart(fig,use_container_width=True)
1688
- # with ___columns[1]:
1689
- # fig=summary_plot(summary_df_sorted, x='Optimized_spend', y='Channel_name', title='Planned Spend', text_column='Optimized_spend')
1690
- # st.plotly_chart(fig,use_container_width=True)
1691
-
1692
- elif auth_status == False:
1693
- st.error("Username/Password is incorrect")
1694
-
1695
- if auth_status != True:
1696
- try:
1697
- username_forgot_pw, email_forgot_password, random_password = (
1698
- authenticator.forgot_password("Forgot password")
1699
- )
1700
- if username_forgot_pw:
1701
- st.session_state["config"]["credentials"]["usernames"][
1702
- username_forgot_pw
1703
- ]["password"] = stauth.Hasher([random_password]).generate()[0]
1704
- send_email(email_forgot_password, random_password)
1705
- st.success("New password sent securely")
1706
- # Random password to be transferred to user securely
1707
- elif username_forgot_pw == False:
1708
- st.error("Username not found")
1709
- except Exception as e:
1710
- st.error(e)
1711
-
1712
- update_db("9_Scenario_Planner.py")