dhuynh95 commited on
Commit
b9e5451
·
1 Parent(s): be1b538

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -3
app.py CHANGED
@@ -5,11 +5,144 @@ import io
5
  import sys
6
  import os
7
  import requests
 
 
8
 
9
  api_base = "https://api.endpoints.anyscale.com/v1"
10
  token = os.environ["OPENAI_API_KEY"]
11
  url = f"{api_base}/chat/completions"
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def generate_test(code):
14
  s = requests.Session()
15
  message = "Write me a test of this function\n{}".format(code)
@@ -93,9 +226,9 @@ with gr.Blocks() as demo:
93
 
94
  generate_btn = gr.Button("Generate test")
95
  with gr.Row():
96
- code_output = gr.Code()
97
- code_output2 = gr.Code()
98
 
99
- generate_btn.click(execute_code, outputs=code_output)
100
  if __name__ == "__main__":
101
  demo.launch()
 
5
  import sys
6
  import os
7
  import requests
8
+ import ast
9
+
10
 
11
  api_base = "https://api.endpoints.anyscale.com/v1"
12
  token = os.environ["OPENAI_API_KEY"]
13
  url = f"{api_base}/chat/completions"
14
 
15
+ def extract_functions_from_file(filename):
16
+ """Given a file written to disk, extract all functions from it into a list."""
17
+ with open(filename, "r") as file:
18
+ tree = ast.parse(file.read())
19
+
20
+ functions = []
21
+
22
+ for node in ast.walk(tree):
23
+ if isinstance(node, ast.FunctionDef):
24
+ start_line = node.lineno
25
+ end_line = node.end_lineno if hasattr(node, "end_lineno") else start_line
26
+ with open(filename, "r") as file:
27
+ function_code = "".join(
28
+ [
29
+ line
30
+ for i, line in enumerate(file)
31
+ if start_line <= i + 1 <= end_line
32
+ ]
33
+ )
34
+ functions.append(function_code)
35
+
36
+ return functions
37
+
38
+
39
+ def extract_tests_from_list(l):
40
+ """Given a list of strings, extract all functions from it into a list."""
41
+ return [t for t in l if t.startswith("def")]
42
+
43
+
44
+ def remove_leading_whitespace(func_str):
45
+ """Given a string representing a function, remove the leading whitespace from each
46
+ line such that the function definition is left-aligned and all following lines
47
+ follow Python's whitespace formatting rules.
48
+ """
49
+ lines = func_str.split("\n")
50
+ # Find the amount of whitespace before 'def' (the function signature)
51
+ leading_whitespace = len(lines[0]) - len(lines[0].lstrip())
52
+ # Remove that amount of whitespace from each line
53
+ new_lines = [line[leading_whitespace:] for line in lines if line.strip()]
54
+ return "\n".join(new_lines)
55
+
56
+
57
+ def main(fxn: str, examples: str = "", temperature: float = 0.7):
58
+ """Requires Anyscale Endpoints Alpha API access.
59
+
60
+ If examples is not a empty string, it will be formatted into
61
+ a list of input/output pairs used to prompt the model.
62
+ """
63
+
64
+ s = requests.Session()
65
+ api_base = os.environ["OPENAI_API_BASE"]
66
+ token = os.environ["OPENAI_API_KEY"]
67
+ url = f"{api_base}/chat/completions"
68
+
69
+ message = "Write me a test of this function\n{}".format(fxn)
70
+
71
+ if examples:
72
+ message += "\nExample input output pairs:\n"
73
+
74
+ system_prompt = """
75
+ You are a helpful coding assistant.
76
+ Your job is to help people write unit tests for their python code. Please write all
77
+ unit tests in the format expected by pytest. If inputs and outputs are provided,
78
+ return a set of unit tests that will verify that the function will produce the
79
+ corect outputs. Also provide tests to handle base and edge cases. It is very
80
+ important that the code is formatted correctly for pytest.
81
+ """
82
+
83
+ body = {
84
+ "model": "meta-llama/Llama-2-70b-chat-hf",
85
+ "messages": [
86
+ {"role": "system", "content": system_prompt},
87
+ {"role": "user", "content": message},
88
+ ],
89
+ "temperature": temperature,
90
+ }
91
+
92
+ with s.post(url, headers={"Authorization": f"Bearer {token}"}, json=body) as resp:
93
+ response = resp.json()["choices"][0]
94
+
95
+ if response["finish_reason"] != "stop":
96
+ raise ValueError("Print please try again -- response was not finished!")
97
+
98
+ # Parse the response to get the tests out.
99
+ split_response = response["message"]["content"].split("```")
100
+ if len(split_response) != 3:
101
+ raise ValueError("Please try again -- response generated too many code blocks!")
102
+
103
+ all_tests = split_response[1]
104
+
105
+ # Writes out all tests to a file. Then, extracts each individual test out into a
106
+ # list.
107
+ with tempfile.NamedTemporaryFile(
108
+ prefix="all_tests_", suffix=".py", mode="w"
109
+ ) as temp:
110
+ temp.writelines(all_tests)
111
+ temp.flush()
112
+ parsed_tests = extract_functions_from_file(temp.name)
113
+
114
+ # Loop through test, run pytest, and return two lists of tests.
115
+ passed_tests, failed_tests = [], []
116
+ for test in parsed_tests:
117
+ test_formatted = remove_leading_whitespace(test)
118
+
119
+ print("testing: \n {}".format(test_formatted))
120
+
121
+ with tempfile.NamedTemporaryFile(
122
+ prefix="test_", suffix=".py", mode="w"
123
+ ) as temp:
124
+ # Writes out each test to a file. Then, runs pytest on that file.
125
+ full_test_file = "#!/usr/bin/env python\n\nimport pytest\n{}\n{}".format(
126
+ fxn, test_formatted
127
+ )
128
+ temp.writelines(full_test_file)
129
+ temp.flush()
130
+
131
+ retcode = pytest.main(["-x", temp.name])
132
+
133
+ print(retcode.name)
134
+ if retcode.name == "TESTS_FAILED":
135
+ failed_tests.append(test)
136
+ print("test failed")
137
+
138
+ elif retcode.name == "OK":
139
+ passed_tests.append(test)
140
+ print("test passed")
141
+
142
+ passed_tests = "\n".join(passed_tests)
143
+ failed_tests = "\n".join(failed_tests)
144
+ return passed_tests, failed_tests
145
+
146
  def generate_test(code):
147
  s = requests.Session()
148
  message = "Write me a test of this function\n{}".format(code)
 
226
 
227
  generate_btn = gr.Button("Generate test")
228
  with gr.Row():
229
+ code_output = gr.Code(label="Passed tests")
230
+ code_output2 = gr.Code(label="Failed tests")
231
 
232
+ generate_btn.click(main, inputs=code_input, outputs=[code_output, code_output2])
233
  if __name__ == "__main__":
234
  demo.launch()