karthikmad's picture
Bug Fix for multiselect
79acc49 verified
raw
history blame
14.1 kB
import logging
import time
from datetime import timedelta
from typing import Any, Dict, List, Tuple

import streamlit as st
from llm_guard.input_scanners import get_scanner_by_name
from llm_guard.input_scanners.anonymize import default_entity_types
from llm_guard.input_scanners.code import SUPPORTED_LANGUAGES as SUPPORTED_CODE_LANGUAGES
from llm_guard.input_scanners.language import MatchType as LanguageMatchType
from llm_guard.input_scanners.prompt_injection import MatchType as PromptInjectionMatchType
from llm_guard.input_scanners.toxicity import MatchType as ToxicityMatchType
from llm_guard.vault import Vault
from streamlit_tags import st_tags
# Module-level logger; get_scanner() writes a debug line to it for every scanner it builds.
logger = logging.getLogger("llm-guard-playground")
def _non_blank_lines(text: str) -> List[str]:
    """Split *text* on newlines and drop empty / whitespace-only lines.

    A trailing newline or a blank line in a text area would otherwise reach
    the scanner as an empty substring or pattern.
    """
    return [line for line in text.split("\n") if line.strip()]


def _anonymize_settings() -> Dict:
    """Render the "Anonymize" sidebar expander and return the chosen settings."""
    with st.sidebar.expander("Anonymize", expanded=False):
        entity_types = st_tags(
            label="Anonymize entities",
            text="Type and press enter",
            value=default_entity_types,
            suggestions=default_entity_types
            + ["DATE_TIME", "NRP", "LOCATION", "MEDICAL_LICENSE", "US_PASSPORT"],
            maxtags=30,
            key="anon_entity_types",
        )
        st.caption(
            "Check all supported entities: https://llm-guard.com/input_scanners/anonymize/"
        )
        hidden_names = st_tags(
            label="Hidden names to be anonymized",
            text="Type and press enter",
            value=[],
            suggestions=[],
            maxtags=30,
            key="anon_hidden_names",
        )
        st.caption("These names will be hidden e.g. [REDACTED_CUSTOM1].")
        allowed_names = st_tags(
            label="Allowed names to ignore",
            text="Type and press enter",
            value=[],
            suggestions=[],
            maxtags=30,
            key="anon_allowed_names",
        )
        st.caption("These names will be ignored even if flagged by the detector.")
        preamble = st.text_input(
            "Preamble", value="Text to prepend to sanitized prompt: "
        )
        use_faker = st.checkbox(
            "Use Faker",
            value=False,
            help="Use Faker library to generate fake data",
            key="anon_use_faker",
        )
        threshold = st.slider(
            label="Threshold",
            value=0.0,
            min_value=0.0,
            max_value=1.0,
            step=0.1,
            key="anon_threshold",
        )
    return {
        "entity_types": entity_types,
        "hidden_names": hidden_names,
        "allowed_names": allowed_names,
        "preamble": preamble,
        "use_faker": use_faker,
        "threshold": threshold,
    }


def _ban_competitors_settings() -> Dict:
    """Render the "Ban Competitors" sidebar expander and return the chosen settings."""
    with st.sidebar.expander("Ban Competitors", expanded=False):
        competitors = st_tags(
            label="List of competitors",
            text="Type and press enter",
            value=["openai", "anthropic", "deepmind", "google"],
            suggestions=[],
            maxtags=30,
            key="bc_competitors",
        )
        threshold = st.slider(
            label="Threshold",
            value=0.5,
            min_value=0.0,
            max_value=1.0,
            step=0.05,
            key="ban_competitors_threshold",
        )
    return {
        "competitors": competitors,
        "threshold": threshold,
    }


def _ban_substrings_settings() -> Dict:
    """Render the "Ban Substrings" sidebar expander and return the chosen settings."""
    with st.sidebar.expander("Ban Substrings", expanded=False):
        # Blank lines are removed so no empty substring is handed to the scanner.
        substrings = _non_blank_lines(
            st.text_area(
                "Enter substrings to ban (one per line)",
                value="test\nhello\nworld",
                height=200,
            )
        )
        match_type = st.selectbox(
            "Match type", ["str", "word"], index=0, key="bs_match_type"
        )
        case_sensitive = st.checkbox(
            "Case sensitive", value=False, key="bs_case_sensitive"
        )
        redact = st.checkbox("Redact", value=False, key="bs_redact")
        contains_all = st.checkbox("Contains all", value=False, key="bs_contains_all")
    return {
        "substrings": substrings,
        "match_type": match_type,
        "case_sensitive": case_sensitive,
        "redact": redact,
        "contains_all": contains_all,
    }


def _ban_topics_settings() -> Dict:
    """Render the "Ban Topics" sidebar expander and return the chosen settings."""
    with st.sidebar.expander("Ban Topics", expanded=False):
        topics = st_tags(
            label="List of topics",
            text="Type and press enter",
            value=["violence"],
            suggestions=[],
            maxtags=30,
            key="bt_topics",
        )
        threshold = st.slider(
            label="Threshold",
            value=0.6,
            min_value=0.0,
            max_value=1.0,
            step=0.05,
            key="ban_topics_threshold",
        )
    return {
        "topics": topics,
        "threshold": threshold,
    }


def _code_settings() -> Dict:
    """Render the "Code" sidebar expander and return the chosen settings."""
    with st.sidebar.expander("Code", expanded=False):
        # NOTE(review): the default must be one of the offered options; confirm
        # that "python" matches the spelling used in SUPPORTED_CODE_LANGUAGES
        # for the installed llm_guard version.
        languages = st.multiselect(
            "Programming languages",
            SUPPORTED_CODE_LANGUAGES,
            default=["python"],
        )
        is_blocked = st.checkbox("Is blocked", value=False, key="code_is_blocked")
    return {
        "languages": languages,
        "is_blocked": is_blocked,
    }


def _language_settings() -> Dict:
    """Render the "Language" sidebar expander and return the chosen settings."""
    with st.sidebar.expander("Language", expanded=False):
        valid_languages = st.multiselect(
            "Languages",
            [
                "ar",
                "bg",
                "de",
                "el",
                "en",
                "es",
                "fr",
                "hi",
                "it",
                "ja",
                "nl",
                "pl",
                "pt",
                "ru",
                "sw",
                "th",
                "tr",
                "ur",
                "vi",
                "zh",
            ],
            default=["en"],
        )
        match_type = st.selectbox(
            "Match type",
            [e.value for e in LanguageMatchType],
            index=1,
            key="language_match_type",
        )
    return {
        "valid_languages": valid_languages,
        "match_type": match_type,
    }


def _prompt_injection_settings() -> Dict:
    """Render the "Prompt Injection" sidebar expander and return the chosen settings."""
    with st.sidebar.expander("Prompt Injection", expanded=False):
        threshold = st.slider(
            label="Threshold",
            value=0.75,
            min_value=0.0,
            max_value=1.0,
            step=0.05,
            key="prompt_injection_threshold",
        )
        match_type = st.selectbox(
            "Match type",
            [e.value for e in PromptInjectionMatchType],
            index=1,
            key="prompt_injection_match_type",
        )
    return {
        "threshold": threshold,
        "match_type": match_type,
    }


def _regex_settings() -> Dict:
    """Render the "Regex" sidebar expander and return the chosen settings."""
    with st.sidebar.expander("Regex", expanded=False):
        # Blank lines are removed so no empty pattern is handed to the scanner.
        patterns = _non_blank_lines(
            st.text_area(
                "Enter patterns to ban (one per line)",
                value="Bearer [A-Za-z0-9-._~+/]+",
                height=200,
            )
        )
        is_blocked = st.checkbox("Is blocked", value=False, key="regex_is_blocked")
        redact = st.checkbox(
            "Redact",
            value=False,
            help="Replace the matched bad patterns with [REDACTED]",
            key="regex_redact",
        )
    return {
        "patterns": patterns,
        "is_blocked": is_blocked,
        "redact": redact,
    }


def _secrets_settings() -> Dict:
    """Render the "Secrets" sidebar expander and return the chosen settings."""
    with st.sidebar.expander("Secrets", expanded=False):
        redact_mode = st.selectbox("Redact mode", ["all", "partial", "hash"])
    return {
        "redact_mode": redact_mode,
    }


def _sentiment_settings() -> Dict:
    """Render the "Sentiment" sidebar expander and return the chosen settings."""
    with st.sidebar.expander("Sentiment", expanded=False):
        threshold = st.slider(
            label="Threshold",
            value=-0.1,
            min_value=-1.0,
            max_value=1.0,
            step=0.1,
            key="sentiment_threshold",
            help="Negative values are negative sentiment, positive values are positive sentiment",
        )
    return {
        "threshold": threshold,
    }


def _token_limit_settings() -> Dict:
    """Render the "Token Limit" sidebar expander and return the chosen settings."""
    with st.sidebar.expander("Token Limit", expanded=False):
        limit = st.number_input(
            "Limit", value=4096, min_value=0, max_value=10000, step=10
        )
        encoding_name = st.selectbox(
            "Encoding name",
            ["cl100k_base", "p50k_base", "r50k_base"],
            index=0,
            help="Read more: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb",
        )
    return {
        "limit": limit,
        "encoding_name": encoding_name,
    }


def _toxicity_settings() -> Dict:
    """Render the "Toxicity" sidebar expander and return the chosen settings."""
    with st.sidebar.expander("Toxicity", expanded=False):
        threshold = st.slider(
            label="Threshold",
            value=0.75,
            min_value=0.0,
            max_value=1.0,
            step=0.05,
            key="toxicity_threshold",
        )
        match_type = st.selectbox(
            "Match type",
            [e.value for e in ToxicityMatchType],
            index=1,
            key="toxicity_match_type",
        )
    return {
        "threshold": threshold,
        "match_type": match_type,
    }


def init_settings() -> Tuple[List[str], Dict[str, Dict]]:
    """Build the sidebar configuration UI for the input scanners.

    Shows a multiselect with every scanner (all enabled by default) and, for
    each enabled scanner, an expander holding that scanner's options.

    Returns:
        A pair ``(enabled_scanners, settings)``: the scanner names picked in
        the multiselect, and a dict mapping each enabled scanner name to the
        keyword settings collected from its expander.
    """
    # Insertion order matters: it is both the option order of the multiselect
    # and the order in which the expanders are rendered in the sidebar.
    settings_builders = {
        "Anonymize": _anonymize_settings,
        "BanCompetitors": _ban_competitors_settings,
        "BanSubstrings": _ban_substrings_settings,
        "BanTopics": _ban_topics_settings,
        "Code": _code_settings,
        "Language": _language_settings,
        "PromptInjection": _prompt_injection_settings,
        "Regex": _regex_settings,
        "Secrets": _secrets_settings,
        "Sentiment": _sentiment_settings,
        "TokenLimit": _token_limit_settings,
        "Toxicity": _toxicity_settings,
    }
    all_scanners = list(settings_builders)
    st_enabled_scanners = st.sidebar.multiselect(
        "Select scanners",
        options=all_scanners,
        default=all_scanners,
        help="The list can be found here: https://llm-guard.com/input_scanners/anonymize/",
    )

    # Only enabled scanners get an expander (and therefore a settings entry).
    settings = {
        name: build_settings()
        for name, build_settings in settings_builders.items()
        if name in st_enabled_scanners
    }
    return st_enabled_scanners, settings
def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
    """Instantiate the input scanner called *scanner_name*.

    Args:
        scanner_name: Name of the scanner, as understood by
            ``get_scanner_by_name``.
        vault: Vault passed to the "Anonymize" scanner as its ``vault`` setting.
        settings: Settings collected for this scanner. The dict is left
            untouched; extra entries are added to a copy.

    Returns:
        Whatever ``get_scanner_by_name`` returns for the given name and
        settings.
    """
    logger.debug("Initializing %s scanner", scanner_name)

    # Shallow copy: the injected vault / use_onnx entries must not leak back
    # into the caller's settings dict.
    scanner_settings = dict(settings)

    if scanner_name == "Anonymize":
        scanner_settings["vault"] = vault

    if scanner_name in ("Anonymize", "BanTopics", "Code", "PromptInjection", "Toxicity"):
        scanner_settings["use_onnx"] = True

    return get_scanner_by_name(scanner_name, scanner_settings)
def scan(
    vault: Vault, enabled_scanners: List[str], settings: Dict, text: str, fail_fast: bool = False
) -> Tuple[str, List[Dict[str, Any]]]:
    """Run the enabled scanners over *text*, one after another.

    Each scanner receives the prompt as sanitized by the scanners before it.
    Progress is shown in a Streamlit status container.

    Args:
        vault: Vault handed to ``get_scanner`` for every scanner.
        enabled_scanners: Scanner names, in the order they should run.
        settings: Mapping of scanner name to that scanner's settings; must
            contain an entry for every name in *enabled_scanners*.
        text: The prompt to scan.
        fail_fast: Stop after the first scanner that reports the prompt as
            invalid.

    Returns:
        A pair ``(sanitized_prompt, results)`` where *results* holds one dict
        per executed scanner with the keys ``scanner``, ``is_valid``,
        ``risk_score`` and ``took_sec`` (seconds, rounded to two decimals).
    """
    sanitized_prompt = text
    results = []

    status_text = "Scanning prompt (fail fast mode)..." if fail_fast else "Scanning prompt..."

    with st.status(status_text, expanded=True) as status:
        for scanner_name in enabled_scanners:
            st.write(f"{scanner_name} scanner...")
            scanner = get_scanner(scanner_name, vault, settings[scanner_name])

            # Time only the scan itself, not the scanner construction.
            start_time = time.monotonic()
            sanitized_prompt, is_valid, risk_score = scanner.scan(sanitized_prompt)
            elapsed_sec = time.monotonic() - start_time

            results.append(
                {
                    "scanner": scanner_name,
                    "is_valid": is_valid,
                    "risk_score": risk_score,
                    "took_sec": round(elapsed_sec, 2),
                }
            )

            if fail_fast and not is_valid:
                break

        status.update(label="Scanning complete", state="complete", expanded=False)

    return sanitized_prompt, results