medKGC / app_logic.py
hanbinChen's picture
update
1a1c17c
import json
from dataclasses import dataclass
from streamlit_text_label import Selection
@dataclass
class Relation:
    """A labelled, directed link between two annotated spans.

    Instances are created by ``find_relations_with_entities``, which pairs
    the currently selected spans according to the relations recorded in the
    original entity data.
    """

    # Span the relation starts from.
    source: Selection
    # Span the relation points to.
    target: Selection
    # Relation type name (first element of a stored relation entry).
    label: str
def load_data(file_path: str):
    """Parse the JSON document stored at ``file_path`` and return it.

    The file is read as UTF-8. Whatever ``json.load`` produces (dict, list,
    ...) is returned unchanged.
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed
def save_data(file_path: str, data):
    """Write ``data`` to ``file_path`` as UTF-8 JSON indented by 4 spaces.

    The document is serialised in memory *before* the target file is opened.
    Opening with mode ``'w'`` truncates the file immediately, so serialising
    first guarantees that a value JSON cannot encode raises without wiping
    out or half-overwriting an existing file.

    Args:
        file_path: Destination path; created or overwritten.
        data: Any JSON-serialisable object.

    Raises:
        TypeError: If ``data`` contains a value JSON cannot encode.
        ValueError: If ``data`` contains a circular reference.
        OSError: If the file cannot be opened or written.
    """
    serialized = json.dumps(data, indent=4)
    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(serialized)
# Hex colours for the known entity labels.
_LABEL_COLORS = {
    'OBS-DP': '#FF6B6B',
    'ANAT-DP': '#4ECDC4',
    'OBS-U': '#FFD93D',
    'OBS-DA': '#95A5A6',
}
# Colour returned for any label missing from the table above.
_FALLBACK_LABEL_COLOR = '#666666'


def get_label_color(label):
    """Look up the hex colour for ``label``; unknown labels get grey."""
    return _LABEL_COLORS.get(label, _FALLBACK_LABEL_COLOR)
def word_to_char_position(text, word_index):
    """Return the character offset at which word ``word_index`` starts.

    Words are the whitespace-separated tokens of ``text`` (``text.split()``).
    Each token is located in the text itself rather than assuming a single
    space between words, so leading whitespace, repeated spaces, tabs and
    newlines all yield correct offsets.

    Args:
        text: The full text.
        word_index: Zero-based index of the word.

    Returns:
        The offset of the word's first character; ``len(text)`` when
        ``word_index`` is past the last word; ``0`` for a negative index.
    """
    if word_index < 0:
        return 0
    cursor = 0
    for index, word in enumerate(text.split()):
        # Tokens appear in order and contain no whitespace, so searching from
        # the end of the previous token lands on this token's first character.
        start = text.index(word, cursor)
        if index == word_index:
            return start
        cursor = start + len(word)
    return len(text)
def word_to_char_span(text, start_ix, end_ix):
    """Convert an inclusive word-index range into a character span.

    Words are the whitespace-separated tokens of ``text`` (``text.split()``).
    Token boundaries are located in the text itself, so runs of whitespace
    between words do not shift the result. The text is tokenised once.

    Args:
        text: The full text.
        start_ix: Zero-based index of the first word of the span.
        end_ix: Zero-based index of the last word of the span (inclusive).

    Returns:
        ``(char_start, char_end)``: the offset of the first character of word
        ``start_ix`` and the offset just past the last character of word
        ``end_ix``. An index past the last word maps to ``len(text)`` and a
        negative index maps to ``0`` instead of raising ``IndexError``.
    """
    bounds = []
    cursor = 0
    for word in text.split():
        # Search from the end of the previous token so repeated words resolve
        # to their own occurrence.
        begin = text.index(word, cursor)
        cursor = begin + len(word)
        bounds.append((begin, cursor))

    def edge(word_ix, side):
        # side 0 -> first character of the word, side 1 -> one past its end.
        if word_ix < 0:
            return 0
        if word_ix >= len(bounds):
            return len(text)
        return bounds[word_ix][side]

    return edge(start_ix, 0), edge(end_ix, 1)
def entities2Selection(text, entities_data):
    """Build one ``Selection`` per entry of ``entities_data``.

    Each entity's word-index range (``start_ix``..``end_ix``) is translated
    to a character span within ``text``. The entity's ``tokens`` become the
    selection text and its ``label`` the selection's only label. Output
    order follows the iteration order of ``entities_data``.
    """
    def to_selection(entity):
        begin, finish = word_to_char_span(
            text, entity['start_ix'], entity['end_ix']
        )
        return Selection(
            start=begin,
            end=finish,
            text=entity['tokens'],
            labels=[entity['label']],
        )

    return [to_selection(entity) for entity in entities_data.values()]
def selection2entities(selections, text=None):
    """Convert Selection objects into an entities dict keyed "1", "2", ...

    Args:
        selections: Iterable of objects exposing ``text``, ``labels``,
            ``start`` and ``end`` (character offsets).
        text: Optional full text the selections refer to. When given,
            ``start_ix``/``end_ix`` are stored as word indices (inclusive),
            the unit ``entities2Selection`` expects when it reads entities
            back. When omitted, the character offsets are stored unchanged,
            which is the legacy behaviour.

    Returns:
        Dict mapping 1-based string ids to entity dicts with keys ``tokens``,
        ``label`` (the first label of the selection), ``start_ix``,
        ``end_ix`` and an empty ``relations`` list.
    """
    word_bounds = None
    if text is not None:
        # (begin, end) character bounds of every whitespace-separated token.
        word_bounds = []
        cursor = 0
        for word in text.split():
            begin = text.index(word, cursor)
            cursor = begin + len(word)
            word_bounds.append((begin, cursor))

    entities = {}
    for i, selection in enumerate(selections, 1):
        start_ix, end_ix = selection.start, selection.end
        if word_bounds is not None:
            char_start, char_end = start_ix, end_ix
            # First word of the span: skip every token that ends at or
            # before the selection begins.
            start_ix = sum(1 for _, stop in word_bounds if stop <= char_start)
            # Last word of the span: the last token starting before the
            # selection ends; never earlier than the first word.
            started = sum(1 for begin, _ in word_bounds if begin < char_end)
            end_ix = max(start_ix, started - 1)
        entities[str(i)] = {
            "tokens": selection.text,
            "label": selection.labels[0],
            "start_ix": start_ix,
            "end_ix": end_ix,
            "relations": []
        }
    return entities
def find_relations_with_entities(entities, entities_data):
    """Rebuild relations between current entities from the original data.

    Current entities are matched to original records by their text
    (``entity.text`` == ``record['tokens']``). For each match, every stored
    relation ``[label, target_id, ...]`` whose target record's tokens also
    belong to a current entity yields a ``Relation``.

    Args:
        entities: Current selections; each must expose ``text``.
        entities_data: Original entities dict, id -> record with ``tokens``
            and optionally ``relations``.

    Returns:
        List of ``Relation`` objects, ordered by current entity, then by
        original record, then by stored relation.

    Note:
        Matching is by text only. When several current entities share the
        same text, the last one is used for all of them.
    """
    text_to_entity = {e.text: e for e in entities}
    id_to_tokens = {
        entity_id: entity['tokens']
        for entity_id, entity in entities_data.items()
    }

    # Group the original records by token text once, instead of rescanning
    # all of entities_data for every current entity.
    records_by_tokens = {}
    for entity in entities_data.values():
        records_by_tokens.setdefault(entity['tokens'], []).append(entity)

    relations = []
    for source_text, source_entity in text_to_entity.items():
        for entity in records_by_tokens.get(source_text, ()):
            for relation in entity.get('relations', []):
                target_text = id_to_tokens.get(relation[1])
                # Skip relations whose target is unknown or is not among the
                # current entities.
                if target_text and target_text in text_to_entity:
                    relations.append(Relation(
                        source=source_entity,
                        target=text_to_entity[target_text],
                        label=relation[0]
                    ))
    return relations