Commit
·
1b8ac52
1
Parent(s):
5a41550
Refactor code and add type annotations
Browse files
main.py
CHANGED
@@ -1,20 +1,20 @@
|
|
1 |
import os
|
2 |
import random
|
3 |
from datetime import timedelta
|
4 |
-
|
5 |
from statistics import mean
|
6 |
from typing import Any, Iterator, Union
|
7 |
-
|
8 |
import fasttext
|
9 |
from cashews import cache
|
10 |
from dotenv import load_dotenv
|
11 |
-
from fastapi import FastAPI
|
12 |
from httpx import AsyncClient, Client, Timeout
|
13 |
from huggingface_hub import hf_hub_download
|
14 |
-
from huggingface_hub.utils import logging
|
15 |
from iso639 import Lang
|
16 |
from starlette.responses import RedirectResponse
|
17 |
from toolz import concat, groupby, valmap
|
|
|
18 |
|
19 |
cache.setup("mem://")
|
20 |
|
@@ -130,6 +130,8 @@ async def get_random_rows(
|
|
130 |
|
131 |
|
132 |
def load_model(repo_id: str) -> fasttext.FastText._FastText:
|
|
|
|
|
133 |
Path("code/models").mkdir(parents=True, exist_ok=True)
|
134 |
model_path = hf_hub_download(
|
135 |
repo_id,
|
@@ -237,14 +239,18 @@ def predict_rows(
|
|
237 |
def root():
|
238 |
return RedirectResponse(url="/docs")
|
239 |
|
|
|
|
|
240 |
|
241 |
@app.get("/predict_dataset_language/{hub_id:path}")
|
242 |
@cache(ttl=timedelta(minutes=10))
|
243 |
async def predict_language(
|
244 |
-
hub_id: str,
|
245 |
config: str | None = None,
|
246 |
split: str | None = None,
|
247 |
-
max_request_calls:
|
|
|
|
|
248 |
number_of_rows: int = 1000,
|
249 |
) -> dict[Any, Any] | None:
|
250 |
is_valid = datasets_server_valid_rows(hub_id)
|
|
|
1 |
import os
|
2 |
import random
|
3 |
from datetime import timedelta
|
4 |
+
|
5 |
from statistics import mean
|
6 |
from typing import Any, Iterator, Union
|
7 |
+
from typing import Annotated
|
8 |
import fasttext
|
9 |
from cashews import cache
|
10 |
from dotenv import load_dotenv
|
11 |
+
from fastapi import FastAPI, Path
|
12 |
from httpx import AsyncClient, Client, Timeout
|
13 |
from huggingface_hub import hf_hub_download
|
|
|
14 |
from iso639 import Lang
|
15 |
from starlette.responses import RedirectResponse
|
16 |
from toolz import concat, groupby, valmap
|
17 |
+
import logging
|
18 |
|
19 |
cache.setup("mem://")
|
20 |
|
|
|
130 |
|
131 |
|
132 |
def load_model(repo_id: str) -> fasttext.FastText._FastText:
|
133 |
+
from pathlib import Path
|
134 |
+
|
135 |
Path("code/models").mkdir(parents=True, exist_ok=True)
|
136 |
model_path = hf_hub_download(
|
137 |
repo_id,
|
|
|
239 |
def root():
|
240 |
return RedirectResponse(url="/docs")
|
241 |
|
242 |
+
# item_id: Annotated[int, Path(title="The ID of the item to get", ge=1)], q: str
|
243 |
+
|
244 |
|
245 |
@app.get("/predict_dataset_language/{hub_id:path}")
|
246 |
@cache(ttl=timedelta(minutes=10))
|
247 |
async def predict_language(
|
248 |
+
hub_id: Annotated[str, Path(title="The hub id of the dataset to predict")],
|
249 |
config: str | None = None,
|
250 |
split: str | None = None,
|
251 |
+
max_request_calls: Annotated[
|
252 |
+
int, Path(title="Max number of requests to datasets server", gt=0, le=20)
|
253 |
+
] = 10,
|
254 |
number_of_rows: int = 1000,
|
255 |
) -> dict[Any, Any] | None:
|
256 |
is_valid = datasets_server_valid_rows(hub_id)
|