codesearchBase / app.py
Forrest99's picture
Update app.py
c79e0ff verified
raw
history blame
3.3 kB
from flask import Flask, request, jsonify
from sentence_transformers import SentenceTransformer, util
import logging
import sys
import signal
# Initialise the Flask application.
app = Flask(__name__)
# Configure logging with the level set to INFO.
logging.basicConfig(level=logging.INFO)
# Point the app's logger at the named "CodeSearchAPI" logger.
app.logger = logging.getLogger("CodeSearchAPI")
# Predefined code snippets.
# Multi-line snippets keep the indentation of their function bodies so that
# every entry is itself valid Python source.
CODE_SNIPPETS = [
    "def sort_list(x): return sorted(x)",
    """def count_above_threshold(elements, threshold=0):
    return sum(1 for e in elements if e > threshold)""",
    """def find_min_max(elements):
    return min(elements), max(elements)""",
]
# Global service state; starts out False.
service_ready = False
# Graceful shutdown handling.
def handle_shutdown(signum, frame):
    """Log that a termination signal arrived, then raise ``SystemExit(0)``.

    Has the two-argument shape of a signal handler; neither ``signum`` nor
    ``frame`` is used.
    """
    app.logger.info("收到终止信号,开始关闭...")
    raise SystemExit(0)
# Run handle_shutdown when the process receives SIGTERM or SIGINT.
signal.signal(signal.SIGTERM, handle_shutdown)
signal.signal(signal.SIGINT, handle_shutdown)
# Load the model and encode the snippets once at import time.
try:
    # Cache path dedicated to Hugging Face Spaces.
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder="/model-cache",
    )
    # Pre-compute the snippet embeddings (forced onto the CPU).
    code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True, device="cpu")
    # Only reached when both steps above succeeded.
    service_ready = True
    app.logger.info("服务初始化完成")
except Exception as exc:
    # Log the failure, then re-raise so the import itself fails.
    app.logger.error(f"初始化失败: {str(exc)}")
    raise
# Hugging Face health-check endpoint.
@app.route('/')
def hf_health_check():
    """Answer the health check on the root path.

    Returns ``{"status": "ready"}`` with 200 once ``service_ready`` is true,
    otherwise ``{"status": "initializing"}`` with 503.
    """
    status, http_code = ("ready", 200) if service_ready else ("initializing", 503)
    return jsonify({"status": status}), http_code
def _extract_query():
    """Return the stripped ``query`` value of the current request.

    GET requests are read from the URL parameters; any other method is read
    from the JSON body. An empty string is returned when the value is missing,
    the body is not a JSON object, or the value is not a string.
    """
    if request.method == 'GET':
        raw = request.args.get('query', '')
    else:
        # silent=True makes a missing, mis-typed or malformed JSON body come
        # back as None instead of raising, so the caller can answer 400
        # rather than having the error surface as a 500.
        payload = request.get_json(silent=True)
        raw = payload.get('query', '') if isinstance(payload, dict) else ''
    return raw.strip() if isinstance(raw, str) else ''


# Search API endpoint supporting both GET and POST requests.
@app.route('/search', methods=['GET', 'POST'])
def handle_search():
    """Return the stored snippet that best matches the request's query.

    Responses:
        200: ``{"code": <snippet>, "score": <float rounded to 4 places>}``.
        400: the query is missing, empty, or not a string.
        500: encoding or searching raised an unexpected error.
        503: the service has not finished initialising.
    """
    if not service_ready:
        return jsonify({"error": "服务正在初始化"}), 503

    # Input validation stays outside the try block so that client mistakes
    # are never reported as internal server errors.
    query = _extract_query()
    if not query:
        app.logger.info("收到空的查询请求")
        return jsonify({"error": "查询不能为空"}), 400

    # Record the incoming query.
    app.logger.info("收到查询请求: %s", query)
    try:
        # Encode the query and look up the single closest snippet.
        query_emb = model.encode(query, convert_to_tensor=True, device="cpu")
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        best = hits[0]
        result = {
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4),
        }
        # Record the result being returned.
        app.logger.info("返回结果: %s", result)
        return jsonify(result)
    except Exception as exc:
        # logger.exception keeps the traceback, which a plain error() drops.
        app.logger.exception("请求处理失败: %s", exc)
        return jsonify({"error": "服务器内部错误"}), 500
if __name__ == "__main__":
    # Hugging Face Spaces starts the app through gunicorn; this entry point is kept only for local testing.
    app.run(host='0.0.0.0', port=7860)