AISimplyExplained committed on
Commit
14c8502
·
verified ·
1 Parent(s): 3e6c12d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +25 -2
main.py CHANGED
@@ -21,20 +21,43 @@ class Guardrail:
21
  def guard(self, prompt):
22
  return self.classifier(prompt)
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  class TextPrompt(BaseModel):
26
  prompt: str
27
 
28
 
 
 
 
 
 
 
 
29
  app = FastAPI()
30
  guardrail = Guardrail()
31
 
32
 
33
- @app.post("/classify/")
34
  def classify_text(text_prompt: TextPrompt):
35
  try:
36
  result = guardrail.guard(text_prompt.prompt)
37
- return result
 
 
 
38
  except Exception as e:
39
  raise HTTPException(status_code=500, detail=str(e))
40
 
 
21
  def guard(self, prompt):
22
  return self.classifier(prompt)
23
 
24
+ def determine_level(self, label, score):
25
+ if label == "SAFE":
26
+ return 0, "safe"
27
+ else:
28
+ if score > 0.9:
29
+ return 4, "high"
30
+ elif score > 0.75:
31
+ return 3, "medium"
32
+ elif score > 0.5:
33
+ return 2, "low"
34
+ else:
35
+ return 1, "very low"
36
+
37
 
38
  class TextPrompt(BaseModel):
39
  prompt: str
40
 
41
 
42
+ class ClassificationResult(BaseModel):
43
+ label: str
44
+ score: float
45
+ level: int
46
+ severity_label: str
47
+
48
+
49
  app = FastAPI()
50
  guardrail = Guardrail()
51
 
52
 
53
+ @app.post("/classify/", response_model=ClassificationResult)
54
  def classify_text(text_prompt: TextPrompt):
55
  try:
56
  result = guardrail.guard(text_prompt.prompt)
57
+ label = result[0]['label']
58
+ score = result[0]['score']
59
+ level, severity_label = guardrail.determine_level(label, score)
60
+ return {"label": label, "score": score, "level": level, "severity_label": severity_label}
61
  except Exception as e:
62
  raise HTTPException(status_code=500, detail=str(e))
63