Marroco93 committed on
Commit 266a737 · 1 Parent(s): fe81f5c

no message

Files changed (1)
  1. main.py +31 -12
main.py CHANGED
@@ -139,31 +139,50 @@ def segment_text(text: str, max_tokens=500): # Setting a conservative limit bel
     return segments
 
 
-classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
+tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
+
+def robust_segment_text(text: str, max_tokens=510):  # Slightly less to ensure a buffer
+    doc = nlp(text)
+    segments = []
+    current_segment = []
+    current_tokens = []
+
+    for sent in doc.sents:
+        words = sent.text.strip().split()
+        sentence_tokens = tokenizer.encode(' '.join(words), add_special_tokens=False)
+
+        if len(current_tokens) + len(sentence_tokens) > max_tokens:
+            if current_tokens:
+                segments.append(tokenizer.decode(current_tokens))
+            current_segment = words
+            current_tokens = sentence_tokens
+        else:
+            current_segment.extend(words)
+            current_tokens.extend(sentence_tokens)
+
+    if current_tokens:
+        segments.append(tokenizer.decode(current_tokens))
+
+    return segments
 
 
 def classify_segments(segments):
-    classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
-    classified_segments = []
-
+    results = []
     for segment in segments:
         try:
-            if len(segment.split()) <= 512:  # Double-check to avoid errors
-                result = classifier(segment)
-                classified_segments.append(result)
-            else:
-                classified_segments.append({"error": f"Segment too long: {len(segment.split())} tokens"})
+            result = classifier(segment)
+            results.append(result)
         except Exception as e:
-            classified_segments.append({"error": str(e)})
+            results.append({"error": str(e), "segment": segment[:50]})  # Include a part of the segment to debug if needed
+    return results
 
-    return classified_segments
 
 
 @app.post("/process_document")
 async def process_document(request: TextRequest):
     try:
         processed_text = preprocess_text(request.text)
-        segments = segment_text(processed_text)
+        segments = robust_segment_text(processed_text)
         classified_segments = classify_segments(segments)
 
         return {
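
For context, the hunk above replaces word-count segmentation with token-aware segmentation so each chunk stays under DistilBERT's 512-token input limit. Below is a minimal sketch, not part of the commit, of how the new helpers could be exercised from within main.py. It assumes spaCy provides the nlp sentence splitter used by robust_segment_text, and note that the commit removes the module-level classifier = pipeline(...) line even though classify_segments still calls classifier(segment), so that global would need to remain defined. The spaCy model name and sample text are illustrative only.

# Sketch only: exercising the helpers introduced in this commit, run after their
# definitions in main.py. Assumes spaCy and transformers are installed; model names
# below are illustrative assumptions, not taken from the commit.
import spacy
from transformers import AutoTokenizer, pipeline

nlp = spacy.load("en_core_web_sm")  # sentence boundaries for robust_segment_text
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
# The commit deletes this global, but classify_segments still calls classifier(segment);
# keeping it defined avoids a NameError at request time.
classifier = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")

sample = "First sentence of a long document. Another sentence that should land in the same chunk."
chunks = robust_segment_text(sample, max_tokens=510)  # token-counted chunks with a safety buffer
labels = classify_segments(chunks)                    # one classifier result (or error dict) per chunk
print(labels)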