# Hugging Face Space status: Running
import logging

from flask import Flask, request, jsonify
from sentence_transformers import SentenceTransformer, util
# Configure logging: give the root logger a handler at INFO level. The named
# logger below has no level of its own, so it inherits INFO from the root.
logging.basicConfig(level=logging.INFO)

app = Flask(import_name=__name__)
# Replace Flask's default application logger with one named "CodeSearchAPI";
# the rest of this module logs through ``app.logger``.
app.logger = logging.getLogger(name="CodeSearchAPI")
# 预定义代码片段(已验证数据) | |
CODE_SNIPPETS = [ | |
"""def sort_list(x): return sorted(x)""", | |
"""def count_above_threshold(elements, threshold=0): | |
return sum(1 for e in elements if e > threshold)""", | |
"""def find_min_max(elements): | |
return min(elements), max(elements)""" | |
] | |
# Input data validation
def validate_snippets(snippets):
    """Normalise a sequence of code snippets before they are embedded.

    Each entry is processed in order:

    * a non-``str`` entry is converted with ``str()`` and a warning naming
      its index is logged through ``app.logger``;
    * every occurrence of ``"..."`` is removed;
    * surrounding whitespace is stripped.

    Entries that are empty after this cleaning are dropped.

    NOTE(review): removing every ``"..."`` also deletes Python's Ellipsis
    literal (e.g. ``def f(): ...`` becomes ``def f():``) — confirm intended.

    Args:
        snippets: Iterable of snippet candidates.

    Returns:
        list[str]: The cleaned, non-empty snippets in their original order.
    """

    def _normalise(position, item):
        # Coerce to text first; only non-strings trigger the warning.
        if isinstance(item, str):
            text = item
        else:
            app.logger.warning(f"索引 {position} 类型错误,已转换为字符串")
            text = str(item)
        return text.replace("...", "").strip()

    normalised = [_normalise(position, item) for position, item in enumerate(snippets)]
    return [text for text in normalised if text]
# Initialise the model and pre-compute the snippet embeddings at import time.
# Any failure is logged and re-raised, so importing this module fails loudly
# instead of leaving `model` / `code_emb` undefined.
try:
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base"
    )
    valid_snippets = validate_snippets(CODE_SNIPPETS)
    code_emb = model.encode(valid_snippets, convert_to_tensor=True)
except Exception as e:
    app.logger.error(f"初始化失败: {e!s}")
    raise
else:
    app.logger.info(f"成功加载模型,编码 {len(valid_snippets)} 个有效代码片段")
# NOTE(review): the original never registered this view with Flask, so the
# endpoint was unreachable. The "/search" path and POST method are
# assumptions — confirm against the clients of this Space.
@app.route("/search", methods=["POST"])
def handle_search():
    """API endpoint: return the stored snippet most similar to a query.

    Expects a JSON object body of the form ``{"query": "<text>"}``. The
    stripped query is encoded with the module-level ``model`` and searched
    against the precomputed ``code_emb`` with ``util.semantic_search``
    (top 1).

    Returns:
        A Flask response (or ``(response, status)`` tuple):

        * 200 ``{"code": <snippet text>, "score": <similarity, 4 d.p.>}``
        * 400 when the body is not a JSON object, ``query`` is not a string,
          or the query is empty / whitespace-only
        * 404 when the search returns no hits
        * 415 when the request is not declared as JSON
        * 500 when encoding the query or any other step fails
    """
    try:
        # Request validation
        if not request.is_json:
            app.logger.warning("无效的 Content-Type")
            return jsonify({"error": "需要 application/json"}), 415

        # silent=True turns malformed JSON into None instead of raising
        # BadRequest, which the generic handler below would misreport as 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            app.logger.warning("请求体不是 JSON 对象")
            return jsonify({"error": "请求体必须是 JSON 对象"}), 400

        raw_query = data.get("query", "")
        # A non-string "query" (number, null, list, ...) has no .strip();
        # treat it as a client error rather than crashing into a 500.
        if not isinstance(raw_query, str):
            app.logger.warning("query 字段类型错误")
            return jsonify({"error": "query 必须是字符串"}), 400

        query = raw_query.strip()
        if not query:
            app.logger.warning("收到空查询")
            return jsonify({"error": "查询不能为空"}), 400

        # Encode the query
        try:
            query_emb = model.encode(query, convert_to_tensor=True)
        except Exception as e:
            app.logger.error("编码失败: %s", e, exc_info=True)
            return jsonify({"error": "查询处理失败"}), 500

        # Semantic search: the result holds one hit list per query; a single
        # query was sent, so element 0 is its (at most one) hit list.
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        if not hits:
            app.logger.error("无匹配结果")
            return jsonify({"error": "无可用匹配"}), 404

        best = hits[0]
        result = {
            "code": valid_snippets[best["corpus_id"]],
            "score": round(float(best["score"]), 4),
        }
        # %r quotes/escapes the user-supplied query so embedded newlines
        # cannot forge additional log lines.
        app.logger.info("成功处理查询: %r", query)
        return jsonify(result)
    except Exception as e:
        app.logger.error("未知错误: %s", e, exc_info=True)
        return jsonify({"error": "服务器内部错误"}), 500
if __name__ == "__main__":
    # Start Flask's built-in development server; host "0.0.0.0" listens on
    # all network interfaces rather than only localhost.
    app.run(port=8080, host="0.0.0.0")