koyu008 commited on
Commit
13afbbc
·
verified ·
1 Parent(s): 3386fb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -34
app.py CHANGED
@@ -6,7 +6,6 @@ from transformers import DistilBertTokenizer, DistilBertModel, AutoModel, AutoTo
6
  from langdetect import detect
7
  from huggingface_hub import snapshot_download
8
  import os
9
- from typing import List
10
 
11
  # Device
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -85,43 +84,35 @@ app.add_middleware(
85
 
86
 
87
class TextIn(BaseModel):
    """Request body for /api/predict: a batch of raw input strings to classify."""

    texts: List[str]
91
@app.post("/api/predict")
def predict(data: TextIn):
    """Classify each text in the request, routing per detected language.

    For every string in ``data.texts``: detect its language with langdetect;
    English text is scored by the DistilBERT-based ``english_model`` (multi-label,
    per-label sigmoid scores), anything else by ``hinglish_model`` (single-label,
    softmax over classes).

    Returns:
        {"results": [{"text", "language", "predictions"}, ...]} where
        ``predictions`` maps label -> score for the model that handled the text.
    """
    results = []

    for text in data.texts:
        try:
            lang = detect(text)
        # FIX: was a bare `except:`, which also swallows KeyboardInterrupt and
        # SystemExit. Narrowed to Exception; the deliberate best-effort fallback
        # to "unknown" (e.g. for empty/undetectable input) is preserved.
        except Exception:
            lang = "unknown"

        if lang == "en":
            tokenizer = english_tokenizer
            model = english_model
            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            # NOTE(review): `outputs` is fed straight to torch.sigmoid, so the
            # custom model presumably returns a raw logits tensor — confirm.
            probs = torch.sigmoid(outputs).squeeze().cpu().tolist()
            predictions = dict(zip(english_labels, probs))
        else:
            tokenizer = hinglish_tokenizer
            model = hinglish_model
            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
            with torch.no_grad():
                outputs = model(**inputs)
            probs = torch.softmax(outputs, dim=1).squeeze().cpu().tolist()
            predictions = dict(zip(hinglish_labels, probs))

        results.append({
            "text": text,
            # Only "en"/"hi" are reported as-is; any other detection is collapsed
            # to "unknown" in the response (routing above still used the model).
            "language": lang if lang in ["en", "hi"] else "unknown",
            "predictions": predictions
        })

    return {"results": results}
125
 
126
  @app.get("/")
127
  def root():
 
6
  from langdetect import detect
7
  from huggingface_hub import snapshot_download
8
  import os
 
9
 
10
  # Device
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
84
 
85
 
86
class TextIn(BaseModel):
    """Request body for /api/predict: a single raw input string to classify."""

    texts: str
90
@app.post("/api/predict")
def predict(data: TextIn):
    """Classify a single text, routing to the English or Hinglish model.

    Detects the input language with langdetect; English text is scored by the
    DistilBERT-based ``english_model`` (multi-label, per-label sigmoid scores),
    anything else by ``hinglish_model`` (single-label, softmax over classes).

    Returns:
        {"language": "English"|"Hinglish", "predictions": {label: score, ...}}
    """
    # BUG FIX: the request model declares the field as `texts`, but this
    # handler read `data.text`, raising AttributeError on every request.
    text = data.texts
    try:
        lang = detect(text)
    # FIX: was a bare `except:` (also swallows KeyboardInterrupt/SystemExit).
    # Narrowed to Exception; best-effort fallback to "unknown" is preserved.
    except Exception:
        lang = "unknown"

    if lang == "en":
        tokenizer = english_tokenizer
        model = english_model
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # NOTE(review): `outputs` is fed straight to torch.sigmoid, so the
        # custom model presumably returns a raw logits tensor — confirm.
        probs = torch.sigmoid(outputs).squeeze().cpu().tolist()
        return {"language": "English", "predictions": dict(zip(english_labels, probs))}
    else:
        tokenizer = hinglish_tokenizer
        model = hinglish_model
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        probs = torch.softmax(outputs, dim=1).squeeze().cpu().tolist()
        return {"language": "Hinglish", "predictions": dict(zip(hinglish_labels, probs))}
 
 
 
 
 
 
 
 
116
 
117
  @app.get("/")
118
  def root():