InteractiveSurvey / src /demo /path_utils.py
technicolor's picture
update
80c0e03
raw
history blame
2.83 kB
import os
import tempfile
# 设置 Hugging Face 缓存目录
def setup_hf_cache():
"""设置 Hugging Face 缓存目录,在 Hugging Face Spaces 中使用临时目录"""
if os.environ.get('SPACE_ID') or os.environ.get('HF_SPACE_ID'):
# 在 Hugging Face Spaces 中使用临时目录作为缓存
cache_dir = tempfile.mkdtemp()
os.environ['HF_HOME'] = cache_dir
os.environ['HF_HUB_CACHE'] = os.path.join(cache_dir, 'hub')
print(f"Using Hugging Face cache directory: {cache_dir}")
return cache_dir
else:
# 本地环境使用默认缓存目录
return None
# 检查是否在 Hugging Face Spaces 环境中
def get_data_paths():
# 如果在 Hugging Face Spaces 中,使用临时目录
if os.environ.get('SPACE_ID') or os.environ.get('HF_SPACE_ID'):
# 使用临时目录
temp_dir = tempfile.mkdtemp()
return {
'DATA_PATH': os.path.join(temp_dir, 'pdf/'),
'TXT_PATH': os.path.join(temp_dir, 'txt/'),
'TSV_PATH': os.path.join(temp_dir, 'tsv/'),
'MD_PATH': os.path.join(temp_dir, 'md/'),
'INFO_PATH': os.path.join(temp_dir, 'info/'),
'IMG_PATH': os.path.join(temp_dir, 'img/'),
'RESULTS_PATH': os.path.join(temp_dir, 'results/')
}
else:
# 本地环境使用原来的路径
return {
'DATA_PATH': './src/static/data/pdf/',
'TXT_PATH': './src/static/data/txt/',
'TSV_PATH': './src/static/data/tsv/',
'MD_PATH': './src/static/data/md/',
'INFO_PATH': './src/static/data/info/',
'IMG_PATH': './src/static/img/',
'RESULTS_PATH': './src/static/data/results/'
}
# 全局路径管理函数
def get_path(path_type, survey_id=None, filename=None):
"""
获取动态路径
path_type: 'pdf', 'txt', 'tsv', 'md', 'info', 'img', 'results'
survey_id: 可选的调查ID
filename: 可选的文件名
"""
paths_config = get_data_paths()
if path_type == 'pdf':
base_path = paths_config['DATA_PATH']
elif path_type == 'txt':
base_path = paths_config['TXT_PATH']
elif path_type == 'tsv':
base_path = paths_config['TSV_PATH']
elif path_type == 'md':
base_path = paths_config['MD_PATH']
elif path_type == 'info':
base_path = paths_config['INFO_PATH']
elif path_type == 'img':
base_path = paths_config['IMG_PATH']
elif path_type == 'results':
base_path = paths_config['RESULTS_PATH']
else:
raise ValueError(f"Unknown path type: {path_type}")
if survey_id:
base_path = os.path.join(base_path, str(survey_id))
if filename:
return os.path.join(base_path, filename)
return base_path