Spaces:
Running
Running
File size: 3,304 Bytes
90f5392 892f484 7c08c7b 90f5392 7c08c7b 90f5392 c79e0ff 892f484 90f5392 79ce9cc 892f484 c79e0ff 892f484 7c08c7b 892f484 7c08c7b 892f484 7c08c7b 79ce9cc 7c08c7b 79ce9cc 7c08c7b c79e0ff 7c08c7b 892f484 7c08c7b 892f484 7c08c7b 79ce9cc c79e0ff 892f484 7c08c7b 79ce9cc c79e0ff aacc39b c79e0ff 892f484 c79e0ff 892f484 c79e0ff 7c08c7b c79e0ff 79ce9cc c79e0ff 79ce9cc c79e0ff 79ce9cc aacc39b c79e0ff 892f484 90f5392 0db0051 7c08c7b c79e0ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
from flask import Flask, request, jsonify
from sentence_transformers import SentenceTransformer, util
import logging
import sys
import signal
# Create the Flask application.
app = Flask(__name__)
# Configure logging with the level set to INFO.
logging.basicConfig(level=logging.INFO)
# Point app.logger at a logger named "CodeSearchAPI"; the log calls in this
# file all go through app.logger.
app.logger = logging.getLogger("CodeSearchAPI")
# Predefined code snippets. This list is the corpus: it is encoded once at
# start-up and search hits are mapped back to it by index.
# NOTE(review): the multi-line snippets have no indentation on their body
# lines as they appear here. The text is only embedded, never executed, but
# the whitespace is part of the embedded string — confirm this is intended.
CODE_SNIPPETS = [
"def sort_list(x): return sorted(x)",
"""def count_above_threshold(elements, threshold=0):
return sum(1 for e in elements if e > threshold)""",
"""def find_min_max(elements):
return min(elements), max(elements)"""
]
# Global service state: starts False and is set to True once the model and
# the snippet embeddings have been initialised; the endpoints read it to
# decide between a normal response and HTTP 503.
service_ready = False
# Graceful shutdown handling.
def handle_shutdown(signum, frame):
    """Log the termination request and exit the process with status 0.

    Registered as the handler for SIGINT and SIGTERM. Both arguments are
    supplied by the signal machinery and are not used.
    """
    app.logger.info("收到终止信号,开始关闭...")
    sys.exit(0)


# Ctrl+C (SIGINT) and a termination request (SIGTERM) share one handler.
signal.signal(signal.SIGINT, handle_shutdown)
signal.signal(signal.SIGTERM, handle_shutdown)
# Initialise the model and the snippet embeddings once, at import time.
try:
    # Cache path dedicated to Hugging Face Spaces.
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder="/model-cache",
    )
    # Pre-compute the corpus embeddings (forced onto the CPU) so that a
    # request only has to encode its query.
    code_emb = model.encode(
        CODE_SNIPPETS,
        convert_to_tensor=True,
        device="cpu",
    )
except Exception as e:
    # logger.exception keeps the same message but also records the traceback;
    # the error is re-raised so a broken start-up never serves requests.
    app.logger.exception("初始化失败: %s", e)
    raise
else:
    # Only reached when both the model load and the encoding succeeded.
    service_ready = True
    app.logger.info("服务初始化完成")
# Hugging Face health-check endpoint.
@app.route("/")
def hf_health_check():
    """Answer the health check, which must be served on the root path.

    Returns:
        A JSON body with a "status" field and the matching HTTP code:
        "ready" with 200 when service_ready is set, otherwise
        "initializing" with 503.
    """
    status, http_code = ("ready", 200) if service_ready else ("initializing", 503)
    return jsonify({"status": status}), http_code
def _extract_query():
    """Return the stripped query text of the current request, or "" if unusable.

    GET requests carry the query in the ``query`` URL parameter; POST requests
    carry it in the ``query`` field of a JSON object body. Anything that is
    not a usable string (missing field, malformed JSON, a non-object body, a
    non-string value) yields "" so the caller can answer with HTTP 400.
    """
    if request.method == 'GET':
        raw = request.args.get('query', '')
    else:
        # silent=True turns a wrong Content-Type or malformed JSON into None
        # instead of raising an HTTP error that would surface as a 500.
        payload = request.get_json(silent=True)
        # A JSON array, number or string has no .get(); treat it as no query.
        raw = payload.get('query', '') if isinstance(payload, dict) else ''
    # A null or numeric "query" has no .strip(); treat it as no query.
    return raw.strip() if isinstance(raw, str) else ''


# Search API endpoint supporting both GET and POST requests.
@app.route('/search', methods=['GET', 'POST'])
def handle_search():
    """Return the predefined code snippet that best matches the query.

    Returns:
        200 with ``{"code": <snippet>, "score": <similarity rounded to 4 dp>}``
        on success; 400 when the query is missing, empty or not a string;
        503 while the service is not ready; 500 when encoding or searching
        fails.
    """
    if not service_ready:
        return jsonify({"error": "服务正在初始化"}), 503

    query = _extract_query()
    if not query:
        app.logger.info("收到空的查询请求")
        return jsonify({"error": "查询不能为空"}), 400

    # Log the received query.
    app.logger.info("收到查询请求: %s", query)
    try:
        # Encode the query and look up the single closest snippet.
        query_emb = model.encode(query,
                                 convert_to_tensor=True,
                                 device="cpu")
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        best = hits[0]
        result = {
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4)
        }
    except Exception as e:
        # Endpoint boundary: record the traceback, hide details from the client.
        app.logger.exception("请求处理失败: %s", e)
        return jsonify({"error": "服务器内部错误"}), 500

    # Log the result being returned.
    app.logger.info("返回结果: %s", result)
    return jsonify(result)
if __name__ == "__main__":
    # Hugging Face Spaces starts the app through gunicorn; this entry point
    # is kept only for local testing.
    app.run(port=7860, host="0.0.0.0")
|