Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from flask import Flask, request, jsonify
|
| 2 |
from sentence_transformers import SentenceTransformer, util
|
| 3 |
import logging
|
|
|
|
| 4 |
|
| 5 |
app = Flask(__name__)
|
| 6 |
|
|
@@ -8,7 +9,7 @@ app = Flask(__name__)
|
|
| 8 |
logging.basicConfig(level=logging.INFO)
|
| 9 |
app.logger = logging.getLogger("CodeSearchAPI")
|
| 10 |
|
| 11 |
-
#
|
| 12 |
CODE_SNIPPETS = [
|
| 13 |
"""def sort_list(x): return sorted(x)""",
|
| 14 |
"""def count_above_threshold(elements, threshold=0):
|
|
@@ -17,66 +18,61 @@ CODE_SNIPPETS = [
|
|
| 17 |
return min(elements), max(elements)"""
|
| 18 |
]
|
| 19 |
|
| 20 |
-
#
|
| 21 |
-
|
| 22 |
-
cleaned = []
|
| 23 |
-
for idx, s in enumerate(snippets):
|
| 24 |
-
if not isinstance(s, str):
|
| 25 |
-
app.logger.warning(f"索引 {idx} 类型错误,已转换为字符串")
|
| 26 |
-
s = str(s)
|
| 27 |
-
cleaned.append(s.replace("...", "").strip())
|
| 28 |
-
return [s for s in cleaned if len(s) > 0]
|
| 29 |
|
| 30 |
-
# 初始化模型和编码
|
| 31 |
try:
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
except Exception as e:
|
| 37 |
-
app.logger.error(f"
|
| 38 |
raise
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
@app.route('/search', methods=['POST'])
|
| 41 |
def handle_search():
|
| 42 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 43 |
try:
|
| 44 |
# 请求验证
|
| 45 |
if not request.is_json:
|
| 46 |
-
app.logger.warning("无效的 Content-Type")
|
| 47 |
return jsonify({"error": "需要 application/json"}), 415
|
| 48 |
|
| 49 |
data = request.get_json()
|
| 50 |
query = data.get('query', '').strip()
|
| 51 |
|
| 52 |
if not query:
|
| 53 |
-
app.logger.warning("收到空查询")
|
| 54 |
return jsonify({"error": "查询不能为空"}), 400
|
| 55 |
|
| 56 |
-
#
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
best = hits[0]
|
| 67 |
-
result = {
|
| 68 |
-
"code": valid_snippets[best['corpus_id']],
|
| 69 |
-
"score": round(float(best['score']), 4)
|
| 70 |
-
}
|
| 71 |
-
app.logger.info(f"成功处理查询: '{query}'")
|
| 72 |
-
return jsonify(result)
|
| 73 |
-
|
| 74 |
-
except IndexError:
|
| 75 |
-
app.logger.error("无匹配结果")
|
| 76 |
-
return jsonify({"error": "无可用匹配"}), 404
|
| 77 |
-
|
| 78 |
except Exception as e:
|
| 79 |
-
app.logger.error(f"
|
| 80 |
return jsonify({"error": "服务器内部错误"}), 500
|
| 81 |
|
| 82 |
if __name__ == "__main__":
|
|
|
|
| 1 |
from flask import Flask, request, jsonify
|
| 2 |
from sentence_transformers import SentenceTransformer, util
|
| 3 |
import logging
|
| 4 |
+
import os
|
| 5 |
|
| 6 |
app = Flask(__name__)
|
| 7 |
|
|
|
|
| 9 |
logging.basicConfig(level=logging.INFO)
|
| 10 |
app.logger = logging.getLogger("CodeSearchAPI")
|
| 11 |
|
| 12 |
+
# 预定义代码片段
|
| 13 |
CODE_SNIPPETS = [
|
| 14 |
"""def sort_list(x): return sorted(x)""",
|
| 15 |
"""def count_above_threshold(elements, threshold=0):
|
|
|
|
| 18 |
return min(elements), max(elements)"""
|
| 19 |
]
|
| 20 |
|
| 21 |
+
# 初始化标记
|
| 22 |
+
model_ready = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
|
|
|
| 24 |
try:
    # Load the sentence-embedding model; the cache directory is taken from
    # the HF_HOME environment variable (None when it is unset).
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder=os.environ.get("HF_HOME"),
    )

    # Encode the snippet corpus once at import time so each request only
    # has to embed its query.
    code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True)
    model_ready = True
    app.logger.info("模型加载完成,服务就绪")
except Exception as exc:
    # Log the failure, then re-raise so the module import fails.
    app.logger.error(f"模型初始化失败: {str(exc)}")
    raise
|
| 38 |
|
| 39 |
+
@app.route('/health')
def health_check():
    """Report whether the model has finished loading.

    Returns:
        A JSON body with a ``status`` field and an HTTP status code:
        ``"ready"`` with 200 when the module-level ``model_ready`` flag is
        truthy, otherwise ``"initializing"`` with 503.
    """
    status, http_code = ("ready", 200) if model_ready else ("initializing", 503)
    return jsonify({"status": status}), http_code
|
| 46 |
+
|
| 47 |
@app.route('/search', methods=['POST'])
def handle_search():
    """Return the stored code snippet that best matches a text query.

    Expects a JSON object body with a non-empty string ``query`` field.

    Returns:
        200 with ``{"code": ..., "score": ...}`` for the top-ranked snippet;
        503 while the model is not ready; 415 when the request is not JSON;
        400 when the body cannot be parsed or ``query`` is missing, empty,
        or not a string; 404 when the search yields no hit; 500 on any
        unexpected failure.
    """
    if not model_ready:
        return jsonify({"error": "服务正在初始化"}), 503

    try:
        # Request validation
        if not request.is_json:
            return jsonify({"error": "需要 application/json"}), 415

        # silent=True yields None for an unparsable body instead of raising
        # BadRequest, which the broad handler below would report as a 500.
        data = request.get_json(silent=True)
        query = data.get('query', '') if isinstance(data, dict) else ''
        # A non-string query (number, list, null) is a client error.
        query = query.strip() if isinstance(query, str) else ''

        if not query:
            return jsonify({"error": "查询不能为空"}), 400

        # Embed the query and rank it against the precomputed snippet
        # embeddings, keeping only the single best match.
        query_emb = model.encode(query, convert_to_tensor=True)
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        if not hits:
            return jsonify({"error": "无可用匹配"}), 404
        best = hits[0]

        return jsonify({
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4)
        })

    except Exception as e:
        # Top-level boundary: record the traceback, hide details from the client.
        app.logger.exception("请求处理失败: %s", e)
        return jsonify({"error": "服务器内部错误"}), 500
|
| 77 |
|
| 78 |
if __name__ == "__main__":
|