InteractiveSurvey / startup.py
technicolor's picture
update
80c0e03
#!/usr/bin/env python3
"""
启动脚本 - 设置环境变量以解决 Hugging Face Spaces 中的权限问题
"""
import os
import tempfile
import sys
def setup_environment():
"""设置所有必要的环境变量"""
# 检测是否在 Hugging Face Spaces 中运行
is_hf_space = bool(os.environ.get('SPACE_ID') or os.environ.get('HF_SPACE_ID'))
if is_hf_space:
print("检测到 Hugging Face Spaces 环境,设置临时目录...")
# 创建临时目录
temp_base = tempfile.mkdtemp()
# 设置 matplotlib 配置目录
matplotlib_dir = os.path.join(temp_base, "matplotlib")
os.makedirs(matplotlib_dir, exist_ok=True)
os.environ["MPLCONFIGDIR"] = matplotlib_dir
# 设置 XDG 缓存目录
os.environ["XDG_CACHE_HOME"] = temp_base
# 设置 numba 缓存目录
numba_dir = os.path.join(temp_base, "numba_cache")
os.makedirs(numba_dir, exist_ok=True)
os.environ["NUMBA_CACHE_DIR"] = numba_dir
# 设置 Hugging Face 缓存目录
hf_dir = os.path.join(temp_base, "hf_cache")
os.makedirs(hf_dir, exist_ok=True)
os.environ["HF_HOME"] = hf_dir
os.environ["HF_HUB_CACHE"] = os.path.join(hf_dir, "hub")
print(f"环境变量已设置:")
print(f" MPLCONFIGDIR: {matplotlib_dir}")
print(f" XDG_CACHE_HOME: {temp_base}")
print(f" NUMBA_CACHE_DIR: {numba_dir}")
print(f" HF_HOME: {hf_dir}")
print(f" HF_HUB_CACHE: {os.environ['HF_HUB_CACHE']}")
else:
print("本地环境,使用默认缓存目录")
return is_hf_space
def check_imports():
"""检查关键导入是否正常工作"""
try:
print("检查导入...")
# 检查 langchain 相关导入
from langchain_community.embeddings import HuggingFaceEmbeddings
print("✅ langchain_community.embeddings 导入成功")
from langchain_text_splitters import RecursiveCharacterTextSplitter
print("✅ langchain_text_splitters 导入成功")
# 检查其他关键库
import torch
print("✅ PyTorch 导入成功")
import transformers
print("✅ Transformers 导入成功")
import matplotlib
print("✅ Matplotlib 导入成功")
import numba
print("✅ Numba 导入成功")
return True
except ImportError as e:
print(f"❌ 导入失败: {e}")
return False
if __name__ == "__main__":
print("=== 环境设置脚本 ===")
# 设置环境变量
is_hf_space = setup_environment()
# 检查导入
if check_imports():
print("\n✅ 环境设置完成,所有导入正常")
if is_hf_space:
print("💡 提示: 在 Hugging Face Spaces 中使用临时目录作为缓存")
else:
print("💡 提示: 在本地环境中使用默认缓存目录")
else:
print("\n❌ 环境设置失败,请检查依赖安装")
sys.exit(1)