import os

# CUDA_LAUNCH_BLOCKING is read when the CUDA context is initialised, so set it
# before torch/unsloth are imported.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import json
import logging
from uuid import uuid4

from confluent_kafka import KafkaException, Producer
from confluent_kafka.serialization import (
    MessageField,
    SerializationContext,
)
from transformers import TextStreamer
from unsloth import FastLanguageModel

hf_token = os.getenv("HF_TOKEN")

class MessageSend:
    def __init__(self, username, title, level, detail=None):
        self.username = username
        self.title = title
        self.level = level
        self.detail = detail
        
def cover_message(msg):
    """Return a dictionary representation of a User instance for serialization."""
    return dict(
        username=msg.username,
        title=msg.title,
        level=msg.level,
        detail=msg.detail
    )

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class TooManyRequestsError(Exception):
    def __init__(self, retry_after):
        super().__init__(f"Too many requests, retry after {retry_after} seconds")
        self.retry_after = retry_after
# Load the fine-tuned vulnerability-detection model in 4-bit precision.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="admincybers2/sentinal",
    max_seq_length=4096,
    dtype=None,  # auto-detect float16/bfloat16 depending on the GPU
    load_in_4bit=True,
    token=hf_token,
)

# Enable native 2x faster inference
FastLanguageModel.for_inference(model)
vulnerable_prompt = "Identify the line of code that is vulnerable and describe the type of software vulnerability, no yapping if no vulnerable code found pls return 'no vulnerable'\n### Code Snippet:\n{}\n### Vulnerability Description:\n{}"


def extract_data(full_message):
    try:
        message = json.loads(full_message)
        return message
    except json.JSONDecodeError as e:
        logger.error(f"Failed to extract data: {e}")
        return None

def perform_ai_task(question):
    prompt = vulnerable_prompt.format(question, "")
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    text_streamer = TextStreamer(tokenizer)

    try:
        model_output = model.generate(
            **inputs,
            streamer=text_streamer,
            use_cache=True,
            max_new_tokens=640,
            do_sample=True,  # required for the sampling parameters below to take effect
            temperature=0.5,
            top_k=50,
            top_p=0.9,
            min_p=0.01,
            typical_p=0.95,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
        )
        generated_text = tokenizer.decode(model_output[0], skip_special_tokens=True)
    except RuntimeError as e:
        error_message = str(e)
        if "probability tensor contains either `inf`, `nan` or element < 0" in error_message:
            logger.error("Encountered probability tensor error, skipping this task.")
        else:
            logger.error(f"Runtime error during model generation: {error_message}.")
        return None

    deduplicated_text = deduplicate_text(generated_text)
    return {
        "detail": deduplicated_text
    }

def deduplicate_text(text):
    """Remove repeated sentences from the generated text, preserving order."""
    sentences = text.split('. ')
    seen_sentences = set()
    deduplicated_sentences = []

    for sentence in sentences:
        if sentence not in seen_sentences:
            seen_sentences.add(sentence)
            deduplicated_sentences.append(sentence)

    result = '. '.join(deduplicated_sentences)
    return result if result.endswith('.') else result + '.'

def delivery_report(err, msg):
    if err is not None:
        logger.error(f"Message delivery failed: {err}")
    else:
        logger.info(f"Message delivered to {msg.topic()} [{msg.partition()}]")

def handle_message(msg, producer, ensure_producer_connected, avro_serializer):
    logger.info(f"Received message: {msg}")
    if msg:
        ensure_producer_connected(producer)
        try:
            ai_results = perform_ai_task(msg['message_send'])
            if ai_results is None:
                logger.error("AI task skipped due to an error in model generation.")
                return
            
            detail = ai_results.get("detail", "No details available")
            
            topic = "get_scan_message"

            messagedict = cover_message(
                MessageSend(
                    username=msg['username'],
                    title=msg['path'],
                    level='',
                    detail=detail
                )
            )
            
            if messagedict:
                byte_value = avro_serializer(messagedict, SerializationContext(topic, MessageField.VALUE))
                producer.produce(
                    topic,
                    value=byte_value,
                    headers={"correlation_id": str(uuid4())},
                    callback=delivery_report
                )
                producer.flush()
            else:
                logger.error("Message serialization failed; skipping production.")
        except KafkaException as e:
            logger.error(f"Kafka error producing message: {e}")
        except Exception as e:
            logger.error(f"Unhandled error in handle_message: {e}")