sugiv committed on
Commit
0f5897f
·
1 Parent(s): 59754d5

Leetmonkey In Action via Inference

Browse files
Files changed (2) hide show
  1. app.py +180 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import logging
4
+ import textwrap
5
+ import autopep8
6
+ import gradio as gr
7
+ from huggingface_hub import hf_hub_download
8
+ from llama_cpp import Llama
9
+ import jwt
10
+ from typing import Generator
11
+
12
+ # Set up logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # JWT settings
17
+ JWT_SECRET = os.environ.get("JWT_SECRET")
18
+ JWT_ALGORITHM = "HS256"
19
+
20
+ # Model settings
21
+ MODEL_NAME = "leetmonkey_peft__q8_0.gguf"
22
+ REPO_ID = "sugiv/leetmonkey-peft-gguf"
23
+
24
+ # Generation parameters
25
+ generation_kwargs = {
26
+ "max_tokens": 2048,
27
+ "stop": ["```", "### Instruction:", "### Response:"],
28
+ "echo": False,
29
+ "temperature": 0.2,
30
+ "top_k": 50,
31
+ "top_p": 0.95,
32
+ "repeat_penalty": 1.1
33
+ }
34
+
35
+ def download_model(model_name: str) -> str:
36
+ logger.info(f"Downloading model: {model_name}")
37
+ model_path = hf_hub_download(
38
+ repo_id=REPO_ID,
39
+ filename=model_name,
40
+ cache_dir="./models",
41
+ force_download=True,
42
+ resume_download=True
43
+ )
44
+ logger.info(f"Model downloaded: {model_path}")
45
+ return model_path
46
+
47
+ # Download and load the 8-bit model at startup
48
+ model_path = download_model(MODEL_NAME)
49
+ llm = Llama(
50
+ model_path=model_path,
51
+ n_ctx=2048,
52
+ n_threads=4,
53
+ n_gpu_layers=-1, # Use all available GPU layers
54
+ verbose=False
55
+ )
56
+ logger.info("8-bit model loaded successfully")
57
+
58
+ def generate_solution(instruction: str) -> str:
59
+ system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
60
+ full_prompt = f"""### Instruction:
61
+ {system_prompt}
62
+
63
+ Implement the following function for the LeetCode problem:
64
+
65
+ {instruction}
66
+
67
+ ### Response:
68
+ Here's the complete Python function implementation:
69
+
70
+ ```python
71
+ """
72
+
73
+ response = llm(full_prompt, **generation_kwargs)
74
+ return response["choices"][0]["text"]
75
+
76
+ def extract_and_format_code(text: str) -> str:
77
+ # Extract code between triple backticks
78
+ code_match = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
79
+ if code_match:
80
+ code = code_match.group(1)
81
+ else:
82
+ code = text
83
+
84
+ # Remove any text before the function definition
85
+ code = re.sub(r'^.*?(?=def\s+\w+\s*\()', '', code, flags=re.DOTALL)
86
+
87
+ # Dedent the code to remove any common leading whitespace
88
+ code = textwrap.dedent(code)
89
+
90
+ # Split the code into lines
91
+ lines = code.split('\n')
92
+
93
+ # Find the function definition line
94
+ func_def_index = next((i for i, line in enumerate(lines) if line.strip().startswith('def ')), 0)
95
+
96
+ # Ensure proper indentation
97
+ indented_lines = [lines[func_def_index]] # Keep the function definition as is
98
+ for line in lines[func_def_index + 1:]:
99
+ if line.strip(): # If the line is not empty
100
+ indented_lines.append(' ' + line) # Add 4 spaces of indentation
101
+ else:
102
+ indented_lines.append(line) # Keep empty lines as is
103
+
104
+ formatted_code = '\n'.join(indented_lines)
105
+
106
+ try:
107
+ return autopep8.fix_code(formatted_code)
108
+ except:
109
+ return formatted_code
110
+
111
+ def verify_token(token: str) -> bool:
112
+ try:
113
+ jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
114
+ return True
115
+ except jwt.PyJWTError:
116
+ return False
117
+
118
+ def generate_solution_api(instruction: str, token: str) -> str:
119
+ if not verify_token(token):
120
+ return "Invalid token. Please provide a valid JWT token."
121
+
122
+ logger.info("Generating solution")
123
+ generated_output = generate_solution(instruction)
124
+ formatted_code = extract_and_format_code(generated_output)
125
+ logger.info("Solution generated successfully")
126
+ return formatted_code
127
+
128
+ def stream_solution_api(instruction: str, token: str) -> Generator[str, None, None]:
129
+ if not verify_token(token):
130
+ yield "Invalid token. Please provide a valid JWT token."
131
+ return
132
+
133
+ logger.info("Streaming solution")
134
+ system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
135
+ full_prompt = f"""### Instruction:
136
+ {system_prompt}
137
+
138
+ Implement the following function for the LeetCode problem:
139
+
140
+ {instruction}
141
+
142
+ ### Response:
143
+ Here's the complete Python function implementation:
144
+
145
+ ```python
146
+ """
147
+
148
+ generated_text = ""
149
+ for chunk in llm(full_prompt, stream=True, **generation_kwargs):
150
+ token = chunk["choices"]["text"]
151
+ generated_text += token
152
+ yield generated_text
153
+
154
+ formatted_code = extract_and_format_code(generated_text)
155
+ logger.info("Solution generated successfully")
156
+ yield formatted_code
157
+
158
+ # Gradio interface
159
+ def gradio_generate(instruction: str, token: str) -> str:
160
+ return generate_solution_api(instruction, token)
161
+
162
+ def gradio_stream(instruction: str, token: str) -> str:
163
+ return "".join(list(stream_solution_api(instruction, token)))
164
+
165
+ iface = gr.Interface(
166
+ fn=[gradio_generate, gradio_stream],
167
+ inputs=[
168
+ gr.Textbox(label="LeetCode Problem Instruction"),
169
+ gr.Textbox(label="JWT Token")
170
+ ],
171
+ outputs=[
172
+ gr.Code(label="Generated Solution"),
173
+ gr.Code(label="Streamed Solution")
174
+ ],
175
+ title="LeetCode Problem Solver",
176
+ description="Enter a LeetCode problem instruction and your JWT token to generate a solution."
177
+ )
178
+
179
+ if __name__ == "__main__":
180
+ iface.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ llama-cpp-python
3
+ huggingface_hub
4
+ pyjwt
5
+ autopep8