camillebrl committed
Commit f5ac2a0 · verified · 1 Parent(s): a66df45

Update tasks/text.py

Files changed (1):
  tasks/text.py  +35 -27
tasks/text.py CHANGED
@@ -21,6 +21,28 @@ router = APIRouter()
 DESCRIPTION = "Random Baseline"
 ROUTE = "/text"
 
+class TextClassifier:
+    def __init__(self):
+        self.config = AutoConfig.from_pretrained("camillebrl/ModernBERT-envclaims-overfit")
+        self.label2id = self.config.label2id
+        self.classifier = pipeline(
+            "text-classification",
+            "camillebrl/ModernBERT-envclaims-overfit",
+            device="cpu",
+            batch_size=16
+        )
+
+    def process_batch(self, batch: List[str]) -> List[int]:
+        """
+        Process a batch of texts and return their predictions
+        """
+        try:
+            batch_preds = self.classifier(list(batch))
+            return [self.label2id[pred[0]["label"]] for pred in batch_preds]
+        except Exception as e:
+            print(f"Error processing batch: {e}")
+            return []
+
 @router.post(ROUTE, tags=["Text Task"],
              description=DESCRIPTION)
 async def evaluate_text(request: TextEvaluationRequest):
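The new TextClassifier bundles the pipeline and the config's label2id mapping into a single object, so worker threads no longer need them passed in as separate arguments. A minimal sketch of exercising the class on its own, assuming the repo's import path and a made-up sample sentence (neither is part of this commit):

    from tasks.text import TextClassifier  # import path assumed from this repo's layout

    clf = TextClassifier()

    # One call per batch; each input text maps to one integer label id.
    sample = ["The company claims its packaging is fully recyclable."]
    print(clf.process_batch(sample))  # e.g. [2], depending on config.label2id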
@@ -83,12 +105,7 @@ async def evaluate_text(request: TextEvaluationRequest):
     # print(predictions)
     # print("final predictions : ", predictions)
     # Initialize the model once
-    classifier = pipeline(
-        "text-classification",
-        "camillebrl/ModernBERT-envclaims-overfit",
-        device="cpu", # Explicitly set device
-        batch_size=16 # Set batch size for pipeline
-    )
+    classifier = TextClassifier()
 
     # Prepare batches
     batch_size = 32
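The pipeline's internal batch_size stays at 16, while the request texts are chunked into outer batches of 32 for the thread pool. The chunking itself sits outside this hunk; a plausible sketch, assuming the inputs live in a test_dataset["quote"] column (field name assumed, not shown in the diff):

    batch_size = 32
    texts = test_dataset["quote"]  # field name assumed
    batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]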
@@ -105,27 +122,18 @@ async def evaluate_text(request: TextEvaluationRequest):
 
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         # Submit all batches for processing
-        future_to_batch = {
-            executor.submit(
-                process_batch,
-                batch,
-                classifier,
-                label2id
-            ): i for i, batch in enumerate(batches)
-        }
-
-        # Collect results in order
-        batch_predictions = [[] for _ in range(len(batches))]
-        for future in future_to_batch:
-            batch_idx = future_to_batch[future]
-            try:
-                batch_predictions[batch_idx] = future.result()
-            except Exception as e:
-                print(f"Batch {batch_idx} generated an exception: {e}")
-                batch_predictions[batch_idx] = []
-
-        # Flatten predictions
-        predictions = [pred for batch in batch_predictions for pred in batch]
+        futures = [
+            executor.submit(classifier.process_batch, batch)
+            for batch in batches
+        ]
+
+        # Collect results in order
+        for future in futures:
+            try:
+                batch_preds = future.result()
+                predictions.extend(batch_preds)
+            except Exception as e:
+                print(f"Batch processing failed: {e}")
 
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
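The futures list is iterated in submission order, so result() yields batch outputs already aligned with the inputs; the per-batch index map and the final flatten step from the old version become unnecessary. A self-contained sketch of the ordering guarantee the new code relies on (square_batch is a toy stand-in for process_batch):

    from concurrent.futures import ThreadPoolExecutor

    def square_batch(batch):
        return [x * x for x in batch]

    batches = [[1, 2], [3, 4], [5, 6]]
    predictions = []
    with ThreadPoolExecutor(max_workers=2) as executor:
        futures = [executor.submit(square_batch, b) for b in batches]
        for future in futures:  # submission order, not completion order
            predictions.extend(future.result())

    print(predictions)  # [1, 4, 9, 16, 25, 36]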
 