main.py CHANGED
@@ -139,31 +139,50 @@ def segment_text(text: str, max_tokens=500): # Setting a conservative limit bel
     return segments


-
+tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
+
+def robust_segment_text(text: str, max_tokens=510): # Slightly less to ensure a buffer
+    doc = nlp(text)
+    segments = []
+    current_segment = []
+    current_tokens = []
+
+    for sent in doc.sents:
+        words = sent.text.strip().split()
+        sentence_tokens = tokenizer.encode(' '.join(words), add_special_tokens=False)
+
+        if len(current_tokens) + len(sentence_tokens) > max_tokens:
+            if current_tokens:
+                segments.append(tokenizer.decode(current_tokens))
+            current_segment = words
+            current_tokens = sentence_tokens
+        else:
+            current_segment.extend(words)
+            current_tokens.extend(sentence_tokens)
+
+    if current_tokens:
+        segments.append(tokenizer.decode(current_tokens))
+
+    return segments


 def classify_segments(segments):
-
-    classified_segments = []
-
+    results = []
     for segment in segments:
         try:
-
-
-            classified_segments.append(result)
-        else:
-            classified_segments.append({"error": f"Segment too long: {len(segment.split())} tokens"})
+            result = classifier(segment)
+            results.append(result)
         except Exception as e:
-
+            results.append({"error": str(e), "segment": segment[:50]}) # Include a part of the segment to debug if needed
+    return results

-    return classified_segments


 @app.post("/process_document")
 async def process_document(request: TextRequest):
     try:
         processed_text = preprocess_text(request.text)
-        segments = segment_text(processed_text)
+        segments = robust_segment_text(processed_text)
         classified_segments = classify_segments(segments)

         return {
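
For reference, the segmentation logic added above can be exercised on its own. What follows is a minimal standalone sketch, not the Space's code: it assumes spaCy's en_core_web_sm model for sentence splitting and a transformers sentiment pipeline as stand-ins for the nlp and classifier objects, which are created earlier in main.py, outside this hunk.

# Standalone sketch of the token-budgeted segmentation added above.
# Assumptions: en_core_web_sm stands in for the Space's `nlp`, and a
# sentiment pipeline stands in for its `classifier`.
import spacy
from transformers import AutoTokenizer, pipeline

nlp = spacy.load("en_core_web_sm")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
classifier = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")

def segment(text: str, max_tokens: int = 510) -> list[str]:
    # Pack whole sentences into segments of at most max_tokens subword
    # tokens; 510 leaves room for the two special tokens the classifier
    # adds, keeping each segment within DistilBERT's 512-token limit.
    segments, current_tokens = [], []
    for sent in nlp(text).sents:
        sentence_tokens = tokenizer.encode(sent.text.strip(), add_special_tokens=False)
        if len(current_tokens) + len(sentence_tokens) > max_tokens:
            if current_tokens:
                segments.append(tokenizer.decode(current_tokens))
            current_tokens = list(sentence_tokens)
        else:
            current_tokens.extend(sentence_tokens)
    if current_tokens:
        segments.append(tokenizer.decode(current_tokens))
    return segments

if __name__ == "__main__":
    text = "Short sentences keep the packing honest. " * 300
    for seg in segment(text):
        print(len(tokenizer.encode(seg, add_special_tokens=False)), classifier(seg))

Note a design point visible in the committed version: segments are rebuilt with tokenizer.decode from the accumulated token ids, so the current_segment word lists are bookkeeping that is never read back, and a single sentence longer than max_tokens still passes through oversized.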
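
The updated endpoint can be smoke-tested in-process. This is a hypothetical check, relying only on what the hunk confirms: main.py defines the app (assumed here to be a FastAPI instance, given the @app.post decorator) and a TextRequest model with a text field; the shape of the returned dict is truncated above.

# Hypothetical in-process check of /process_document using FastAPI's
# TestClient; `app` and `TextRequest` come from main.py itself.
from fastapi.testclient import TestClient

from main import app

client = TestClient(app)
response = client.post("/process_document", json={"text": "An example document. " * 400})
print(response.status_code)  # expect 200 on the happy path
print(response.json())       # shape follows the return dict, truncated in this hunk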