# Author: Canstralian
# Commit message: Update app.py
# Commit: 58022ea (verified)
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
import re
# Configure the root logger: records at INFO level and above are written to
# a local file using a "timestamp:LEVEL:message" layout.
_LOG_SETTINGS = {
    "filename": "app.log",
    "level": logging.INFO,
    "format": "%(asctime)s:%(levelname)s:%(message)s",
}
logging.basicConfig(**_LOG_SETTINGS)
# Successfully loaded (model, tokenizer) pair, reused across calls so the
# weights are not loaded again on every request.
_LOADED_MODEL = None


def load_model():
    """
    Loads and caches the pre-trained language model and tokenizer.

    The first successful load is stored in a module-level cache and returned
    by every later call. Failed loads are not cached, so a later call retries.

    Returns:
        tuple: ``(model, tokenizer)`` on success, or ``(None, None)`` if
        loading raised an exception (the error is logged with its traceback).
    """
    global _LOADED_MODEL
    if _LOADED_MODEL is not None:
        return _LOADED_MODEL

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model_path = "Canstralian/pentest_ai"  # Replace with the actual path if different
        # NOTE(review): trust_remote_code=True runs Python code shipped with
        # the model repository; keep it only if that repository is trusted.
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            # Half precision on GPU, full precision on CPU
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map={"": device},  # This will specify CPU or GPU explicitly
            load_in_8bit=False,  # Disabled for stability
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    except Exception as e:
        # Best-effort contract: callers check for (None, None) instead of
        # handling exceptions, so log the full traceback and report failure.
        logging.exception("Error loading model: %s", e)
        return None, None

    _LOADED_MODEL = (model, tokenizer)
    logging.info("Model and tokenizer loaded successfully.")
    return _LOADED_MODEL
def sanitize_input(text):
    """
    Strips every character outside a small allow-list from user-supplied text.

    Only ASCII letters, digits, whitespace and the punctuation ``. , ! ?``
    are kept; all other characters are removed, after which leading and
    trailing whitespace is trimmed.

    Args:
        text (str): Raw text supplied by the user.

    Returns:
        str: The filtered, trimmed text.

    Raises:
        ValueError: If ``text`` is not a ``str``.
    """
    if isinstance(text, str):
        # Negated character class: matches anything NOT on the allow-list
        disallowed = re.compile(r"[^a-zA-Z0-9\s\.,!?]")
        return disallowed.sub("", text).strip()
    raise ValueError("Input must be a string.")
def generate_text(model, tokenizer, instruction):
    """
    Generates text based on the provided instruction using the loaded model.

    Args:
        model: The language model.
        tokenizer: Tokenizer for encoding/decoding.
        instruction (str): Instruction text for the model.

    Returns:
        str: Generated text response from the model, or the fixed message
        "Error in text generation." if encoding, generation or decoding fails.

    Raises:
        ValueError: If the instruction is not a string, or is empty once
            sanitized.
    """
    # Validate outside the try block so that bad input reaches the caller as
    # ValueError instead of being reported as a generic generation failure.
    instruction = sanitize_input(instruction)
    if not instruction:
        raise ValueError("Instruction is empty after sanitization.")

    try:
        # Follow the model's actual placement rather than re-probing CUDA,
        # so the input tensor always lands on the same device as the weights.
        tokens = tokenizer.encode(instruction, return_tensors="pt").to(model.device)
        generated_tokens = model.generate(
            tokens,
            # Single unpadded prompt: every position is a real token
            attention_mask=torch.ones_like(tokens),
            max_length=1024,  # Upper bound on prompt plus generated tokens
            do_sample=True,  # Needed for temperature/top_k/top_p to take effect
            top_p=1.0,
            temperature=0.5,
            top_k=50,
        )
        generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    except Exception as e:
        logging.exception("Error generating text: %s", e)
        return "Error in text generation."

    logging.info("Text generated successfully.")
    return generated_text
# Gradio Interface Function
def gradio_interface(instruction):
    """
    Interface function for Gradio to interact with the model and generate text.

    Args:
        instruction (str): Text entered by the user in the Gradio textbox.

    Returns:
        str: The generated text, or a human-readable error message when the
        model cannot be loaded, the input is rejected, or generation fails.
    """
    # Load the model and tokenizer
    model, tokenizer = load_model()
    # load_model() signals failure with None. Compare by identity: truthiness
    # would depend on whatever __bool__/__len__ these objects define.
    if model is None or tokenizer is None:
        return "Failed to load model or tokenizer. Please check your configuration."

    # Generate the text
    try:
        return generate_text(model, tokenizer, instruction)
    except ValueError as ve:
        return f"Invalid input: {ve}"
    except Exception as e:
        # Top-level boundary for the UI: log the traceback, show a generic message
        logging.exception("Error during text generation: %s", e)
        return "An error occurred. Please try again."
# Build the Gradio UI: one textbox for the instruction, one for the result
_instruction_box = gr.Textbox(
    label="Enter an instruction for the model:",
    placeholder="Type your instruction here...",
)
_result_box = gr.Textbox(label="Generated Text:")
iface = gr.Interface(
    fn=gradio_interface,
    inputs=_instruction_box,
    outputs=_result_box,
    title="Penetration Testing AI Assistant",
    description="This tool allows you to interact with a pre-trained AI model for penetration testing assistance. Enter an instruction to generate a response.",
)
# Start serving the interface
iface.launch()