Spaces:
Runtime error
Runtime error
File size: 4,989 Bytes
a9fd595 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import os
import logging
from confluent_kafka import KafkaException, Producer
import json
import torch
from transformers import TextStreamer, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from confluent_kafka.serialization import (
MessageField,
SerializationContext,
)
from unsloth import FastLanguageModel
from uuid import uuid4
import concurrent.futures
# Force synchronous CUDA kernel launches so GPU errors surface at the offending
# call site (debugging aid; slows inference).
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# Hugging Face auth token for the model repo; None if the env var is unset.
hf_token = os.getenv("HF_TOKEN")
class MessageSend:
    """Container for one scan-result message published to Kafka.

    Attributes:
        username: owner of the scan request.
        title: message title (the caller passes the scanned file path here).
        level: severity level string (may be empty).
        detail: model-generated finding text, or None when absent.
    """

    def __init__(self, username, title, level, detail=None):
        self.username = username
        self.title = title
        self.level = level
        self.detail = detail

    def __repr__(self):
        # Debug-friendly representation (the original class had none, so
        # instances printed as an opaque <MessageSend object at 0x...>).
        return (f"{type(self).__name__}(username={self.username!r}, "
                f"title={self.title!r}, level={self.level!r}, "
                f"detail={self.detail!r})")
def cover_message(msg):
    """Flatten a MessageSend-like object into a plain dict for serialization."""
    return {
        "username": msg.username,
        "title": msg.title,
        "level": msg.level,
        "detail": msg.detail,
    }
# Module-level logger; basicConfig makes INFO-and-above records visible.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class TooManyRequestsError(Exception):
    """Raised when an upstream service rate-limits a request.

    Attributes:
        retry_after: seconds to wait before retrying, as given by the caller.
    """

    def __init__(self, retry_after):
        # Initialize the base Exception with a message so str(e) is
        # informative (the original left the base class uninitialized and
        # str(e) was empty).
        super().__init__(f"Too many requests; retry after {retry_after}s")
        self.retry_after = retry_after
# Load the quantized model once at import time. hf_token authenticates against
# the Hugging Face repo; dtype=None lets unsloth choose automatically.
# NOTE(review): this runs on import and requires a GPU + network access.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "admincybers2/sentinal",
    max_seq_length = 4096,
    dtype = None,
    load_in_4bit = True,
    token=hf_token
)
# Enable native 2x faster inference
FastLanguageModel.for_inference(model)
# Prompt template: the first {} receives the code snippet, the second {} the
# vulnerability description (currently always filled with an empty string).
vulnerable_prompt = "Identify the line of code that is vulnerable and describe the type of software vulnerability, no yapping if no vulnerable code found pls return 'no vulnerable'\n### Code Snippet:\n{}\n### Vulnerability Description:\n{}"
def extract_data(full_message):
    """Decode a JSON payload; return the parsed object, or None on bad JSON."""
    try:
        return json.loads(full_message)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to extract data: {e}")
        return None
def perform_ai_task(question):
    """Run the local model over a code snippet and return its findings.

    Args:
        question: the code snippet to analyze for vulnerabilities.

    Returns:
        A dict {"detail": <deduplicated model output>} on success, or None
        when generation failed with a RuntimeError.
    """
    prompt = vulnerable_prompt.format(question, "")
    inputs = tokenizer([prompt], return_tensors="pt")
    text_streamer = TextStreamer(tokenizer)
    try:
        model_output = model.generate(
            **inputs,
            streamer=text_streamer,
            use_cache=True,
            max_new_tokens=640,
            temperature=0.5,
            top_k=50,
            top_p=0.9,
            min_p=0.01,
            typical_p=0.95,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
        )
        generated_text = tokenizer.decode(model_output[0], skip_special_tokens=True)
    except RuntimeError as e:
        error_message = str(e)
        if "probability tensor contains either `inf`, `nan` or element < 0" in error_message:
            logger.error("Encountered probability tensor error, skipping this task.")
        else:
            logger.error(f"Runtime error during model generation: {error_message}. Switching to remote inference.")
        # BUG FIX: the original only returned None for the probability-tensor
        # branch; any other RuntimeError fell through and dereferenced the
        # unbound `generated_text`, raising UnboundLocalError. Both error
        # branches now return None so the caller can skip the task.
        return None
    deduplicated_text = deduplicate_text(generated_text)
    return {
        "detail": deduplicated_text
    }
def deduplicate_text(text):
    """Drop repeated sentences from *text*, keeping first-occurrence order.

    Sentences are naive '. '-separated spans. The result always ends with a
    single period.

    Args:
        text: free-form model output.

    Returns:
        The text with duplicate sentences removed, terminated by '.'.
    """
    seen_sentences = set()
    deduplicated_sentences = []
    for sentence in text.split('. '):
        if sentence not in seen_sentences:
            seen_sentences.add(sentence)
            deduplicated_sentences.append(sentence)
    result = '. '.join(deduplicated_sentences)
    # BUG FIX: the original appended '.' unconditionally, so input that
    # already ended with a period came back with a double period
    # ("A. B." -> "A. B.."). Append only when the terminator is missing.
    if not result.endswith('.'):
        result += '.'
    return result
def delivery_report(err, msg):
    """Producer delivery callback: log the outcome of one produced message."""
    if err is None:
        logger.info(f"Message delivered to {msg.topic()} [{msg.partition()}]")
        return
    logger.error(f"Message delivery failed: {err}")
def handle_message(msg, producer, ensure_producer_connected, avro_serializer):
    """Process one decoded scan request end-to-end.

    Runs the model on msg['message_send'], wraps the finding in a MessageSend,
    Avro-serializes it, and produces it to the 'get_scan_message' topic.
    All failures are logged rather than raised.
    """
    logger.info(f'Message value {msg}')
    if not msg:
        return
    ensure_producer_connected(producer)
    try:
        ai_results = perform_ai_task(msg['message_send'])
        if ai_results is None:
            logger.error("AI task skipped due to an error in model generation.")
            return
        detail = ai_results.get("detail", "No details available")
        topic = "get_scan_message"
        messagedict = cover_message(
            MessageSend(
                username=msg['username'],
                title=msg['path'],
                level='',
                detail=detail
            )
        )
        if not messagedict:
            logger.error("Message serialization failed; skipping production.")
            return
        byte_value = avro_serializer(messagedict, SerializationContext(topic, MessageField.VALUE))
        producer.produce(
            topic,
            value=byte_value,
            headers={"correlation_id": str(uuid4())},
            callback=delivery_report
        )
        producer.flush()
    except KafkaException as e:
        logger.error(f"Kafka error producing message: {e}")
    except Exception as e:
        logger.error(f"Unhandled error in handle_message: {e}")