File size: 3,304 Bytes
90f5392
 
892f484
7c08c7b
 
90f5392
7c08c7b
90f5392
 
c79e0ff
892f484
 
90f5392
79ce9cc
892f484
c79e0ff
892f484
 
 
 
 
 
7c08c7b
 
892f484
7c08c7b
 
 
 
 
 
 
 
 
892f484
7c08c7b
79ce9cc
 
7c08c7b
79ce9cc
 
7c08c7b
 
c79e0ff
 
7c08c7b
 
 
892f484
7c08c7b
892f484
 
7c08c7b
 
 
 
 
79ce9cc
 
 
 
c79e0ff
 
892f484
7c08c7b
79ce9cc
c79e0ff
aacc39b
c79e0ff
 
 
 
 
 
 
892f484
c79e0ff
892f484
c79e0ff
 
 
 
 
7c08c7b
c79e0ff
 
79ce9cc
 
c79e0ff
 
79ce9cc
 
c79e0ff
 
 
 
 
79ce9cc
aacc39b
c79e0ff
892f484
90f5392
0db0051
7c08c7b
c79e0ff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from flask import Flask, request, jsonify
from sentence_transformers import SentenceTransformer, util
import logging
import sys
import signal

# Initialise the Flask application
app = Flask(__name__)

# Configure logging with the log level set to INFO
logging.basicConfig(level=logging.INFO)
# Replace the app's logger with a dedicated logger named "CodeSearchAPI"
app.logger = logging.getLogger("CodeSearchAPI")

# Predefined code snippets; each entry is the full source text of one function
CODE_SNIPPETS = [
    "def sort_list(x): return sorted(x)",
    (
        "def count_above_threshold(elements, threshold=0):\n"
        "    return sum(1 for e in elements if e > threshold)"
    ),
    (
        "def find_min_max(elements):\n"
        "    return min(elements), max(elements)"
    ),
]

# Global service state: starts as False and is set to True once the model
# and the snippet embeddings have been initialised successfully.
service_ready = False

# Graceful shutdown handling
def handle_shutdown(signum, frame):
    """Log the termination request and exit the process with status 0.

    Args:
        signum: Number of the signal that was received (unused).
        frame: Stack frame supplied by the signal machinery (unused).

    Raises:
        SystemExit: Always, with exit code 0.
    """
    app.logger.info("收到终止信号,开始关闭...")
    raise SystemExit(0)


# SIGINT and SIGTERM share the same handler
signal.signal(signal.SIGINT, handle_shutdown)
signal.signal(signal.SIGTERM, handle_shutdown)

# Load the model and pre-compute the snippet embeddings at import time
try:
    # Cache path dedicated to Hugging Face Spaces
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder="/model-cache",
    )

    # Encode every snippet once up front (forced onto the CPU)
    code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True, device="cpu")
except Exception as exc:
    # Record the failure, then let it propagate so the module fails to load
    app.logger.error(f"初始化失败: {str(exc)}")
    raise
else:
    # Only reached when both the model load and the encoding succeeded
    service_ready = True
    app.logger.info("服务初始化完成")

# Hugging Face health check endpoint
@app.route('/')
def hf_health_check():
    """Answer health checks on the root path.

    Returns:
        A JSON body and status code: ``{"status": "ready"}`` with 200 when
        the module-level ``service_ready`` flag is true, otherwise
        ``{"status": "initializing"}`` with 503.
    """
    status, http_code = ("ready", 200) if service_ready else ("initializing", 503)
    return jsonify({"status": status}), http_code

# Search API endpoint supporting both GET and POST requests
@app.route('/search', methods=['GET', 'POST'])
def handle_search():
    """Return the predefined code snippet that best matches a text query.

    The query is read from the ``query`` URL parameter for GET requests and
    from the ``query`` field of the JSON body for POST requests.

    Returns:
        A JSON response, with one of these outcomes:

        * 200 - ``{"code": <snippet>, "score": <score rounded to 4 places>}``
        * 400 - the query is missing, blank or not a string, or the POST
          body is not a JSON object
        * 500 - encoding or searching failed unexpectedly
        * 503 - the service has not finished initialising
    """
    if not service_ready:
        return jsonify({"error": "服务正在初始化"}), 503

    # GET takes the query from the URL parameters, POST from the JSON body.
    if request.method == 'GET':
        query = request.args.get('query', '')
    else:
        # silent=True returns None instead of raising on a missing, malformed
        # or wrongly-typed body, so a bad client payload is answered with 400
        # below rather than being reported as a server error.
        data = request.get_json(silent=True)
        # Only a JSON object can carry a "query" field (not a list/scalar).
        query = data.get('query', '') if isinstance(data, dict) else ''

    # A non-string query (e.g. JSON null or a number) is treated as missing.
    query = query.strip() if isinstance(query, str) else ''

    if not query:
        app.logger.info("收到空的查询请求")
        return jsonify({"error": "查询不能为空"}), 400

    # Log the received query
    app.logger.info("收到查询请求: %s", query)

    try:
        # Encode the query and look up the single best-matching snippet
        query_emb = model.encode(query,
                                 convert_to_tensor=True,
                                 device="cpu")
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        best = hits[0]

        result = {
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4)
        }
    except Exception as e:
        # Top-level boundary: keep the traceback in the log, hide it from
        # the client.
        app.logger.exception("请求处理失败: %s", str(e))
        return jsonify({"error": "服务器内部错误"}), 500

    # Log the result being returned
    app.logger.info("返回结果: %s", result)
    return jsonify(result)

if __name__ == "__main__":
    # Hugging Face Spaces launches the app through gunicorn; this direct
    # run is kept only for local testing.
    app.run(host='0.0.0.0', port=7860)