File size: 2,744 Bytes
c954503
6a46dba
c954503
151aa67
c0298f8
 
6a46dba
c0298f8
6a46dba
 
c0298f8
 
 
 
 
 
 
 
 
c954503
c0298f8
6a46dba
c0298f8
b38a095
c0298f8
b38a095
 
6a46dba
 
 
c0298f8
c954503
 
 
 
69f088a
c954503
c0298f8
 
 
 
 
 
 
 
 
 
 
6a46dba
 
 
c0298f8
 
 
 
 
 
9bc591d
 
c0298f8
 
 
9bc591d
c0298f8
 
 
 
9bc591d
c0298f8
9bc591d
 
c0298f8
 
 
 
 
 
 
 
69f088a
9bc591d
c0298f8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
import torch

# Sidebar display name -> checkpoint id passed to `from_pretrained`.
MODEL_MAPPING = dict(
    text2shellcommands="Canstralian/text2shellcommands",
    pentest_ai="Canstralian/pentest_ai",
)

# Sidebar for model selection
def select_model():
    """Draw the sidebar model picker and return the selected MODEL_MAPPING key."""
    sidebar = st.sidebar
    sidebar.header("Model Configuration")
    available = [*MODEL_MAPPING]
    return sidebar.selectbox("Select a model", available)


# Load model and tokenizer with caching
@st.cache_resource
def load_model_and_tokenizer(model_name):
    """Load the tokenizer and model for a checkpoint, cached by ``st.cache_resource``.

    ``"Canstralian/text2shellcommands"`` is replaced by the ``"t5-small"``
    placeholder checkpoint. The text2shellcommands choice is run through
    ``model.generate()`` in ``predict_with_model``, so it is always loaded with
    ``AutoModelForSeq2SeqLM``. A purely name-based rule would miss it, because
    ``"t5-small"`` does not contain ``"seq2seq"``. It would then be loaded as a
    sequence classifier (or fail to load, depending on the transformers
    version). Any other checkpoint is loaded as seq2seq if its name contains
    ``"seq2seq"``, otherwise with ``AutoModelForSequenceClassification``.

    Args:
        model_name: Checkpoint id passed to ``from_pretrained`` (a value of
            ``MODEL_MAPPING``).

    Returns:
        ``(tokenizer, model)`` on success, or ``(None, None)`` after showing the
        error with ``st.error``.

    NOTE(review): the failure path returns normally, so ``st.cache_resource``
    presumably caches ``(None, None)`` for this name as well. A transient
    download error would then not be retried until the cache is cleared.
    Confirm whether that is acceptable.
    """
    requested_name = model_name
    try:
        # Use a fallback model for testing
        if model_name == "Canstralian/text2shellcommands":
            model_name = "t5-small"

        # Load the tokenizer and model
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Decide on the requested name as well as the (possibly substituted)
        # checkpoint name: the shell-command model needs an LM head for
        # generation whatever the fallback checkpoint happens to be called.
        use_seq2seq = (
            requested_name == "Canstralian/text2shellcommands"
            or "seq2seq" in model_name.lower()
        )
        model_cls = AutoModelForSeq2SeqLM if use_seq2seq else AutoModelForSequenceClassification
        model = model_cls.from_pretrained(model_name)

        return tokenizer, model
    except Exception as e:
        # UI boundary: surface the problem in the app instead of crashing the run.
        st.error(f"Error loading model: {e}")
        return None, None


# Handle predictions
def predict_with_model(user_input, model, tokenizer, model_choice, max_new_tokens=128):
    """Run one inference pass on *user_input* and return labelled results.

    Args:
        user_input: Text entered by the user (a single string).
        model: Loaded model. For ``"text2shellcommands"`` it must provide
            ``generate``. Otherwise it is called directly and must return an
            object with a ``logits`` attribute.
        tokenizer: Tokenizer paired with *model*. It encodes the input and, on
            the generation path, decodes the generated ids.
        model_choice: ``MODEL_MAPPING`` key. ``"text2shellcommands"`` selects
            generation; any other value selects classification.
        max_new_tokens: Upper bound on generated tokens (generation path only).
            Without an explicit limit, ``generate`` falls back to the model's
            generation config, whose library default ``max_length`` of 20
            silently truncates longer commands. Pass ``None`` to use the
            model's own default.

    Returns:
        ``{"Generated Shell Command": str}`` for generation, or
        ``{"Predicted Class": int, "Logits": list}`` for classification.
    """
    # Both paths encode the input the same way, so tokenize once.
    inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)

    if model_choice == "text2shellcommands":
        generate_kwargs = {} if max_new_tokens is None else {"max_new_tokens": max_new_tokens}
        with torch.no_grad():
            outputs = model.generate(**inputs, **generate_kwargs)
        # One prompt in, so decode the first (only) returned sequence.
        generated_command = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"Generated Shell Command": generated_command}

    # Perform classification
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # .item() needs a single element, which holds because one string was encoded.
    predicted_class = torch.argmax(logits, dim=-1).item()
    return {
        "Predicted Class": predicted_class,
        "Logits": logits.tolist(),
    }


# Main Streamlit app
def main():
    """Render the dashboard: model picker, input box, and prediction output.

    A prediction is shown only when the user has entered text and the selected
    model loaded successfully. Otherwise the user gets a message explaining
    which of the two is missing.
    """
    st.title("AI Model Inference Dashboard")

    # Model selection
    model_choice = select_model()
    model_name = MODEL_MAPPING.get(model_choice)
    tokenizer, model = load_model_and_tokenizer(model_name)

    # Input text box
    user_input = st.text_area("Enter text:")

    if not user_input:
        st.info("Please enter some text for prediction.")
        return

    # The loader returns (None, None) on failure after reporting the cause via
    # st.error. Say why nothing happens rather than asking again for text the
    # user has already entered.
    if model is None or tokenizer is None:
        st.warning("No prediction: the selected model could not be loaded.")
        return

    result = predict_with_model(user_input, model, tokenizer, model_choice)
    for key, value in result.items():
        st.write(f"{key}: {value}")


# Script entry point: render the dashboard when this file is executed directly
# rather than imported as a module.
if __name__ == "__main__":
    main()