|
import sqlite3 |
|
import gradio as gr |
|
from hashlib import md5 as hash_algo |
|
from re import match |
|
from io import BytesIO |
|
from pypdf import PdfReader |
|
from llm_rs import AutoModel,SessionConfig,GenerationConfig,Precision |
|
|
|
# Hugging Face repo and quantized GGML model file loaded below.
repo_name = "rustformers/mpt-7b-ggml"

file_name = "mpt-7b-instruct-q5_1-ggjt.bin"

# When set to 'test', process_stream() skips model inference and emits
# dummy tokens instead (see its else-branch).
script_env = 'prod'


# NOTE(review): 2 threads / batch size 2 — presumably sized for the hosting
# hardware; confirm before changing.
session_config = SessionConfig(threads=2,batch_size=2)

# Downloads (if needed) and loads the model at import time; this makes module
# import slow and network-dependent.
model = AutoModel.from_pretrained(repo_name, model_file=file_name, session_config=session_config,verbose=True)
|
|
|
def process_stream(rules, log, temperature, top_p, top_k, max_new_tokens, seed):
    """Check the uploaded log/query file against the uploaded rule documents.

    Builds an instruction prompt from the rule files, consults a local SQLite
    cache keyed on an MD5 fingerprint of all inputs, and either replays the
    cached answer or streams a fresh completion from the model, storing it in
    the cache afterwards.

    Args:
        rules: uploaded rule file(s) — objects exposing a ``.name`` path;
            a single file or a list of files.
        log: uploaded log/query file to be checked.
        temperature, top_p, top_k, max_new_tokens, seed: sampling settings
            forwarded to ``GenerationConfig``.

    Yields:
        str: the growing response text, suitable for streaming UI updates.
    """
    con = sqlite3.connect("history.db")
    cur = con.cursor()
    try:
        instruction = ''
        file_hashes = []

        # Gradio passes a bare file object when only one file was uploaded.
        if not isinstance(rules, list):
            rules = [rules]

        for rule in rules:
            data, file_hash = get_file_contents(rule)
            instruction += data + '\n'
            file_hashes.append(file_hash)

        # Sort so the cache key is independent of upload order.
        # MD5 is acceptable here: it is a cache key, not a security boundary.
        file_hashes.sort()
        digest = hash_algo(''.join(file_hashes).encode()).hexdigest()

        # Find the highest "N." rule number so the next item can be appended
        # with the following number.
        largest = 0
        lines = instruction.split('\r\n')

        # Fall back to Unix line endings when no CRLF was found.
        if len(lines) == 1:
            lines = instruction.split('\n')

        for line in lines:
            m = match(r'^(\d+)\.', line)
            if m is not None:
                num = int(m.group(1))
                if num > largest:
                    largest = num

        instruction += str(largest + 1) + '. '

        query, log_hash = get_file_contents(log)
        # Fold the log file's hash into the cache key as well.
        digest = hash_algo((digest + log_hash).encode()).hexdigest()

        instruction = instruction.replace('\r\r\n', '\n')

        prompt = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.

Q: Read the rules stated below and check the queries for any violation. State the rules which are violated by a query (if any). Also suggest a possible remediation, if possible. Do not make any assumptions outside of the rules stated below.

{instruction}The queries are as follows:
{query}

A:

### Response:
Answer:"""

        response = ""
        row = cur.execute('SELECT response FROM queries WHERE hexdigest = ?',
                          [digest]).fetchone()

        if row is not None:
            # Cache hit: replay the stored answer in one shot.
            response += "Cached Result:\n" + row[0]
            yield response
        else:
            if script_env != 'test':
                generation_config = GenerationConfig(
                    seed=seed,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    max_new_tokens=max_new_tokens,
                )
                streamer = model.stream(prompt=prompt,
                                        generation_config=generation_config)
                for new_text in streamer:
                    response += new_text
                    yield response
            else:
                # Test mode: emit dummy tokens instead of running the model.
                for num in range(100):
                    response += " " + str(num)
                    yield response

            # Store the finished answer for future identical requests.
            cur.execute('INSERT INTO queries VALUES(?, ?)', (digest, response))
            con.commit()
    finally:
        # Always release the DB handles, even on error or if the caller
        # abandons the generator mid-stream.
        cur.close()
        con.close()
|
|
|
def get_file_contents(file):
    """Read an uploaded file and return its text plus a hash of its raw bytes.

    Args:
        file: an uploaded-file object exposing a ``.name`` filesystem path.

    Returns:
        tuple[str, str]: ``(text, hexdigest)`` where ``text`` is the extracted
        text (PDF pages are concatenated; CSV commas are replaced by spaces)
        and ``hexdigest`` is the MD5 of the raw bytes, used as a cache key.
    """
    with open(file.name, 'rb') as f:
        raw = f.read()
    byte_hash = hash_algo(raw).hexdigest()

    if file.name.endswith('.pdf'):
        reader = PdfReader(BytesIO(raw))
        # extract_text() may return None for pages with no extractable text
        # in some pypdf versions; treat those pages as empty.
        data = ''.join(page.extract_text() or '' for page in reader.pages)
    else:
        # assumes text files are UTF-8/ASCII; .decode() raises otherwise —
        # TODO confirm acceptable for the expected uploads
        data = raw.decode()

        if file.name.endswith(".csv"):
            data = data.replace(',', ' ')

    return (data, byte_hash)
|
|
|
def upload_log_file(files):
    """Return the on-disk paths of the uploaded log files for the gr.File output."""
    return [uploaded.name for uploaded in files]
|
|
|
def upload_file(files):
    """Return the on-disk paths of the uploaded compliance documents for the gr.File output."""
    paths = []
    for uploaded in files:
        paths.append(uploaded.name)
    return paths
|
|
|
# --- Gradio UI -------------------------------------------------------------
# NOTE(review): gr.Box and queue(concurrency_count=...) were removed in
# Gradio 4.x — this layout appears to target Gradio 3.x; confirm the pinned
# version before upgrading.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    gr.Markdown(
        """<h1><center>Grid 5.0 Information Security Track</center></h1>
"""
    )

    # Compliance documents: the upload button writes the chosen paths into
    # the gr.File component via upload_file().
    rules = gr.File(file_count="multiple")
    upload_button = gr.UploadButton("Click to upload a new Compliance Document", file_types=[".txt", ".pdf"], file_count="multiple")
    upload_button.upload(upload_file, upload_button, rules)

    # Log/query file to be checked against the rules.
    with gr.Row():
        with gr.Column():
            log = gr.File()
            upload_log_button = gr.UploadButton("Click to upload a log file", file_types=[".txt", ".csv", ".pdf"], file_count="multiple")
            upload_log_button.upload(upload_log_file, upload_log_button, log)

    # Sampling controls forwarded to process_stream().
    with gr.Accordion("Advanced Options:", open=False):
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    temperature = gr.Slider(
                        label="Temperature",
                        value=0.8,
                        minimum=0.1,
                        maximum=1.0,
                        step=0.1,
                        interactive=True,
                        info="Higher values produce more diverse outputs",
                    )
            with gr.Column():
                with gr.Row():
                    top_p = gr.Slider(
                        label="Top-p (nucleus sampling)",
                        value=0.95,
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        interactive=True,
                        info=(
                            "Sample from the smallest possible set of tokens whose cumulative probability "
                            "exceeds top_p. Set to 1 to disable and sample from all tokens."
                        ),
                    )
            with gr.Column():
                with gr.Row():
                    top_k = gr.Slider(
                        label="Top-k",
                        value=40,
                        minimum=5,
                        maximum=80,
                        step=1,
                        interactive=True,
                        info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
                    )
            with gr.Column():
                with gr.Row():
                    max_new_tokens = gr.Slider(
                        label="Maximum new tokens",
                        value=256,
                        minimum=0,
                        maximum=1024,
                        step=5,
                        interactive=True,
                        info="The maximum number of new tokens to generate",
                    )

            with gr.Column():
                with gr.Row():
                    seed = gr.Number(
                        label="Seed",
                        value=42,
                        interactive=True,
                        info="The seed to use for the generation",
                        precision=0
                    )
    with gr.Row():
        submit = gr.Button("Submit")
    with gr.Row():
        with gr.Box():
            gr.Markdown("**Output**")
            output_7b = gr.Markdown()

    # process_stream is a generator, so the Markdown output updates as
    # tokens stream in.
    submit.click(
        process_stream,
        inputs=[rules, log, temperature, top_p, top_k, max_new_tokens,seed],
        outputs=output_7b,
    )

# Queue requests (max 4 waiting, one generation at a time) and start the app.
demo.queue(max_size=4, concurrency_count=1).launch(debug=True)