File size: 2,333 Bytes
90f5392
 
892f484
79ce9cc
90f5392
 
 
892f484
 
 
90f5392
79ce9cc
892f484
 
 
 
 
 
 
 
79ce9cc
 
892f484
 
79ce9cc
 
 
 
 
 
 
 
 
 
892f484
79ce9cc
892f484
 
79ce9cc
 
 
 
 
 
 
 
892f484
 
79ce9cc
 
 
 
aacc39b
0db0051
 
892f484
0db0051
 
892f484
0db0051
892f484
 
 
79ce9cc
 
 
 
 
 
 
 
 
 
aacc39b
79ce9cc
892f484
90f5392
0db0051
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from flask import Flask, request, jsonify
from sentence_transformers import SentenceTransformer, util
import logging
import os

# Configure root logging at INFO level before the application object exists.
logging.basicConfig(level=logging.INFO)

app = Flask(__name__)
# Send the application's log output through a dedicated, named logger.
app.logger = logging.getLogger("CodeSearchAPI")

# Predefined code snippets that make up the searchable corpus.
# Each entry is the full source text of one small function.
CODE_SNIPPETS = [
    "def sort_list(x): return sorted(x)",
    (
        "def count_above_threshold(elements, threshold=0):\n"
        "    return sum(1 for e in elements if e > threshold)"
    ),
    (
        "def find_min_max(elements):\n"
        "    return min(elements), max(elements)"
    ),
]

# Readiness flag: set to True only after the model and the snippet
# embeddings have both been created successfully.
model_ready = False

try:
    # Load the embedding model, using the cache directory named by the
    # HF_HOME environment variable when it is set (None otherwise).
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder=os.environ.get("HF_HOME"),
    )

    # Embed all snippets once at startup so that a search request only
    # has to encode the incoming query.
    code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True)
    model_ready = True
    app.logger.info("模型加载完成,服务就绪")
except Exception as exc:
    # Record the failure, then re-raise so the module import fails.
    app.logger.error(f"模型初始化失败: {exc!s}")
    raise

@app.route('/health')
def health_check():
    """Health-check endpoint.

    Returns a JSON body with ``status`` "ready" and HTTP 200 when the
    module-level ``model_ready`` flag is set, otherwise ``status``
    "initializing" with HTTP 503.
    """
    if not model_ready:
        return jsonify({"status": "initializing"}), 503
    return jsonify({"status": "ready"}), 200

@app.route('/search', methods=['POST'])
def handle_search():
    """Return the predefined code snippet most similar to the posted query.

    Expects a JSON object body containing a string ``query`` field.

    Responses:
        200: ``{"code": <snippet text>, "score": <similarity, 4 decimals>}``.
        400: the body is not a valid JSON object, or ``query`` is missing,
            not a string, or blank after stripping whitespace.
        415: the request content type is not JSON.
        500: an unexpected error occurred while encoding or searching.
        503: the model has not finished initializing.
    """
    if not model_ready:
        return jsonify({"error": "服务正在初始化"}), 503

    # Request validation. It is kept outside the try block below so that
    # client mistakes are never reported as internal server errors.
    if not request.is_json:
        return jsonify({"error": "需要 application/json"}), 415

    # silent=True returns None for a malformed body instead of raising,
    # which lets us answer with an explicit 400.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # Valid JSON that is not an object (list, string, number, null)
        # has no "query" field to read.
        return jsonify({"error": "请求体必须是 JSON 对象"}), 400

    query = data.get('query', '')
    if not isinstance(query, str):
        return jsonify({"error": "query 必须是字符串"}), 400

    query = query.strip()
    if not query:
        return jsonify({"error": "查询不能为空"}), 400

    try:
        # Encode the query and pick the single closest snippet embedding.
        query_emb = model.encode(query, convert_to_tensor=True)
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        best = hits[0]

        return jsonify({
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4)
        })
    except Exception as e:
        # Boundary handler: log with traceback, hide details from the client.
        app.logger.exception("请求处理失败: %s", e)
        return jsonify({"error": "服务器内部错误"}), 500

if __name__ == "__main__":
    # When executed directly as a script, listen on all interfaces at port 8080.
    app.run(port=8080, host='0.0.0.0')