File size: 1,669 Bytes
85a7a06
2242346
 
255ee19
2242346
b1aff02
52674a0
2242346
 
 
b1aff02
 
 
 
 
 
 
 
 
85a7a06
 
 
2242346
403301f
2242346
 
 
 
 
270e841
 
 
 
 
 
 
 
 
 
 
2242346
270e841
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2242346
bdf5cae
85a7a06
 
f6dd423
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
from typing import List
from typing import Optional

from fastapi import FastAPI
from fastapi import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import login
from pydantic import BaseModel
from transformers import pipeline

app = FastAPI()

# Let browser clients on any origin call this API, with any method/header.
# NOTE(review): pairing a wildcard origin with allow_credentials=True means
# Starlette echoes the caller's Origin back on credentialed requests, so any
# site can make cookie-bearing cross-origin calls. Nothing visible here uses
# cookies or auth, so allow_credentials=False is probably safe — confirm no
# client sends `credentials: "include"` before changing it.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)


# Optional Hugging Face Hub token, used to authenticate model downloads.
access_token = os.environ.get("ACCESS_TOKEN_1")

# Only log in when a token is actually configured: with token=None,
# huggingface_hub's login() falls back to an interactive prompt, which would
# block or crash a headless server at import time. An empty string is treated
# as "not configured" as well.
if access_token:
    login(token=access_token, add_to_git_credential=True)

# Load the model and tokenizer from the Hugging Face Hub
model_name = "MHULO/yembaner"
nlp = pipeline("ner", model=model_name, tokenizer=model_name)

class TextRequest(BaseModel):
    """Request body for ``/predict/``: the raw text to run NER over."""

    text: str

# Pydantic model describing one item of the /predict/ response.
class Entity(BaseModel):
    """One NER result, as produced by ``predict`` from the pipeline output.

    ``start`` and ``end`` are optional: the transformers token-classification
    pipeline reports ``None`` offsets when the tokenizer cannot provide an
    offset mapping (e.g. a slow tokenizer). With a required ``int``, such
    results would fail FastAPI's response validation.
    """

    entity: str
    score: float
    index: int
    word: str
    start: Optional[int] = None
    end: Optional[int] = None


@app.post("/predict/", response_model=List[Entity])
def predict(request: TextRequest):
    """Run the module-level ``nlp`` NER pipeline on ``request.text``.

    Returns a list with one dict per result returned by the pipeline. Each
    dict has the keys ``entity``, ``score``, ``index``, ``word``, ``start``
    and ``end``, and FastAPI validates it against ``Entity``.

    Raises:
        HTTPException: status 500, with the error message as ``detail``, if
            running the pipeline or reshaping its output fails.
    """
    try:
        ner_results = nlp(request.text)

        # Keep only the fields exposed by the response model.
        entities = [
            {
                "entity": result["entity"],
                # The score may be a NumPy scalar; convert it to a plain float
                # so validation/serialisation do not depend on NumPy types.
                "score": float(result["score"]),
                "index": result["index"],
                "word": result["word"],
                "start": result["start"],
                "end": result["end"],
            }
            for result in ner_results
        ]
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return entities


@app.get("/")
def root():
    """Return a JSON object advertising the public URL of the predict endpoint."""
    predict_url = "https://mhulo-yembaner.hf.space/predict/"
    payload = {"prediction url": predict_url}
    return payload