camillebrl committed (verified)
Commit 93741cc · 1 Parent(s): f5ac2a0

Update tasks/text.py

Files changed (1)
  1. tasks/text.py +38 -17
tasks/text.py CHANGED
@@ -32,16 +32,26 @@ class TextClassifier:
             batch_size=16
         )
 
-    def process_batch(self, batch: List[str]) -> List[int]:
+    def process_batch(self, batch: List[str], batch_idx: int) -> Tuple[List[int], int]:
         """
-        Process a batch of texts and return their predictions
+        Process a batch of texts and return their predictions along with batch index
+
+        Args:
+            batch: List of texts to process
+            batch_idx: Index of the current batch
+
+        Returns:
+            Tuple containing list of predictions and batch index
         """
         try:
+            print(f"Processing batch {batch_idx} with {len(batch)} items")
             batch_preds = self.classifier(list(batch))
-            return [self.label2id[pred[0]["label"]] for pred in batch_preds]
+            predictions = [self.label2id[pred[0]["label"]] for pred in batch_preds]
+            print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
+            return predictions, batch_idx
         except Exception as e:
-            print(f"Error processing batch: {e}")
-            return []
+            print(f"Error in batch {batch_idx}: {str(e)}")
+            return [], batch_idx
 
 @router.post(ROUTE, tags=["Text Task"],
              description=DESCRIPTION)
@@ -122,18 +132,29 @@ async def evaluate_text(request: TextEvaluationRequest):
 
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         # Submit all batches for processing
-        futures = [
-            executor.submit(classifier.process_batch, batch)
-            for batch in batches
-        ]
-
-        # Collect results in order
-        for future in futures:
-            try:
-                batch_preds = future.result()
-                predictions.extend(batch_preds)
-            except Exception as e:
-                print(f"Batch processing failed: {e}")
+        future_to_batch = {
+            executor.submit(
+                classifier.process_batch,
+                batch,
+                idx
+            ): idx for idx, batch in enumerate(batches)
+        }
+
+        # Collect results in order
+        for future in future_to_batch:
+            batch_idx = future_to_batch[future]
+            try:
+                predictions, idx = future.result()
+                batch_results[idx] = predictions
+                print(f"Stored results for batch {idx}")
+            except Exception as e:
+                print(f"Failed to get results for batch {batch_idx}: {e}")
+                batch_results[batch_idx] = []
+
+        # Flatten predictions while maintaining order
+        all_predictions = [pred for batch_preds in batch_results for pred in batch_preds]
+        print(f"Total predictions collected: {len(all_predictions)}")
+
 
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
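
For context, the change keys each submitted future by its batch index so results can be written back into a per-batch slot and flattened in submission order, instead of extending a shared list as futures return. Below is a minimal, self-contained sketch of that pattern; fake_classify, the sample texts, and the up-front allocation of batch_results are placeholders for illustration and are not taken from tasks/text.py (the real code calls the transformers pipeline held in self.classifier and maps labels through label2id).

# Illustrative sketch only -- fake_classify and the sample data are stand-ins,
# not part of tasks/text.py.
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple

def fake_classify(batch: List[str], batch_idx: int) -> Tuple[List[int], int]:
    # Stand-in for TextClassifier.process_batch: one label id per input text.
    return [len(text) % 2 for text in batch], batch_idx

texts = [f"sample text {i}" for i in range(10)]
batch_size = 4
batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]

# One result slot per batch, so out-of-order completion cannot scramble output.
batch_results: List[List[int]] = [[] for _ in batches]

with ThreadPoolExecutor(max_workers=2) as executor:
    # Key each future by its batch index, as the commit does.
    future_to_batch = {
        executor.submit(fake_classify, batch, idx): idx
        for idx, batch in enumerate(batches)
    }
    for future in future_to_batch:
        batch_idx = future_to_batch[future]
        try:
            preds, idx = future.result()
            batch_results[idx] = preds
        except Exception as exc:
            print(f"Batch {batch_idx} failed: {exc}")
            batch_results[batch_idx] = []

# Flatten per-batch lists back into a single ordered prediction list.
all_predictions = [pred for preds in batch_results for pred in preds]
print(all_predictions)  # 10 predictions, in the original text order

Iterating over the future_to_batch keys walks the futures in submission order; future.result() then blocks until that batch finishes, but because every worker is already running, overall throughput is unaffected and the final ordering stays deterministic.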