File size: 3,231 Bytes
e4d1c0f
a19817b
e4d1c0f
 
a19817b
e4d1c0f
a19817b
 
 
e4d1c0f
 
a19817b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4d1c0f
a19817b
e4d1c0f
a19817b
e4d1c0f
 
 
 
a19817b
 
 
e4d1c0f
a19817b
e4d1c0f
a19817b
 
 
e4d1c0f
 
a19817b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4d1c0f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
import torch
import numpy as np

from sentence_transformers import SentenceTransformer, util

# 1. Model setup.
#    The Hub id below names the retrieval model the author fine-tuned on
#    CodeSearchNet (Python) and pushed to the Hugging Face Hub.
model_name = "juanwisz/modernbert-python-code-retrieval"

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Built once at import time; SentenceTransformer bundles tokenization and
# embedding, so the rest of the file only ever calls `.encode(...)` on it.
embedding_model = SentenceTransformer(model_name, device=device)

# 2. Retrieval:
#    - Parse code snippets from the text box (separated by lines reading "---")
#    - Compute embeddings for the user's query and each snippet
#    - Return the most relevant code snippets based on cosine similarity
def _split_snippets(code_input):
    """Split pasted text into snippets on lines that consist solely of ``---``.

    Only a line whose stripped content is exactly ``---`` counts as a
    delimiter. Splitting on the bare substring would cut snippets apart
    wherever the code itself contains three or more dashes (docstring
    underlines, comment rulers, ...).

    Args:
        code_input: Raw text from the snippets box; ``None`` is treated as
            empty input.

    Returns:
        The non-empty snippets, each stripped of surrounding whitespace, in
        the order they appeared.
    """
    chunks = []
    current = []
    for line in (code_input or "").split("\n"):
        if line.strip() == "---":
            chunks.append("\n".join(current).strip())
            current = []
        else:
            current.append(line)
    # Whatever follows the last delimiter (or the whole text if none) is a snippet too.
    chunks.append("\n".join(current).strip())
    return [chunk for chunk in chunks if chunk]


def retrieve_top_snippets(query, code_input, top_k=3):
    """Rank pasted code snippets against a natural-language query.

    Args:
        query: Natural-language description to search for.
        code_input: Text containing code snippets separated by lines that
            read ``---``.
        top_k: Maximum number of matches to return (default 3). Fewer are
            returned when fewer snippets were supplied.

    Returns:
        A Markdown string with one block per match, best first, each showing
        the cosine-similarity score followed by the snippet in a Python code
        fence. A short plain-text message is returned instead when no
        snippets or no query were provided.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    snippets = _split_snippets(code_input)

    # Edge case: nothing to rank.
    if not snippets:
        return "No code snippets detected (make sure to separate them with ---)."

    # Edge case: a blank query carries no signal, so any ranking would be noise.
    if not (query or "").strip():
        return "Please enter a query to search for."

    # Embed the query and the snippets with the module-level model.
    query_emb = embedding_model.encode(query, convert_to_tensor=True)
    snippets_emb = embedding_model.encode(snippets, convert_to_tensor=True)

    # cos_sim returns a score matrix; row 0 holds the single query's
    # similarity to every snippet.
    cos_scores = util.cos_sim(query_emb, snippets_emb)[0]

    # topk yields scores and indices already sorted from best to worst.
    top = torch.topk(cos_scores, k=min(top_k, len(snippets)))

    # .tolist() gives plain Python floats/ints for formatting and indexing.
    results = [
        f"**Score**: {score:.4f}\n```python\n{snippets[idx]}\n```"
        for score, idx in zip(top.values.tolist(), top.indices.tolist())
    ]

    # Blank line between matches so Markdown renders them as separate blocks.
    return "\n\n".join(results)


#####################
### Gradio Layout ###
#####################
# Custom stylesheet passed to gr.Blocks(css=...) below. The `#container`
# selector matches the element whose id is "container" (the gr.Column is
# created with elem_id="container"): `margin: 0 auto` centers it horizontally
# and `max-width` caps its width at 700px.
css = """
#container {
    margin: 0 auto;
    max-width: 700px;
}
"""

# Static UI copy, kept apart from the layout code so the structure below
# stays easy to scan.
_INTRO_MARKDOWN = (
    "# Code Retrieval using ModernBERT\n"
    "Enter a natural language query and paste multiple Python code snippets, "
    "delimited by `---`. We'll return the top 3 matches."
)
_QUERY_PLACEHOLDER = "What does your function do? e.g., 'Parse JSON from a string'"
_SNIPPETS_PLACEHOLDER = (
    "Example:\n---\ndef parse_json(data):\n    return json.loads(data)\n"
    "---\ndef add_numbers(a, b):\n    return a + b\n---"
)

with gr.Blocks(css=css) as demo:
    # Page title and usage instructions.
    gr.Markdown(_INTRO_MARKDOWN)

    # Everything interactive sits inside the column styled by `#container`.
    with gr.Column(elem_id="container"):
        with gr.Row():
            query_input = gr.Textbox(
                placeholder=_QUERY_PLACEHOLDER,
                label="Natural Language Query",
            )

        code_snippets_input = gr.Textbox(
            placeholder=_SNIPPETS_PLACEHOLDER,
            lines=10,
            label="Paste Python functions (delimited by ---)",
        )

        search_btn = gr.Button("Search", variant="primary")
        results_output = gr.Markdown(label="Top 3 Matches")

        # Clicking "Search" feeds both text boxes to the retrieval function
        # and renders its Markdown result.
        search_btn.click(
            outputs=results_output,
            inputs=[query_input, code_snippets_input],
            fn=retrieve_top_snippets,
        )

# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()