# rootxhacker's picture
# Update app.py
# 9d39551 verified
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import gradio as gr
import spaces
# Hub id of the PEFT adapter; it is used again further down when the adapter
# is attached to the base model.
peft_model_id = "rootxhacker/CodeAstra-7B"
# Adapter configuration. Its `base_model_name_or_path` field is what selects
# the base checkpoint and tokenizer loaded later in this file.
config = PeftConfig.from_pretrained(peft_model_id)
def to_cpu(obj):
    """Return ``obj`` with every tensor it contains moved to the CPU.

    Lists, tuples and dicts are rebuilt recursively; for dicts only the
    values are converted, keys are kept as they are. Any other object is
    returned unchanged.

    Args:
        obj: A ``torch.Tensor``, a list/tuple/dict nesting of tensors, or
            any other value.

    Returns:
        The result of ``obj.cpu()`` for a tensor, a new ``list``, ``tuple``
        or ``dict`` holding the converted items for a container, or ``obj``
        itself for everything else.
    """
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, list):
        return [to_cpu(item) for item in obj]
    if isinstance(obj, tuple):
        # Tuple subclasses (e.g. namedtuples) come back as plain tuples.
        return tuple(to_cpu(item) for item in obj)
    if isinstance(obj, dict):
        # Dict subclasses come back as plain dicts.
        return {key: to_cpu(value) for key, value in obj.items()}
    return obj
# Load the base model named by the adapter config, quantized to 4-bit.
# The quantization request goes through `quantization_config`; passing
# `load_in_4bit=True` straight to `from_pretrained` is deprecated.
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)
# The tokenizer comes from the same base checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
# Load the Lora model on top of the base model.
model = PeftModel.from_pretrained(model, peft_model_id)
@spaces.GPU()
def get_completion(query, model, tokenizer):
    """Generate a sampled completion for ``query`` and return it as text.

    The model is moved to CUDA for the duration of the call and back to the
    CPU afterwards. Any ``Exception`` raised while moving the model,
    tokenizing, generating or decoding is reported in the returned string
    instead of being propagated.

    Args:
        query: Prompt text handed to the tokenizer.
        model: Object providing ``cuda()``, ``cpu()`` and ``generate()``.
        tokenizer: Callable that encodes ``query`` and provides ``decode()``.

    Returns:
        The decoded first sequence returned by ``generate`` (special tokens
        skipped), or ``"An error occurred: <message>"`` when an exception
        was caught.
    """
    try:
        # Move model to CUDA
        model = model.cuda()
        # Ensure input is on CUDA, i.e. on the same device as the model.
        inputs = tokenizer(query, return_tensors="pt").to("cuda")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1024,
                do_sample=True,
                temperature=0.7,
            )
        # Move outputs to CPU before decoding
        outputs = to_cpu(outputs)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Report the failure as the result text rather than raising.
        return f"An error occurred: {e}"
    finally:
        # Runs on both the success and the error path: move the model back
        # to CPU and release cached GPU memory. An exception raised here is
        # not caught by the handler above.
        model = model.cpu()
        torch.cuda.empty_cache()
@spaces.GPU()
def code_review(code_to_analyze):
    """Ask the model to find vulnerabilities in ``code_to_analyze``.

    Args:
        code_to_analyze: The code text to embed in the prompt.

    Returns:
        Whatever ``get_completion`` returns for the prompt, passed through
        without any post-processing.
    """
    # The prompt wording is part of the model input; keep it exactly as is.
    prompt = f"""find all vulnerabilities which in the code
{code_to_analyze} """
    # Uses the module-level `model` and `tokenizer`.
    return get_completion(prompt, model, tokenizer)
# Build the Gradio UI: a multi-line box for the code to analyze and a text
# box for the review result, wired to `code_review`.
_code_box = gr.Textbox(lines=10, label="Enter code to analyze")
_result_box = gr.Textbox(label="Code Review Result")
iface = gr.Interface(
    fn=code_review,
    inputs=_code_box,
    outputs=_result_box,
    title="Code Review Expert",
    description="This tool analyzes code for potential security flaws, logic vulnerabilities, and provides guidance on secure coding practices.",
)
# Start the app.
iface.launch()