ClinicalTrialV2 / app.py
Satoc's picture
Fix Gradio TypeError and improve stability
c46f583
raw
history blame
11.3 kB
import gradio as gr
import pandas as pd
import time
import traceback
from OpenAITools.FetchTools import fetch_clinical_trials
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq
from OpenAITools.CrinicalTrialTools import SimpleClinicalTrialAgent, GraderAgent, LLMTranslator, generate_ex_question_English
# モデルとエージェントの初期化
groq = ChatGroq(model_name="llama3-70b-8192", temperature=0)
translator = LLMTranslator(groq)
CriteriaCheckAgent = SimpleClinicalTrialAgent(groq)
grader_agent = GraderAgent(groq)
# エラーハンドリング付きでエージェント評価を実行する関数
def evaluate_with_retry(agent, criteria, question, max_retries=3):
"""エラーハンドリング付きでエージェント評価を実行"""
for attempt in range(max_retries):
try:
return agent.evaluate_eligibility(criteria, question)
except Exception as e:
if "missing variables" in str(e):
# プロンプトテンプレートの変数エラーの場合
print(f"プロンプトテンプレートエラー (試行 {attempt + 1}/{max_retries}): {e}")
return "評価エラー: プロンプトテンプレートの設定に問題があります"
elif "no healthy upstream" in str(e) or "InternalServerError" in str(e):
# Groqサーバーエラーの場合
print(f"Groqサーバーエラー (試行 {attempt + 1}/{max_retries}): {e}")
if attempt < max_retries - 1:
time.sleep(2) # 2秒待機してリトライ
continue
else:
return "評価エラー: サーバーに接続できませんでした"
else:
print(f"予期しないエラー (試行 {attempt + 1}/{max_retries}): {e}")
if attempt < max_retries - 1:
time.sleep(1)
continue
else:
return f"評価エラー: {str(e)}"
return "評価エラー: 最大リトライ回数に達しました"
def evaluate_grade_with_retry(agent, judgment, max_retries=3):
"""エラーハンドリング付きでグレード評価を実行"""
for attempt in range(max_retries):
try:
return agent.evaluate_eligibility(judgment)
except Exception as e:
if "no healthy upstream" in str(e) or "InternalServerError" in str(e):
print(f"Groqサーバーエラー (グレード評価 - 試行 {attempt + 1}/{max_retries}): {e}")
if attempt < max_retries - 1:
time.sleep(2)
continue
else:
return "unclear"
else:
print(f"予期しないエラー (グレード評価 - 試行 {attempt + 1}/{max_retries}): {e}")
if attempt < max_retries - 1:
time.sleep(1)
continue
else:
return "unclear"
return "unclear"
# データフレームを生成する関数
def generate_dataframe(age, sex, tumor_type, GeneMutation, Meseable, Biopsiable):
try:
# 入力検証
if not all([age, sex, tumor_type]):
return pd.DataFrame(), pd.DataFrame()
# 日本語の腫瘍タイプを英語に翻訳
try:
TumorName = translator.translate(tumor_type)
except Exception as e:
print(f"翻訳エラー: {e}")
TumorName = tumor_type # 翻訳に失敗した場合は元の値を使用
# 質問文を生成
try:
ex_question = generate_ex_question_English(age, sex, TumorName, GeneMutation, Meseable, Biopsiable)
except Exception as e:
print(f"質問生成エラー: {e}")
return pd.DataFrame(), pd.DataFrame()
# 臨床試験データの取得
try:
df = fetch_clinical_trials(TumorName)
if df.empty:
print("臨床試験データが見つかりませんでした")
return pd.DataFrame(), pd.DataFrame()
except Exception as e:
print(f"臨床試験データ取得エラー: {e}")
return pd.DataFrame(), pd.DataFrame()
df['AgentJudgment'] = None
df['AgentGrade'] = None
# 臨床試験の適格性の評価
NCTIDs = list(df['NCTID'])
progress = gr.Progress(track_tqdm=True)
for i, nct_id in enumerate(NCTIDs):
try:
target_criteria = df.loc[df['NCTID'] == nct_id, 'Eligibility Criteria'].values[0]
# エラーハンドリング付きで評価実行
agent_judgment = evaluate_with_retry(CriteriaCheckAgent, target_criteria, ex_question)
agent_grade = evaluate_grade_with_retry(grader_agent, agent_judgment)
# データフレームの更新
df.loc[df['NCTID'] == nct_id, 'AgentJudgment'] = agent_judgment
df.loc[df['NCTID'] == nct_id, 'AgentGrade'] = agent_grade
except Exception as e:
print(f"NCTID {nct_id} の評価中にエラー: {e}")
df.loc[df['NCTID'] == nct_id, 'AgentJudgment'] = f"エラー: {str(e)}"
df.loc[df['NCTID'] == nct_id, 'AgentGrade'] = "unclear"
progress((i + 1) / len(NCTIDs))
# 列を指定した順に並び替え
columns_order = ['NCTID', 'AgentGrade', 'Title', 'AgentJudgment', 'Japanes Locations',
'Primary Completion Date', 'Cancer', 'Summary', 'Eligibility Criteria']
# 存在する列のみを選択
available_columns = [col for col in columns_order if col in df.columns]
df = df[available_columns]
return df, df # フィルタ用と表示用にデータフレームを返す
except Exception as e:
print(f"データフレーム生成中に予期しないエラー: {e}")
traceback.print_exc()
return pd.DataFrame(), pd.DataFrame()
# CSVとして保存しダウンロードする関数
def download_filtered_csv(df):
try:
if df is None or len(df) == 0:
return None
file_path = "filtered_data.csv"
df.to_csv(file_path, index=False)
return file_path
except Exception as e:
print(f"CSV保存エラー: {e}")
return None
# 全体結果をCSVとして保存しダウンロードする関数
def download_full_csv(df):
try:
if df is None or len(df) == 0:
return None
file_path = "full_data.csv"
df.to_csv(file_path, index=False)
return file_path
except Exception as e:
print(f"CSV保存エラー: {e}")
return None
# Gradioインターフェースの作成
with gr.Blocks(title="臨床試験適格性評価", theme=gr.themes.Soft()) as demo:
gr.Markdown("## 臨床試験適格性評価インターフェース")
gr.Markdown("⚠️ **注意**: サーバーエラーが発生する場合があります。エラーが続く場合は少し時間をおいてから再試行してください。")
# 各種入力フィールド
with gr.Row():
with gr.Column():
age_input = gr.Textbox(label="Age", placeholder="例: 65", value="")
sex_input = gr.Dropdown(choices=["男性", "女性"], label="Sex", value=None)
tumor_type_input = gr.Textbox(label="Tumor Type", placeholder="例: gastric cancer", value="")
with gr.Column():
gene_mutation_input = gr.Textbox(label="Gene Mutation", placeholder="例: HER2", value="")
measurable_input = gr.Dropdown(choices=["有り", "無し", "不明"], label="Measurable Tumor", value=None)
biopsiable_input = gr.Dropdown(choices=["有り", "無し", "不明"], label="Biopsiable Tumor", value=None)
# データフレーム表示エリア
dataframe_output = gr.DataFrame(
headers=["NCTID", "AgentGrade", "Title", "AgentJudgment", "Status"],
datatype=["str", "str", "str", "str", "str"],
value=None
)
# 内部状態用の非表示コンポーネント
original_df_state = gr.State(value=None)
filtered_df_state = gr.State(value=None)
# ボタン類
with gr.Row():
generate_button = gr.Button("Generate Clinical Trials Data", variant="primary")
with gr.Row():
yes_button = gr.Button("Show Eligible Trials", variant="secondary")
no_button = gr.Button("Show Ineligible Trials", variant="secondary")
unclear_button = gr.Button("Show Unclear Trials", variant="secondary")
with gr.Row():
download_filtered_button = gr.Button("Download Filtered Data")
download_full_button = gr.Button("Download Full Data")
# ダウンロードファイル
download_filtered_output = gr.File(label="Download Filtered Data", visible=False)
download_full_output = gr.File(label="Download Full Data", visible=False)
# イベントハンドリング
def update_dataframe_and_state(age, sex, tumor_type, gene_mutation, measurable, biopsiable):
"""データフレーム生成と状態更新"""
df, _ = generate_dataframe(age, sex, tumor_type, gene_mutation, measurable, biopsiable)
return df, df, df
def filter_and_update(original_df, grade):
"""フィルタリングと表示更新"""
if original_df is None or len(original_df) == 0:
return original_df, original_df
try:
df_filtered = original_df[original_df['AgentGrade'] == grade]
return df_filtered, df_filtered
except Exception as e:
print(f"フィルタリングエラー: {e}")
return original_df, original_df
# ボタン動作の設定
generate_button.click(
fn=update_dataframe_and_state,
inputs=[age_input, sex_input, tumor_type_input, gene_mutation_input, measurable_input, biopsiable_input],
outputs=[dataframe_output, original_df_state, filtered_df_state]
)
yes_button.click(
fn=lambda df: filter_and_update(df, "yes"),
inputs=[original_df_state],
outputs=[dataframe_output, filtered_df_state]
)
no_button.click(
fn=lambda df: filter_and_update(df, "no"),
inputs=[original_df_state],
outputs=[dataframe_output, filtered_df_state]
)
unclear_button.click(
fn=lambda df: filter_and_update(df, "unclear"),
inputs=[original_df_state],
outputs=[dataframe_output, filtered_df_state]
)
download_filtered_button.click(
fn=download_filtered_csv,
inputs=[filtered_df_state],
outputs=[download_filtered_output]
)
download_full_button.click(
fn=download_full_csv,
inputs=[original_df_state],
outputs=[download_full_output]
)
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
debug=False,
show_error=True
)