File size: 3,435 Bytes
9985fd7
 
 
 
 
 
 
 
 
 
1a1c17c
9985fd7
1a1c17c
9985fd7
 
1a1c17c
9985fd7
1a1c17c
9985fd7
 
 
 
 
1a1c17c
 
 
 
9985fd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import json
from dataclasses import dataclass
from streamlit_text_label import Selection

@dataclass
class Relation:
    """A labelled, directed link from one annotated span to another.

    Instances are produced by ``find_relations_with_entities``, which pairs
    the current selections with the relation lists stored in the raw entity
    data.
    """

    source: Selection  # span the relation starts from
    target: Selection  # span the relation points to
    label: str  # relation label text taken from the raw relation entry

def load_data(file_path: str):
    """Read the JSON document stored at *file_path* (e.g. dev.json).

    The file is decoded as UTF-8 and the parsed Python object (whatever
    top-level JSON value the file holds) is returned.
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        payload = json.load(handle)
    return payload

def save_data(file_path: str, data):
    """Write *data* to *file_path* as JSON indented with four spaces.

    The file is opened in write mode (replacing any previous content) with
    UTF-8 encoding. ``json.dump`` keeps its default ``ensure_ascii=True``, so
    non-ASCII characters are written as ``\\uXXXX`` escapes.
    """
    with open(file_path, mode='w', encoding='utf-8') as handle:
        json.dump(data, handle, indent=4)

# Display colour for each known entity label.
_LABEL_COLORS = {
    'OBS-DP': '#FF6B6B',
    'ANAT-DP': '#4ECDC4',
    'OBS-U': '#FFD93D',
    'OBS-DA': '#95A5A6',
}
# Colour used for any label not listed in _LABEL_COLORS.
_DEFAULT_LABEL_COLOR = '#666666'

def get_label_color(label):
    """Look up the hex display colour for an entity *label*.

    Unknown labels fall back to a neutral grey.
    """
    if label in _LABEL_COLORS:
        return _LABEL_COLORS[label]
    return _DEFAULT_LABEL_COLOR

def word_to_char_position(text, word_index):
    """Return the character offset at which word ``word_index`` of *text* starts.

    Words are the tokens produced by ``text.split()``, i.e. any run of
    whitespace separates two words. Each word is located in the actual
    string, so leading whitespace, repeated spaces, tabs and newlines are
    accounted for instead of assuming a single space between words.

    Args:
        text: the document text.
        word_index: 0-based word index. Negative values are treated as 0.

    Returns:
        The offset of the first character of the word, or ``len(text)`` when
        ``word_index`` is at or past the number of words.
    """
    words = text.split()
    if word_index >= len(words):
        return len(text)
    position = 0
    cursor = 0
    for word in words[:max(word_index, 0) + 1]:
        # Only whitespace lies between the cursor and the next word, and a word
        # contains no whitespace, so the first match is that word's real start.
        position = text.find(word, cursor)
        cursor = position + len(word)
    return position

def word_to_char_span(text, start_ix, end_ix):
    """Map the inclusive word range ``start_ix..end_ix`` onto character offsets.

    Returns ``(char_start, char_end)``. ``char_start`` is the position of
    word ``start_ix``. ``char_end`` is the position of word ``end_ix`` plus
    that word's length. Words are the tokens of ``text.split()``.

    Raises:
        IndexError: if ``end_ix`` is not a valid index into ``text.split()``.
    """
    span_start = word_to_char_position(text, start_ix)
    # A single-word span (start_ix == end_ix) is covered by the same formula.
    last_word = text.split()[end_ix]
    span_end = word_to_char_position(text, end_ix) + len(last_word)
    return span_start, span_end

def entities2Selection(text, entities_data):
    """Build one Selection per raw entity, in the iteration order of *entities_data*.

    Each entity dict must provide ``start_ix``/``end_ix`` (word indices,
    converted to a character span of *text* via ``word_to_char_span``),
    ``tokens`` (used as the selection text) and ``label`` (wrapped in a
    one-element ``labels`` list). Entity ids are not used.
    """
    def to_selection(entity):
        begin, finish = word_to_char_span(text, entity['start_ix'], entity['end_ix'])
        return Selection(
            start=begin,
            end=finish,
            text=entity['tokens'],
            labels=[entity['label']],
        )

    return [to_selection(entity) for entity in entities_data.values()]

def _word_spans(text):
    """Return ``(start, end)`` character offsets (end exclusive) of each word in ``text.split()``."""
    spans = []
    cursor = 0
    for word in text.split():
        # Words contain no whitespace, so the first match after the cursor is the word itself.
        start = text.find(word, cursor)
        cursor = start + len(word)
        spans.append((start, cursor))
    return spans

def _char_span_to_word_span(spans, char_start, char_end):
    """Return inclusive ``(first, last)`` indices of the words overlapping ``[char_start, char_end)``."""
    overlapping = [
        ix for ix, (word_start, word_end) in enumerate(spans)
        if word_start < char_end and word_end > char_start
    ]
    if not overlapping:
        raise ValueError(
            f"character span [{char_start}, {char_end}) does not cover any word"
        )
    return overlapping[0], overlapping[-1]

def selection2entities(selections, text=None):
    """Convert Selection objects back into the entities dict layout.

    Entities are keyed ``"1"``, ``"2"``, ... in selection order. Relations are
    not reconstructed here; every entity gets an empty ``relations`` list.

    Args:
        selections: iterable of Selection-like objects exposing ``text``,
            ``labels`` (only the first label is kept), ``start`` and ``end``.
        text: the document text that the selections' character offsets refer
            to. When given, ``start_ix``/``end_ix`` are stored as inclusive
            word indices into ``text.split()``, which is the form
            ``entities2Selection`` reads back. When omitted, the raw
            character offsets are stored unchanged (legacy behaviour). These
            are not word indices, so the result will not round-trip through
            ``entities2Selection``.

    Returns:
        dict mapping the string entity id to the entity dict.

    Raises:
        ValueError: if *text* is given and a selection overlaps no word.
        IndexError: if a selection has an empty ``labels`` list.
    """
    spans = None if text is None else _word_spans(text)
    entities = {}
    for i, selection in enumerate(selections, 1):
        if spans is None:
            # Legacy path: character offsets stored as-is.
            start_ix, end_ix = selection.start, selection.end
        else:
            start_ix, end_ix = _char_span_to_word_span(
                spans, selection.start, selection.end
            )
        entities[str(i)] = {
            "tokens": selection.text,
            "label": selection.labels[0],
            "start_ix": start_ix,
            "end_ix": end_ix,
            "relations": []
        }
    return entities

def find_relations_with_entities(entities, entities_data):
    """Rebuild Relation objects between the current selections from raw entity data.

    Selections are matched to raw entities purely by text: a raw entity
    whose ``'tokens'`` equals a selection's ``.text`` contributes its
    relations with that selection as the source. A relation is kept only
    when its target entity's tokens also match some current selection.

    Args:
        entities: iterable of Selection-like objects exposing ``.text``.
        entities_data: mapping of entity id -> raw entity dict with a
            ``'tokens'`` value and an optional ``'relations'`` list. Each
            relation entry is indexed as ``[label, target_id, ...]``.

    Returns:
        list of Relation, ordered by each selection text's first appearance in
        *entities*, then by *entities_data* order, then by relation order.

    NOTE(review): because matching is by text, selections sharing the same
    text collapse to the last such selection, and every raw entity with
    that text contributes its relations to it.
    """
    # Keyed by text; later selections with the same text overwrite earlier ones.
    text_to_entity = {e.text: e for e in entities}
    id_to_tokens = {entity_id: entity['tokens'] for entity_id, entity in entities_data.items()}

    # Group raw entities by token text once, rather than rescanning
    # entities_data for every selection. Insertion order is preserved.
    raw_by_tokens = {}
    for entity in entities_data.values():
        raw_by_tokens.setdefault(entity['tokens'], []).append(entity)

    relations = []
    for source_text, source_entity in text_to_entity.items():
        for entity in raw_by_tokens.get(source_text, ()):
            for relation in entity.get('relations', []):
                target_text = id_to_tokens.get(relation[1])
                # Empty/missing target text, or a target with no live selection, is skipped.
                if target_text and target_text in text_to_entity:
                    relations.append(Relation(
                        source=source_entity,
                        target=text_to_entity[target_text],
                        label=relation[0]
                    ))
    return relations