Spaces:
Running
Running
File size: 2,333 Bytes
90f5392 892f484 79ce9cc 90f5392 892f484 90f5392 79ce9cc 892f484 79ce9cc 892f484 79ce9cc 892f484 79ce9cc 892f484 79ce9cc 892f484 79ce9cc aacc39b 0db0051 892f484 0db0051 892f484 0db0051 892f484 79ce9cc aacc39b 79ce9cc 892f484 90f5392 0db0051 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
from flask import Flask, request, jsonify
from sentence_transformers import SentenceTransformer, util
import logging
import os
# Set up root logging at INFO level, then create the Flask application and
# replace its logger with a dedicated, named logger for this service.
logging.basicConfig(level=logging.INFO)
app = Flask(__name__)
app.logger = logging.getLogger("CodeSearchAPI")
# Predefined corpus of code snippets that search queries are matched against.
# Each entry is complete, valid Python source: multi-line function bodies keep
# their indentation so a snippet can be compiled or pasted as-is.
CODE_SNIPPETS = [
    """def sort_list(x): return sorted(x)""",
    """def count_above_threshold(elements, threshold=0):
    return sum(1 for e in elements if e > threshold)""",
    """def find_min_max(elements):
    return min(elements), max(elements)""",
]
# Readiness flag: remains False until both the model and the snippet
# embeddings below have been created without error.
model_ready = False

try:
    # Initialise the model (using the pre-downloaded cache); the cache folder
    # comes from the HF_HOME environment variable and is None when unset.
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder=os.environ.get("HF_HOME"),
    )

    # Pre-compute the embeddings of the snippet corpus once at start-up.
    code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True)

    model_ready = True
    app.logger.info("模型加载完成,服务就绪")
except Exception as init_error:
    # Log the failure, then re-raise so start-up aborts with the original error.
    app.logger.error(f"模型初始化失败: {str(init_error)}")
    raise
@app.route('/health')
def health_check():
    """Health-check endpoint.

    Returns:
        A ``(response, status)`` pair: ``{"status": "ready"}`` with HTTP 200
        when the module-level ``model_ready`` flag is set, otherwise
        ``{"status": "initializing"}`` with HTTP 503.
    """
    state, http_status = ("ready", 200) if model_ready else ("initializing", 503)
    return jsonify({"status": state}), http_status
@app.route('/search', methods=['POST'])
def handle_search():
    """Handle a code-search request.

    Expects a JSON object body with a string ``query`` field. The query is
    encoded with the module-level model and compared against the
    pre-computed snippet embeddings; the single best match is returned.

    Returns:
        A Flask JSON response (optionally paired with a status code):

        * 200 -- ``{"code": <snippet>, "score": <float rounded to 4 places>}``
        * 400 -- body is not a valid JSON object, ``query`` is not a string,
          or ``query`` is empty after stripping whitespace
        * 415 -- request content type is not JSON
        * 503 -- ``model_ready`` is still False
        * 500 -- encoding or searching raised an exception
    """
    if not model_ready:
        return jsonify({"error": "服务正在初始化"}), 503

    # Request validation. It lives outside the try block so that client
    # mistakes are answered with a 4xx status instead of being swallowed by
    # the generic handler below and reported as a 500.
    if not request.is_json:
        return jsonify({"error": "需要 application/json"}), 415

    # silent=True makes a malformed body yield None rather than raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "请求体必须是 JSON 对象"}), 400

    query = data.get('query', '')
    if not isinstance(query, str):
        return jsonify({"error": "query 必须是字符串"}), 400

    query = query.strip()
    if not query:
        return jsonify({"error": "查询不能为空"}), 400

    try:
        # Process the query: embed it and take the top hit from the corpus.
        query_emb = model.encode(query, convert_to_tensor=True)
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        best = hits[0]
        return jsonify({
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4)
        })
    except Exception as e:
        # Top-level boundary: record the full traceback, hide details from
        # the client.
        app.logger.exception("请求处理失败: %s", e)
        return jsonify({"error": "服务器内部错误"}), 500
if __name__ == "__main__":
    # Run directly as a script: serve on all network interfaces, port 8080.
    app.run(host='0.0.0.0', port=8080)