import json
import logging
import os
from uuid import uuid4

from confluent_kafka import KafkaException, Producer
from confluent_kafka.serialization import MessageField, SerializationContext
from transformers import TextStreamer
from unsloth import FastLanguageModel

# Synchronous CUDA launches make GPU errors surface at the offending call,
# which helps debugging at some cost in throughput.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

hf_token = os.getenv("HF_TOKEN")

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class MessageSend:
    """Payload produced back to Kafka after a scan."""

    def __init__(self, username, title, level, detail=None):
        self.username = username
        self.title = title
        self.level = level
        self.detail = detail


def cover_message(msg):
    """Return a dictionary representation of a MessageSend instance for serialization."""
    return dict(
        username=msg.username,
        title=msg.title,
        level=msg.level,
        detail=msg.detail,
    )


class TooManyRequestsError(Exception):
    """Raised when an upstream service rate-limits us; carries the retry delay."""

    def __init__(self, retry_after):
        super().__init__(f"Too many requests, retry after {retry_after}s")
        self.retry_after = retry_after


# Load the fine-tuned model in 4-bit via Unsloth's patched loader.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="admincybers2/sentinal",
    max_seq_length=4096,
    dtype=None,  # auto-detect float16 / bfloat16
    load_in_4bit=True,
    token=hf_token,
)
# Enable native 2x faster inference.
FastLanguageModel.for_inference(model)

vulnerable_prompt = (
    "Identify the line of code that is vulnerable and describe the type of "
    "software vulnerability, no yapping if no vulnerable code found pls "
    "return 'no vulnerable'\n"
    "### Code Snippet:\n{}\n"
    "### Vulnerability Description:\n{}"
)


def extract_data(full_message):
    """Parse a raw JSON payload, returning None on malformed input."""
    try:
        return json.loads(full_message)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to extract data: {e}")
        return None


def perform_ai_task(question):
    prompt = vulnerable_prompt.format(question, "")
    # Move the inputs to the model's device; the 4-bit model lives on the GPU,
    # and CPU tensors would raise a device-mismatch error in generate().
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    text_streamer = TextStreamer(tokenizer)
    try:
        model_output = model.generate(
            **inputs,
            streamer=text_streamer,
            use_cache=True,
            max_new_tokens=640,
            temperature=0.5,
            top_k=50,
            top_p=0.9,
            min_p=0.01,
            typical_p=0.95,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
        )
        generated_text = tokenizer.decode(model_output[0], skip_special_tokens=True)
    except RuntimeError as e:
        error_message = str(e)
        if "probability tensor contains either `inf`, `nan` or element < 0" in error_message:
            logger.error("Encountered probability tensor error, skipping this task.")
        else:
            logger.error(f"Runtime error during model generation: {error_message}.")
        # No output was generated, so there is nothing to post-process.
        return None
    return {"detail": deduplicate_text(generated_text)}


def deduplicate_text(text):
    """Drop repeated sentences while preserving their original order."""
    seen_sentences = set()
    deduplicated_sentences = []
    for sentence in text.split('. '):
        if sentence not in seen_sentences:
            seen_sentences.add(sentence)
            deduplicated_sentences.append(sentence)
    # The final sentence keeps its own trailing period after the split,
    # so join without appending another one.
    return '. '.join(deduplicated_sentences)
def delivery_report(err, msg):
    """Producer callback: log whether the broker acknowledged the message."""
    if err is not None:
        logger.error(f"Message delivery failed: {err}")
    else:
        logger.info(f"Message delivered to {msg.topic()} [{msg.partition()}]")


def handle_message(msg, producer, ensure_producer_connected, avro_serializer):
    """Run the vulnerability scan for one consumed message and produce the result."""
    logger.info(f"Message value {msg}")
    if not msg:
        return
    ensure_producer_connected(producer)
    try:
        ai_results = perform_ai_task(msg["message_send"])
        if ai_results is None:
            logger.error("AI task skipped due to an error in model generation.")
            return
        detail = ai_results.get("detail", "No details available")
        topic = "get_scan_message"
        messagedict = cover_message(
            MessageSend(
                username=msg["username"],
                title=msg["path"],
                level="",
                detail=detail,
            )
        )
        if messagedict:
            byte_value = avro_serializer(
                messagedict, SerializationContext(topic, MessageField.VALUE)
            )
            producer.produce(
                topic,
                value=byte_value,
                headers={"correlation_id": str(uuid4())},
                callback=delivery_report,
            )
            producer.flush()
        else:
            logger.error("Message serialization failed; skipping production.")
    except KafkaException as e:
        logger.error(f"Kafka error producing message: {e}")
    except Exception as e:
        logger.error(f"Unhandled error in handle_message: {e}")
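

# --- Usage sketch (not part of the original service) ---
# A minimal, hypothetical wiring of handle_message to a Kafka consumer loop,
# shown to make the expected payload shape and dependencies concrete. The
# broker and Schema Registry URLs, topic names, group id, the Avro value
# schema, and ensure_producer_connected below are all assumptions, not the
# real deployment configuration.
if __name__ == "__main__":
    from confluent_kafka import Consumer
    from confluent_kafka.schema_registry import SchemaRegistryClient
    from confluent_kafka.schema_registry.avro import AvroSerializer

    def ensure_producer_connected(producer):
        # Hypothetical health check: list_topics raises KafkaException if the
        # broker is unreachable, so failures surface before produce().
        producer.list_topics(timeout=5)

    schema_registry = SchemaRegistryClient({"url": "http://localhost:8081"})  # placeholder
    # Placeholder schema matching the dict shape returned by cover_message().
    value_schema = json.dumps({
        "type": "record",
        "name": "MessageSend",
        "fields": [
            {"name": "username", "type": "string"},
            {"name": "title", "type": "string"},
            {"name": "level", "type": "string"},
            {"name": "detail", "type": ["null", "string"], "default": None},
        ],
    })
    avro_serializer = AvroSerializer(schema_registry, value_schema)

    producer = Producer({"bootstrap.servers": "localhost:9092"})  # placeholder broker
    consumer = Consumer({
        "bootstrap.servers": "localhost:9092",
        "group.id": "sentinal-scanner",  # placeholder group id
        "auto.offset.reset": "earliest",
    })
    consumer.subscribe(["scan_requests"])  # placeholder input topic

    try:
        while True:
            record = consumer.poll(1.0)
            if record is None or record.error():
                continue
            # The JSON payload is expected to carry the fields handle_message
            # reads: "message_send", "username", and "path".
            payload = extract_data(record.value().decode("utf-8"))
            if payload:
                handle_message(payload, producer, ensure_producer_connected, avro_serializer)
    finally:
        consumer.close()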