CyberEndE / aitask.py
admincybers2's picture
Create aitask.py
a9fd595 verified
raw
history blame
4.99 kB
import os
import logging
from confluent_kafka import KafkaException, Producer
import json
import torch
from transformers import TextStreamer, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from confluent_kafka.serialization import (
MessageField,
SerializationContext,
)
from unsloth import FastLanguageModel
from uuid import uuid4
import concurrent.futures
# Force synchronous CUDA kernel launches so CUDA errors surface at the
# offending call site instead of asynchronously later (debugging aid;
# slows inference slightly).
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# Hugging Face access token used below to download the model weights.
hf_token = os.getenv("HF_TOKEN")
class MessageSend:
    """Container for a scan-result message published to Kafka.

    Attributes:
        username: Owner/recipient of the message.
        title: Message title (callers pass the scanned file path here).
        level: Severity level string.
        detail: Optional free-text detail (the AI analysis output).
    """

    def __init__(self, username, title, level, detail=None):
        self.username, self.title = username, title
        self.level, self.detail = level, detail
def cover_message(msg):
    """Return a dictionary representation of a User instance for serialization."""
    return {
        "username": msg.username,
        "title": msg.title,
        "level": msg.level,
        "detail": msg.detail,
    }
# Module-wide logger; basicConfig sets INFO as the global threshold so the
# progress/error messages below are actually emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class TooManyRequestsError(Exception):
    """Raised when a rate limit is hit; carries the suggested retry delay.

    Args:
        retry_after: Time to wait before retrying (units defined by the raiser).
    """

    def __init__(self, retry_after):
        # Fix: the original never called Exception.__init__, so str(e) was
        # empty and any log of the exception carried no information.
        super().__init__(f"Too many requests; retry after {retry_after}")
        self.retry_after = retry_after
# Load the fine-tuned vulnerability-detection model, 4-bit quantized to fit
# in limited GPU memory. dtype=None lets the library auto-select; this runs
# at import time and downloads weights (needs HF_TOKEN for access).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "admincybers2/sentinal",
    max_seq_length = 4096,
    dtype = None,
    load_in_4bit = True,
    token=hf_token
)
# Enable native 2x faster inference
FastLanguageModel.for_inference(model)
# Prompt template: first {} is filled with the code snippet; the second {}
# (vulnerability description) is passed empty by perform_ai_task below.
vulnerable_prompt = "Identify the line of code that is vulnerable and describe the type of software vulnerability, no yapping if no vulnerable code found pls return 'no vulnerable'\n### Code Snippet:\n{}\n### Vulnerability Description:\n{}"
def extract_data(full_message):
    """Decode a JSON string; return the parsed object, or None on invalid JSON."""
    try:
        return json.loads(full_message)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to extract data: {e}")
        return None
def perform_ai_task(question):
    """Run the vulnerability-detection model over a code snippet.

    Args:
        question: Source-code snippet to analyze.

    Returns:
        dict with key "detail" holding the deduplicated generated text,
        or None if generation failed with a RuntimeError.
    """
    prompt = vulnerable_prompt.format(question, "")
    inputs = tokenizer([prompt], return_tensors="pt")
    # Stream tokens to stdout as they are produced, for progress visibility.
    text_streamer = TextStreamer(tokenizer)
    try:
        model_output = model.generate(
            **inputs,
            streamer=text_streamer,
            use_cache=True,
            max_new_tokens=640,
            temperature=0.5,
            top_k=50,
            top_p=0.9,
            min_p=0.01,
            typical_p=0.95,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
        )
        generated_text = tokenizer.decode(model_output[0], skip_special_tokens=True)
    except RuntimeError as e:
        error_message = str(e)
        if "probability tensor contains either `inf`, `nan` or element < 0" in error_message:
            logger.error("Encountered probability tensor error, skipping this task.")
        else:
            logger.error(f"Runtime error during model generation: {error_message}. Switching to remote inference.")
        # BUG FIX: the original fell through on the non-probability-tensor
        # branch and referenced the unbound `generated_text`, raising
        # UnboundLocalError. Every RuntimeError now returns None.
        return None
    deduplicated_text = deduplicate_text(generated_text)
    return {
        "detail": deduplicated_text
    }
def deduplicate_text(text):
    """Remove repeated sentences from *text*, preserving first-seen order.

    Sentences are naively delimited by ". ". Fixes two defects in the
    original: it blindly appended ".", doubling the final period when the
    input already ended with one, and it returned "." for empty input.
    """
    seen = set()
    unique = []
    for sentence in text.split('. '):
        if sentence not in seen:
            seen.add(sentence)
            unique.append(sentence)
    result = '. '.join(unique)
    # Append the terminating period only when it is actually missing.
    if result and not result.endswith('.'):
        result += '.'
    return result
def delivery_report(err, msg):
    """Kafka producer delivery callback: log per-message success or failure."""
    if err is None:
        logger.info(f"Message delivered to {msg.topic()} [{msg.partition()}]")
    else:
        logger.error(f"Message delivery failed: {err}")
def handle_message(msg, producer, ensure_producer_connected, avro_serializer):
    """Run the AI scan for one consumed message and publish the result to Kafka.

    Args:
        msg: Decoded message dict; reads 'message_send', 'username', 'path'.
        producer: Kafka Producer used to publish the result.
        ensure_producer_connected: Callable invoked with the producer before use.
        avro_serializer: Callable serializing the outgoing payload dict.
    """
    logger.info(f'Message value {msg}')
    if not msg:
        return
    ensure_producer_connected(producer)
    try:
        ai_results = perform_ai_task(msg['message_send'])
        if ai_results is None:
            logger.error("AI task skipped due to an error in model generation.")
            return
        topic = "get_scan_message"
        outgoing = MessageSend(
            username=msg['username'],
            title=msg['path'],
            level='',
            detail=ai_results.get("detail", "No details available"),
        )
        messagedict = cover_message(outgoing)
        if not messagedict:
            logger.error("Message serialization failed; skipping production.")
            return
        byte_value = avro_serializer(messagedict, SerializationContext(topic, MessageField.VALUE))
        producer.produce(
            topic,
            value=byte_value,
            headers={"correlation_id": str(uuid4())},
            callback=delivery_report
        )
        producer.flush()
    except KafkaException as e:
        logger.error(f"Kafka error producing message: {e}")
    except Exception as e:
        logger.error(f"Unhandled error in handle_message: {e}")