"""Gradio demo for the RUPunct punctuation-restoration models.

Loads one token-classification pipeline per RUPunct checkpoint at import
time, then serves a small web UI where the user pastes text, picks a model
and gets the text back with casing and punctuation applied.
"""

import gradio as gr
from transformers import AutoTokenizer, pipeline

# UI label -> Hugging Face model id.
models = {
    "RUPunct-small": "RUPunct/RUPunct_small",
    "RUPunct-big": "RUPunct/RUPunct_big",
    "RUPunct-medium": "RUPunct/RUPunct_medium",
}

# Model preselected in the UI and used when no valid model name is supplied.
DEFAULT_MODEL = next(iter(models))

# Punctuation part of a label -> text appended after the token.
PUNCTUATION_SUFFIXES = {
    "O": "",
    "PERIOD": ".",
    "COMMA": ",",
    "QUESTION": "?",
    # The dash is separated from the preceding word by a space, for every
    # casing variant (the lowercase variant used to glue it to the word).
    "TIRE": " —",
    "DVOETOCHIE": ":",
    "VOSKL": "!",
    "PERIODCOMMA": ";",
    "DEFIS": "-",
    "MNOGOTOCHIE": "...",
    "QUESTIONVOSKL": "?!",
}

# Casing part of a label -> transformation applied to the token.
# Order matters: "UPPER_TOTAL_" must be tried before its own prefix "UPPER_".
CASE_TRANSFORMS = (
    ("UPPER_TOTAL_", str.upper),
    ("UPPER_", str.capitalize),
    ("LOWER_", lambda token: token),  # token is passed through untouched
)

# Every model is loaded eagerly so that switching models in the UI is instant.
pipelines = {}
for model_name, model_path in models.items():
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        strip_accents=False,
        add_prefix_space=True,
    )
    pipelines[model_name] = pipeline(
        "ner",
        model=model_path,
        tokenizer=tokenizer,
        aggregation_strategy="first",
    )


def process_token(token: str, label: str) -> str:
    """Apply the casing and punctuation encoded in *label* to *token*.

    A label has the form ``<CASE>_<PUNCT>`` where ``<CASE>`` is ``LOWER``
    (token unchanged), ``UPPER`` (``str.capitalize``) or ``UPPER_TOTAL``
    (``str.upper``), and ``<PUNCT>`` is a key of ``PUNCTUATION_SUFFIXES``.

    Args:
        token: The word to transform.
        label: The predicted class, e.g. ``"UPPER_COMMA"``.

    Returns:
        The transformed token with its punctuation appended. A label that is
        not recognised yields the token unchanged, so the word is never lost
        and the caller never receives ``None``.
    """
    for prefix, transform in CASE_TRANSFORMS:
        if label.startswith(prefix):
            suffix = PUNCTUATION_SUFFIXES.get(label[len(prefix):])
            if suffix is not None:
                return transform(token) + suffix
            # Known casing prefix but unknown punctuation: stop searching.
            break
    return token


def punctuate(input_text, model_name):
    """Restore casing and punctuation in *input_text* with the chosen model.

    Args:
        input_text: Raw text from the UI; may be empty or ``None``.
        model_name: A key of ``pipelines``. Anything else (including ``None``
            when nothing was selected) falls back to ``DEFAULT_MODEL``.

    Returns:
        The processed words joined by single spaces, or ``""`` when the
        input contains no text.
    """
    if not input_text or not input_text.strip():
        return ""

    classifier = pipelines.get(model_name) or pipelines[DEFAULT_MODEL]

    words = []
    for item in classifier(input_text):
        word = item["word"]
        # A standalone "." is emitted as-is, ignoring the predicted label, so
        # that no extra punctuation gets attached to it.
        label = "LOWER_O" if word == "." else item["entity_group"]
        words.append(process_token(word.strip(), label))
    return " ".join(words).strip()


iface = gr.Interface(
    fn=punctuate,
    inputs=[
        gr.components.Textbox(lines=5, placeholder="Введите текст"),
        gr.components.Radio(
            list(models.keys()),
            value=DEFAULT_MODEL,  # preselect so a submit never sends None
            label="Модель",
        ),
    ],
    outputs="text",
    title="RUPunct",
    description="Демо RUPunct - модели для автоматической расстановки знаков препинания в русском тексте.",
)

iface.launch()