Spaces:
Sleeping
Sleeping
import functools
import re

import gradio as gr
import spacy
from gradio_client import Client
from textblob import TextBlob
from transformers import pipeline
# Models are loaded once, at import time, and shared by every request.
# spaCy English pipeline: provides the entities and tokens used below.
nlp = spacy.load(name="en_core_web_sm")
# Text-to-text generation model used as a second spelling corrector.
spell_checker = pipeline(
    task="text2text-generation",
    model="oliverguhr/spelling-correction-english-base",
)
def _capitalize_first_letter(word: str) -> str:
    """Uppercase the first alphanumeric char of *word* and lowercase the rest.

    Leading punctuation (quotes, brackets, ...) is kept as-is, so
    '"HEllo' becomes '"Hello' rather than '"hello'.
    """
    for i, ch in enumerate(word):
        if ch.isalnum():
            return word[:i] + ch.upper() + word[i + 1:].lower()
    return word  # no alphanumeric character at all: nothing to do


def preprocess_capitalization(text: str) -> str:
    """Normalize erratic capitalization word by word.

    Each whitespace-delimited word is handled as follows:

    * words that mix ASCII uppercase and lowercase letters (e.g. "HEllo")
      are rewritten with their first alphanumeric character uppercased and
      the rest lowercased ("Hello"); leading punctuation is preserved;
    * everything else, including all-uppercase acronyms such as "NASA",
      is left unchanged.

    The original whitespace (spaces, tabs, newlines) is preserved exactly.

    Args:
        text: Input text.

    Returns:
        The text with mixed-case words normalized.
    """
    # Split on any run of whitespace but keep the separators (capturing
    # group), so "".join() reproduces the original spacing and line breaks.
    pieces = re.split(r"(\s+)", text)
    return "".join(
        _capitalize_first_letter(piece)
        if re.search(r"[A-Z]", piece) and re.search(r"[a-z]", piece)
        else piece
        for piece in pieces
    )
def preprocess_text(text: str):
    """Run capitalization, spelling and NLP analysis over *text*.

    Steps:
      1. Normalize mixed-case words with ``preprocess_capitalization``.
      2. Spell-check the normalized text with TextBlob and with the
         transformer ``spell_checker`` pipeline.
      3. Extract named entities with spaCy and collect #hashtags / @mentions.

    Each correction is recorded as a whole-text
    ``{"original": ..., "corrected": ...}`` pair; no character positions
    are tracked.

    Args:
        text: Raw user input.

    Returns:
        A ``(text, result)`` tuple. ``text`` is the capitalization-normalized
        input; the spell-check suggestions are NOT applied to it. ``result``
        is a dict with keys:
          * ``"spell_suggestions"``: list of ``{"original", "corrected"}``
          * ``"entities"``: list of ``{"text", "label"}`` from spaCy NER
          * ``"tags"``: list of ``#hashtag`` / ``@mention`` strings
    """
    result = {
        "spell_suggestions": [],
        "entities": [],
        "tags": [],
    }

    # Nothing to correct or analyse in empty / whitespace-only input; skip
    # running the (expensive, generative) models on it.
    if not text.strip():
        return text, result

    # 1. Capitalization normalization; later steps work on the normalized text.
    capitalized_text = preprocess_capitalization(text)
    if capitalized_text != text:
        result["spell_suggestions"].append({
            "original": text,
            "corrected": capitalized_text,
        })
        text = capitalized_text

    # 2a. TextBlob spell check.
    corrected = str(TextBlob(text).correct())
    if corrected != text:
        result["spell_suggestions"].append({
            "original": text,
            "corrected": corrected,
        })

    # 2b. Transformer spell check; only record it when it adds something
    # beyond both the input and the TextBlob suggestion.
    spell_checked = spell_checker(text, max_length=512)[0]["generated_text"]
    if spell_checked not in (text, corrected):
        result["spell_suggestions"].append({
            "original": text,
            "corrected": spell_checked,
        })

    # 3. Entities and tags.
    doc = nlp(text)
    result["entities"] = [{"text": ent.text, "label": ent.label_} for ent in doc.ents]
    # spaCy's default English tokenizer splits "#" off as its own token, so
    # "#python" would yield just "#"; match tags on the raw text instead.
    # The lookbehind skips the "@domain" part of e-mail addresses.
    result["tags"] = re.findall(r"(?<!\w)[#@]\w+", text)
    return text, result
@functools.lru_cache(maxsize=1)
def _get_translation_client() -> Client:
    """Return a shared gradio_client connection to the translation Space.

    Created lazily on first use and reused afterwards, because constructing
    a Client connects to the remote Space. If construction raises, nothing
    is cached, so the next request retries.
    """
    return Client("Frenchizer/space_17")


def preprocess_and_forward(text: str):
    """Preprocess *text* and send it to the translation Space.

    Args:
        text: Raw user input from the UI.

    Returns:
        ``(translation, preprocessing_result)``. If the translation service
        cannot be reached or queried, ``translation`` is an ``"Error: ..."``
        string instead. ``preprocessing_result`` is always the dict produced
        by ``preprocess_text``.
    """
    # NOTE: preprocess_text returns the capitalization-normalized text, not
    # the raw input; that normalized text is what gets translated. The
    # spell-check suggestions are reported but not applied.
    processed_text, preprocessing_result = preprocess_text(text)
    try:
        # Client creation is inside the try so an unreachable Space is
        # reported in the UI instead of raising out of the handler.
        translation = _get_translation_client().predict(processed_text)
    except Exception as e:  # UI boundary: report the failure, don't crash
        return f"Error: {str(e)}", preprocessing_result
    return translation, preprocessing_result
# Gradio interface
with gr.Blocks() as demo:
    input_text = gr.Textbox(label="Input Text")
    output_text = gr.Textbox(label="Output Text")
    preprocess_button = gr.Button("Process")
    # preprocess_and_forward returns (translation, preprocessing_result).
    # Each value needs its own output component; with a single Textbox the
    # whole tuple would be sent to that box.
    preprocessing_output = gr.JSON(label="Preprocessing Details")
    preprocess_button.click(
        fn=preprocess_and_forward,
        inputs=[input_text],
        outputs=[output_text, preprocessing_output],
    )

if __name__ == "__main__":
    demo.launch()