Update tasks/text.py
tasks/text.py (+35 -27)
@@ -21,6 +21,28 @@ router = APIRouter()
 DESCRIPTION = "Random Baseline"
 ROUTE = "/text"
 
+class TextClassifier:
+    def __init__(self):
+        self.config = AutoConfig.from_pretrained("camillebrl/ModernBERT-envclaims-overfit")
+        self.label2id = self.config.label2id
+        self.classifier = pipeline(
+            "text-classification",
+            "camillebrl/ModernBERT-envclaims-overfit",
+            device="cpu",
+            batch_size=16
+        )
+
+    def process_batch(self, batch: List[str]) -> List[int]:
+        """
+        Process a batch of texts and return their predictions
+        """
+        try:
+            batch_preds = self.classifier(list(batch))
+            return [self.label2id[pred[0]["label"]] for pred in batch_preds]
+        except Exception as e:
+            print(f"Error processing batch: {e}")
+            return []
+
 @router.post(ROUTE, tags=["Text Task"],
              description=DESCRIPTION)
 async def evaluate_text(request: TextEvaluationRequest):
@@ -83,12 +105,7 @@ async def evaluate_text(request: TextEvaluationRequest):
     # print(predictions)
     # print("final predictions : ", predictions)
     # Initialize the model once
-    classifier = pipeline(
-        "text-classification",
-        "camillebrl/ModernBERT-envclaims-overfit",
-        device="cpu",  # Explicitly set device
-        batch_size=16  # Set batch size for pipeline
-    )
+    classifier = TextClassifier()
 
     # Prepare batches
     batch_size = 32
@@ -105,27 +122,18 @@ async def evaluate_text(request: TextEvaluationRequest):
 
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         # Submit all batches for processing
-        ...
-            batch_idx = future_to_batch[future]
-            try:
-                batch_predictions[batch_idx] = future.result()
-            except Exception as e:
-                print(f"Batch {batch_idx} generated an exception: {e}")
-                batch_predictions[batch_idx] = []
-
-    # Flatten predictions
-    predictions = [pred for batch in batch_predictions for pred in batch]
+        futures = [
+            executor.submit(classifier.process_batch, batch)
+            for batch in batches
+        ]
+
+        # Collect results in order
+        for future in futures:
+            try:
+                batch_preds = future.result()
+                predictions.extend(batch_preds)
+            except Exception as e:
+                print(f"Batch processing failed: {e}")
 
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
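
The substantive change in the last hunk is the result-collection strategy: the old code mapped completed futures back to batch indices through a future_to_batch dict (presumably drained with as_completed) and flattened at the end, while the new code submits futures in batch order and reads them back with future.result(), which blocks per future and therefore preserves input order with no flatten step. Below is a minimal standalone sketch of that pattern; the dummy process_batch and sample data are illustrative stand-ins, not code from tasks/text.py.

    from concurrent.futures import ThreadPoolExecutor

    def process_batch(batch):
        # Stand-in for TextClassifier.process_batch: one dummy label per text.
        return [len(text) % 2 for text in batch]

    batches = [["claim a", "claim b"], ["claim c"], ["claim d", "claim e"]]
    predictions = []

    with ThreadPoolExecutor(max_workers=4) as executor:
        # Submit every batch first, then iterate the futures list in
        # submission order; future.result() blocks until that batch is done,
        # so the extended predictions stay in the original input order.
        futures = [executor.submit(process_batch, batch) for batch in batches]
        for future in futures:
            try:
                predictions.extend(future.result())
            except Exception as e:
                print(f"Batch processing failed: {e}")

    print(predictions)  # flattened predictions, one per input text, in order

Iterating the futures in submission order does not serialize the work: all batches still run concurrently on the pool, and result() only waits for the one batch whose turn it is, so the as_completed bookkeeping and the final flatten both disappear at no cost in wall time.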
|