File size: 6,142 Bytes
c954503
b8e3be9
 
 
 
 
c954503
6a09dd7
151aa67
6a09dd7
c0298f8
6a09dd7
 
6a46dba
 
6a09dd7
c0298f8
6a09dd7
 
 
 
 
c0298f8
6a09dd7
 
c0298f8
 
6a09dd7
c0298f8
 
6a09dd7
 
 
 
 
 
 
 
 
 
c954503
6a09dd7
b38a095
6a09dd7
 
b8e3be9
6a09dd7
6a46dba
 
6a09dd7
6a46dba
c0298f8
c954503
 
6a09dd7
 
c954503
69f088a
c954503
6a09dd7
c0298f8
6a09dd7
 
 
 
 
 
 
 
 
 
 
 
c0298f8
6a09dd7
c0298f8
 
 
 
 
 
 
6a46dba
 
 
c0298f8
 
 
 
 
 
9bc591d
 
6a09dd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0298f8
 
6a09dd7
 
 
 
 
 
9bc591d
c0298f8
 
 
 
9bc591d
6a09dd7
 
9bc591d
6a09dd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0298f8
69f088a
9bc591d
b8e3be9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import streamlit as st
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
)
import torch
import os

# Maps each user-facing model key (shown in the sidebar) to the Hugging Face Hub
# checkpoint that backs it. Insertion order is the order shown in the dropdown.
# NOTE(review): both checkpoints are generic examples. "t5-small" is not trained for
# shell-command generation. "bert-base-uncased" is a base checkpoint with no fine-tuned
# classification head, so its class predictions are presumably placeholders.
# Replace both with task-specific models.
MODEL_MAPPING = dict(
    text2shellcommands="t5-small",  # seq2seq model used to generate shell commands
    pentest_ai="bert-base-uncased",  # classification model used for pentesting tasks
)

# Function to create a sidebar for model selection
def select_model():
    """
    Adds a dropdown to the Streamlit sidebar for selecting a model.
    Returns:
        str: The selected model key from MODEL_MAPPING.
    """
    st.sidebar.header("Model Configuration")
    selected_model = st.sidebar.selectbox("Select a model", list(MODEL_MAPPING.keys()))
    return selected_model


# Function to load the model and tokenizer with caching
@st.cache_resource
def _load_model_and_tokenizer_cached(model_name):
    """
    Load (and cache) the tokenizer and model for a Hugging Face model name.

    Exceptions propagate on purpose. st.cache_resource does not store the result of
    a call that raises, so a failed load (e.g. a transient download error) is retried
    on the next run instead of being memoized as a permanent failure.

    Args:
        model_name (str): The name of the Hugging Face model to load.

    Returns:
        tuple: (tokenizer, model).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # NOTE(review): the model class is chosen from the checkpoint name alone. A seq2seq
    # checkpoint whose name contains neither "t5" nor "seq2seq" will be loaded as a
    # classifier.
    if "t5" in model_name or "seq2seq" in model_name:
        # Load a sequence-to-sequence model
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    else:
        # Load a sequence classification model
        model = AutoModelForSequenceClassification.from_pretrained(model_name)

    return tokenizer, model


def load_model_and_tokenizer(model_name):
    """
    Loads the tokenizer and model for the specified Hugging Face model name.

    Successful loads are cached with st.cache_resource. Failed loads are not cached,
    so they are retried on the next script run.

    Args:
        model_name (str): The name of the Hugging Face model to load.

    Returns:
        tuple: (tokenizer, model) on success, or (None, None) after an error message
        has been displayed in the Streamlit app.
    """
    try:
        return _load_model_and_tokenizer_cached(model_name)
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing the page.
        st.error(f"An error occurred while loading the model or tokenizer: {str(e)}")
        return None, None


# Keep the cache-clearing hook that the previously decorated function exposed.
load_model_and_tokenizer.clear = _load_model_and_tokenizer_cached.clear


# Function to handle predictions based on the selected model
def predict_with_model(user_input, model, tokenizer, model_choice, max_new_tokens=128):
    """
    Run inference on a single text input with an already-loaded model/tokenizer.

    Args:
        user_input (str): Text input from the user.
        model: Loaded Hugging Face model (seq2seq for "text2shellcommands",
            sequence classifier otherwise).
        tokenizer: Loaded Hugging Face tokenizer matching ``model``.
        model_choice (str): Selected model key from MODEL_MAPPING.
        max_new_tokens (int): Upper bound on generated tokens for the seq2seq task.
            Without an explicit limit, generate() falls back to a short default
            length and silently truncates longer commands.

    Returns:
        dict: For "text2shellcommands", {"Generated Shell Command": str}. Otherwise
        {"Predicted Class": int, "Logits": nested list of floats}.
    """
    # Both tasks tokenize identically. truncation=True clips inputs longer than the
    # tokenizer's maximum length, so very long inputs (e.g. whole files) are only
    # partially seen by the model.
    inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)

    if model_choice == "text2shellcommands":
        # Generate shell commands (Seq2Seq task)
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
        generated_command = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"Generated Shell Command": generated_command}

    # Perform classification
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class = torch.argmax(logits, dim=-1).item()
    return {
        "Predicted Class": predicted_class,
        "Logits": logits.tolist(),
    }


# Function to process uploaded files
def process_uploaded_file(uploaded_file):
    """
    Read an uploaded text or CSV file and return its contents as a string.

    The format is chosen from the file-name extension first, with the reported MIME
    type as a fallback. CSV uploads are often reported as "text/csv", which a plain
    "text" check would misroute to the text branch. Some platforms report CSV as a
    non-CSV MIME type, so the extension is the more reliable signal.

    Args:
        uploaded_file: A file-like upload object (e.g. Streamlit's UploadedFile)
            exposing ``name``, ``type`` and ``read()``; may be None.

    Returns:
        str | None: The decoded text, or the CSV rendered with DataFrame.to_string().
        None if no file was given, or if the file is unsupported or unreadable. In
        those last two cases an error message is shown in the app.
    """
    if uploaded_file is None:
        return None
    try:
        file_name = getattr(uploaded_file, "name", "") or ""
        extension = os.path.splitext(file_name)[1].lower()
        # The MIME type may be missing; normalize to "" so substring checks cannot fail.
        mime_type = (getattr(uploaded_file, "type", "") or "").lower()

        # Rewind in case the buffer was already consumed earlier.
        if hasattr(uploaded_file, "seek"):
            uploaded_file.seek(0)

        # CSV is checked before generic text because "text/csv" contains "text".
        if extension == ".csv" or "csv" in mime_type:
            import pandas as pd

            df = pd.read_csv(uploaded_file)
            return df.to_string()  # Convert the dataframe to string
        if extension == ".txt" or "text" in mime_type:
            # "utf-8-sig" decodes plain UTF-8 and also drops a leading byte-order mark.
            return uploaded_file.read().decode("utf-8-sig")

        st.error("Unsupported file type. Please upload a text or CSV file.")
        return None
    except Exception as e:
        st.error(f"Error processing file: {e}")
        return None


# Helper shared by both input paths to render prediction output
def _display_results(result):
    """
    Render a prediction-result dict under a "Prediction Results" heading.

    Args:
        result (dict): Mapping of label -> value, as returned by predict_with_model.
    """
    st.write("### Prediction Results:")
    for key, value in result.items():
        st.write(f"**{key}:** {value}")


# Main function to define the Streamlit app
def main():
    """Build the dashboard: model picker, input selector, and prediction output."""
    st.title("AI Model Inference Dashboard")
    st.markdown(
        """
        This dashboard allows you to interact with different AI models for inference tasks, 
        such as generating shell commands or performing text classification.
        """
    )

    # Model selection
    model_choice = select_model()
    model_name = MODEL_MAPPING.get(model_choice)
    tokenizer, model = load_model_and_tokenizer(model_name)
    if tokenizer is None or model is None:
        # load_model_and_tokenizer returns (None, None) after showing an error.
        # Inference would crash on those values, so halt this script run here.
        st.stop()

    # Input text area or file upload
    input_choice = st.radio("Choose Input Method", ("Text Input", "Upload File"))

    if input_choice == "Text Input":
        user_input = st.text_area("Enter your text input:", placeholder="Type your text here...")

        if st.button("Submit"):
            if user_input and user_input.strip():
                with st.spinner("Running inference..."):
                    result = predict_with_model(user_input, model, tokenizer, model_choice)
                _display_results(result)
            else:
                st.warning("Please enter some text before submitting.")

    elif input_choice == "Upload File":
        uploaded_file = st.file_uploader("Choose a text or CSV file", type=["txt", "csv"])

        if st.button("Submit"):
            if uploaded_file is None:
                st.warning("Please upload a file before submitting.")
                return
            file_content = process_uploaded_file(uploaded_file)
            if file_content and file_content.strip():
                st.write("### File Content:")
                # st.text keeps whitespace (e.g. the tabular CSV rendering) and does not
                # interpret the file's contents as Markdown, unlike st.write on a string.
                st.text(file_content)
                with st.spinner("Running inference..."):
                    result = predict_with_model(file_content, model, tokenizer, model_choice)
                _display_results(result)
            else:
                st.info("No valid content found in the file.")


# Entry point: build the dashboard when this file is executed as a script.
# NOTE(review): presumably launched via `streamlit run <this file>` — confirm.
if __name__ == "__main__":
    main()