import os

import chromadb  # backend used by the Chroma vector store
import requests
import gradio as gr
from dotenv import load_dotenv, find_dotenv
from diffusers import StableDiffusionPipeline
# LangChain is split across the langchain / langchain_core / langchain_community packages.
from langchain.chains import LLMChain, RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.embeddings import ZhipuAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_openai import ChatOpenAI

_ = load_dotenv(find_dotenv())  # read the local .env file
zhipuai_api_key = os.environ['ZHIPUAI_API_KEY']
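
# HealthcareAgent wires the pieces together: a ZhipuAI GLM chat model (served
# through an OpenAI-compatible endpoint) answers symptom questions over a local
# Chroma vector store (RAG), Stable Diffusion renders reference images for skin
# conditions, and the Amap place-search API suggests nearby medical facilities.
# Running it assumes a ZHIPUAI_API_KEY in .env, a pre-built vector store at the
# persist_directory below, a CUDA GPU for diffusers, and an Amap API key.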

class HealthcareAgent:
    def __init__(self):
        self.vectordb = self.get_vectordb()
        self.llm = ChatOpenAI(
            model="glm-3-turbo",
            temperature=0.7,
            openai_api_key=zhipuai_api_key,
            openai_api_base="https://open.bigmodel.cn/api/paas/v4/"
        )
        # Stable Diffusion renders reference images of skin conditions; requires a CUDA GPU.
        self.diffusion_model = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5"
        ).to("cuda")

    def get_vectordb(self):
        embedding = ZhipuAIEmbeddings()
        persist_directory = '/Users/chenshuyi/Documents/agent/data_base/vector_db'
        vectordb = Chroma(
            persist_directory=persist_directory,
            embedding_function=embedding
        )
        return vectordb
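
    # NOTE: the persisted Chroma store is assumed to have been built in advance,
    # e.g. by an ingestion script along the lines of (hypothetical):
    #   Chroma.from_documents(docs, ZhipuAIEmbeddings(), persist_directory=persist_directory)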

    def generate_response(self, input_text):
        output = self.llm.invoke(input_text)
        output_parser = StrOutputParser()
        output = output_parser.invoke(output)
        return output

    def rag_search(self, symptoms):
        # RetrievalQA's stuff chain only fills {context} and {question},
        # so the symptoms are passed in as the question.
        template = """使用以下上下文来回答关于症状的问题。如果你不知道答案,就说你不知道,不要试图编造答案。最多使用三句话。尽量使答案简明扼要。总是在回答的最后说"谢谢你的提问!"。
上下文: {context}
问题: 基于这些症状 "{question}",可能是什么疾病?请列出这些疾病的其他常见症状。
回答格式:
可能的疾病: [疾病1, 疾病2, ...]
其他常见症状: [症状1, 症状2, ...]
回答:"""
        QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"], template=template)
        retriever = self.vectordb.as_retriever()
        qa_chain = RetrievalQA.from_chain_type(
            self.llm,
            retriever=retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
        )
        result = qa_chain({"query": symptoms})
        return result["result"]

    def assess_severity(self, condition, symptoms):
        template = """使用以下上下文来评估疾病的严重程度。
上下文: {context}
疾病: {condition}
症状: {symptoms}
请根据给定的疾病和症状,评估病情的严重程度。将严重程度分为轻度、中度和重度三个等级。
同时,请给出这个评估的理由,并提供一些建议。
回答格式:
严重程度: [轻度/中度/重度]
理由: [您的解释]
建议: [您的建议]
回答:"""
        # RetrievalQA only supplies {context} and {question}, so condition and
        # symptoms are bound to the prompt up front with partial().
        QA_CHAIN_PROMPT = PromptTemplate(
            input_variables=["context", "condition", "symptoms"],
            template=template
        ).partial(condition=condition, symptoms=symptoms)
        retriever = self.vectordb.as_retriever()
        qa_chain = RetrievalQA.from_chain_type(
            self.llm,
            retriever=retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
        )
        result = qa_chain({"query": f"{condition} {symptoms}"})
        return result["result"]

    def generate_skin_condition_image(self, condition):
        # Render one reference image per severity level (mild / moderate / severe).
        severities = ["轻度", "中度", "重度"]
        images = []
        for severity in severities:
            prompt = f"{severity}{condition}的皮肤症状"
            image = self.diffusion_model(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
            images.append(image)
        return images

    def recommend_medical_facility(self, user_location, condition, severity):
        # First, let the LLM pick an appropriate type of medical facility.
        template = """
基于以下信息推荐合适的医疗设施类型:
疾病: {condition}
严重程度: {severity}
请从以下选项中选择最合适的医疗设施类型:
1. 药房
2. 社区医院
3. 二甲医院
4. 三甲医院
只需回复数字1-4,不需要其他解释。
推荐:
"""
        prompt = PromptTemplate(template=template, input_variables=["condition", "severity"])
        llm_chain = LLMChain(prompt=prompt, llm=self.llm)
        facility_type = llm_chain.run(condition=condition, severity=severity).strip()
        # Map the LLM's numeric answer to an actual facility type.
        facility_types = {
            "1": "药房",
            "2": "社区医院",
            "3": "二甲医院",
            "4": "三甲医院"
        }
        recommended_type = facility_types.get(facility_type, "医院")  # fall back to a generic hospital
        # Search for nearby facilities with the Amap (高德地图) place-search API.
        amap_key = "您的高德地图API密钥"  # replace with your actual Amap API key
        url = f"https://restapi.amap.com/v3/place/text?key={amap_key}&keywords={recommended_type}&city={user_location}&offset=10&page=1&extensions=all"
        response = requests.get(url)
        if response.status_code == 200:
            data = response.json()
            if data["status"] == "1" and data["pois"]:
                facilities = data["pois"]
                # Keep only the top three results.
                top_facilities = facilities[:3]
                result = f"根据您的情况,我们推荐您去{recommended_type}。以下是附近的几个选择:\n\n"
                for facility in top_facilities:
                    result += f"名称: {facility['name']}\n"
                    result += f"地址: {facility['address']}\n"
                    result += f"电话: {facility.get('tel', '未提供')}\n\n"
                return result
            else:
                return f"抱歉,我们无法在您的位置找到合适的{recommended_type}。请考虑寻求紧急医疗帮助或咨询当地卫生部门。"
        else:
            return "抱歉,我们暂时无法获取医疗设施信息。请稍后再试或直接联系当地医疗机构。"

    def interact(self, symptoms, user_location):
        condition = self.rag_search(symptoms)
        if "皮肤" in condition:
            images = self.generate_skin_condition_image(condition)
            return condition, images, True, None  # None is a placeholder for the facility recommendation
        else:
            severity_assessment = self.assess_severity(condition, symptoms)
            severity, reason, advice = self.parse_severity_result(severity_assessment)
            facility_recommendation = self.recommend_medical_facility(user_location, condition, severity)
            return condition, (severity, reason, advice), False, facility_recommendation

    def parse_severity_result(self, result):
        # Parsing depends on the model's actual output format; this is only a
        # best-effort example that tolerates half- and full-width colons.
        severity = ""
        reason = ""
        advice = ""
        for line in result.split('\n'):
            line = line.replace('：', ':')  # normalize full-width colons
            if line.startswith("严重程度:"):
                severity = line.split(':', 1)[1].strip()
            elif line.startswith("理由:"):
                reason = line.split(':', 1)[1].strip()
            elif line.startswith("建议:"):
                advice = line.split(':', 1)[1].strip()
        return severity, reason, advice

def gradio_interface():
    agent = HealthcareAgent()

    def process_input(symptoms, user_location):
        condition, result, is_skin_condition, facility_recommendation = agent.interact(symptoms, user_location)
        if is_skin_condition:
            # Skin condition: show the diagnosis and the image gallery; there is no facility recommendation yet.
            return gr.update(visible=True, value=condition), gr.update(visible=True, value=result), gr.update(visible=False), gr.update(visible=False)
        else:
            severity, reason, advice = result
            return gr.update(visible=True, value=f"诊断: {condition}\n严重程度: {severity}\n理由: {reason}\n建议: {advice}"), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True, value=facility_recommendation)

    def on_select(evt: gr.SelectData):
        severities = ["轻度", "中度", "重度"]
        # Reveal the severity box with the level matching the selected image.
        return gr.update(visible=True, value=f"您选择的严重程度为: {severities[evt.index]}")

    with gr.Blocks() as iface:
        gr.Markdown("# 医疗保健助手")
        symptoms_input = gr.Textbox(label="请描述您的症状")
        location_input = gr.Textbox(label="请输入您的位置")
        submit_btn = gr.Button("提交")
        with gr.Group() as output_group:
            text_output = gr.Textbox(label="诊断结果", visible=False)
            image_output = gr.Gallery(label="请选择最接近您症状的图片", visible=False, columns=3, height=300)
            severity_output = gr.Textbox(label="严重程度", visible=False)
            facility_output = gr.Textbox(label="推荐医疗设施", visible=False)
        submit_btn.click(
            process_input,
            inputs=[symptoms_input, location_input],
            outputs=[text_output, image_output, severity_output, facility_output]
        )
        image_output.select(on_select, None, severity_output)
    return iface
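
# Build the UI and start a local Gradio server when run as a script
# (share=True could be passed to launch() for a temporary public URL).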
if __name__ == "__main__":
    iface = gradio_interface()
    iface.launch()