Commit
·
41869c7
1
Parent(s):
f1bc1ad
Add fastapi.responses and starlette.responses imports
Browse files
main.py
CHANGED
@@ -3,7 +3,7 @@ import random
|
|
3 |
from pathlib import Path
|
4 |
from statistics import mean
|
5 |
from typing import Any, Iterator, Union
|
6 |
-
|
7 |
import fasttext
|
8 |
from dotenv import load_dotenv
|
9 |
from fastapi import FastAPI
|
@@ -11,6 +11,7 @@ from httpx import AsyncClient, Client, Timeout
|
|
11 |
from huggingface_hub import hf_hub_download
|
12 |
from huggingface_hub.utils import logging
|
13 |
from toolz import concat, groupby, valmap
|
|
|
14 |
|
15 |
app = FastAPI()
|
16 |
logger = logging.get_logger(__name__)
|
@@ -19,16 +20,17 @@ HF_TOKEN = os.getenv("HF_TOKEN")
|
|
19 |
|
20 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
21 |
|
|
|
|
|
22 |
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
|
23 |
-
DEFAULT_FAST_TEXT_MODEL = "
|
24 |
# Default HTTP headers for the httpx clients created right after this block.
# HF_TOKEN is read from the environment via os.getenv("HF_TOKEN").
# Python f-strings interpolate with `{name}`; the previous `${HF_TOKEN}` left a
# literal "$" in front of the token, so the bearer credential sent was malformed.
headers = {
    "authorization": f"Bearer {HF_TOKEN}",
}
|
27 |
timeout = Timeout(60, read=120)
|
28 |
client = Client(headers=headers, timeout=timeout)
|
29 |
async_client = AsyncClient(headers=headers, timeout=timeout)
|
30 |
-
|
31 |
-
# we prefer to use columns in this order i.e. if there is a column named "text" we will use it first
|
32 |
TARGET_COLUMN_NAMES = {
|
33 |
"text",
|
34 |
"input",
|
@@ -116,10 +118,20 @@ async def get_random_rows(
|
|
116 |
|
117 |
|
118 |
def load_model(repo_id: str) -> fasttext.FastText._FastText:
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
return fasttext.load_model(model_path)
|
121 |
|
122 |
|
|
|
|
|
|
|
123 |
def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
|
124 |
for row in rows:
|
125 |
if isinstance(row, str):
|
@@ -139,21 +151,6 @@ def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterat
|
|
139 |
continue
|
140 |
|
141 |
|
142 |
-
FASTTEXT_PREFIX_LENGTH = 9 # fasttext labels are formatted like "__label__eng_Latn"
|
143 |
-
|
144 |
-
# model = load_model(DEFAULT_FAST_TEXT_MODEL)
|
145 |
-
Path("code/models").mkdir(parents=True, exist_ok=True)
|
146 |
-
model = fasttext.load_model(
|
147 |
-
hf_hub_download(
|
148 |
-
"facebook/fasttext-language-identification",
|
149 |
-
"model.bin",
|
150 |
-
cache_dir="code/models",
|
151 |
-
local_dir="code/models",
|
152 |
-
local_dir_use_symlinks=False,
|
153 |
-
)
|
154 |
-
)
|
155 |
-
|
156 |
-
|
157 |
def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
|
158 |
predictions = model.predict(inputs, k=k)
|
159 |
return [
|
@@ -196,6 +193,17 @@ def predict_rows(rows, target_column, language_threshold_percent=0.2):
|
|
196 |
}
|
197 |
|
198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
@app.get("/predict_dataset_language/{hub_id}")
|
200 |
async def predict_language(
|
201 |
hub_id: str,
|
|
|
3 |
from pathlib import Path
|
4 |
from statistics import mean
|
5 |
from typing import Any, Iterator, Union
|
6 |
+
from fastapi.responses import HTMLResponse
|
7 |
import fasttext
|
8 |
from dotenv import load_dotenv
|
9 |
from fastapi import FastAPI
|
|
|
11 |
from huggingface_hub import hf_hub_download
|
12 |
from huggingface_hub.utils import logging
|
13 |
from toolz import concat, groupby, valmap
|
14 |
+
from starlette.responses import RedirectResponse
|
15 |
|
16 |
app = FastAPI()
|
17 |
logger = logging.get_logger(__name__)
|
|
|
20 |
|
21 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
22 |
|
23 |
+
FASTTEXT_PREFIX_LENGTH = 9 # fasttext labels are formatted like "__label__eng_Latn"
|
24 |
+
|
25 |
BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
|
26 |
+
DEFAULT_FAST_TEXT_MODEL = "facebook/fasttext-language-identification"
|
27 |
# Default HTTP headers for the httpx clients created right after this block.
# HF_TOKEN is read from the environment via os.getenv("HF_TOKEN").
# Python f-strings interpolate with `{name}`; the previous `${HF_TOKEN}` left a
# literal "$" in front of the token, so the bearer credential sent was malformed.
headers = {
    "authorization": f"Bearer {HF_TOKEN}",
}
|
30 |
timeout = Timeout(60, read=120)
|
31 |
client = Client(headers=headers, timeout=timeout)
|
32 |
async_client = AsyncClient(headers=headers, timeout=timeout)
|
33 |
+
|
|
|
34 |
TARGET_COLUMN_NAMES = {
|
35 |
"text",
|
36 |
"input",
|
|
|
118 |
|
119 |
|
120 |
def load_model(repo_id: str) -> fasttext.FastText._FastText:
    """Fetch ``model.bin`` from a Hugging Face Hub repo and load it with fastText.

    Args:
        repo_id: Hub repository identifier passed straight to
            ``hf_hub_download``.

    Returns:
        The object returned by ``fasttext.load_model`` for the downloaded file.
    """
    models_dir = "code/models"
    # Make sure the download target exists before asking the Hub client to use it.
    Path(models_dir).mkdir(parents=True, exist_ok=True)

    # The same directory is used as both the cache and the local copy location.
    # NOTE(review): local_dir_use_symlinks=False presumably forces a real file
    # (not a symlink into the cache) -- confirm against the installed
    # huggingface_hub version.
    download_options = {
        "cache_dir": models_dir,
        "local_dir": models_dir,
        "local_dir_use_symlinks": False,
    }
    weights_file = hf_hub_download(repo_id, "model.bin", **download_options)
    return fasttext.load_model(weights_file)
|
130 |
|
131 |
|
132 |
+
model = load_model(DEFAULT_FAST_TEXT_MODEL)
|
133 |
+
|
134 |
+
|
135 |
def yield_clean_rows(rows: Union[list[str], str], min_length: int = 3) -> Iterator[str]:
|
136 |
for row in rows:
|
137 |
if isinstance(row, str):
|
|
|
151 |
continue
|
152 |
|
153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
def model_predict(inputs: str, k=1) -> list[dict[str, float]]:
|
155 |
predictions = model.predict(inputs, k=k)
|
156 |
return [
|
|
|
193 |
}
|
194 |
|
195 |
|
196 |
+
# @app.get("/", response_class=HTMLResponse)
|
197 |
+
# async def read_index():
|
198 |
+
# html_content = Path("index.html").read_text()
|
199 |
+
# return HTMLResponse(content=html_content)
|
200 |
+
|
201 |
+
|
202 |
+
@app.get("/", include_in_schema=False)
def root():
    """Redirect requests for the site root to the ``/docs`` page.

    The route is registered with ``include_in_schema=False``.
    """
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
|
205 |
+
|
206 |
+
|
207 |
@app.get("/predict_dataset_language/{hub_id}")
|
208 |
async def predict_language(
|
209 |
hub_id: str,
|