Commit
·
1b8ac52
1
Parent(s):
5a41550
Refactor code and add type annotations
Browse files
main.py
CHANGED
|
@@ -1,20 +1,20 @@
|
|
| 1 |
import os
|
| 2 |
import random
|
| 3 |
from datetime import timedelta
|
| 4 |
-
|
| 5 |
from statistics import mean
|
| 6 |
from typing import Any, Iterator, Union
|
| 7 |
-
|
| 8 |
import fasttext
|
| 9 |
from cashews import cache
|
| 10 |
from dotenv import load_dotenv
|
| 11 |
-
from fastapi import FastAPI
|
| 12 |
from httpx import AsyncClient, Client, Timeout
|
| 13 |
from huggingface_hub import hf_hub_download
|
| 14 |
-
from huggingface_hub.utils import logging
|
| 15 |
from iso639 import Lang
|
| 16 |
from starlette.responses import RedirectResponse
|
| 17 |
from toolz import concat, groupby, valmap
|
|
|
|
| 18 |
|
| 19 |
cache.setup("mem://")
|
| 20 |
|
|
@@ -130,6 +130,8 @@ async def get_random_rows(
|
|
| 130 |
|
| 131 |
|
| 132 |
def load_model(repo_id: str) -> fasttext.FastText._FastText:
|
|
|
|
|
|
|
| 133 |
Path("code/models").mkdir(parents=True, exist_ok=True)
|
| 134 |
model_path = hf_hub_download(
|
| 135 |
repo_id,
|
|
@@ -237,14 +239,18 @@ def predict_rows(
|
|
| 237 |
def root():
|
| 238 |
return RedirectResponse(url="/docs")
|
| 239 |
|
|
|
|
|
|
|
| 240 |
|
| 241 |
@app.get("/predict_dataset_language/{hub_id:path}")
|
| 242 |
@cache(ttl=timedelta(minutes=10))
|
| 243 |
async def predict_language(
|
| 244 |
-
hub_id: str,
|
| 245 |
config: str | None = None,
|
| 246 |
split: str | None = None,
|
| 247 |
-
max_request_calls:
|
|
|
|
|
|
|
| 248 |
number_of_rows: int = 1000,
|
| 249 |
) -> dict[Any, Any] | None:
|
| 250 |
is_valid = datasets_server_valid_rows(hub_id)
|
|
|
|
| 1 |
import os
|
| 2 |
import random
|
| 3 |
from datetime import timedelta
|
| 4 |
+
|
| 5 |
from statistics import mean
|
| 6 |
from typing import Any, Iterator, Union
|
| 7 |
+
from typing import Annotated
|
| 8 |
import fasttext
|
| 9 |
from cashews import cache
|
| 10 |
from dotenv import load_dotenv
|
| 11 |
+
from fastapi import FastAPI, Path
|
| 12 |
from httpx import AsyncClient, Client, Timeout
|
| 13 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 14 |
from iso639 import Lang
|
| 15 |
from starlette.responses import RedirectResponse
|
| 16 |
from toolz import concat, groupby, valmap
|
| 17 |
+
import logging
|
| 18 |
|
| 19 |
cache.setup("mem://")
|
| 20 |
|
|
|
|
| 130 |
|
| 131 |
|
| 132 |
def load_model(repo_id: str) -> fasttext.FastText._FastText:
|
| 133 |
+
from pathlib import Path
|
| 134 |
+
|
| 135 |
Path("code/models").mkdir(parents=True, exist_ok=True)
|
| 136 |
model_path = hf_hub_download(
|
| 137 |
repo_id,
|
|
|
|
| 239 |
def root():
|
| 240 |
return RedirectResponse(url="/docs")
|
| 241 |
|
| 242 |
+
# Reference pattern (FastAPI path-parameter validation example, not used by this app): item_id: Annotated[int, Path(title="The ID of the item to get", ge=1)], q: str
|
| 243 |
+
|
| 244 |
|
| 245 |
@app.get("/predict_dataset_language/{hub_id:path}")
|
| 246 |
@cache(ttl=timedelta(minutes=10))
|
| 247 |
async def predict_language(
|
| 248 |
+
hub_id: Annotated[str, Path(title="The hub id of the dataset to predict")],
|
| 249 |
config: str | None = None,
|
| 250 |
split: str | None = None,
|
| 251 |
+
max_request_calls: Annotated[
|
| 252 |
+
int, Path(title="Max number of requests to datasets server", gt=0, le=20)
|
| 253 |
+
] = 10,
|
| 254 |
number_of_rows: int = 1000,
|
| 255 |
) -> dict[Any, Any] | None:
|
| 256 |
is_valid = datasets_server_valid_rows(hub_id)
|