File size: 5,994 Bytes
a06ac21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""
医学放射学报告命名实体识别(NER)评估脚本

该脚本用于从放射学报告中提取临床术语实体,包括解剖部位(ANAT)和观察结果(OBS)。
主要功能:
1. 使用LLM模型进行实体抽取
2. 支持few-shot学习,可以从训练数据中加载示例
3. 包含错误重试机制
4. 结果输出格式化处理

实体标签说明:
- OBS-DP: 明确存在的观察结果
- ANAT-DP: 明确存在的解剖部位
- OBS-U: 不确定的观察结果
- OBS-DA: 明确不存在的观察结果

输出格式:
{
    "1": {"tokens": "...", "label": "..."},
    "2": {"tokens": "...", "label": "..."},
    ...
}
"""

import ast
import json
import pprint
import time  # 添加time模块用于重试机制
from typing import Dict, List, Any  # 添加类型提示支持

import requests


# System instructions sent as the first chat message of every request.
# The label vocabulary requested in the ANSWER format must be the same one
# defined in the legend (OBS-DP / ANAT-DP / OBS-U / OBS-DA); presumably the
# few-shot answers built from train.json use these labels too — confirm.
# Tuples are shown as quoted Python literals because that is the form the
# few-shot answers use and the form the response parser accepts.
system_prompt = '''You are a radiologist performing clinical term extraction from the FINDINGS and IMPRESSION sections in the radiology report. Here a clinical term can be either anatomy or observation that is related to a finding or an impression. The anatomy term refers to an anatomical body part such as a 'lung'. The observation terms refer to observations made when referring to the associated radiology image. Observations are associated with visual features, identifiable pathophysiologic processes, or diagnostic disease classifications. For example, an observation could be 'effusion' or description phrases like 'increased'. You also need to assign a label to indicate whether the clinical term is present, absent or uncertain. The labels are:
- OBS-DP: Observation definitely present
- ANAT-DP: Anatomy definitely present
- OBS-U: Observation uncertain
- OBS-DA: Observation definitely absent

Given a piece of radiology text input in the format:

<INPUT>
<text>
</INPUT>

reply with the following structure:

<OUTPUT>
ANSWER: tuples separated by newlines. Each tuple has the format: ('<clinical term text>', '<label: OBS-DP|ANAT-DP|OBS-U|OBS-DA>'). If there is nothing to extract from the findings or impression, return ()
</OUTPUT>
'''

# One shared requests.Session for every call to the LLM endpoint, so
# connection state is reused between requests instead of rebuilt per call.
session: requests.Session = requests.Session()


def get_question_prompt(text):
    """Build the user message that asks the model to annotate one report.

    The layout reproduces the original triple-quoted template exactly:
    continuation lines carry an 8-space indent, ``</INPUT>`` is followed by a
    single space, and the prompt ends with ``<OUTPUT> ANSWER:`` on its own
    line followed by a newline and 8 spaces of indentation.

    Args:
        text: report text, interpolated like an f-string placeholder.

    Returns:
        str: the complete question prompt.
    """
    pad = ' ' * 8
    instruction = (
        'What are the clinical terms and their labels in this text? '
        'Discard sections other than FINDINGS and IMPRESSION: eg. INDICATION, '
        'HISTORY, TECHNIQUE, COMPARISON sections. If there is no extraction '
        'from findings and impression, return (). Please only output the '
        'tuples without additional notes or explanations.'
    )
    parts = [
        '<INPUT>',
        f'{pad}{text}',
        f'{pad}</INPUT> ',  # the trailing space is part of the template
        '',
        pad + instruction,
        '',
        f'{pad}<OUTPUT> ANSWER:',
        pad,
    ]
    return '\n'.join(parts)


def _parse_entities(content):
    """Parse the model's reply into ``{"1": {"tokens": ..., "label": ...}, ...}``.

    For every line containing a ``(...)`` span, the span from the first ``(``
    to the last ``)`` is parsed with ``ast.literal_eval``. ``eval`` is never
    used because the reply is untrusted model output. Only 2-tuples are kept.
    Prose, the empty answer ``()``, and unquoted or malformed tuples are
    skipped instead of aborting the whole parse. Entities are numbered
    consecutively from 1 in the order they appear.
    """
    result = {}
    for line in content.splitlines():
        start, end = line.find('('), line.rfind(')')
        if start == -1 or end < start:
            continue
        try:
            parsed = ast.literal_eval(line[start:end + 1])
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            continue
        # Reject bare strings like "('ab')", which would otherwise unpack
        # character-by-character into a bogus (term, label) pair.
        if not (isinstance(parsed, tuple) and len(parsed) == 2):
            continue
        term, label = parsed
        result[str(len(result) + 1)] = {
            "tokens": term,
            "label": label,
        }
    return result


def extract_entities(text, num_shots=5, *,
                     url='http://ollama.corinth.informatik.rwth-aachen.de/api/chat',
                     model='llama3.1:8b', timeout=300):
    """Extract labelled clinical terms from one radiology report via the LLM.

    Builds the chat (system prompt, ``num_shots`` few-shot turns, then the
    question for ``text``). It posts the chat to the Ollama-style ``/api/chat``
    endpoint without streaming and parses the tuples in the reply.

    Args:
        text: report text to annotate.
        num_shots: number of few-shot examples taken from train.json.
        url: chat endpoint to post to.
        model: model name sent in the request body.
        timeout: seconds to wait for the HTTP response. Without a timeout a
            stalled server blocks forever, so no retry would ever happen.

    Returns:
        dict: ``{"1": {"tokens": term, "label": label}, "2": ...}`` with
        consecutive keys.

    Raises:
        requests.RequestException: on connection errors, timeouts or an HTTP
            error status.
        ValueError / KeyError: if the response body is not JSON of the
            expected ``{"message": {"content": ...}}`` shape.
    """
    messages = [{'role': 'system', 'content': system_prompt}]
    messages = create_messages_with_shots(num_shots, messages)
    messages.append({'role': 'user', 'content': get_question_prompt(text)})

    data = {
        "model": model,
        "messages": messages,
        "stream": False,
    }

    response = session.post(url, json=data, timeout=timeout)
    response.raise_for_status()  # surface HTTP errors (4xx/5xx) as exceptions
    content = response.json()['message']['content']

    return _parse_entities(content)


def _format_shot_answer(entities):
    """Render gold entities as the assistant reply the model should imitate.

    Each entity becomes a ``(tokens, label)`` tuple written with ``repr``.
    Terms containing quotes or backslashes (e.g. "patient's") therefore
    remain valid, parseable literals. An example without entities is
    answered with ``()``, as the prompts instruct.
    """
    lines = [
        f"({info['tokens']!r}, {info['label']!r})"
        for info in entities.values()
    ]
    body = '\n'.join(lines) if lines else '()'
    return body + '\n</OUTPUT>'


def create_messages_with_shots(num_shots, messages, path='train.json'):
    """Append few-shot (question, answer) turns from the training file.

    The first ``num_shots`` entries of the JSON file, in file order, each
    become a user turn (built with ``get_question_prompt``) followed by an
    assistant turn listing that example's gold entities.

    Args:
        num_shots (int): number of few-shot examples. Values <= 0 add none
            and do not read the file, so zero-shot runs don't need it.
        messages (list): chat messages to extend; modified in place.
        path (str): JSON file whose values each hold a ``'text'`` string and
            an ``'entities'`` dict of ``{'tokens': ..., 'label': ...}``.

    Returns:
        list: the same ``messages`` list, with the few-shot turns appended.
    """
    if num_shots <= 0:
        return messages

    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        train_data = json.load(f)

    for _, example in zip(range(num_shots), train_data.values()):
        messages.append({
            'role': 'user',
            'content': get_question_prompt(example['text']),
        })
        messages.append({
            'role': 'assistant',
            'content': _format_shot_answer(example['entities']),
        })

    return messages


def main(num_shots=50):
    """Run extraction on the first example of dev.json and print the result.

    Only the first entry (in file order) is sent to the model.

    Args:
        num_shots: few-shot examples to include in the prompt.
    """
    with open('dev.json', 'r', encoding='utf-8') as f:
        data = json.load(f)

    # With an empty dev.json, a for/break loop would leave ``entities``
    # unbound and crash at print(); handle that case explicitly.
    first = next(iter(data.values()), None)
    if first is None:
        print('dev.json contains no examples')
        return

    entities = extract_entities(first['text'], num_shots=num_shots)
    print(entities)


def process_sample(text: str, num_shots: int, max_retries: int = 3,
                   retry_delay: float = 1.0) -> Dict[str, Any]:
    """Run extract_entities on one report, retrying failed attempts.

    Args:
        text: report text.
        num_shots: number of few-shot examples passed to extract_entities.
        max_retries: total number of attempts. Values below 1 are treated as
            1, so the function never falls through and returns None.
        retry_delay: seconds to wait before the first retry. The wait doubles
            after each further failure (exponential backoff), so a struggling
            server is not hit with immediate re-requests.

    Returns:
        The entity dict produced by extract_entities.

    Raises:
        Exception: whatever the final attempt raised, with its original
            traceback.
    """
    attempts = max(1, max_retries)
    for attempt in range(attempts):
        try:
            return extract_entities(text, num_shots)
        # NOTE(review): every exception is retried, including non-transient
        # ones such as a missing train.json; narrow this if that matters.
        except Exception as e:
            if attempt == attempts - 1:
                raise
            print(f'第{attempt + 1}次尝试失败: {e}')
            time.sleep(retry_delay * 2 ** attempt)


# Entry-point guard kept last so every function above is defined before main() runs.
if __name__ == '__main__':
    main()