import streamlit as st
import torch
import pandas as pd
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
)
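# Dependency note: this app assumes streamlit, torch, transformers, and pandas
# are installed (e.g. `pip install streamlit torch transformers pandas`).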

# Define the model names and their corresponding Hugging Face models
MODEL_MAPPING = {
    "text2shellcommands": "t5-small",  # Example seq2seq model for generating shell commands
    "pentest_ai": "bert-base-uncased",  # Example classification model for pentesting tasks
}
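# Any compatible Hugging Face checkpoint can be swapped in here. A hypothetical
# extra entry (not part of the original mapping) might look like:
#   MODEL_MAPPING["summarizer"] = "google/flan-t5-small"
# Note that the loader below routes on the substring "t5" or "seq2seq" in the
# checkpoint name, so a seq2seq model must match one of those patterns.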

# Function to create a sidebar for model selection
def select_model():
    """
    Adds a dropdown to the Streamlit sidebar for selecting a model.

    Returns:
        str: The selected model key from MODEL_MAPPING.
    """
    st.sidebar.header("Model Configuration")
    selected_model = st.sidebar.selectbox("Select a model", list(MODEL_MAPPING.keys()))
    return selected_model

# Function to load the model and tokenizer with caching
@st.cache_resource
def load_model_and_tokenizer(model_name):
    """
    Loads the tokenizer and model for the specified Hugging Face model name.
    Uses caching to optimize performance.

    Args:
        model_name (str): The name of the Hugging Face model to load.

    Returns:
        tuple: A tokenizer and model instance, or (None, None) on failure.
    """
    try:
        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Determine the correct model class to use
        if "t5" in model_name or "seq2seq" in model_name:
            # Load a sequence-to-sequence model
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        else:
            # Load a sequence classification model
            model = AutoModelForSequenceClassification.from_pretrained(model_name)
        return tokenizer, model
    except Exception as e:
        # Display an error message in the Streamlit app
        st.error(f"An error occurred while loading the model or tokenizer: {str(e)}")
        return None, None
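# Usage sketch (assumes the checkpoints above are downloadable; when called
# outside a Streamlit run, st.cache_resource warns and simply runs uncached):
#   tokenizer, model = load_model_and_tokenizer("t5-small")
#   if model is not None:
#       print(type(model).__name__)  # -> T5ForConditionalGeneration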

# Function to handle predictions based on the selected model
def predict_with_model(user_input, model, tokenizer, model_choice):
    """
    Handles predictions using the loaded model and tokenizer.

    Args:
        user_input (str): Text input from the user.
        model: Loaded Hugging Face model.
        tokenizer: Loaded Hugging Face tokenizer.
        model_choice (str): Selected model key from MODEL_MAPPING.

    Returns:
        dict: A dictionary containing the prediction results.
    """
    if model_choice == "text2shellcommands":
        # Generate shell commands (seq2seq task)
        inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = model.generate(**inputs)
        generated_command = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"Generated Shell Command": generated_command}
    else:
        # Perform classification
        inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        predicted_class = torch.argmax(logits, dim=-1).item()
        return {
            "Predicted Class": predicted_class,
            "Logits": logits.tolist(),
        }
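# Note: the raw logits returned above can be converted to probabilities with a
# softmax if calibrated scores are preferred, e.g.:
#   probs = torch.softmax(logits, dim=-1).tolist()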

# Function to process uploaded files
def process_uploaded_file(uploaded_file):
    """
    Reads and processes the uploaded file. Supports text and CSV files.

    Args:
        uploaded_file: The uploaded file.

    Returns:
        str: The content of the file as a string, or None if it cannot be read.
    """
    if uploaded_file is None:
        return None
    try:
        file_type = uploaded_file.type
        # Text file processing
        if "text" in file_type:
            return uploaded_file.read().decode("utf-8")
        # CSV file processing
        elif "csv" in file_type:
            df = pd.read_csv(uploaded_file)
            return df.to_string()  # Flatten the dataframe into a plain string
        else:
            st.error("Unsupported file type. Please upload a text or CSV file.")
            return None
    except Exception as e:
        st.error(f"Error processing file: {e}")
        return None
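# Caveat: a CSV is flattened to one long string, so the tokenizer will truncate
# it to the model's maximum sequence length (512 tokens for bert-base-uncased);
# only the beginning of a large file actually reaches the model.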

# Main function to define the Streamlit app
def main():
    st.title("AI Model Inference Dashboard")
    st.markdown(
        """
        This dashboard allows you to interact with different AI models for inference tasks,
        such as generating shell commands or performing text classification.
        """
    )

    # Model selection
    model_choice = select_model()
    model_name = MODEL_MAPPING.get(model_choice)
    tokenizer, model = load_model_and_tokenizer(model_name)
    if tokenizer is None or model is None:
        # load_model_and_tokenizer already displayed the error details
        st.warning("Model failed to load; cannot run inference.")
        return

    # Input text area or file upload
    input_choice = st.radio("Choose Input Method", ("Text Input", "Upload File"))

    if input_choice == "Text Input":
        user_input = st.text_area("Enter your text input:", placeholder="Type your text here...")
        # Handle prediction after submit
        submit_button = st.button("Submit")
        if submit_button and user_input:
            st.write("### Prediction Results:")
            result = predict_with_model(user_input, model, tokenizer, model_choice)
            for key, value in result.items():
                st.write(f"**{key}:** {value}")
    elif input_choice == "Upload File":
        uploaded_file = st.file_uploader("Choose a text or CSV file", type=["txt", "csv"])
        # Handle prediction after submit
        submit_button = st.button("Submit")
        if submit_button and uploaded_file:
            file_content = process_uploaded_file(uploaded_file)
            if file_content:
                st.write("### File Content:")
                st.write(file_content)
                result = predict_with_model(file_content, model, tokenizer, model_choice)
                st.write("### Prediction Results:")
                for key, value in result.items():
                    st.write(f"**{key}:** {value}")
            else:
                st.info("No valid content found in the file.")


if __name__ == "__main__":
    main()
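# To launch the dashboard locally (assuming this file is saved as app.py):
#   streamlit run app.py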