# OSINT_Tool / app.py
# Author: Canstralian -- last commit c0298f8 ("Update app.py", verified), 2.74 kB.
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
import torch
# Sidebar display key -> model identifier handed to ``from_pretrained``.
# Insertion order is the order the options appear in the sidebar picker.
MODEL_MAPPING = dict(
    text2shellcommands="Canstralian/text2shellcommands",
    pentest_ai="Canstralian/pentest_ai",
)
# Sidebar for model selection
def select_model():
    """Render the sidebar model picker and return the chosen MODEL_MAPPING key."""
    sidebar = st.sidebar
    sidebar.header("Model Configuration")
    choices = [key for key in MODEL_MAPPING]
    return sidebar.selectbox("Select a model", choices)
# Load model and tokenizer with caching
@st.cache_resource
def _load_model_and_tokenizer_cached(model_name):
    """Load and cache ``(tokenizer, model)`` for *model_name*; raise on failure.

    Exceptions are deliberately allowed to propagate: ``st.cache_resource``
    only memoizes successful return values, so a failed (e.g. transient
    network) load is retried on the next rerun instead of being cached.
    """
    # text2shellcommands is consumed via ``model.generate()`` in
    # predict_with_model, so it needs a seq2seq LM head no matter which
    # checkpoint backs it. The name check must happen before the fallback
    # substitution below, otherwise "t5-small" would never match.
    is_seq2seq = (
        model_name == "Canstralian/text2shellcommands"
        or "seq2seq" in model_name.lower()
    )
    if model_name == "Canstralian/text2shellcommands":
        # Use a fallback model for testing.
        model_name = "t5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model_cls = (
        AutoModelForSeq2SeqLM if is_seq2seq else AutoModelForSequenceClassification
    )
    model = model_cls.from_pretrained(model_name)
    return tokenizer, model


def load_model_and_tokenizer(model_name):
    """Return a cached ``(tokenizer, model)`` pair for *model_name*.

    Args:
        model_name: Model identifier (a value of MODEL_MAPPING).

    Returns:
        ``(tokenizer, model)`` on success, or ``(None, None)`` after showing
        the error in the app when loading fails.
    """
    try:
        return _load_model_and_tokenizer_cached(model_name)
    except Exception as e:  # UI boundary: report instead of crashing the page.
        st.error(f"Error loading model: {e}")
        return None, None


# Keep ``load_model_and_tokenizer.clear()`` working as it did when this
# function itself carried the @st.cache_resource decorator.
load_model_and_tokenizer.clear = _load_model_and_tokenizer_cached.clear
# Handle predictions
def predict_with_model(user_input, model, tokenizer, model_choice, *, max_new_tokens=64):
    """Run *user_input* through *model* and return display-ready results.

    Args:
        user_input: Text entered by the user.
        model: Loaded model. For ``"text2shellcommands"`` it must provide
            ``generate()``; otherwise it is called directly and its output
            must expose ``.logits``.
        tokenizer: Tokenizer matching *model*.
        model_choice: MODEL_MAPPING key; ``"text2shellcommands"`` selects
            generation, anything else selects classification.
        max_new_tokens: Upper bound on generated tokens (generation only).
            Without an explicit bound, ``generate()`` falls back to the
            checkpoint's generation config. For checkpoints that set nothing,
            transformers' default is ``max_length=20``, which cuts longer
            commands short.

    Returns:
        ``{"Generated Shell Command": str}`` for generation, or
        ``{"Predicted Class": int, "Logits": list}`` for classification.
    """
    # Both paths encode the input the same way, so tokenize once.
    inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
    if model_choice == "text2shellcommands":
        # Generate shell commands
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        generated_command = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return {"Generated Shell Command": generated_command}
    # Perform classification
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class = torch.argmax(logits, dim=-1).item()
    return {
        "Predicted Class": predicted_class,
        "Logits": logits.tolist(),
    }
# Main Streamlit app
def main():
    """Render the dashboard: model picker, text input, and prediction output."""
    st.title("AI Model Inference Dashboard")
    # Model selection
    model_choice = select_model()
    model_name = MODEL_MAPPING.get(model_choice)
    tokenizer, model = load_model_and_tokenizer(model_name)
    # Input text box
    user_input = st.text_area("Enter text:")
    # A failed load returns (None, None) after load_model_and_tokenizer has
    # shown the error; don't misreport that as "enter some text". Use
    # `is None` rather than truthiness -- tokenizers define __len__.
    if model is None or tokenizer is None:
        st.warning(f"Model '{model_choice}' is unavailable, so predictions are disabled.")
        return
    # Whitespace-only input is truthy but gives the model nothing to work on.
    if not user_input or not user_input.strip():
        st.info("Please enter some text for prediction.")
        return
    result = predict_with_model(user_input, model, tokenizer, model_choice)
    for key, value in result.items():
        st.write(f"{key}: {value}")


if __name__ == "__main__":
    main()