# Spaces: Running
import json
from dataclasses import dataclass

from streamlit_text_label import Selection
@dataclass
class Relation:
    """A labelled, directed link between two annotated spans.

    The ``@dataclass`` decorator supplies the keyword-argument constructor
    that ``find_relations_with_entities`` relies on
    (``Relation(source=..., target=..., label=...)``); without it the class
    only carries annotations and cannot be instantiated that way.
    """

    source: Selection  # span whose entity record listed the relation
    target: Selection  # span the relation points to
    label: str  # relation type, taken from the first element of a relation record
def load_data(file_path: str):
    """Parse the JSON file at ``file_path`` (read as UTF-8) and return its contents."""
    with open(file_path, mode='r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed
def save_data(file_path: str, data):
    """Serialise ``data`` as JSON into ``file_path``, replacing any previous contents.

    The file is opened for writing as UTF-8 and the JSON is emitted with a
    four-space indent.
    """
    with open(file_path, mode='w', encoding='utf-8') as handle:
        json.dump(data, handle, indent=4)
def get_label_color(label):
    """Look up the hex colour string associated with an entity label.

    Labels that have no dedicated colour fall back to ``'#666666'``.
    """
    palette = dict([
        ('OBS-DP', '#FF6B6B'),
        ('ANAT-DP', '#4ECDC4'),
        ('OBS-U', '#FFD93D'),
        ('OBS-DA', '#95A5A6'),
    ])
    if label in palette:
        return palette[label]
    return '#666666'
def word_to_char_position(text, word_index):
    """Return the character offset at which word ``word_index`` starts in ``text``.

    Words are the whitespace-delimited tokens produced by ``text.split()``.
    The offset is found by locating each word in the actual text rather than
    by assuming a single space between words, so leading whitespace, repeated
    spaces, tabs and newlines are all accounted for. For text whose words are
    separated by exactly one space and that has no leading whitespace the
    result is the same as summing ``len(word) + 1``.

    Args:
        text: The string the words belong to.
        word_index: Zero-based index of the word.

    Returns:
        The start offset of the word; ``len(text)`` when ``word_index`` is past
        the last word; ``0`` when ``word_index`` is negative.
    """
    words = text.split()
    if word_index >= len(words):
        return len(text)
    if word_index < 0:
        # Kept from the previous implementation, whose summation loop never
        # ran for a negative index and therefore yielded 0.
        return 0

    cursor = 0
    for word in words[:word_index]:
        # Words contain no whitespace and only whitespace lies between the
        # cursor and the next word, so the first match at or after the cursor
        # is exactly that word's own occurrence.
        cursor = text.index(word, cursor) + len(word)
    return text.index(words[word_index], cursor)
def word_to_char_span(text, start_ix, end_ix):
    """Translate an inclusive word-index range into a ``(char_start, char_end)`` pair.

    ``char_start`` is the offset reported by ``word_to_char_position`` for word
    ``start_ix``; ``char_end`` is the offset reported for word ``end_ix`` plus
    that word's length. Indexing ``text.split()`` raises ``IndexError`` when
    the requested word does not exist.
    """
    char_start = word_to_char_position(text, start_ix)
    words = text.split()
    if start_ix == end_ix:
        # Single-word span: the end is just the start plus that word's length.
        return char_start, char_start + len(words[start_ix])
    last_word_start = word_to_char_position(text, end_ix)
    return char_start, last_word_start + len(words[end_ix])
def entities2Selection(text, entities_data):
    """Build one ``Selection`` per entity record in ``entities_data``.

    Each record's ``start_ix``/``end_ix`` word indices are turned into a
    character span over ``text`` via ``word_to_char_span``; the record's
    ``tokens`` become the selection text and its ``label`` the single entry
    of the selection's label list. The mapping's keys are not used.
    """
    def to_selection(record):
        span_start, span_end = word_to_char_span(
            text, record['start_ix'], record['end_ix']
        )
        return Selection(
            start=span_start,
            end=span_end,
            text=record['tokens'],
            labels=[record['label']],
        )

    return [to_selection(record) for record in entities_data.values()]
def selection2entities(selections, text=None):
    """Convert a list of Selection-like objects into an entities mapping.

    Entities are keyed ``"1"``, ``"2"``, ... in input order. Each entry holds
    the selection text as ``tokens``, the first of its labels as ``label``,
    an empty ``relations`` list, and ``start_ix``/``end_ix``.

    ``entities2Selection`` in this file reads ``start_ix``/``end_ix`` as word
    indices, while a selection's ``start``/``end`` are character offsets.
    Pass ``text`` to have the offsets converted back to inclusive word indices
    so the output can be fed to ``entities2Selection`` again.

    Args:
        selections: Objects exposing ``text``, ``labels``, ``start`` and ``end``.
        text: Optional source text the character offsets refer to. When
            omitted (the default) the raw ``start``/``end`` values are stored
            unchanged, exactly as before this parameter existed.

    Returns:
        A dict mapping the 1-based position (as a string) to the entity record.

    Raises:
        IndexError: If a selection has an empty ``labels`` list.
    """
    word_spans = None
    if text is not None:
        # Character span (start, stop) of every whitespace-delimited word,
        # located in the real text so irregular spacing is handled.
        word_spans = []
        cursor = 0
        for word in text.split():
            begin = text.index(word, cursor)
            cursor = begin + len(word)
            word_spans.append((begin, cursor))

    entities = {}
    for i, selection in enumerate(selections, 1):
        if word_spans is None:
            start_ix, end_ix = selection.start, selection.end
        else:
            # First word not entirely before the selection.
            start_ix = sum(1 for _, stop in word_spans if stop <= selection.start)
            # Last word that begins before the selection ends; never smaller
            # than start_ix so an empty selection still yields a valid range.
            end_ix = max(
                start_ix,
                sum(1 for begin, _ in word_spans if begin < selection.end) - 1,
            )
        entities[str(i)] = {
            "tokens": selection.text,
            "label": selection.labels[0],
            "start_ix": start_ix,
            "end_ix": end_ix,
            "relations": [],
        }
    return entities
def find_relations_with_entities(entities, entities_data):
    """Rebuild ``Relation`` objects between the current entities.

    The stored records in ``entities_data`` are matched to the current
    ``entities`` by text: a record applies to an entity when the record's
    ``tokens`` equal the entity's ``text``. For every relation listed on a
    matching record, element ``1`` is the id of the target record and element
    ``0`` is the relation label; a ``Relation`` is emitted only when the
    target record's tokens also correspond to a current entity.

    Args:
        entities: Objects exposing a ``text`` attribute. When several share
            the same text, the last one is the one used for matching.
        entities_data: Mapping of entity id to a record with ``tokens`` and an
            optional ``relations`` list.

    Returns:
        A list of ``Relation`` objects, ordered by source entity, then by
        record order in ``entities_data``, then by relation order.
    """
    text_to_entity = {e.text: e for e in entities}

    # One pass over the stored records builds both lookups, so the records do
    # not have to be rescanned for every source entity.
    id_to_tokens = {}
    records_by_tokens = {}
    for entity_id, record in entities_data.items():
        tokens = record['tokens']
        id_to_tokens[entity_id] = tokens
        records_by_tokens.setdefault(tokens, []).append(record)

    relations = []
    for source_text, source_entity in text_to_entity.items():
        for record in records_by_tokens.get(source_text, ()):
            for relation in record.get('relations', []):
                target_text = id_to_tokens.get(relation[1])
                # Skip targets that are unknown, have empty tokens, or are not
                # among the current entities.
                if target_text and target_text in text_to_entity:
                    relations.append(Relation(
                        source=source_entity,
                        target=text_to_entity[target_text],
                        label=relation[0],
                    ))
    return relations