Commit
·
ef19caa
1
Parent(s):
9915c6f
Refactor app.py: Import modules, update function parameters, and improve logging
Browse files
app.py
CHANGED
@@ -1,16 +1,15 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from httpx import Client
|
3 |
-
import random
|
4 |
import os
|
|
|
|
|
|
|
|
|
5 |
import fasttext
|
6 |
-
|
7 |
-
from typing import Union
|
8 |
-
from typing import Iterator
|
9 |
from dotenv import load_dotenv
|
10 |
-
from
|
11 |
-
from
|
12 |
-
from httpx import Timeout
|
13 |
from huggingface_hub.utils import logging
|
|
|
14 |
|
15 |
logger = logging.get_logger(__name__)
|
16 |
load_dotenv()
|
@@ -24,6 +23,7 @@ headers = {
|
|
24 |
}
|
25 |
timeout = Timeout(60, read=120)
|
26 |
client = Client(headers=headers, timeout=timeout)
|
|
|
27 |
# non exhaustive list of columns that might contain text which can be used for language detection
|
28 |
# we prefer to use columns in this order i.e. if there is a column named "text" we will use it first
|
29 |
TARGET_COLUMN_NAMES = {
|
@@ -73,10 +73,10 @@ def get_dataset_info(hub_id: str, config: str | None = None):
|
|
73 |
|
74 |
|
75 |
def get_random_rows(
|
76 |
-
hub_id,
|
77 |
-
total_length,
|
78 |
-
number_of_rows,
|
79 |
-
max_request_calls,
|
80 |
config="default",
|
81 |
split="train",
|
82 |
):
|
@@ -88,8 +88,9 @@ def get_random_rows(
|
|
88 |
for _ in range(min(max_request_calls, number_of_rows // rows_per_call)):
|
89 |
offset = random.randint(0, total_length - rows_per_call)
|
90 |
url = f"https://datasets-server.huggingface.co/rows?dataset={hub_id}&config={config}&split={split}&offset={offset}&length={rows_per_call}"
|
|
|
|
|
91 |
response = client.get(url)
|
92 |
-
|
93 |
if response.status_code == 200:
|
94 |
data = response.json()
|
95 |
batch_rows = data.get("rows")
|
@@ -107,10 +108,6 @@ def load_model(repo_id: str) -> fasttext.FastText._FastText:
|
|
107 |
return fasttext.load_model(model_path)
|
108 |
|
109 |
|
110 |
-
# def predict_language_for_rows(rows: list[dict], target_column_names: list[str] | str):
|
111 |
-
# pass
|
112 |
-
|
113 |
-
|
114 |
def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
|
115 |
for row in rows:
|
116 |
if isinstance(row, str):
|
@@ -186,7 +183,8 @@ def predict_language(
|
|
186 |
config: str | None = None,
|
187 |
split: str | None = None,
|
188 |
max_request_calls: int = 10,
|
189 |
-
|
|
|
190 |
is_valid = datasets_server_valid_rows(hub_id)
|
191 |
if not is_valid:
|
192 |
gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
|
@@ -202,7 +200,7 @@ def predict_language(
|
|
202 |
logger.info(f"Column names: {column_names}")
|
203 |
if not set(column_names).intersection(TARGET_COLUMN_NAMES):
|
204 |
raise gr.Error(
|
205 |
-
f"Dataset {hub_id}
|
206 |
)
|
207 |
for column in TARGET_COLUMN_NAMES:
|
208 |
if column in column_names:
|
@@ -210,7 +208,12 @@ def predict_language(
|
|
210 |
logger.info(f"Using column {target_column} for language detection")
|
211 |
break
|
212 |
random_rows = get_random_rows(
|
213 |
-
hub_id,
|
|
|
|
|
|
|
|
|
|
|
214 |
)
|
215 |
logger.info(f"Predicting language for {len(random_rows)} rows")
|
216 |
predictions = predict_rows(random_rows, target_column)
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
+
import random
|
3 |
+
from statistics import mean
|
4 |
+
from typing import Iterator, Union
|
5 |
+
|
6 |
import fasttext
|
7 |
+
import gradio as gr
|
|
|
|
|
8 |
from dotenv import load_dotenv
|
9 |
+
from httpx import Client, Timeout
|
10 |
+
from huggingface_hub import hf_hub_download
|
|
|
11 |
from huggingface_hub.utils import logging
|
12 |
+
from toolz import concat, groupby, valmap
|
13 |
|
14 |
logger = logging.get_logger(__name__)
|
15 |
load_dotenv()
|
|
|
23 |
}
|
24 |
timeout = Timeout(60, read=120)
|
25 |
client = Client(headers=headers, timeout=timeout)
|
26 |
+
# async_client = AsyncClient(headers=headers, timeout=timeout)
|
27 |
# non exhaustive list of columns that might contain text which can be used for language detection
|
28 |
# we prefer to use columns in this order i.e. if there is a column named "text" we will use it first
|
29 |
TARGET_COLUMN_NAMES = {
|
|
|
73 |
|
74 |
|
75 |
def get_random_rows(
|
76 |
+
hub_id: str,
|
77 |
+
total_length: int,
|
78 |
+
number_of_rows: int,
|
79 |
+
max_request_calls: int,
|
80 |
config="default",
|
81 |
split="train",
|
82 |
):
|
|
|
88 |
for _ in range(min(max_request_calls, number_of_rows // rows_per_call)):
|
89 |
offset = random.randint(0, total_length - rows_per_call)
|
90 |
url = f"https://datasets-server.huggingface.co/rows?dataset={hub_id}&config={config}&split={split}&offset={offset}&length={rows_per_call}"
|
91 |
+
logger.info(f"Fetching {url}")
|
92 |
+
print(url)
|
93 |
response = client.get(url)
|
|
|
94 |
if response.status_code == 200:
|
95 |
data = response.json()
|
96 |
batch_rows = data.get("rows")
|
|
|
108 |
return fasttext.load_model(model_path)
|
109 |
|
110 |
|
|
|
|
|
|
|
|
|
111 |
def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
|
112 |
for row in rows:
|
113 |
if isinstance(row, str):
|
|
|
183 |
config: str | None = None,
|
184 |
split: str | None = None,
|
185 |
max_request_calls: int = 10,
|
186 |
+
number_of_rows: int = 1000,
|
187 |
+
) -> dict[str, float | str]:
|
188 |
is_valid = datasets_server_valid_rows(hub_id)
|
189 |
if not is_valid:
|
190 |
gr.Error(f"Dataset {hub_id} is not accessible via the datasets server.")
|
|
|
200 |
logger.info(f"Column names: {column_names}")
|
201 |
if not set(column_names).intersection(TARGET_COLUMN_NAMES):
|
202 |
raise gr.Error(
|
203 |
+
f"Dataset {hub_id} {column_names} is not in any of the target columns {TARGET_COLUMN_NAMES}"
|
204 |
)
|
205 |
for column in TARGET_COLUMN_NAMES:
|
206 |
if column in column_names:
|
|
|
208 |
logger.info(f"Using column {target_column} for language detection")
|
209 |
break
|
210 |
random_rows = get_random_rows(
|
211 |
+
hub_id,
|
212 |
+
total_rows_for_split,
|
213 |
+
number_of_rows,
|
214 |
+
max_request_calls,
|
215 |
+
config,
|
216 |
+
split,
|
217 |
)
|
218 |
logger.info(f"Predicting language for {len(random_rows)} rows")
|
219 |
predictions = predict_rows(random_rows, target_column)
|