|
from openai import OpenAI |
|
from dotenv import load_dotenv |
|
import os |
|
import threading |
|
import time |
|
import gradio as gr |
|
from lang import LANGUAGE_CONFIG |
|
|
|
|
|
load_dotenv(override=True) |
|
required_env_vars = ["API_KEY", "API_URL", "API_MODEL"] |
|
secondary_api_exists = all( |
|
os.getenv(f"{var}_2") for var in ["API_KEY", "API_URL", "API_MODEL"] |
|
) |
|
missing_vars = [var for var in required_env_vars if not os.getenv(var)] |
|
if missing_vars: |
|
raise EnvironmentError( |
|
f"Missing required environment variables: {', '.join(missing_vars)}" |
|
) |
|
|
|
|
|
class AppConfig: |
|
DEFAULT_THROUGHPUT = 10 |
|
SYNC_THRESHOLD_DEFAULT = 0 |
|
API_TIMEOUT = 20 |
|
|
|
|
|
class DynamicState: |
|
"""动态UI状态""" |
|
|
|
def __init__(self): |
|
self.should_stream = False |
|
self.stream_completed = False |
|
self.in_cot = True |
|
self.current_language = "en" |
|
self.waiting_api = False |
|
self.label_passthrough = False |
|
|
|
def control_button_handler(self): |
|
original_state = self.should_stream |
|
self.should_stream = not self.should_stream |
|
|
|
|
|
if not original_state and self.should_stream: |
|
self.waiting_api = True |
|
self.stream_completed = False |
|
|
|
return self.ui_state_controller() |
|
|
|
def ui_state_controller(self): |
|
"""生成动态UI组件状态""" |
|
|
|
lang_data = LANGUAGE_CONFIG[self.current_language] |
|
control_value = ( |
|
lang_data["pause_btn"] if self.should_stream else lang_data["generate_btn"] |
|
) |
|
control_variant = "secondary" if self.should_stream else "primary" |
|
|
|
if self.waiting_api and self.should_stream: |
|
status_suffix = lang_data["waiting_api"] |
|
elif self.waiting_api and not self.should_stream: |
|
status_suffix = lang_data["api_retry"] |
|
else: |
|
status_suffix = ( |
|
lang_data["completed"] |
|
if self.stream_completed |
|
else lang_data["interrupted"] |
|
) |
|
editor_label = f"{lang_data['editor_label']} - {status_suffix}" |
|
output = ( |
|
gr.update(value=control_value, variant=control_variant), |
|
gr.update() if self.label_passthrough else gr.update(label=editor_label), |
|
gr.update(interactive=not self.should_stream), |
|
) |
|
self.label_passthrough = False |
|
return output |
|
|
|
def reset_workspace(self): |
|
"""重置工作区状态""" |
|
self.stream_completed = False |
|
self.should_stream = False |
|
self.in_cot = True |
|
self.waiting_api = False |
|
|
|
return self.ui_state_controller() + ( |
|
"", |
|
"", |
|
LANGUAGE_CONFIG["en"]["bot_default"], |
|
DEFAULT_PERSISTENT, |
|
) |
|
|
|
|
|
class CoordinationManager: |
|
"""管理人类与AI的协同节奏""" |
|
|
|
def __init__(self, paragraph_threshold, initial_content): |
|
self.paragraph_threshold = paragraph_threshold |
|
self.initial_paragraph_count = initial_content.count("\n\n") |
|
self.triggered = False |
|
|
|
def should_pause_for_human(self, current_content): |
|
if self.paragraph_threshold <= 0 or self.triggered: |
|
return False |
|
|
|
current_paragraphs = current_content.count("\n\n") |
|
if ( |
|
current_paragraphs - self.initial_paragraph_count |
|
>= self.paragraph_threshold |
|
): |
|
self.triggered = True |
|
return True |
|
return False |
|
|
|
|
|
class ConvoState: |
|
"""State of current ROUND of convo""" |
|
|
|
def __init__(self): |
|
self.throughput = AppConfig.DEFAULT_THROUGHPUT |
|
self.sync_threshold = AppConfig.SYNC_THRESHOLD_DEFAULT |
|
self.current_language = "en" |
|
self.convo = [] |
|
self.initialize_new_round() |
|
self.is_error = False |
|
self.result_editing_toggle = False |
|
|
|
def get_api_config(self, language): |
|
suffix = "_2" if language == "zh" and secondary_api_exists else "" |
|
return { |
|
"key": os.getenv(f"API_KEY{suffix}"), |
|
"url": os.getenv(f"API_URL{suffix}"), |
|
"model": os.getenv(f"API_MODEL{suffix}"), |
|
} |
|
|
|
def initialize_new_round(self): |
|
self.current = {} |
|
self.current["user"] = "" |
|
self.current["cot"] = "" |
|
self.current["result"] = "" |
|
self.current["raw"] = "" |
|
self.convo.append(self.current) |
|
|
|
def flatten_output(self): |
|
output = [] |
|
for round in self.convo: |
|
output.append({"role": "user", "content": round["user"]}) |
|
if len(round["cot"]) > 0: |
|
output.append( |
|
{ |
|
"role": "assistant", |
|
"content": round["cot"], |
|
"metadata": {"title": f"Chain of Thought"}, |
|
} |
|
) |
|
if len(round["result"]) > 0: |
|
output.append({"role": "assistant", "content": round["result"]}) |
|
return output |
|
|
|
def generate_ai_response(self, user_prompt, current_content, dynamic_state): |
|
lang_data = LANGUAGE_CONFIG[self.current_language] |
|
dynamic_state.stream_completed = False |
|
full_response = current_content |
|
self.current["raw"] = full_response |
|
api_config = self.get_api_config(self.current_language) |
|
api_client = OpenAI( |
|
api_key=api_config["key"], |
|
base_url=api_config["url"], |
|
timeout=AppConfig.API_TIMEOUT, |
|
) |
|
|
|
coordinator = CoordinationManager(self.sync_threshold, current_content) |
|
|
|
editor_output = current_content |
|
|
|
try: |
|
|
|
|
|
if dynamic_state.waiting_api: |
|
status = lang_data["waiting_api"] |
|
editor_label = f"{lang_data['editor_label']} - {status}" |
|
yield full_response, gr.update( |
|
label=editor_label |
|
), self.flatten_output() |
|
|
|
coordinator = CoordinationManager(self.sync_threshold, current_content) |
|
messages = [ |
|
{"role": "user", "content": user_prompt}, |
|
{ |
|
"role": "assistant", |
|
"content": f"<think>\n{current_content}", |
|
"prefix": True, |
|
}, |
|
] |
|
self.current["user"] = user_prompt |
|
response_stream = api_client.chat.completions.create( |
|
model=os.getenv("API_MODEL"), |
|
messages=messages, |
|
stream=True, |
|
timeout=AppConfig.API_TIMEOUT, |
|
top_p=0.95, |
|
temperature=0.6, |
|
) |
|
for chunk in response_stream: |
|
chunk_content = chunk.choices[0].delta.content |
|
if ( |
|
coordinator.should_pause_for_human(full_response) |
|
and dynamic_state.in_cot |
|
): |
|
dynamic_state.should_stream = False |
|
if not dynamic_state.should_stream: |
|
break |
|
|
|
if chunk_content: |
|
dynamic_state.waiting_api = False |
|
full_response += chunk_content.replace("<think>", "") |
|
self.current["raw"] = full_response |
|
|
|
think_complete = "</think>" in full_response |
|
dynamic_state.in_cot = not think_complete |
|
if think_complete: |
|
self.current["cot"], self.current["result"] = ( |
|
full_response.split("</think>") |
|
) |
|
else: |
|
self.current["cot"], self.current["result"] = ( |
|
full_response, |
|
"", |
|
) |
|
status = ( |
|
lang_data["loading_thinking"] |
|
if dynamic_state.in_cot |
|
else lang_data["loading_output"] |
|
) |
|
editor_label = f"{lang_data['editor_label']} - {status}" |
|
if self.result_editing_toggle: |
|
editor_output = full_response |
|
else: |
|
editor_output = self.current["cot"] + ( |
|
"</think>" if think_complete else "" |
|
) |
|
yield editor_output, gr.update( |
|
label=editor_label |
|
), self.flatten_output() |
|
|
|
interval = 1.0 / self.throughput |
|
start_time = time.time() |
|
while ( |
|
(time.time() - start_time) < interval |
|
and dynamic_state.should_stream |
|
and dynamic_state.in_cot |
|
): |
|
time.sleep(0.005) |
|
|
|
except Exception as e: |
|
if str(e) == "list index out of range": |
|
dynamic_state.stream_completed = True |
|
else: |
|
if str(e) == "The read operation timed out": |
|
error_msg = lang_data["api_interrupted"] |
|
else: |
|
error_msg = "❓ " + str(e) |
|
|
|
editor_label_error = f"{lang_data['editor_label']} - {error_msg}" |
|
self.is_error = True |
|
dynamic_state.label_passthrough = True |
|
|
|
finally: |
|
dynamic_state.should_stream = False |
|
if "response_stream" in locals(): |
|
response_stream.close() |
|
final_status = ( |
|
lang_data["completed"] |
|
if dynamic_state.stream_completed |
|
else lang_data["interrupted"] |
|
) |
|
editor_label = f"{lang_data['editor_label']} - {final_status}" |
|
if not self.is_error: |
|
yield editor_output, gr.update( |
|
label=editor_label |
|
), self.flatten_output() |
|
else: |
|
yield editor_output, gr.update( |
|
label=editor_label_error |
|
), self.flatten_output() + [ |
|
{ |
|
"role": "assistant", |
|
"content": error_msg, |
|
"metadata": {"title": f"❌Error"}, |
|
} |
|
] |
|
self.is_error = False |
|
|
|
|
|
def update_interface_language(selected_lang, convo_state, dynamic_state): |
|
"""更新界面语言配置""" |
|
convo_state.current_language = selected_lang |
|
dynamic_state.current_language = selected_lang |
|
lang_data = LANGUAGE_CONFIG[selected_lang] |
|
base_editor_label = lang_data["editor_label"] |
|
status_suffix = ( |
|
lang_data["completed"] |
|
if dynamic_state.stream_completed |
|
else lang_data["interrupted"] |
|
) |
|
editor_label = f"{base_editor_label} - {status_suffix}" |
|
api_config = convo_state.get_api_config(selected_lang) |
|
new_bot_content = [ |
|
{ |
|
"role": "assistant", |
|
"content": f"{selected_lang} - Running `{api_config['model']}` @ {api_config['url']}", |
|
"metadata": {"title": f"API INFO"}, |
|
} |
|
] |
|
return [ |
|
gr.update(value=f"{lang_data['title']}"), |
|
gr.update( |
|
label=lang_data["prompt_label"], placeholder=lang_data["prompt_placeholder"] |
|
), |
|
gr.update(label=editor_label, placeholder=lang_data["editor_placeholder"]), |
|
gr.update( |
|
label=lang_data["sync_threshold_label"], |
|
info=lang_data["sync_threshold_info"], |
|
), |
|
gr.update( |
|
label=lang_data["throughput_label"], info=lang_data["throughput_info"] |
|
), |
|
gr.update( |
|
value=( |
|
lang_data["pause_btn"] |
|
if dynamic_state.should_stream |
|
else lang_data["generate_btn"] |
|
), |
|
variant="secondary" if dynamic_state.should_stream else "primary", |
|
), |
|
gr.update(label=lang_data["language_label"]), |
|
gr.update( |
|
value=lang_data["clear_btn"], interactive=not dynamic_state.should_stream |
|
), |
|
gr.update(value=lang_data["introduction"]), |
|
gr.update( |
|
value=lang_data["bot_default"] + new_bot_content, |
|
label=lang_data["bot_label"], |
|
), |
|
gr.update(label=lang_data["result_editing_toggle"]), |
|
] |
|
|
|
|
|
theme = gr.themes.Base(font="system-ui", primary_hue="stone") |
|
|
|
with gr.Blocks(theme=theme, css_paths="styles.css") as demo: |
|
DEFAULT_PERSISTENT = {"prompt_input": "", "thought_editor": ""} |
|
convo_state = gr.State(ConvoState) |
|
dynamic_state = gr.State(DynamicState) |
|
persistent_state = gr.BrowserState(DEFAULT_PERSISTENT) |
|
|
|
bot_default = LANGUAGE_CONFIG["en"]["bot_default"] + [ |
|
{ |
|
"role": "assistant", |
|
"content": f"Running `{os.getenv('API_MODEL')}` @ {os.getenv('API_URL')}", |
|
"metadata": {"title": f"EN API INFO"}, |
|
} |
|
] |
|
if secondary_api_exists: |
|
bot_default.append( |
|
{ |
|
"role": "assistant", |
|
"content": f"Switch to zh ↗ to use SiliconFlow API: `{os.getenv('API_MODEL_2')}` @ {os.getenv('API_URL_2')}", |
|
"metadata": {"title": f"CN API INFO"}, |
|
} |
|
) |
|
|
|
with gr.Row(variant=""): |
|
title_md = gr.Markdown( |
|
f"{LANGUAGE_CONFIG['en']['title']} ", |
|
container=False, |
|
) |
|
lang_selector = gr.Dropdown( |
|
choices=["en", "zh"], |
|
value="en", |
|
elem_id="compact_lang_selector", |
|
scale=0, |
|
container=False, |
|
) |
|
|
|
with gr.Row(equal_height=True): |
|
with gr.Column(scale=1, min_width=400): |
|
prompt_input = gr.Textbox( |
|
label=LANGUAGE_CONFIG["en"]["prompt_label"], |
|
lines=2, |
|
placeholder=LANGUAGE_CONFIG["en"]["prompt_placeholder"], |
|
max_lines=2, |
|
) |
|
thought_editor = gr.Textbox( |
|
label=f"{LANGUAGE_CONFIG['en']['editor_label']} - {LANGUAGE_CONFIG['en']['editor_default']}", |
|
lines=16, |
|
max_lines=16, |
|
placeholder=LANGUAGE_CONFIG["en"]["editor_placeholder"], |
|
autofocus=True, |
|
elem_id="editor", |
|
) |
|
with gr.Row(): |
|
control_button = gr.Button( |
|
value=LANGUAGE_CONFIG["en"]["generate_btn"], variant="primary" |
|
) |
|
next_turn_btn = gr.Button( |
|
value=LANGUAGE_CONFIG["en"]["clear_btn"], interactive=True |
|
) |
|
|
|
with gr.Column(scale=1, min_width=500): |
|
chatbot = gr.Chatbot( |
|
type="messages", |
|
height=300, |
|
value=bot_default, |
|
group_consecutive_messages=False, |
|
show_copy_all_button=True, |
|
show_share_button=True, |
|
label=LANGUAGE_CONFIG["en"]["bot_label"], |
|
) |
|
with gr.Row(): |
|
sync_threshold_slider = gr.Slider( |
|
minimum=0, |
|
maximum=20, |
|
value=AppConfig.SYNC_THRESHOLD_DEFAULT, |
|
step=1, |
|
label=LANGUAGE_CONFIG["en"]["sync_threshold_label"], |
|
info=LANGUAGE_CONFIG["en"]["sync_threshold_info"], |
|
) |
|
throughput_control = gr.Slider( |
|
minimum=1, |
|
maximum=100, |
|
value=AppConfig.DEFAULT_THROUGHPUT, |
|
step=1, |
|
label=LANGUAGE_CONFIG["en"]["throughput_label"], |
|
info=LANGUAGE_CONFIG["en"]["throughput_info"], |
|
) |
|
result_editing_toggle = gr.Checkbox( |
|
label=LANGUAGE_CONFIG["en"]["result_editing_toggle"], |
|
interactive=True, |
|
scale=0, |
|
container=False, |
|
) |
|
|
|
intro_md = gr.Markdown(LANGUAGE_CONFIG["en"]["introduction"], visible=False) |
|
|
|
@demo.load(inputs=[persistent_state], outputs=[prompt_input, thought_editor]) |
|
def recover_persistent_state(persistant_state): |
|
if persistant_state["prompt_input"] or persistant_state["thought_editor"]: |
|
return persistant_state["prompt_input"], persistant_state["thought_editor"] |
|
else: |
|
return gr.update(), gr.update() |
|
|
|
|
|
stateful_ui = (control_button, thought_editor, next_turn_btn) |
|
|
|
throughput_control.change( |
|
lambda val, s: setattr(s, "throughput", val), |
|
[throughput_control, convo_state], |
|
None, |
|
concurrency_limit=None, |
|
) |
|
|
|
sync_threshold_slider.change( |
|
lambda val, s: setattr(s, "sync_threshold", val), |
|
[sync_threshold_slider, convo_state], |
|
None, |
|
concurrency_limit=None, |
|
) |
|
|
|
def wrap_stream_generator(convo_state, dynamic_state, prompt, content): |
|
for response in convo_state.generate_ai_response( |
|
prompt, content, dynamic_state |
|
): |
|
yield response + ( |
|
{ |
|
"prompt_input": convo_state.current["user"], |
|
"thought_editor": convo_state.current["cot"], |
|
}, |
|
) |
|
|
|
gr.on( |
|
[control_button.click, prompt_input.submit, thought_editor.submit], |
|
lambda d: d.control_button_handler(), |
|
[dynamic_state], |
|
stateful_ui, |
|
show_progress=False, |
|
concurrency_limit=None, |
|
).then( |
|
wrap_stream_generator, |
|
[convo_state, dynamic_state, prompt_input, thought_editor], |
|
[thought_editor, thought_editor, chatbot, persistent_state], |
|
concurrency_limit=1000, |
|
).then( |
|
lambda d: d.ui_state_controller(), |
|
[dynamic_state], |
|
stateful_ui, |
|
show_progress=False, |
|
concurrency_limit=None, |
|
) |
|
|
|
next_turn_btn.click( |
|
lambda d: d.reset_workspace(), |
|
[dynamic_state], |
|
stateful_ui + (thought_editor, prompt_input, chatbot, persistent_state), |
|
concurrency_limit=None, |
|
show_progress=False, |
|
) |
|
|
|
def toggle_editor_result(convo_state, allow): |
|
setattr(convo_state, "result_editing_toggle", allow) |
|
if allow: |
|
return gr.update(value=convo_state.current["raw"]) |
|
else: |
|
return gr.update(value=convo_state.current["cot"]) |
|
|
|
result_editing_toggle.change( |
|
toggle_editor_result, |
|
inputs=[convo_state, result_editing_toggle], |
|
outputs=[thought_editor], |
|
) |
|
|
|
lang_selector.change( |
|
lambda lang, s, d: update_interface_language(lang, s, d), |
|
[lang_selector, convo_state, dynamic_state], |
|
[ |
|
title_md, |
|
prompt_input, |
|
thought_editor, |
|
sync_threshold_slider, |
|
throughput_control, |
|
control_button, |
|
lang_selector, |
|
next_turn_btn, |
|
intro_md, |
|
chatbot, |
|
result_editing_toggle, |
|
], |
|
concurrency_limit=None, |
|
) |
|
|
|
if __name__ == "__main__": |
|
demo.queue(default_concurrency_limit=1000) |
|
demo.launch() |
|
|