Update app.py
app.py
CHANGED
@@ -1,191 +1,83 @@
 import gradio as gr
-import
-import
-import numpy as np
 import matplotlib.pyplot as plt
 import seaborn as sns
 
[removed lines 8 to 20: content not legible in this capture]
-#
-def
[removed lines 23 to 41: content not legible in this capture]
-def
[removed lines 43 to 49: content not legible in this capture]
-    # Merge embeddings
-    county_embeddings.index = county_embeddings.index.astype('str')
-    X_train['place'] = X_train['place'].astype('str')
-    X_test['place'] = X_test['place'].astype('str')
-
-    X_train = X_train.merge(county_embeddings, left_on='place', right_index=True, how='left')
-    X_test = X_test.merge(county_embeddings, left_on='place', right_index=True, how='left')
-
-    # Remove non-numeric columns
-    numeric_cols_train = X_train.select_dtypes(include=['float64', 'int64']).columns
-    X_train_numeric = X_train[numeric_cols_train]
-    numeric_cols_test = X_test.select_dtypes(include=['float64', 'int64']).columns
-    X_test_numeric = X_test[numeric_cols_test]
-
-    # Impute missing values
-    X_train_imputed = imputer.transform(X_train_numeric)
-    X_test_imputed = imputer.transform(X_test_numeric)
-
-    # Apply PCA
-    X_train_pca = pca.transform(X_train_imputed)
-    X_test_pca = pca.transform(X_test_imputed)
-
-    # Convert labels to GPU arrays
-    y_train = y_train.to_cupy()
-    y_test = y_test.to_cupy()
-
-    return X_train_pca, X_test_pca, y_train, y_test, numeric_cols_train
-
-def train_and_evaluate_models(X_train_pca, X_test_pca, y_train, y_test, numeric_cols_train, selected_models):
-    # Define models
-    all_models = {
-        "Random Forest": RandomForestRegressor(n_estimators=100, random_state=42),
-        "XGBoost": xgb.XGBRegressor(n_estimators=100, random_state=42, tree_method='gpu_hist', gpu_id=0),
-        "Ridge Regression": Ridge(alpha=1.0),
-        "CatBoost": CatBoostRegressor(iterations=100, random_seed=42, task_type="GPU", devices='0')
-    }
-
-    # Filter selected models
-    models = {name: model for name, model in all_models.items() if name in selected_models}
-
-    results = {}
-    feature_importances = {}
-
-    for name, model in models.items():
-        if name == "XGBoost":
-            model.fit(cp.asnumpy(X_train_pca), cp.asnumpy(y_train))
-            y_pred = model.predict(cp.asnumpy(X_test_pca))
-            y_pred = cp.asarray(y_pred)
-        elif name == "CatBoost":
-            model.fit(cp.asnumpy(X_train_pca), cp.asnumpy(y_train), verbose=False)
-            y_pred = model.predict(cp.asnumpy(X_test_pca))
-            y_pred = cp.asarray(y_pred)
-        else:
-            model.fit(X_train_pca, y_train)
-            y_pred = model.predict(X_test_pca)
-
-        # Compute metrics
-        rmse = cp.sqrt(mean_squared_error(y_test, y_pred)).get()
-        r2 = r2_score(y_test, y_pred).get()
-        results[name] = {'RMSE': rmse, 'R-squared': r2}
-
-        # Feature importances
-        if hasattr(model, 'feature_importances_'):
-            importances = model.feature_importances_
-            if isinstance(importances, cp.ndarray):
-                importances = cp.asnumpy(importances)
-            feature_importances[name] = importances
-
-    return results, feature_importances, numeric_cols_train
-
-def plot_feature_importance(importances, feature_names, model_name):
-    feature_importance_df = pd.DataFrame({'Feature': feature_names[:len(importances)], 'Importance': importances})
-    feature_importance_df = feature_importance_df.sort_values('Importance', ascending=False).head(20)
-
-    plt.figure(figsize=(10, 8))
-    sns.barplot(x='Importance', y='Feature', data=feature_importance_df)
-    plt.title(f'{model_name} Feature Importance')
-    plt.tight_layout()
-    plt.close()
-    return plt.gcf()
-
-def plot_metrics(results):
-    metrics_df = pd.DataFrame(results).T.reset_index().rename(columns={'index': 'Model'})
-
-    plt.figure(figsize=(8, 6))
-    sns.barplot(x='Model', y='RMSE', data=metrics_df)
-    plt.title('RMSE for Each Model')
-    plt.xticks(rotation=45)
-    plt.tight_layout()
-    plt.close()
-    rmse_plot = plt.gcf()
-
-    plt.figure(figsize=(8, 6))
-    sns.barplot(x='Model', y='R-squared', data=metrics_df)
-    plt.title('R-squared for Each Model')
-    plt.xticks(rotation=45)
-    plt.tight_layout()
-    plt.close()
-    r2_plot = plt.gcf()
-
-    return rmse_plot, r2_plot
-
-def main(selected_models):
-    # Load data
-    county_embeddings, county_embeddings_pca, pca, imputer = load_embeddings()
-    unemployment_long = load_unemployment_data()
-
-    # Preprocess data
-    X_train_pca, X_test_pca, y_train, y_test, numeric_cols_train = preprocess_data(
-        county_embeddings, county_embeddings_pca, unemployment_long, pca, imputer
-    )
-
-    # Train and evaluate models
-    results, feature_importances, feature_names = train_and_evaluate_models(
-        X_train_pca, X_test_pca, y_train, y_test, numeric_cols_train, selected_models
-    )
-
-    # Plot metrics
-    rmse_plot, r2_plot = plot_metrics(results)
-
-    # Plot feature importance for models that have it
-    feature_importance_plots = {}
-    for model_name, importances in feature_importances.items():
-        fig = plot_feature_importance(importances, [f'PC{i+1}' for i in range(len(importances))], model_name)
-        feature_importance_plots[model_name] = fig
-
-    return results, rmse_plot, r2_plot, feature_importance_plots
-
-def gradio_app():
     with gr.Blocks() as demo:
-        gr.Markdown("
[removed lines 181 to 191: content not legible in this capture]
 import gradio as gr
+import pandas as pd
+import geopandas as gpd
 import matplotlib.pyplot as plt
 import seaborn as sns
+from timesfm import TimesFm, TimesFmHparams, TimesFmCheckpoint
+from sklearn.ensemble import GradientBoostingRegressor
+import numpy as np
 
+# GPU-optimized TimesFM setup
+timesfm_backend = "gpu"
+timesfm_model_config = TimesFmHparams(
+    context_len=512,
+    horizon_len=128,
+    per_core_batch_size=128,
+    backend=timesfm_backend,
+)
+timesfm_model = TimesFm(
+    hparams=timesfm_model_config,
+    checkpoint=TimesFmCheckpoint(huggingface_repo_id="google/timesfm-1.0-200m-pytorch")
+)
+
+# Compute the Herfindahl-Hirschman Index (HHI) for each market in an uploaded CSV
+def calculate_hhi(file, market_col, id_col, weight_col):
+    df = pd.read_csv(file.name)
+    df['denominator'] = df.groupby(market_col)[weight_col].transform('sum')
+    df['numerator'] = df.groupby([market_col, id_col])[weight_col].transform('sum')
+    df['market_share'] = 100 * (df['numerator'] / df['denominator'])
+    df['market_share_sq'] = df['market_share'] ** 2
+    # Keep one row per (market, entity) so each entity's squared share is counted once
+    firm_shares = df.drop_duplicates(subset=[market_col, id_col])
+    hhi = firm_shares.groupby(market_col)['market_share_sq'].sum()
+    return hhi.reset_index(name='hhi')
+
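As a quick sanity check on the HHI arithmetic above (toy numbers, not data from this Space): a market split 60/40 between two insurers should score 60² + 40² = 5200. The column names below are illustrative only.

import pandas as pd

toy = pd.DataFrame({
    "market": ["A", "A"],
    "insurer": ["x", "y"],
    "lives": [60, 40],
})
firm = toy.groupby(["market", "insurer"])["lives"].sum()
total = toy.groupby("market")["lives"].sum()
share = 100 * firm.div(total, level="market")
hhi = (share ** 2).groupby(level="market").sum()
print(hhi)  # market A -> 60**2 + 40**2 = 5200.0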
+# Plot a county-level HHI choropleth from an HHI CSV and a shapefile
+def plot_hhi_map(hhi_csv, shapefile):
+    hhi_df = pd.read_csv(hhi_csv.name)
+    gdf = gpd.read_file(shapefile.name)
+    # Assumes the shapefile identifies counties by 'fips_code' and the HHI CSV
+    # stores the matching market identifier in a column named 'market_col'
+    gdf = gdf.merge(hhi_df, left_on='fips_code', right_on='market_col', how='left')
+    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
+    gdf.plot(column='hhi', cmap='RdBu', legend=True, ax=ax, missing_kwds={"color": "lightgrey"})
+    ax.set_title("HHI by County")
+    return fig
+
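The map step hinges on the join key between the HHI table and the shapefile. A minimal sketch of that join on synthetic geometries (the 'fips_code' key mirrors the assumption noted above; the values are illustrative, not from this Space):

import pandas as pd
import geopandas as gpd
from shapely.geometry import box

# Two fake county polygons keyed by FIPS code (illustrative only)
counties = gpd.GeoDataFrame(
    {"fips_code": ["01001", "01003"]},
    geometry=[box(0, 0, 1, 1), box(1, 0, 2, 1)],
    crs="EPSG:4326",
)
hhi = pd.DataFrame({"fips_code": ["01001"], "hhi": [5200.0]})
merged = counties.merge(hhi, on="fips_code", how="left")
# Counties without an HHI value fall back to grey, as in plot_hhi_map
ax = merged.plot(column="hhi", legend=True, missing_kwds={"color": "lightgrey"})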
+# Forecast each place's series with TimesFM
+def forecast(file, forecast_steps):
+    # Expects rows indexed by 'place' and columns holding consecutive time steps
+    df = pd.read_csv(file.name).set_index('place')
+    # TimesFM returns a point forecast and an experimental quantile forecast
+    point_forecast, _ = timesfm_model.forecast(inputs=df.values)
+    # The model always emits horizon_len steps; keep only the requested horizon
+    return pd.DataFrame(point_forecast[:, :int(forecast_steps)], index=df.index)
+
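For reference, this is the call pattern the function above relies on, exercised on toy series. It assumes the timesfm package's forecast API, which returns a point forecast of shape (batch, horizon_len) together with an experimental quantile forecast; the series and freq values here are illustrative.

import numpy as np

# Two toy input series of different lengths
series = [np.sin(np.arange(60) / 6.0), np.arange(48, dtype=float)]
# Per the TimesFM docs, freq=1 marks medium-frequency (e.g. monthly) series
point, quantiles = timesfm_model.forecast(inputs=series, freq=[1, 1])
print(point.shape)  # (2, 128) given horizon_len=128 in the hparams above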
+# Gradio app interface
+def gradio_interface():
     with gr.Blocks() as demo:
+        gr.Markdown("### Healthcare Network Analysis and Forecasting")
+
+        with gr.Tab("Upload Embeddings"):
+            file_upload = gr.File(label="Upload Embeddings (CSV)")
+            # Column names to look up in the uploaded CSV
+            market_col = gr.Textbox(label="Market Column", value="market_col")
+            id_col = gr.Textbox(label="ID Column", value="id_col")
+            weight_col = gr.Textbox(label="Weight Column", value="weight_col")
+            hhi_results = gr.DataFrame(label="HHI Results")
+            calculate_button = gr.Button("Calculate HHI")
+            calculate_button.click(
+                calculate_hhi,
+                inputs=[file_upload, market_col, id_col, weight_col],
+                outputs=hhi_results
+            )
+
+        with gr.Tab("Visualize Map"):
+            hhi_csv = gr.File(label="Upload HHI CSV")
+            shapefile = gr.File(label="Upload Shapefile")
+            map_plot = gr.Plot(label="HHI Map")
+            plot_button = gr.Button("Generate Map")
+            plot_button.click(plot_hhi_map, inputs=[hhi_csv, shapefile], outputs=map_plot)
+
+        with gr.Tab("Forecasting"):
+            forecast_file = gr.File(label="Upload Historical Data (CSV)")
+            forecast_steps = gr.Slider(minimum=1, maximum=24, step=1, label="Forecast Steps")
+            forecast_results = gr.DataFrame(label="Forecasted Data")
+            forecast_button = gr.Button("Forecast")
+            forecast_button.click(forecast, inputs=[forecast_file, forecast_steps], outputs=forecast_results)
+
+    return demo
+
+# Run app
+if __name__ == "__main__":
+    gradio_interface().launch()