File size: 2,823 Bytes
3087bb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf544db
3087bb1
 
 
 
 
 
 
 
 
 
 
 
cf544db
 
 
 
 
 
 
 
3087bb1
 
 
be1b538
 
 
 
cf544db
3087bb1
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr
import tempfile
import pytest
import io
import sys
import os
import requests

# Base URL of the hosted API that serves the chat-completions endpoint.
api_base = "https://api.endpoints.anyscale.com/v1"
# Bearer token sent in the Authorization header of every request.
# Read at import time: a missing OPENAI_API_KEY variable raises KeyError
# before any UI is built.
# NOTE(review): the variable is named OPENAI_API_KEY although the host is
# Anyscale -- presumably an Anyscale key is expected here; confirm.
token = os.environ["OPENAI_API_KEY"]
# Full chat-completions URL that generate_test() posts to.
url = f"{api_base}/chat/completions"

def generate_test(code, *, timeout=120.0):
    """Ask the hosted Llama 2 chat model to write unit tests for ``code``.

    Args:
        code: Source text of the Python function to be tested.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The contents of the single fenced code block in the model's reply,
        without the fence markers or a leading language tag, ending in a
        newline.

    Raises:
        requests.HTTPError: The API answered with an error status.
        ValueError: The reply was cut short, or did not contain exactly one
            fenced code block.
    """
    message = f"Write me a test of this function\n{code}"
    system_prompt = """
    You are a helpful coding assistant.
    Your job is to help people write unit tests for the python code.
    If inputs and outputs are provided, please return a set of unit tests that will
    verify that the function will produce the corect outputs. Also provide tests to
    handle base and edge cases.
    """

    body = {
        "model": "meta-llama/Llama-2-70b-chat-hf",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        "temperature": 0.7,
    }

    # requests never times out on its own, so a stalled connection would hang
    # the caller forever without an explicit timeout.
    with requests.post(
        url,
        headers={"Authorization": f"Bearer {token}"},
        json=body,
        timeout=timeout,
    ) as resp:
        # Surface HTTP failures (bad key, rate limit, ...) as HTTPError rather
        # than as a confusing KeyError on the missing "choices" field.
        resp.raise_for_status()
        response = resp.json()["choices"][0]

    if response["finish_reason"] != "stop":
        raise ValueError("Please try again -- response was not finished!")

    # Exactly one fenced block splits the reply into
    # [text before, block contents, text after].
    split_response = response["message"]["content"].split("```")
    if len(split_response) != 3:
        raise ValueError(
            "Please try again -- response did not contain exactly one code block!"
        )

    snippet = split_response[1]
    # An opening fence is usually followed by a language tag ("```python");
    # drop that first line so only runnable source remains.
    tag, newline, remainder = snippet.partition("\n")
    if newline and tag.strip().lower() in ("", "py", "python", "python3"):
        snippet = remainder

    return snippet.strip("\n") + "\n"


def execute_code(code, test):
    """Run ``test`` against ``code`` with pytest and return the printed report.

    Both snippets are written to one temporary ``.py`` file, pytest is run on
    that file in-process with ``-x`` (stop at first failure), and everything
    written to ``sys.stdout`` during the run is captured.

    Args:
        code: Source text of the function(s) under test.
        test: Source text of the tests; they share a module with ``code``.

    Returns:
        The text written to standard output during the pytest run.
    """
    from contextlib import redirect_stdout

    # WARNING: this executes arbitrary user- and model-supplied Python inside
    # the current process. Only expose it where that is acceptable.

    # Join with explicit blank lines so the first line of the test can never
    # fuse with the last line of the code when the latter lacks a newline.
    parts = [part.rstrip() for part in (code, test) if part]
    source = "\n\n\n".join(parts) + "\n"

    # delete=False because pytest must reopen the file by name after we close
    # it (an open NamedTemporaryFile cannot be reopened on every platform).
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=".py", encoding="utf-8"
    ) as f:
        f.write(source)
        temp_path = f.name

    captured = io.StringIO()
    try:
        # redirect_stdout restores sys.stdout even if pytest raises.
        with redirect_stdout(captured):
            # Each call gets a fresh file name, so Python's import cache never
            # serves a stale module on repeated in-process runs.
            pytest.main(["-x", temp_path])
    finally:
        # The file was created with delete=False; remove it ourselves.
        os.remove(temp_path)

    return captured.getvalue()

# Sample snippets offered in the UI. Each entry must be valid, self-contained
# Python so it can be sent to the model and executed as-is.
examples = [
    """
def prime_factors(n):
    i = 2
    factors = []
    while i * i <= n:
        if n % i:
            i += 1
        else:
            n //= i
            factors.append(i)
    if n > 1:
        factors.append(n)
    return factors
    """,
    """
import numpy as np


def matrix_multiplication(A, B):
    return np.dot(A, B)
""",
]
# Snippet pre-loaded into the code editor when the page opens.
example = examples[0]

def _generate_and_run(code):
    """Generate a test for ``code``, run it, and return (test, pytest output)."""
    test = generate_test(code)
    if not test or not test.strip():
        raise ValueError("Please try again -- no test was generated!")
    return test, execute_code(code, test)


with gr.Blocks() as demo:
    gr.Markdown("<h1><center>Llama_test: generate unit test for your Python code</center></h1>")
    code_input = gr.Code(example, label="Provide the code of the function you want to test")
    gr.Examples(examples=examples, inputs=code_input)

    generate_btn = gr.Button("Generate test")
    with gr.Row():
        code_output = gr.Code(label="Generated test")
        code_output2 = gr.Code(label="Test results")

    # A listener only receives the components named in `inputs`, and its
    # return values are mapped positionally onto `outputs`. The editor content
    # goes in; the generated test and the pytest report come out.
    generate_btn.click(
        _generate_and_run,
        inputs=code_input,
        outputs=[code_output, code_output2],
    )

if __name__ == "__main__":
    demo.launch()