tangminhanh commited on
Commit
f493449
·
verified ·
1 Parent(s): 5a33d2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -1,15 +1,25 @@
1
- from fastapi import FastAPI
 
 
2
 
3
  app = FastAPI()
4
- from transformers import pipeline
5
 
6
- pipe = pipeline("text-classification", model="kmcs-casulit/hr_cate")
 
 
 
 
 
 
 
 
7
 
8
  @app.get("/")
9
  def greet_json():
10
  return {"message": "Hello, World!"}
11
 
12
- @app.get("/")
13
- def classify(text: str):
14
- output = pipe(text)
 
15
  return {"output": output[0]['label']}
 
1
+ from fastapi import FastAPI, Query
2
+ from pydantic import BaseModel
3
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
4
 
5
  app = FastAPI()
 
6
 
7
+ # Load the model and tokenizer
8
+ model = AutoModelForSequenceClassification.from_pretrained(
9
+ "kmcs-casulit/hr_cate")
10
+ tokenizer = AutoTokenizer.from_pretrained("kmcs-casulit/hr_cate")
11
+ pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
12
+
13
+ # Define a request model
14
+ class TextRequest(BaseModel):
15
+ text: str
16
 
17
  @app.get("/")
18
  def greet_json():
19
  return {"message": "Hello, World!"}
20
 
21
+
22
+ @app.post("/classify/")
23
+ def classify(request: TextRequest):
24
+ output = pipe(request.text)
25
  return {"output": output[0]['label']}