LeonceNsh committed on
Commit
d722ce7
·
verified ·
1 Parent(s): d7dbe6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +191 -73
app.py CHANGED
@@ -1,83 +1,201 @@
1
- import gradio as gr
2
  import pandas as pd
3
  import geopandas as gpd
 
4
  import matplotlib.pyplot as plt
5
  import seaborn as sns
6
- from timesfm import TimesFm, TimesFmHparams, TimesFmCheckpoint
7
- from sklearn.ensemble import GradientBoostingRegressor
8
  import numpy as np
9
 
10
- # GPU-optimized TimesFM setup
11
- timesfm_backend = "gpu"
12
- timesfm_model_config = TimesFmHparams(
13
- context_len=512,
14
- horizon_len=128,
15
- per_core_batch_size=128,
16
- backend=timesfm_backend,
17
- )
18
- timesfm_model = TimesFm(
19
- hparams=timesfm_model_config,
20
- checkpoint=TimesFmCheckpoint(huggingface_repo_id="google/timesfm-1.0-200m-pytorch")
21
- )
22
-
23
- # Function to load embeddings and calculate HHI
24
- def calculate_hhi(file, market_col, id_col, weight_col):
25
- df = pd.read_csv(file.name)
26
- df['denominator'] = df.groupby(market_col)[weight_col].transform('sum')
27
- df['numerator'] = df.groupby([market_col, id_col])[weight_col].transform('sum')
28
- df['market_share'] = 100 * (df['numerator'] / df['denominator'])
29
- df['market_share_sq'] = df['market_share'] ** 2
30
- hhi = df.groupby(market_col).apply(lambda x: x['market_share_sq'].sum())
31
- return hhi.reset_index(name='hhi')
32
-
33
- # Function to visualize HHI map
34
- def plot_hhi_map(hhi_csv, shapefile):
35
- hhi_df = pd.read_csv(hhi_csv.name)
36
- gdf = gpd.read_file(shapefile.name)
37
- gdf = gdf.merge(hhi_df, left_on='fips_code', right_on='market_col', how='left')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  fig, ax = plt.subplots(1, 1, figsize=(12, 8))
39
- gdf.plot(column='hhi', cmap='RdBu', legend=True, ax=ax, missing_kwds={"color": "lightgrey"})
40
- ax.set_title("HHI by County")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  return fig
42
 
43
- # Function to forecast using TimesFM
44
- def forecast(file, history_steps, forecast_steps):
45
- df = pd.read_csv(file.name).set_index('place')
46
- history = df[history_steps]
47
- forecast = timesfm_model.forecast(inputs=history.values)
48
- return pd.DataFrame(forecast, index=history.index)
49
-
50
- # Gradio app interface
51
- def gradio_interface():
52
- with gr.Blocks() as demo:
53
- gr.Markdown("### Healthcare Network Analysis and Forecasting")
54
-
55
- with gr.Tab("Upload Embeddings"):
56
- file_upload = gr.File(label="Upload Embeddings (CSV)")
57
- hhi_results = gr.DataFrame(label="HHI Results")
58
- calculate_button = gr.Button("Calculate HHI")
59
- calculate_button.click(
60
- calculate_hhi,
61
- inputs=[file_upload, "market_col", "id_col", "weight_col"],
62
- outputs=hhi_results
63
- )
64
-
65
- with gr.Tab("Visualize Map"):
66
- hhi_csv = gr.File(label="Upload HHI CSV")
67
- shapefile = gr.File(label="Upload Shapefile")
68
- map_plot = gr.Plot(label="HHI Map")
69
- plot_button = gr.Button("Generate Map")
70
- plot_button.click(plot_hhi_map, inputs=[hhi_csv, shapefile], outputs=map_plot)
71
-
72
- with gr.Tab("Forecasting"):
73
- forecast_file = gr.File(label="Upload Historical Data (CSV)")
74
- forecast_steps = gr.Slider(minimum=1, maximum=24, step=1, label="Forecast Steps")
75
- forecast_results = gr.DataFrame(label="Forecasted Data")
76
- forecast_button = gr.Button("Forecast")
77
- forecast_button.click(forecast, inputs=[forecast_file, forecast_steps], outputs=forecast_results)
78
-
79
- return demo
80
-
81
- # Run app
82
  if __name__ == "__main__":
83
- gradio_interface().launch()
 
 
1
  import pandas as pd
2
  import geopandas as gpd
3
+ import gradio as gr
4
  import matplotlib.pyplot as plt
5
  import seaborn as sns
 
 
6
  import numpy as np
7
 
8
+ # ========================
9
+ # Data Loading
10
+ # ========================
11
+
12
+ # Load the health and demographic data
13
+ conus_data = pd.read_csv("conus27.csv")
14
+ # Load the county boundaries (GeoJSON)
15
+ county_geojson = gpd.read_file("county.geojson")
16
+ # Load the county embeddings
17
+ county_embeddings = pd.read_csv("county_embeddings.csv")
18
+ # Load the unemployment data
19
+ county_unemployment = pd.read_csv("county_unemployment.csv")
20
+ # Load the poverty data
21
+ zcta_poverty = pd.read_csv("zcta_poverty.csv")
22
+ # Load the ZCTA boundaries (GeoJSON)
23
+ zcta_geojson = gpd.read_file("zcta.geojson")
24
+
25
+ # Merge unemployment data with county_geojson
26
+ county_unemployment_melted = county_unemployment.melt(id_vars=['place'],
27
+ var_name='date',
28
+ value_name='unemployment_rate')
29
+ county_unemployment_melted['place'] = county_unemployment_melted['place'].astype(str)
30
+
31
+
32
+ county_geojson_unemployment = county_geojson.merge(county_unemployment_melted, left_on='place', right_on='place', how='left')
33
+
34
+ # Prepare poverty data
35
+ zcta_poverty_melted = zcta_poverty.melt(id_vars=['place'], var_name='year', value_name='poverty_rate')
36
+
37
+
38
+ zcta_poverty_melted['place'] = zcta_poverty_melted['place'].astype(str)
39
+
40
+
41
+ zcta_geojson['place'] = zcta_geojson['place'].astype(str)
42
+
43
+ zcta_geojson_poverty = zcta_geojson.merge(zcta_poverty_melted, left_on='place', right_on='place', how='left')
44
+
45
+
46
+ # List of health metrics available
47
+ health_metrics = [col for col in conus_data.columns if col.startswith('Percent_Person_')]
48
+ # Simplify metric names
49
+ simplified_metrics = [col.replace('Percent_Person_', '') for col in health_metrics]
50
+ metric_mapping = dict(zip(simplified_metrics, health_metrics))
51
+
52
+ # ========================
53
+ # Utility Functions
54
+ # ========================
55
+
56
+ def plot_health_metric(metric):
57
+ """
58
+ Plots the geographical distribution of a selected health metric.
59
+ """
60
+ metric_full_name = metric_mapping[metric]
61
+ fig, ax = plt.subplots(1, 1, figsize=(12, 8))
62
+ gdf_health.plot(
63
+ column=metric_full_name,
64
+ cmap='OrRd',
65
+ markersize=50,
66
+ legend=True,
67
+ legend_kwds={'label': f"{metric} (%)"},
68
+ ax=ax,
69
+ alpha=0.7,
70
+ edgecolor='k'
71
+ )
72
+ ax.set_title(f'Geographical Distribution of {metric}', fontsize=15)
73
+ ax.axis('off')
74
+ plt.tight_layout()
75
+ return fig
76
+
77
+ def plot_correlation_matrix(selected_metrics):
78
+ """
79
+ Plots the correlation matrix for selected health metrics.
80
+ """
81
+ selected_columns = [metric_mapping[metric] for metric in selected_metrics]
82
+ corr = conus_data[selected_columns].corr()
83
+ fig, ax = plt.subplots(figsize=(10, 8))
84
+ sns.heatmap(corr, annot=True, cmap='coolwarm', square=True, ax=ax)
85
+ ax.set_title('Correlation Matrix of Selected Health Metrics', fontsize=15)
86
+ plt.tight_layout()
87
+ return fig
88
+
89
+ def plot_unemployment_map(date):
90
+ """
91
+ Plots the unemployment rate map for a selected date.
92
+ """
93
+ date = str(date)
94
+ data = county_geojson_unemployment[county_geojson_unemployment['date'] == date]
95
  fig, ax = plt.subplots(1, 1, figsize=(12, 8))
96
+ data.plot(
97
+ column='unemployment_rate',
98
+ cmap='Blues',
99
+ linewidth=0.8,
100
+ ax=ax,
101
+ edgecolor='0.8',
102
+ legend=True,
103
+ missing_kwds={"color": "lightgrey", "label": "Missing values"},
104
+ )
105
+ ax.set_title(f'Unemployment Rate by County ({date})', fontsize=15)
106
+ ax.axis('off')
107
+ plt.tight_layout()
108
+ return fig
109
+
110
+ def plot_poverty_map(year):
111
+ """
112
+ Plots the poverty rate map for a selected year.
113
+ """
114
+ year = str(year)
115
+ data = zcta_geojson_poverty[zcta_geojson_poverty['year'] == year]
116
+ fig, ax = plt.subplots(1, 1, figsize=(12, 8))
117
+ data.plot(
118
+ column='poverty_rate',
119
+ cmap='Reds',
120
+ linewidth=0.8,
121
+ ax=ax,
122
+ edgecolor='0.8',
123
+ legend=True,
124
+ missing_kwds={"color": "lightgrey", "label": "Missing values"},
125
+ )
126
+ ax.set_title(f'Poverty Rate by ZCTA ({year})', fontsize=15)
127
+ ax.axis('off')
128
+ plt.tight_layout()
129
+ return fig
130
+
131
+ def summarize_health_metrics(metric):
132
+ """
133
+ Generates summary statistics for a selected health metric.
134
+ """
135
+ metric_full_name = metric_mapping[metric]
136
+ summary = conus_data[metric_full_name].describe().to_frame().reset_index()
137
+ summary.columns = ['Statistic', 'Value']
138
+ return summary
139
+
140
+ # ========================
141
+ # Gradio Interface Functions
142
+ # ========================
143
+
144
+ def health_metric_interface(metric):
145
+ fig = plot_health_metric(metric)
146
+ summary = summarize_health_metrics(metric)
147
+ return fig, summary
148
+
149
+ def correlation_interface(metrics):
150
+ fig = plot_correlation_matrix(metrics)
151
+ return fig
152
+
153
+ def unemployment_interface(date):
154
+ fig = plot_unemployment_map(date)
155
+ return fig
156
+
157
+ def poverty_interface(year):
158
+ fig = plot_poverty_map(year)
159
  return fig
160
 
161
+ # ========================
162
+ # Gradio App Setup
163
+ # ========================
164
+
165
+ with gr.Blocks(title="US Population Health Dashboard") as demo:
166
+ gr.Markdown("# US Population Health Dashboard")
167
+ gr.Markdown("Explore health metrics, socioeconomic data, and their geospatial distributions across the United States.")
168
+
169
+ with gr.Tab("Health Metrics Map"):
170
+ gr.Markdown("### Geographical Distribution of Health Metrics")
171
+ health_metric = gr.Dropdown(label="Select a Health Metric", choices=simplified_metrics, value=simplified_metrics[0])
172
+ health_plot = gr.Plot()
173
+ health_summary = gr.Dataframe(headers=["Statistic", "Value"])
174
+ health_metric.change(health_metric_interface, inputs=health_metric, outputs=[health_plot, health_summary])
175
+
176
+ with gr.Tab("Health Metrics Correlation"):
177
+ gr.Markdown("### Correlation Matrix of Health Metrics")
178
+ correlation_metrics = gr.CheckboxGroup(label="Select Health Metrics", choices=simplified_metrics, value=simplified_metrics[:5])
179
+ correlation_plot = gr.Plot()
180
+ correlation_metrics.change(correlation_interface, inputs=correlation_metrics, outputs=correlation_plot)
181
+
182
+ with gr.Tab("Unemployment Rate Map"):
183
+ gr.Markdown("### Geographical Distribution of Unemployment Rates")
184
+ dates = county_unemployment_melted['date'].unique().tolist()
185
+ unemployment_date = gr.Slider(label="Select a Date", minimum=min(dates), maximum=max(dates), step=1, value=dates[0])
186
+ unemployment_plot = gr.Plot()
187
+ unemployment_date.change(unemployment_interface, inputs=unemployment_date, outputs=unemployment_plot)
188
+
189
+ with gr.Tab("Poverty Rate Map"):
190
+ gr.Markdown("### Geographical Distribution of Poverty Rates")
191
+ years = zcta_poverty_melted['year'].unique().astype(int).tolist()
192
+ poverty_year = gr.Slider(label="Select a Year", minimum=min(years), maximum=max(years), step=1, value=years[0])
193
+ poverty_plot = gr.Plot()
194
+ poverty_year.change(poverty_interface, inputs=poverty_year, outputs=poverty_plot)
195
+
196
+ # ========================
197
+ # Launch the App
198
+ # ========================
199
+
200
  if __name__ == "__main__":
201
+ demo.launch()