Spaces:
Runtime error
Runtime error
# 导入tushare | |
import tushare as ts | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import os | |
import json | |
from matplotlib.ticker import MaxNLocator | |
import matplotlib.font_manager as fm | |
from lab_gpt4_call import send_chat_request,send_chat_request_Azure,send_official_call | |
#import ast | |
import re | |
from tool import * | |
import tiktoken | |
import concurrent.futures | |
import datetime | |
from PIL import Image | |
from io import BytesIO | |
import queue | |
import datetime | |
from threading import Thread | |
# plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] | |
# plt.rcParams['axes.unicode_minus'] = False | |
import openai | |
# To override the Thread method | |
class MyThread(Thread): | |
def __init__(self, target, args): | |
super(MyThread, self).__init__() | |
self.func = target | |
self.args = args | |
def run(self): | |
self.result = self.func(*self.args) | |
def get_result(self): | |
return self.result | |
def parse_and_exe(call_dict, result_buffer, parallel_step: str='1'): | |
""" | |
Parse the input and call the corresponding function to obtain the result. | |
:param call_dict: dict, including arg, func, and output | |
:param result_buffer: dict, storing the corresponding intermediate results | |
:param parallel_step: int, parallel step | |
:return: Returns func(arg) and stores the corresponding result in result_buffer. | |
""" | |
arg_list = call_dict['arg' + parallel_step] | |
replace_arg_list = [result_buffer[item][0] if isinstance(item, str) and ('result' in item or 'input' in item) else item for item in arg_list] # 参数 | |
func_name = call_dict['function' + parallel_step] # | |
output = call_dict['output' + parallel_step] # | |
desc = call_dict['description' + parallel_step] # | |
if func_name == 'loop_rank': | |
replace_arg_list[1] = eval(replace_arg_list[1]) | |
result = eval(func_name)(*replace_arg_list) | |
result_buffer[output] = (result, desc) # 'result1': (df1, desc) | |
return result_buffer | |
def load_tool_and_prompt(tool_lib, tool_prompt ): | |
''' | |
Read two JSON files. | |
:param tool_lib: Tool description | |
:param tool_prompt: Tool prompt | |
:return: Flattened prompt | |
''' | |
# | |
with open(tool_lib, 'r') as f: | |
tool_lib = json.load(f) | |
with open(tool_prompt, 'r') as f: | |
# | |
tool_prompt = json.load(f) | |
for key, value in tool_lib.items(): | |
tool_prompt["Function Library:"] = tool_prompt["Function Library:"] + key + " " + value+ '\n\n' | |
prompt_flat = '' | |
for key, value in tool_prompt.items(): | |
prompt_flat = prompt_flat + key +' '+ value + '\n\n' | |
return prompt_flat | |
# callback function | |
intermediate_results = queue.Queue() # Create a queue to store intermediate results. | |
def add_to_queue(intermediate_result): | |
intermediate_results.put(f"After planing, the intermediate result is {intermediate_result}") | |
def check_RPM(run_time_list, new_time, max_RPM=1): | |
# Check if there are already 3 timestamps in the run_time_list, with a maximum of 3 accesses per minute. | |
# False means no rest is needed, True means rest is needed. | |
if len(run_time_list) < 3: | |
run_time_list.append(new_time) | |
return 0 | |
else: | |
if (new_time - run_time_list[0]).seconds < max_RPM: | |
# Calculate the required rest time. | |
sleep_time = 60 - (new_time - run_time_list[0]).seconds | |
print('sleep_time:', sleep_time) | |
run_time_list.pop(0) | |
run_time_list.append(new_time) | |
return sleep_time | |
else: | |
run_time_list.pop(0) | |
run_time_list.append(new_time) | |
return 0 | |
def run(instruction, add_to_queue=None, send_chat_request_Azure = send_official_call, openai_key = '', api_base='', engine=''): | |
output_text = '' | |
################################# Step-1:Task select ########################################### | |
current_time = datetime.datetime.now() | |
formatted_time = current_time.strftime("%Y-%m-%d") | |
# If the time has not exceeded 3 PM, use yesterday's data. | |
if current_time.hour < 15: | |
formatted_time = (current_time - datetime.timedelta(days=1)).strftime("%Y-%m-%d") | |
print('===============================Intent Detecting===========================================') | |
with open('./prompt_lib/prompt_intent_detection.json', 'r') as f: | |
prompt_task_dict = json.load(f) | |
prompt_intent_detection = '' | |
for key, value in prompt_task_dict.items(): | |
prompt_intent_detection = prompt_intent_detection + key + ": " + value+ '\n\n' | |
prompt_intent_detection = prompt_intent_detection + '\n\n' + 'Instruction:' + '今天的日期是'+ formatted_time +', '+ instruction + ' ###New Instruction: ' | |
# Record the running time. | |
# current_time = datetime.datetime.now() | |
# sleep_time = check_RPM(run_time, current_time) | |
# if sleep_time > 0: | |
# time.sleep(sleep_time) | |
try: | |
response = send_chat_request_Azure(prompt_intent_detection, openai_key=openai_key, api_base=api_base, engine=engine) | |
# 返回错误 | |
except Exception as e: | |
return e | |
new_instruction = response | |
print('new_instruction:', new_instruction) | |
output_text = output_text + '\n======Intent Detecting Stage=====\n\n' | |
output_text = output_text + new_instruction +'\n\n' | |
if add_to_queue is not None: | |
add_to_queue(output_text) | |
event_happen = True | |
print('===============================Task Planing===========================================') | |
output_text= output_text + '=====Task Planing Stage=====\n\n' | |
with open('./prompt_lib/prompt_task.json', 'r') as f: | |
prompt_task_dict = json.load(f) | |
prompt_task = '' | |
for key, value in prompt_task_dict.items(): | |
prompt_task = prompt_task + key + ": " + value+ '\n\n' | |
prompt_task = prompt_task + '\n\n' + 'Instruction:' + new_instruction + ' ###Plan:' | |
# current_time = datetime.datetime.now() | |
# sleep_time = check_RPM(run_time, current_time) | |
# if sleep_time > 0: | |
# time.sleep(sleep_time) | |
try: | |
response = send_chat_request_Azure(prompt_task, openai_key=openai_key,api_base=api_base,engine=engine) | |
except Exception as e: | |
return e | |
task_select = response | |
pattern = r"(task\d+=)(\{[^}]*\})" | |
matches = re.findall(pattern, task_select) | |
task_plan = {} | |
for task in matches: | |
task_step, task_select = task | |
task_select = task_select.replace("'", "\"") # Replace single quotes with double quotes. | |
task_select = json.loads(task_select) | |
task_name = list(task_select.keys())[0] | |
task_instruction = list(task_select.values())[0] | |
task_plan[task_name] = task_instruction | |
# task_plan | |
for key, value in task_plan.items(): | |
print(key, ':', value) | |
output_text = output_text + key + ': ' + str(value) + '\n' | |
output_text = output_text +'\n' | |
if add_to_queue is not None: | |
add_to_queue(output_text) | |
################################# Step-2:Tool select and use ########################################### | |
print('===============================Tool select and using Stage===========================================') | |
output_text = output_text + '======Tool select and using Stage======\n\n' | |
# Read the task_select JSON file name. | |
task_name = list(task_plan.keys())[0].split('_task')[0] | |
task_instruction = list(task_plan.values())[0] | |
tool_lib = './tool_lib/' + 'tool_' + task_name + '.json' | |
tool_prompt = './prompt_lib/' + 'prompt_' + task_name + '.json' | |
prompt_flat = load_tool_and_prompt(tool_lib, tool_prompt) | |
prompt_flat = prompt_flat + '\n\n' +'Instruction :'+ task_instruction+ ' ###Function Call' | |
#response = "step1={\n \"arg1\": [\"贵州茅台\"],\n \"function1\": \"get_stock_code\",\n \"output1\": \"result1\"\n},step2={\n \"arg1\": [\"result1\",\"20180123\",\"20190313\",\"daily\"],\n \"function1\": \"get_stock_prices_data\",\n \"output1\": \"result2\"\n},step3={\n \"arg1\": [\"result2\",\"close\"],\n \"function1\": \"calculate_stock_index\",\n \"output1\": \"result3\"\n}, ###Output:{\n \"贵州茅台在2018年1月23日到2019年3月13的每日收盘价格的时序表格\": \"result3\",\n}" | |
# current_time = datetime.datetime.now() | |
# sleep_time = check_RPM(run_time, current_time) | |
# if sleep_time > 0: | |
# time.sleep(sleep_time) | |
try: | |
response = send_chat_request_Azure(prompt_flat, openai_key=openai_key,api_base=api_base, engine=engine) | |
except Exception as e: | |
return e | |
#response = "Function Call:step1={\n \"arg1\": [\"五粮液\"],\n \"function1\": \"get_stock_code\",\n \"output1\": \"result1\",\n \"arg2\": [\"泸州老窖\"],\n \"function2\": \"get_stock_code\",\n \"output2\": \"result2\"\n},step2={\n \"arg1\": [\"result1\",\"20190101\",\"20220630\",\"daily\"],\n \"function1\": \"get_stock_prices_data\",\n \"output1\": \"result3\",\n \"arg2\": [\"result2\",\"20190101\",\"20220630\",\"daily\"],\n \"function2\": \"get_stock_prices_data\",\n \"output2\": \"result4\"\n},step3={\n \"arg1\": [\"result3\",\"Cumulative_Earnings_Rate\"],\n \"function1\": \"calculate_stock_index\",\n \"output1\": \"result5\",\n \"arg2\": [\"result4\",\"Cumulative_Earnings_Rate\"],\n \"function2\": \"calculate_stock_index\",\n \"output2\": \"result6\"\n}, ###Output:{\n \"五粮液在2019年1月1日到2022年06月30的每日收盘价格时序表格\": \"result5\",\n \"泸州老窖在2019年1月1日到2022年06月30的每日收盘价格时序表格\": \"result6\"\n}" | |
call_steps, _ = response.split('###') | |
pattern = r"(step\d+=)(\{[^}]*\})" | |
matches = re.findall(pattern, call_steps) | |
result_buffer = {} # The stored format is as follows: {'result1': (000001.SH, 'Stock code of China Ping An'), 'result2': (df2, 'Stock data of China Ping An from January to June 2021')}. | |
output_buffer = [] # Store the variable names [result5, result6] that will be passed as the final output to the next task. | |
# print(task_output) | |
# | |
for match in matches: | |
step, content = match | |
content = content.replace("'", "\"") # Replace single quotes with double quotes. | |
print('==================') | |
print("\n\nstep:", step) | |
print('content:',content) | |
call_dict = json.loads(content) | |
print('It has parallel steps:', len(call_dict) / 4) | |
output_text = output_text + step + ': ' + str(call_dict) + '\n\n' | |
# Execute the following code in parallel using multiple processes. | |
with concurrent.futures.ThreadPoolExecutor() as executor: | |
# Submit tasks to thread pool | |
futures = {executor.submit(parse_and_exe, call_dict, result_buffer, str(parallel_step)) | |
for parallel_step in range(1, int(len(call_dict) / 4) + 1)} | |
# Collect results as they become available | |
for idx, future in enumerate(concurrent.futures.as_completed(futures)): | |
# Handle possible exceptions | |
try: | |
result = future.result() | |
# Print the current parallel step number. | |
print('parallel step:', idx+1) | |
# print(list(result[1].keys())[0]) | |
# print(list(result[1].values())[0]) | |
except Exception as exc: | |
print(f'Generated an exception: {exc}') | |
if step == matches[-1][0]: | |
# Current task's final step. Save the output of the final step. | |
for parallel_step in range(1, int(len(call_dict) / 4) + 1): | |
output_buffer.append(call_dict['output' + str(parallel_step)]) | |
output_text = output_text + '\n' | |
if add_to_queue is not None: | |
add_to_queue(output_text) | |
################################# Step-3:visualization ########################################### | |
print('===============================Visualization Stage===========================================') | |
output_text = output_text + '======Visualization Stage====\n\n' | |
task_name = list(task_plan.keys())[1].split('_task')[0] #visualization_task | |
#task_name = 'visualization' | |
task_instruction = list(task_plan.values())[1] #'' | |
tool_lib = './tool_lib/' + 'tool_' + task_name + '.json' | |
tool_prompt = './prompt_lib/' + 'prompt_' + task_name + '.json' | |
result_buffer_viz={} | |
Previous_result = {} | |
for output_name in output_buffer: | |
rename = 'input'+ str(output_buffer.index(output_name)+1) | |
Previous_result[rename] = result_buffer[output_name][1] | |
result_buffer_viz[rename] = result_buffer[output_name] | |
prompt_flat = load_tool_and_prompt(tool_lib, tool_prompt) | |
prompt_flat = prompt_flat + '\n\n' +'Instruction: '+ task_instruction + ', Previous_result: '+ str(Previous_result) + ' ###Function Call' | |
# current_time = datetime.datetime.now() | |
# sleep_time = check_RPM(run_time, current_time) | |
# if sleep_time > 0: | |
# time.sleep(sleep_time) | |
try: | |
response = send_chat_request_Azure(prompt_flat, openai_key=openai_key, api_base=api_base, engine=engine) | |
except Exception as e: | |
return e | |
call_steps, _ = response.split('###') | |
pattern = r"(step\d+=)(\{[^}]*\})" | |
matches = re.findall(pattern, call_steps) | |
for match in matches: | |
step, content = match | |
content = content.replace("'", "\"") # Replace single quotes with double quotes. | |
print('==================') | |
print("\n\nstep:", step) | |
print('content:',content) | |
call_dict = json.loads(content) | |
print('It has parallel steps:', len(call_dict) / 4) | |
result_buffer_viz = parse_and_exe(call_dict, result_buffer_viz, parallel_step = '' ) | |
output_text = output_text + step + ': ' + str(call_dict) + '\n\n' | |
if add_to_queue is not None: | |
add_to_queue(output_text) | |
finally_output = list(result_buffer_viz.values()) # plt.Axes | |
# | |
df = pd.DataFrame() | |
str_out = output_text + 'Finally result: ' | |
for ax in finally_output: | |
if isinstance(ax[0], plt.Axes): # If the output is plt.Axes, display it. | |
plt.grid() | |
#plt.show() | |
str_out = str_out + ax[1]+ ':' + 'plt.Axes' + '\n\n' | |
# | |
elif isinstance(ax[0], pd.DataFrame): | |
df = ax[0] | |
str_out = str_out + ax[1]+ ':' + 'pd.DataFrame' + '\n\n' | |
else: | |
str_out = str_out + str(ax[1])+ ':' + str(ax[0]) + '\n\n' | |
# | |
print('===============================Summary Stage===========================================') | |
output_prompt = "请用第一人称总结一下整个任务规划和解决过程,并且输出结果,用[Task]表示每个规划任务,用\{function\}表示每个任务里调用的函数." + \ | |
"示例1:###我用将您的问题拆分成两个任务,首先第一个任务[stock_task],我依次获取五粮液和贵州茅台从2013年5月20日到2023年5月20日的净资产回报率roe的时序数据. \n然后第二个任务[visualization_task],我用折线图绘制五粮液和贵州茅台从2013年5月20日到2023年5月20日的净资产回报率,并计算它们的平均值和中位数. \n\n在第一个任务中我分别使用了2个工具函数\{get_stock_code\},\{get_Financial_data_from_time_range\}获取到两只股票的roe数据,在第二个任务里我们使用折线图\{plot_stock_data\}工具函数来绘制他们的roe十年走势,最后并计算了两只股票十年ROE的中位数\{output_median_col\}和均值\{output_mean_col\}.\n\n最后贵州茅台的ROE的均值和中位数是\{\},{},五粮液的ROE的均值和中位数是\{\},\{\}###" + \ | |
"示例2:###我用将您的问题拆分成两个任务,首先第一个任务[stock_task],我依次获取20230101到20230520这段时间北向资金每日净流入和每日累计流入时序数据,第二个任务是[visualization_task],因此我在同一张图里同时绘制北向资金20230101到20230520的每日净流入柱状图和每日累计流入的折线图 \n\n为了完成第一个任务中我分别使用了2个工具函数\{get_north_south_money\},\{calculate_stock_index\}分别获取到北上资金的每日净流入量和每日的累计净流入量,第二个任务里我们使用折线图\{plot_stock_data\}绘制来两个指标的变化走势.\n\n最后我们给您提供了包含两个指标的折线图和数据表格." + \ | |
"示例3:###我用将您的问题拆分成两个任务,首先第一个任务[economic_task],我爬取了上市公司贵州茅台和其主营业务介绍信息. \n然后第二个任务[visualization_task],我用表格打印贵州茅台及其相关信息. \n\n在第一个任务中我分别使用了1个工具函数\{get_company_info\} 获取到贵州茅台的公司信息,在第二个任务里我们使用折线图\{print_save_table\}工具函数来输出表格.\n" | |
try: | |
output_result = send_chat_request_Azure(output_prompt + str_out + '###', openai_key=openai_key, api_base=api_base,engine=engine) | |
except Exception as e: | |
return e | |
print(output_result) | |
buf = BytesIO() | |
plt.savefig(buf, format='png') | |
buf.seek(0) | |
# | |
# | |
image = Image.open(buf) | |
return output_text, image, output_result, df | |
def gradio_interface(query, openai_key, openai_key_azure, api_base, engine): | |
# Create a new thread to run the function. | |
placeholder_dataframe = pd.DataFrame() | |
placeholder_image = np.zeros((100, 100, 3), dtype=np.uint8) # Create a placeholder image. | |
try: | |
if openai_key.startswith('sk') and openai_key_azure == '': | |
print('send_official_call') | |
thread = MyThread(target=run, args=(query, add_to_queue, send_official_call, openai_key)) | |
elif openai_key =='' and len(openai_key_azure)>0: | |
print('send_chat_request_Azure') | |
thread = MyThread(target=run, args=(query, add_to_queue, send_chat_request_Azure, openai_key_azure, api_base, engine)) | |
thread.start() | |
# | |
# Wait for the result of the calculate function and display the intermediate results simultaneously. | |
while thread.is_alive(): | |
while not intermediate_results.empty(): | |
yield intermediate_results.get(), placeholder_image, 'Running' , placeholder_dataframe # Use the yield keyword to return intermediate results in real-time | |
time.sleep(0.1) # Avoid excessive resource consumption. | |
finally_text, img, output, df = thread.get_result() | |
yield finally_text, img, output, df | |
except Exception as e: | |
yield str(e), placeholder_image, str(e), placeholder_dataframe | |
# Return the final result. | |
instruction = '画一下五粮液和泸州老窖从2019年年初到2022年年中的收益率走势' | |
if __name__ == '__main__': | |
# 初始化pro接口 | |
#openai_call = send_chat_request_Azure # | |
openai_call = send_official_call # | |
openai_key = os.getenv("OPENAI_KEY") | |
output, image, df , output_result = run(instruction, send_chat_request_Azure = openai_call, openai_key=openai_key, api_base='', engine='') | |
print(output_result) | |
plt.show() | |