|
import os |
|
import re |
|
import logging |
|
import textwrap |
|
import autopep8 |
|
import gradio as gr |
|
from huggingface_hub import hf_hub_download |
|
from llama_cpp import Llama |
|
import jwt |
|
from typing import Generator |
|
|
|
|
|
# --- Logging ---------------------------------------------------------------
# Configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Authentication --------------------------------------------------------
# Secret used to validate incoming JWTs; this is None when the JWT_SECRET
# environment variable is not set.
JWT_SECRET = os.getenv("JWT_SECRET")
JWT_ALGORITHM = "HS256"

# --- Model location --------------------------------------------------------
# File name of the model and the repository it is downloaded from.
MODEL_NAME = "leetmonkey_peft__q8_0.gguf"
REPO_ID = "sugiv/leetmonkey-peft-gguf"

# --- Sampling settings -----------------------------------------------------
# Keyword arguments forwarded to every llm(...) call made in this module.
generation_kwargs = dict(
    max_tokens=2048,
    stop=["```", "### Instruction:", "### Response:"],
    echo=False,
    temperature=0.2,
    top_k=50,
    top_p=0.95,
    repeat_penalty=1.1,
)
|
|
|
def download_model(model_name: str) -> str:
    """Fetch ``model_name`` from the ``REPO_ID`` Hugging Face repository.

    The file is stored under the ``./models`` cache directory.

    Args:
        model_name: File name of the model inside the repository.

    Returns:
        The local path reported by ``hf_hub_download``.
    """
    logger.info("Downloading model: %s", model_name)
    # Do not pass force_download=True: it bypasses the cache and re-fetches
    # the whole model on every start-up, which defeats cache_dir and also
    # contradicts resume_download (itself deprecated in huggingface_hub).
    # The hub client's default cache handling is used instead.
    model_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=model_name,
        cache_dir="./models",
    )
    logger.info("Model downloaded: %s", model_path)
    return model_path
|
|
|
|
|
# Load the model once at import time; the module-level ``llm`` object is the
# one used by generate_solution() and stream_solution_api().
_LLAMA_SETTINGS = {
    "n_ctx": 2048,
    "n_threads": 4,
    "n_gpu_layers": -1,
    "verbose": False,
}

model_path = download_model(MODEL_NAME)
llm = Llama(model_path=model_path, **_LLAMA_SETTINGS)
logger.info("8-bit model loaded successfully")
|
|
|
def generate_solution(instruction: str) -> str:
    """Ask the model to implement ``instruction`` and return its raw text.

    The prompt ends right after an opening ```python fence, so the text that
    comes back is the continuation the model writes from that point.

    Args:
        instruction: The LeetCode problem / function to implement.

    Returns:
        The ``text`` field of the first choice in the completion.
    """
    system_prompt = (
        "You are a Python coding assistant specialized in solving LeetCode problems. "
        "Provide only the complete implementation of the given function. "
        "Ensure proper indentation and formatting. "
        "Do not include any explanations or multiple solutions."
    )
    template = (
        "### Instruction:\n"
        "{system}\n"
        "\n"
        "Implement the following function for the LeetCode problem:\n"
        "\n"
        "{task}\n"
        "\n"
        "### Response:\n"
        "Here's the complete Python function implementation:\n"
        "\n"
        "```python\n"
    )
    full_prompt = template.format(system=system_prompt, task=instruction)

    completion = llm(full_prompt, **generation_kwargs)
    first_choice = completion["choices"][0]
    return first_choice["text"]
|
|
|
def extract_and_format_code(text: str) -> str:
    """Pull the generated function out of ``text`` and tidy its layout.

    Steps:
      1. Use the contents of a ```python fenced block when one is present,
         otherwise the whole text.
      2. Drop everything that precedes the first ``def name(``.
      3. If the function body has lost its indentation (the first non-blank
         line after the ``def`` is not indented deeper than the ``def``),
         indent every non-blank following line by four spaces. A body that
         is already indented is left untouched, so it is not over-indented
         and trailing top-level code is not pulled into the function.
      4. Run the result through ``autopep8`` on a best-effort basis.

    Text that contains no ``def`` line is not re-indented at all.

    Args:
        text: Raw model output.

    Returns:
        The formatted code, or the unformatted code if ``autopep8`` fails.
    """
    fenced = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
    code = fenced.group(1) if fenced else text

    # Strip any preamble in front of the first function definition.
    code = re.sub(r'^.*?(?=def\s+\w+\s*\()', '', code, flags=re.DOTALL)
    code = textwrap.dedent(code)

    lines = code.split('\n')
    def_index = next(
        (i for i, line in enumerate(lines) if line.strip().startswith('def ')),
        None,
    )

    if def_index is None:
        # Nothing to anchor the indentation on; leave the text as it is
        # instead of indenting everything after the first line.
        formatted_code = code
    else:
        header = lines[def_index]
        body = lines[def_index + 1:]
        header_indent = len(header) - len(header.lstrip())
        first_stmt = next((line for line in body if line.strip()), '')
        first_indent = len(first_stmt) - len(first_stmt.lstrip())
        if first_indent <= header_indent:
            # The body is flush with the ``def``: push it one level in.
            body = ['    ' + line if line.strip() else line for line in body]
        formatted_code = '\n'.join([header, *body])

    try:
        return autopep8.fix_code(formatted_code)
    except Exception:
        # Formatting is cosmetic: keep the unformatted code rather than
        # failing the request, but leave a trace of what went wrong.
        logging.getLogger(__name__).warning(
            "autopep8 formatting failed; returning unformatted code",
            exc_info=True,
        )
        return formatted_code
|
|
|
def verify_token(token: str) -> bool:
    """Return True only if ``token`` decodes as a JWT under ``JWT_SECRET``.

    Args:
        token: The encoded JWT supplied by the caller.

    Returns:
        True when ``jwt.decode`` accepts the token with ``JWT_ALGORITHM``;
        False for an empty token, a missing secret, or any ``PyJWTError``.
    """
    if not token:
        # An empty or missing token can never be valid; skip decoding.
        return False
    if not JWT_SECRET:
        # Fail closed: JWT_SECRET is None when the environment variable is
        # unset, and decoding with a missing key is not guaranteed to fail
        # with a PyJWTError, so it must not reach jwt.decode.
        logger.error("JWT_SECRET is not configured; rejecting token")
        return False
    try:
        jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.PyJWTError:
        return False
    return True
|
|
|
def generate_solution_api(instruction: str, token: str) -> str:
    """Generate and format a solution for an authenticated caller.

    Args:
        instruction: The LeetCode problem / function to implement.
        token: JWT checked with ``verify_token``.

    Returns:
        The formatted code, or an error message when the token is rejected.
    """
    if verify_token(token):
        logger.info("Generating solution")
        raw_output = generate_solution(instruction)
        solution = extract_and_format_code(raw_output)
        logger.info("Solution generated successfully")
        return solution

    return "Invalid token. Please provide a valid JWT token."
|
|
|
def stream_solution_api(instruction: str, token: str) -> Generator[str, None, None]:
    """Stream a solution for an authenticated caller.

    Args:
        instruction: The LeetCode problem / function to implement.
        token: JWT checked with ``verify_token``.

    Yields:
        For a rejected token, a single error message. Otherwise the text
        accumulated so far after each streamed chunk, followed by one final
        item holding the formatted code.
    """
    if not verify_token(token):
        yield "Invalid token. Please provide a valid JWT token."
        return

    logger.info("Streaming solution")
    system_prompt = "You are a Python coding assistant specialized in solving LeetCode problems. Provide only the complete implementation of the given function. Ensure proper indentation and formatting. Do not include any explanations or multiple solutions."
    full_prompt = (
        "### Instruction:\n"
        f"{system_prompt}\n"
        "\n"
        "Implement the following function for the LeetCode problem:\n"
        "\n"
        f"{instruction}\n"
        "\n"
        "### Response:\n"
        "Here's the complete Python function implementation:\n"
        "\n"
        "```python\n"
    )

    generated_text = ""
    for chunk in llm(full_prompt, stream=True, **generation_kwargs):
        # "choices" is a list (generate_solution indexes it with [0] too);
        # indexing it with "text" directly raises TypeError. A separate
        # local name also avoids shadowing the ``token`` parameter.
        piece = chunk["choices"][0]["text"]
        generated_text += piece
        yield generated_text

    formatted_code = extract_and_format_code(generated_text)
    logger.info("Solution generated successfully")
    yield formatted_code
|
|
|
|
|
def gradio_generate(instruction: str, token: str) -> str:
    """Gradio callback that delegates to ``generate_solution_api``."""
    solution = generate_solution_api(instruction, token)
    return solution
|
|
|
def gradio_stream(instruction: str, token: str) -> str:
    """Gradio callback that drains ``stream_solution_api``.

    Every item the stream yields is already cumulative, and the last one is
    the formatted code (or the error message). Joining the items would
    repeat each partial prefix, so only the final item is returned.

    Args:
        instruction: The LeetCode problem / function to implement.
        token: JWT forwarded to ``stream_solution_api``.

    Returns:
        The last item produced by the stream, or "" if it yields nothing.
    """
    latest = ""
    for latest in stream_solution_api(instruction, token):
        pass
    return latest
|
|
|
def _generate_and_stream(instruction: str, token: str) -> tuple:
    """Run both callbacks and return one value per output component.

    The first item feeds "Generated Solution", the second "Streamed Solution".
    """
    return gradio_generate(instruction, token), gradio_stream(instruction, token)


# gr.Interface takes a single callable for ``fn``; a list of functions is not
# accepted, so both callbacks are wrapped in one function whose two return
# values line up with the two output components below.
iface = gr.Interface(
    fn=_generate_and_stream,
    inputs=[
        gr.Textbox(label="LeetCode Problem Instruction"),
        gr.Textbox(label="JWT Token"),
    ],
    outputs=[
        gr.Code(label="Generated Solution"),
        gr.Code(label="Streamed Solution"),
    ],
    title="LeetCode Problem Solver",
    description="Enter a LeetCode problem instruction and your JWT token to generate a solution.",
)
|
|
|
if __name__ == "__main__":
    # Start the Gradio app only when this file is executed as a script.
    launch_options = {"share": True}
    iface.launch(**launch_options)