Commit
·
a2d40e3
1
Parent(s):
2995c02
Refactor predict_rows function to include raw predictions
Browse filesThis commit refactors the predict_rows function in main.py to include an optional parameter, return_raw_predictions, which when set to True, returns the raw predictions along with the mean scores. This change improves the flexibility and usefulness of the function.
main.py
CHANGED
@@ -181,7 +181,9 @@ def filter_by_frequency(counts_dict: dict, threshold_percent: float = 0.2):
|
|
181 |
return {k for k, v in counts_dict.items() if v >= threshold}
|
182 |
|
183 |
|
184 |
-
def predict_rows(
|
|
|
|
|
185 |
rows = (row.get(target_column) for row in rows)
|
186 |
rows = (row for row in rows if row is not None)
|
187 |
rows = list(yield_clean_rows(rows))
|
@@ -194,9 +196,20 @@ def predict_rows(rows, target_column, language_threshold_percent=0.2):
|
|
194 |
langues_counts, threshold_percent=language_threshold_percent
|
195 |
)
|
196 |
filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
|
197 |
-
|
198 |
"predictions": dict(valmap(get_mean_score, filtered_dict)),
|
|
|
|
|
199 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
|
201 |
|
202 |
@app.get("/", include_in_schema=False)
|
|
|
181 |
return {k for k, v in counts_dict.items() if v >= threshold}
|
182 |
|
183 |
|
184 |
+
def predict_rows(
|
185 |
+
rows, target_column, language_threshold_percent=0.2, return_raw_predictions=False
|
186 |
+
):
|
187 |
rows = (row.get(target_column) for row in rows)
|
188 |
rows = (row for row in rows if row is not None)
|
189 |
rows = list(yield_clean_rows(rows))
|
|
|
196 |
langues_counts, threshold_percent=language_threshold_percent
|
197 |
)
|
198 |
filtered_dict = {k: v for k, v in predictions_by_lang.items() if k in keys_to_keep}
|
199 |
+
default_data = {
|
200 |
"predictions": dict(valmap(get_mean_score, filtered_dict)),
|
201 |
+
"hub_id": "hub_id",
|
202 |
+
"config": "config",
|
203 |
}
|
204 |
+
if return_raw_predictions:
|
205 |
+
default_data["raw_predictions"] = predictions
|
206 |
+
return default_data
|
207 |
+
|
208 |
+
|
209 |
+
# @app.get("/", response_class=HTMLResponse)
|
210 |
+
# async def read_index():
|
211 |
+
# html_content = Path("index.html").read_text()
|
212 |
+
# return HTMLResponse(content=html_content)
|
213 |
|
214 |
|
215 |
@app.get("/", include_in_schema=False)
|