import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
import re

# Set up logging
logging.basicConfig(
    filename="app.log",
    level=logging.INFO,
    format="%(asctime)s:%(levelname)s:%(message)s"
)
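
# Example log line produced by the format above (illustrative):
#   2024-05-01 12:00:00,000:INFO:Model and tokenizer loaded successfully.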

# Model and tokenizer loading function with caching
_MODEL_CACHE = {}


def load_model():
    """
    Loads and caches the pre-trained language model and tokenizer.

    Returns:
        model: Pre-trained language model.
        tokenizer: Tokenizer for the model.
    """
    # Serve from the cache so the model is only loaded once per process
    if "model" in _MODEL_CACHE:
        return _MODEL_CACHE["model"], _MODEL_CACHE["tokenizer"]
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model_path = "Canstralian/pentest_ai"  # Replace with the actual path if different
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map={"": device},  # Pin the whole model explicitly to CPU or GPU
            load_in_8bit=False,  # Disabled for stability
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        _MODEL_CACHE["model"], _MODEL_CACHE["tokenizer"] = model, tokenizer
        logging.info("Model and tokenizer loaded successfully.")
        return model, tokenizer
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        return None, None
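
# Alternative caching sketch: decorating load_model with
# @functools.lru_cache(maxsize=1) (standard library) would give the same
# load-once behavior without the module-level _MODEL_CACHE dict.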

def sanitize_input(text):
    """
    Sanitizes and validates user input text to prevent injection or formatting issues.

    Args:
        text (str): User input text.

    Returns:
        str: Sanitized text.
    """
    if not isinstance(text, str):
        raise ValueError("Input must be a string.")
    # Whitelist approach: keep only alphanumerics, whitespace, and basic punctuation
    sanitized_text = re.sub(r"[^a-zA-Z0-9\s\.,!?]", "", text)
    return sanitized_text.strip()
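
# Illustrative effect of the whitelist regex above:
#   sanitize_input("ls -la /etc; cat passwd")  ->  "ls la etc cat passwd"
# (shell metacharacters such as -, / and ; are stripped)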

def generate_text(model, tokenizer, instruction):
    """
    Generates text based on the provided instruction using the loaded model.

    Args:
        model: The language model.
        tokenizer: Tokenizer for encoding/decoding.
        instruction (str): Instruction text for the model.

    Returns:
        str: Generated text response from the model.
    """
    try:
        # Validate and sanitize instruction input
        instruction = sanitize_input(instruction)
        # Encode on the same device the model was placed on
        tokens = tokenizer.encode(instruction, return_tensors="pt").to(model.device)
        generated_tokens = model.generate(
            tokens,
            max_length=1024,
            do_sample=True,  # Required for top_p/temperature/top_k to take effect
            top_p=1.0,
            temperature=0.5,
            top_k=50,
        )
        generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        logging.info("Text generated successfully.")
        return generated_text
    except ValueError:
        # Let invalid-input errors propagate so the caller can report them
        raise
    except Exception as e:
        logging.error(f"Error generating text: {e}")
        return "Error in text generation."
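
# Illustrative standalone usage (assumes load_model() succeeded):
#   model, tokenizer = load_model()
#   print(generate_text(model, tokenizer, "Explain what a SYN scan does."))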

# Gradio Interface Function
def gradio_interface(instruction):
    """
    Interface function for Gradio to interact with the model and generate text.
    """
    # Load the model and tokenizer (cached after the first call)
    model, tokenizer = load_model()
    if not model or not tokenizer:
        return "Failed to load model or tokenizer. Please check your configuration."

    # Generate the text
    try:
        return generate_text(model, tokenizer, instruction)
    except ValueError as ve:
        return f"Invalid input: {ve}"
    except Exception as e:
        logging.error(f"Error during text generation: {e}")
        return "An error occurred. Please try again."

# Create Gradio Interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(label="Enter an instruction for the model:", placeholder="Type your instruction here..."),
    outputs=gr.Textbox(label="Generated Text:"),
    title="Penetration Testing AI Assistant",
    description="This tool allows you to interact with a pre-trained AI model for penetration testing assistance. Enter an instruction to generate a response.",
)

# Launch the Gradio interface
iface.launch()
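
# Note: on Hugging Face Spaces the app is served automatically; for a local
# run, iface.launch(share=True) would additionally create a temporary public
# share link.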