File size: 9,598 Bytes
13c0d56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import chromadb
from langchain import LLMChain, PromptTemplate
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.output_parsers import StrOutputParser
from langchain.embeddings import ZhipuAIEmbeddings
from langchain.vectorstores import Chroma
from diffusers import StableDiffusionPipeline
import requests
import gradio as gr
import os
from dotenv import load_dotenv, find_dotenv

_ = load_dotenv(find_dotenv())    # 读取本地 .env 文件
zhipuai_api_key = os.environ['ZHIPUAI_API_KEY']

class HealthcareAgent:
    def __init__(self):
        self.vectordb = self.get_vectordb()
        self.llm = ChatOpenAI(
            model="glm-3-turbo",
            temperature=0.7,
            openai_api_key=zhipuai_api_key,
            openai_api_base="https://open.bigmodel.cn/api/paas/v4/"
        )
        self.diffusion_model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")

    def get_vectordb(self):
        embedding = ZhipuAIEmbeddings()
        persist_directory = '/Users/chenshuyi/Documents/agent/data_base/vector_db'
        vectordb = Chroma(
            persist_directory=persist_directory,
            embedding_function=embedding
        )
        return vectordb

    def generate_response(self, input_text):
        output = self.llm.invoke(input_text)
        output_parser = StrOutputParser()
        output = output_parser.invoke(output)
        return output

    def rag_search(self, symptoms):
        template = """使用以下上下文来回答关于症状的问题。如果你不知道答案,就说你不知道,不要试图编造答案。最多使用三句话。尽量使答案简明扼要。总是在回答的最后说"谢谢你的提问!"。
        上下文: {context}
        问题: 基于这些症状 "{symptoms}",可能是什么疾病?请列出这些疾病的其他常见症状。
        回答格式:
        可能的疾病: [疾病1, 疾病2, ...]
        其他常见症状: [症状1, 症状2, ...]
        回答:"""
        QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "symptoms"], template=template)
        retriever = self.vectordb.as_retriever()
        qa_chain = RetrievalQA.from_chain_type(
            self.llm,
            retriever=retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
        )
        result = qa_chain({"query": symptoms})
        return result["result"]
    
    def assess_severity(self, condition, symptoms):
        template = """使用以下上下文来评估疾病的严重程度。
        上下文: {context}
        疾病: {condition}
        症状: {symptoms}
        请根据给定的疾病和症状,评估病情的严重程度。将严重程度分为轻度、中度和重度三个等级。
        同时,请给出这个评估的理由,并提供一些建议。
        回答格式:
        严重程度: [轻度/中度/重度]
        理由: [您的解释]
        建议: [您的建议]
        回答:"""
        QA_CHAIN_PROMPT = PromptTemplate(
            input_variables=["context", "condition", "symptoms"], 
            template=template
        )
        retriever = self.vectordb.as_retriever()
        qa_chain = RetrievalQA.from_chain_type(
            self.llm,
            retriever=retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
        )
        result = qa_chain({"query": f"{condition} {symptoms}", "condition": condition, "symptoms": symptoms})
        return result["result"]

    def generate_skin_condition_image(self, condition):
        severities = ["轻度", "中度", "重度"]
        images = []
        for severity in severities:
            prompt = f"{severity}{condition}的皮肤症状"
            image = self.diffusion_model(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
            images.append(image)
        return images

    def recommend_medical_facility(self, user_location, condition, severity):
        # 首先使用LLM推荐医疗设施类型
        template = """
        基于以下信息推荐合适的医疗设施类型:
        
        疾病: {condition}
        严重程度: {severity}
        
        请从以下选项中选择最合适的医疗设施类型:
        1. 药房
        2. 社区医院
        3. 二甲医院
        4. 三甲医院
        
        只需回复数字1-4,不需要其他解释。
        
        推荐:
        """
        
        prompt = PromptTemplate(template=template, input_variables=["condition", "severity"])
        llm_chain = LLMChain(prompt=prompt, llm=self.llm)
        facility_type = llm_chain.run(condition=condition, severity=severity).strip()
        
        # 将LLM的推荐转换为实际的设施类型
        facility_types = {
            "1": "药房",
            "2": "社区医院",
            "3": "二甲医院",
            "4": "三甲医院"
        }
        recommended_type = facility_types.get(facility_type, "医院")  # 默认为"医院"
        
        # 调用高德地图API搜索附近的医疗设施
        amap_key = "您的高德地图API密钥"  # 请替换为您的实际API密钥
        url = f"https://restapi.amap.com/v3/place/text?key={amap_key}&keywords={recommended_type}&city={user_location}&offset=10&page=1&extensions=all"
        
        response = requests.get(url)
        if response.status_code == 200:
            data = response.json()
            if data["status"] == "1" and data["pois"]:
                facilities = data["pois"]
                # 返回前三个结果
                top_facilities = facilities[:3]
                result = f"根据您的情况,我们推荐您去{recommended_type}。以下是附近的几个选择:\n\n"
                for facility in top_facilities:
                    result += f"名称: {facility['name']}\n"
                    result += f"地址: {facility['address']}\n"
                    result += f"电话: {facility.get('tel', '未提供')}\n\n"
                return result
            else:
                return f"抱歉,我们无法在您的位置找到合适的{recommended_type}。请考虑寻求紧急医疗帮助或咨询当地卫生部门。"
        else:
            return "抱歉,我们暂时无法获取医疗设施信息。请稍后再试或直接联系当地医疗机构。"

    def interact(self, symptoms, user_location):
        condition = self.rag_search(symptoms)
        
        if "皮肤" in condition:
            images = self.generate_skin_condition_image(condition)
            return condition, images, True, None  # 添加None作为医疗设施推荐的占位符
        else:
            severity_assessment = self.assess_severity(condition, symptoms)
            severity, reason, advice = self.parse_severity_result(severity_assessment)
            facility_recommendation = self.recommend_medical_facility(user_location, condition, severity)
            return condition, (severity, reason, advice), False, facility_recommendation

    def parse_severity_result(self, result):
        # 这个函数需要根据实际的输出格式来实现
        # 这里只是一个示例
        lines = result.split('\n')
        severity = ""
        reason = ""
        advice = ""
        for line in lines:
            if line.startswith("严重程度:"):
                severity = line.split(':')[1].strip()
            elif line.startswith("理由:"):
                reason = line.split(':')[1].strip()
            elif line.startswith("建议:"):
                advice = line.split(':')[1].strip()
        return severity, reason, advice

def gradio_interface():
    agent = HealthcareAgent()

    def process_input(symptoms, user_location):
        condition, result, is_skin_condition, facility_recommendation = agent.interact(symptoms, user_location)
        if is_skin_condition:
            return gr.update(visible=True, value=condition), gr.update(visible=True, value=result), gr.update(visible=False), gr.update(visible=True, value=facility_recommendation)
        else:
            severity, reason, advice = result
            return gr.update(visible=True, value=f"诊断: {condition}\n严重程度: {severity}\n理由: {reason}\n建议: {advice}"), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True, value=facility_recommendation)

    def on_select(evt: gr.SelectData):
        severities = ["轻度", "中度", "重度"]
        return f"您选择的严重程度为: {severities[evt.index]}"

    with gr.Blocks() as iface:
        gr.Markdown("# 医疗保健助手")
        symptoms_input = gr.Textbox(label="请描述您的症状")
        location_input = gr.Textbox(label="请输入您的位置")
        submit_btn = gr.Button("提交")
        
        with gr.Group() as output_group:
            text_output = gr.Textbox(label="诊断结果", visible=False)
            image_output = gr.Gallery(label="请选择最接近您症状的图片", visible=False, columns=3, height=300)
            severity_output = gr.Textbox(label="严重程度", visible=False)
            facility_output = gr.Textbox(label="推荐医疗设施", visible=False)

        submit_btn.click(process_input, inputs=[symptoms_input, location_input], outputs=[text_output, image_output, severity_output, facility_output])
        image_output.select(on_select, None, severity_output)

    return iface

if __name__ == "__main__":
    iface = gradio_interface()
    iface.launch()