Spaces:
Running
Running
import gradio as gr | |
import pandas as pd | |
import time | |
import traceback | |
import os | |
from OpenAITools.FetchTools import fetch_clinical_trials | |
from langchain_openai import ChatOpenAI | |
from langchain_groq import ChatGroq | |
from OpenAITools.CrinicalTrialTools import SimpleClinicalTrialAgent, GraderAgent, LLMTranslator, generate_ex_question_English | |
# 環境変数チェック | |
def check_environment(): | |
"""環境変数をチェックし、不足している場合は警告""" | |
missing_vars = [] | |
if not os.getenv("GROQ_API_KEY"): | |
missing_vars.append("GROQ_API_KEY") | |
if not os.getenv("OPENAI_API_KEY"): | |
missing_vars.append("OPENAI_API_KEY") | |
if missing_vars: | |
print(f"⚠️ 環境変数が設定されていません: {', '.join(missing_vars)}") | |
print("一部の機能が制限される可能性があります。") | |
return len(missing_vars) == 0 | |
# 環境変数チェック実行 | |
env_ok = check_environment() | |
# モデルとエージェントの安全な初期化 | |
def safe_init_agents(): | |
"""エージェントを安全に初期化""" | |
try: | |
groq = ChatGroq(model_name="llama3-70b-8192", temperature=0) | |
translator = LLMTranslator(groq) | |
criteria_agent = SimpleClinicalTrialAgent(groq) | |
grader_agent = GraderAgent(groq) | |
return translator, criteria_agent, grader_agent | |
except Exception as e: | |
print(f"エージェント初期化エラー: {e}") | |
return None, None, None | |
# エージェント初期化 | |
translator, CriteriaCheckAgent, grader_agent = safe_init_agents() | |
# エラーハンドリング付きでエージェント評価を実行する関数 | |
def evaluate_with_retry(agent, criteria, question, max_retries=3): | |
"""エラーハンドリング付きでエージェント評価を実行""" | |
if agent is None: | |
return "評価エラー: エージェントが初期化されていません。API keyを確認してください。" | |
for attempt in range(max_retries): | |
try: | |
return agent.evaluate_eligibility(criteria, question) | |
except Exception as e: | |
if "missing variables" in str(e): | |
# プロンプトテンプレートの変数エラーの場合 | |
print(f"プロンプトテンプレートエラー (試行 {attempt + 1}/{max_retries}): {e}") | |
return "評価エラー: プロンプトテンプレートの設定に問題があります" | |
elif "no healthy upstream" in str(e) or "InternalServerError" in str(e): | |
# Groqサーバーエラーの場合 | |
print(f"Groqサーバーエラー (試行 {attempt + 1}/{max_retries}): {e}") | |
if attempt < max_retries - 1: | |
time.sleep(2) # 2秒待機してリトライ | |
continue | |
else: | |
return "評価エラー: サーバーに接続できませんでした" | |
elif "API key" in str(e) or "authentication" in str(e).lower(): | |
return "評価エラー: API keyが無効または設定されていません" | |
else: | |
print(f"予期しないエラー (試行 {attempt + 1}/{max_retries}): {e}") | |
if attempt < max_retries - 1: | |
time.sleep(1) | |
continue | |
else: | |
return f"評価エラー: {str(e)}" | |
return "評価エラー: 最大リトライ回数に達しました" | |
def evaluate_grade_with_retry(agent, judgment, max_retries=3): | |
"""エラーハンドリング付きでグレード評価を実行""" | |
if agent is None: | |
return "unclear" | |
for attempt in range(max_retries): | |
try: | |
return agent.evaluate_eligibility(judgment) | |
except Exception as e: | |
if "no healthy upstream" in str(e) or "InternalServerError" in str(e): | |
print(f"Groqサーバーエラー (グレード評価 - 試行 {attempt + 1}/{max_retries}): {e}") | |
if attempt < max_retries - 1: | |
time.sleep(2) | |
continue | |
else: | |
return "unclear" | |
elif "API key" in str(e) or "authentication" in str(e).lower(): | |
return "unclear" | |
else: | |
print(f"予期しないエラー (グレード評価 - 試行 {attempt + 1}/{max_retries}): {e}") | |
if attempt < max_retries - 1: | |
time.sleep(1) | |
continue | |
else: | |
return "unclear" | |
return "unclear" | |
# データフレームを生成する関数 | |
def generate_dataframe(age, sex, tumor_type, GeneMutation, Meseable, Biopsiable): | |
try: | |
# 入力検証 | |
if not all([age, sex, tumor_type]): | |
return pd.DataFrame(), pd.DataFrame() | |
# 日本語の腫瘍タイプを英語に翻訳 | |
try: | |
if translator is not None: | |
TumorName = translator.translate(tumor_type) | |
else: | |
print("翻訳エージェントが利用できません。元の値を使用します。") | |
TumorName = tumor_type | |
except Exception as e: | |
print(f"翻訳エラー: {e}") | |
TumorName = tumor_type # 翻訳に失敗した場合は元の値を使用 | |
# 質問文を生成 | |
try: | |
ex_question = generate_ex_question_English(age, sex, TumorName, GeneMutation, Meseable, Biopsiable) | |
except Exception as e: | |
print(f"質問生成エラー: {e}") | |
return pd.DataFrame(), pd.DataFrame() | |
# 臨床試験データの取得 | |
try: | |
df = fetch_clinical_trials(TumorName) | |
if df.empty: | |
print("臨床試験データが見つかりませんでした") | |
return pd.DataFrame(), pd.DataFrame() | |
except Exception as e: | |
print(f"臨床試験データ取得エラー: {e}") | |
return pd.DataFrame(), pd.DataFrame() | |
df['AgentJudgment'] = None | |
df['AgentGrade'] = None | |
# 臨床試験の適格性の評価 | |
NCTIDs = list(df['NCTID']) | |
progress = gr.Progress(track_tqdm=True) | |
for i, nct_id in enumerate(NCTIDs): | |
try: | |
target_criteria = df.loc[df['NCTID'] == nct_id, 'Eligibility Criteria'].values[0] | |
# エラーハンドリング付きで評価実行 | |
agent_judgment = evaluate_with_retry(CriteriaCheckAgent, target_criteria, ex_question) | |
agent_grade = evaluate_grade_with_retry(grader_agent, agent_judgment) | |
# データフレームの更新 | |
df.loc[df['NCTID'] == nct_id, 'AgentJudgment'] = agent_judgment | |
df.loc[df['NCTID'] == nct_id, 'AgentGrade'] = agent_grade | |
except Exception as e: | |
print(f"NCTID {nct_id} の評価中にエラー: {e}") | |
df.loc[df['NCTID'] == nct_id, 'AgentJudgment'] = f"エラー: {str(e)}" | |
df.loc[df['NCTID'] == nct_id, 'AgentGrade'] = "unclear" | |
progress((i + 1) / len(NCTIDs)) | |
# 列を指定した順に並び替え | |
columns_order = ['NCTID', 'AgentGrade', 'Title', 'AgentJudgment', 'Japanes Locations', | |
'Primary Completion Date', 'Cancer', 'Summary', 'Eligibility Criteria'] | |
# 存在する列のみを選択 | |
available_columns = [col for col in columns_order if col in df.columns] | |
df = df[available_columns] | |
return df, df # フィルタ用と表示用にデータフレームを返す | |
except Exception as e: | |
print(f"データフレーム生成中に予期しないエラー: {e}") | |
traceback.print_exc() | |
return pd.DataFrame(), pd.DataFrame() | |
# CSVとして保存しダウンロードする関数 | |
def download_filtered_csv(df): | |
try: | |
if df is None or len(df) == 0: | |
return None | |
file_path = "filtered_data.csv" | |
df.to_csv(file_path, index=False) | |
return file_path | |
except Exception as e: | |
print(f"CSV保存エラー: {e}") | |
return None | |
# 全体結果をCSVとして保存しダウンロードする関数 | |
def download_full_csv(df): | |
try: | |
if df is None or len(df) == 0: | |
return None | |
file_path = "full_data.csv" | |
df.to_csv(file_path, index=False) | |
return file_path | |
except Exception as e: | |
print(f"CSV保存エラー: {e}") | |
return None | |
# Gradioインターフェースの作成 | |
with gr.Blocks(title="臨床試験適格性評価", theme=gr.themes.Soft()) as demo: | |
gr.Markdown("## 臨床試験適格性評価インターフェース") | |
# 環境変数状態の表示 | |
if env_ok: | |
gr.Markdown("✅ **ステータス**: 全ての環境変数が設定されています") | |
else: | |
gr.Markdown("⚠️ **注意**: 一部の環境変数が設定されていません。機能が制限される可能性があります。") | |
gr.Markdown("💡 **使用方法**: 患者情報を入力して「Generate Clinical Trials Data」をクリックしてください。") | |
# 各種入力フィールド | |
with gr.Row(): | |
with gr.Column(): | |
age_input = gr.Textbox(label="Age", placeholder="例: 65", value="") | |
sex_input = gr.Dropdown(choices=["男性", "女性"], label="Sex", value=None) | |
tumor_type_input = gr.Textbox(label="Tumor Type", placeholder="例: gastric cancer", value="") | |
with gr.Column(): | |
gene_mutation_input = gr.Textbox(label="Gene Mutation", placeholder="例: HER2", value="") | |
measurable_input = gr.Dropdown(choices=["有り", "無し", "不明"], label="Measurable Tumor", value=None) | |
biopsiable_input = gr.Dropdown(choices=["有り", "無し", "不明"], label="Biopsiable Tumor", value=None) | |
# データフレーム表示エリア | |
dataframe_output = gr.DataFrame( | |
headers=["NCTID", "AgentGrade", "Title", "AgentJudgment", "Status"], | |
datatype=["str", "str", "str", "str", "str"], | |
value=None | |
) | |
# 内部状態用の非表示コンポーネント | |
original_df_state = gr.State(value=None) | |
filtered_df_state = gr.State(value=None) | |
# ボタン類 | |
with gr.Row(): | |
generate_button = gr.Button("Generate Clinical Trials Data", variant="primary") | |
with gr.Row(): | |
yes_button = gr.Button("Show Eligible Trials", variant="secondary") | |
no_button = gr.Button("Show Ineligible Trials", variant="secondary") | |
unclear_button = gr.Button("Show Unclear Trials", variant="secondary") | |
with gr.Row(): | |
download_filtered_button = gr.Button("Download Filtered Data") | |
download_full_button = gr.Button("Download Full Data") | |
# ダウンロードファイル | |
download_filtered_output = gr.File(label="Download Filtered Data", visible=False) | |
download_full_output = gr.File(label="Download Full Data", visible=False) | |
# イベントハンドリング | |
def update_dataframe_and_state(age, sex, tumor_type, gene_mutation, measurable, biopsiable): | |
"""データフレーム生成と状態更新""" | |
df, _ = generate_dataframe(age, sex, tumor_type, gene_mutation, measurable, biopsiable) | |
return df, df, df | |
def filter_and_update(original_df, grade): | |
"""フィルタリングと表示更新""" | |
if original_df is None or len(original_df) == 0: | |
return original_df, original_df | |
try: | |
df_filtered = original_df[original_df['AgentGrade'] == grade] | |
return df_filtered, df_filtered | |
except Exception as e: | |
print(f"フィルタリングエラー: {e}") | |
return original_df, original_df | |
# ボタン動作の設定 | |
generate_button.click( | |
fn=update_dataframe_and_state, | |
inputs=[age_input, sex_input, tumor_type_input, gene_mutation_input, measurable_input, biopsiable_input], | |
outputs=[dataframe_output, original_df_state, filtered_df_state] | |
) | |
yes_button.click( | |
fn=lambda df: filter_and_update(df, "yes"), | |
inputs=[original_df_state], | |
outputs=[dataframe_output, filtered_df_state] | |
) | |
no_button.click( | |
fn=lambda df: filter_and_update(df, "no"), | |
inputs=[original_df_state], | |
outputs=[dataframe_output, filtered_df_state] | |
) | |
unclear_button.click( | |
fn=lambda df: filter_and_update(df, "unclear"), | |
inputs=[original_df_state], | |
outputs=[dataframe_output, filtered_df_state] | |
) | |
download_filtered_button.click( | |
fn=download_filtered_csv, | |
inputs=[filtered_df_state], | |
outputs=[download_filtered_output] | |
) | |
download_full_button.click( | |
fn=download_full_csv, | |
inputs=[original_df_state], | |
outputs=[download_full_output] | |
) | |
if __name__ == "__main__": | |
demo.launch( | |
server_name="0.0.0.0", | |
server_port=7860, | |
share=False, | |
debug=False, | |
show_error=True | |
) |