navindusa committed on
Commit
f1c3cd3
·
1 Parent(s): 2d02134

Enhance API with bulk prediction and model metadata; improve error handling and text processing

Browse files
Files changed (1) hide show
  1. app.py +86 -12
app.py CHANGED
@@ -1,15 +1,36 @@
1
- from fastapi import FastAPI
2
- from transformers import pipeline
 
 
3
  import re
4
 
5
- pipe = pipeline("text-classification", model="JungleLee/bert-toxic-comment-classification")
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  app = FastAPI(
8
  title="Hopeline - AI Inference API",
9
- description="API for detecting toxic comments",
10
- version="0.1"
11
  )
12
 
 
 
 
 
 
 
 
13
  def preprocess_text(text: str) -> str:
14
  # Remove special characters and extra whitespace
15
  text = re.sub(r'[^\w\s]', '', text)
@@ -23,17 +44,70 @@ def preprocess_text(text: str) -> str:
23
  async def welcome():
24
  return "Welcome to Hopeline - AI Inference API"
25
 
 
 
 
 
26
  @app.post('/predict')
27
- async def predict_post(request_body: dict):
28
- text = request_body.get('text', '')
29
- if not text:
30
- return {"error": "No text provided"}
31
 
32
  # Preprocess text
33
- processed_text = preprocess_text(text)
 
 
 
 
 
 
34
 
35
  # Get prediction
36
- prediction = pipe(processed_text)
37
- return prediction
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import List, Optional
4
+ from transformers import pipeline, AutoTokenizer
5
  import re
6
 
7
# Load the model and tokenizer
# Hugging Face hub id of the fine-tuned toxicity classifier.
model_name = "JungleLee/bert-toxic-comment-classification"
# Loaded once at import time so every request reuses the same pipeline.
pipe = pipeline("text-classification", model=model_name)
# Tokenizer kept separately so endpoints can measure and truncate input length.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Model metadata
# Static description served verbatim by the /model-info endpoint.
model_info = {
    "name": "BERT for Toxic Comment Classification",
    "description": "A fine-tuned BERT model that detects toxic content in text",
    "labels": ["toxic", "non-toxic"],
    "max_sequence_length": tokenizer.model_max_length,
    "author": "JungleLee"
}
20
 
21
# Application object; all routes below are registered against this instance.
app = FastAPI(
    title="Hopeline - AI Inference API",
    description="API for detecting toxic comments using a BERT-based model",
    version="0.2"
)
26
 
27
class TextRequest(BaseModel):
    """Request body for /predict: a single text to classify."""
    # Raw input text; the endpoint rejects empty strings with HTTP 400.
    text: str
29
+
30
class BulkTextRequest(BaseModel):
    """Request body for /predict-bulk: a batch of texts plus a score threshold."""
    # Texts to classify; the endpoint rejects an empty list with HTTP 400.
    texts: List[str]
    # Toxicity score cutoff for each item's "exceeds_threshold" flag.
    # NOTE(review): Optional admits an explicit JSON null, which the bulk
    # endpoint then compares against a float — confirm null is intended.
    threshold: Optional[float] = 0.5
33
+
34
  def preprocess_text(text: str) -> str:
35
  # Remove special characters and extra whitespace
36
  text = re.sub(r'[^\w\s]', '', text)
 
44
async def welcome():
    """Root endpoint: return a plain greeting string confirming the API is up."""
    return "Welcome to Hopeline - AI Inference API"
46
 
47
@app.get("/model-info")
async def get_model_info():
    """Return the static metadata dict describing the loaded model."""
    return model_info
50
+
51
@app.post('/predict')
async def predict_post(request: TextRequest):
    """Classify a single text as toxic or non-toxic.

    Returns the original text, the model's label and score, and a boolean
    ``is_toxic`` convenience flag. Raises HTTP 400 when the text is empty.
    """
    # Guard clause: reject empty input up front.
    if not request.text:
        raise HTTPException(status_code=400, detail="No text provided")

    cleaned = preprocess_text(request.text)

    # Keep the input within the model's window; reserve two positions for the
    # special tokens the tokenizer adds around the sequence.
    limit = tokenizer.model_max_length - 2
    token_list = tokenizer.tokenize(cleaned)
    if len(token_list) > limit:
        cleaned = tokenizer.convert_tokens_to_string(token_list[:limit])

    # Single-item pipeline call; take the first (only) result dict.
    result = pipe(cleaned)[0]

    return {
        "text": request.text,
        "label": result["label"],
        "score": result["score"],
        "is_toxic": result["label"] == "toxic",
    }
74
+
75
@app.post('/predict-bulk')
async def predict_bulk(request: BulkTextRequest):
    """Classify a batch of texts and return per-text results plus a summary.

    Raises HTTP 400 when the input list is empty. An explicit JSON ``null``
    threshold is treated as the default 0.5 — previously it flowed into a
    float comparison and raised ``TypeError`` (HTTP 500).
    """
    if not request.texts:
        raise HTTPException(status_code=400, detail="No texts provided")

    # threshold is Optional[float]; normalize an explicit null to the default
    # so the comparison below cannot raise TypeError.
    threshold = 0.5 if request.threshold is None else request.threshold
    # Loop-invariant: model window minus two positions for special tokens.
    max_tokens = tokenizer.model_max_length - 2

    results = []
    for text in request.texts:
        # Preprocess text
        processed_text = preprocess_text(text)

        # Check token length and truncate if needed
        tokens = tokenizer.tokenize(processed_text)
        if len(tokens) > max_tokens:
            processed_text = tokenizer.convert_tokens_to_string(tokens[:max_tokens])

        # Get prediction (single-item call; first result dict)
        prediction = pipe(processed_text)[0]

        is_toxic = prediction["label"] == "toxic"
        results.append({
            "text": text,
            "label": prediction["label"],
            "score": prediction["score"],
            "is_toxic": is_toxic,
            # Only a toxic label can exceed the threshold, matching the
            # original conditional-expression semantics.
            "exceeds_threshold": is_toxic and prediction["score"] > threshold,
        })

    toxic_count = sum(1 for r in results if r["is_toxic"])
    return {
        "results": results,
        "summary": {
            "total": len(results),
            "toxic_count": toxic_count,
            # Complement of toxic_count; one pass instead of a second scan.
            "non_toxic_count": len(results) - toxic_count,
            "threshold_exceeded_count": sum(1 for r in results if r["exceeds_threshold"]),
        },
    }
112
 
113