Spaces:
Runtime error
Runtime error
import json | |
import os | |
from typing import Tuple | |
import pandas as pd | |
from common.call_llm import chat | |
EXTRACT_ENDPOINT = os.environ.get("EXTRACT_ENDPOINT") | |
prompt_template = """### 角色能力 ### | |
你是一个信息提取助手,你可以按下面给出的提取字段及描述对文档内容进行信息提取,并按给定的格式返回。 | |
确保提取的信息完整且与文档内容一致,如果有字段提取不到对应信息请返回'无'。 | |
### 提取字段及描述 ### | |
{fields_prompt} | |
### 文档内容 ### | |
{context} | |
### 返回格式 ### | |
请严格按照下面描述的JSON格式进行输出,不需要解释,输出JSON格式如下: | |
{response_prompt} | |
确保输出的格式可以被Python的json.loads方法解析。 | |
""" | |
def extract_slots(page_content: str, extraction_df: "pd.DataFrame") -> Tuple[str, None]: | |
""" | |
Extract slots from page content | |
:param page_content: | |
:param extract_requirement: | |
:return: | |
""" | |
extract_requirement = "" | |
output_requirement = dict() | |
df = pd.DataFrame(columns=["字段名称", "字段抽取结果"]) | |
# remove nan | |
extraction_df = extraction_df[extraction_df['字段名称'].notna()] | |
for _, row in extraction_df.iterrows(): | |
if not row['字段名称'] or not row['字段描述']: | |
continue | |
extract_requirement += f"{row['字段名称']}: {row['字段描述']}\n" | |
output_requirement[row['字段名称']] = row['字段描述'] | |
if not output_requirement: | |
return df | |
output_requirement_description = json.dumps([output_requirement], ensure_ascii=False, indent=4) | |
prompt = prompt_template.format(context=page_content, fields_prompt=extract_requirement, response_prompt=output_requirement_description) | |
messages = [{"role": "user", "content": prompt}] | |
max_retry = 6 | |
retry = 0 | |
result = None | |
while not result and retry < max_retry: | |
try: | |
result = chat(messages=messages, endpoint=EXTRACT_ENDPOINT) | |
if result.startswith("```json"): | |
result = result.replace("```json", "").replace("```", "").strip() | |
result = json.loads(result) | |
if isinstance(result, list): | |
result = result[0] | |
except Exception as e: | |
print(f"error: {e} {result}") | |
result = None | |
retry += 1 | |
print(f"extract slots prompt: {prompt} result: {result}") | |
if not result: | |
return df | |
for field in output_requirement: | |
df.loc[len(df)] = {"字段名称": field, "字段抽取结果": result.get(field, "无")} | |
return df | |