import concurrent.futures
import json
import logging
import os
from uuid import uuid4

import torch
from confluent_kafka import KafkaException, Producer
from confluent_kafka.serialization import MessageField, SerializationContext
from transformers import TextStreamer, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from unsloth import FastLanguageModel

# Force synchronous CUDA kernel launches so GPU errors surface at the offending
# call instead of at some later, unrelated one.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

hf_token = os.getenv("HF_TOKEN")


class MessageSend:
    """Outgoing scan-result payload."""

    def __init__(self, username, title, level, detail=None):
        self.username = username
        self.title = title
        self.level = level
        self.detail = detail


def cover_message(msg):
    """Return a dictionary representation of a MessageSend instance for serialization."""
    return dict(
        username=msg.username,
        title=msg.title,
        level=msg.level,
        detail=msg.detail,
    )


logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class TooManyRequestsError(Exception):
    """Raised on rate limiting; retry_after is the suggested backoff in seconds."""

    def __init__(self, retry_after):
        super().__init__(f"Too many requests; retry after {retry_after}s")
        self.retry_after = retry_after
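
# Note: TooManyRequestsError is never raised in this file; presumably a caller
# consumes it along these lines (illustrative sketch only, not original code):
#
#   try:
#       submit_scan(...)            # hypothetical rate-limited call
#   except TooManyRequestsError as e:
#       time.sleep(e.retry_after)   # back off for the advertised interval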

# Load the fine-tuned model with unsloth: 4-bit quantized, dtype auto-detected.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="admincybers2/sentinal",
    max_seq_length=4096,
    dtype=None,
    load_in_4bit=True,
    token=hf_token,
)
# Enable unsloth's native 2x faster inference path.
FastLanguageModel.for_inference(model)

vulnerable_prompt = (
    "Identify the line of code that is vulnerable and describe the type of software "
    "vulnerability, no yapping if no vulnerable code found pls return 'no vulnerable'"
    "\n### Code Snippet:\n{}\n### Vulnerability Description:\n{}"
)


def extract_data(full_message):
    """Parse a JSON message string into a dict; return None on malformed input."""
    try:
        return json.loads(full_message)
    except json.JSONDecodeError as e:
        logger.error(f"Failed to extract data: {e}")
        return None
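
# Hypothetical shape of the payload extract_data() is expected to yield. The
# field names are inferred from the lookups in handle_message() below; this is
# an illustration, not a schema from the original source:
#
#   {
#       "username": "alice",
#       "path": "src/app/login.py",
#       "message_send": "<code snippet to analyse>"
#   }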


def perform_ai_task(question):
    """Run the vulnerability prompt through the model; return a result dict or None."""
    prompt = vulnerable_prompt.format(question, "")
    # Move the tokenized inputs onto the model's device (unsloth loads onto the GPU).
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    text_streamer = TextStreamer(tokenizer)
    try:
        model_output = model.generate(
            **inputs,
            streamer=text_streamer,
            use_cache=True,
            max_new_tokens=640,
            temperature=0.5,
            top_k=50,
            top_p=0.9,
            min_p=0.01,
            typical_p=0.95,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
        )
        generated_text = tokenizer.decode(model_output[0], skip_special_tokens=True)
    except RuntimeError as e:
        error_message = str(e)
        if "probability tensor contains either `inf`, `nan` or element < 0" in error_message:
            logger.error("Encountered probability tensor error, skipping this task.")
        else:
            logger.error(f"Runtime error during model generation: {error_message}")
        # generated_text was never assigned in either branch, so bail out here
        # rather than fall through to a NameError below.
        return None
    return {
        "detail": deduplicate_text(generated_text),
    }


def deduplicate_text(text):
    """Drop repeated sentences while preserving order (the model sometimes loops)."""
    sentences = text.split('. ')
    seen_sentences = set()
    deduplicated_sentences = []
    for sentence in sentences:
        if sentence not in seen_sentences:
            seen_sentences.add(sentence)
            deduplicated_sentences.append(sentence)
    result = '. '.join(deduplicated_sentences)
    # Ensure exactly one trailing period instead of unconditionally appending one.
    return result if result.endswith('.') else result + '.'
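
# Example of the intended behaviour (sentence-level, order-preserving):
#   deduplicate_text("A is unsafe. B is fine. A is unsafe. C")
#   -> "A is unsafe. B is fine. C."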


def delivery_report(err, msg):
    """Kafka delivery callback: log the outcome of each produced message."""
    if err is not None:
        logger.error(f"Message delivery failed: {err}")
    else:
        logger.info(f"Message delivered to {msg.topic()} [{msg.partition()}]")


def handle_message(msg, producer, ensure_producer_connected, avro_serializer):
    logger.info(f'Message value {msg}')
    if not msg:
        return
    ensure_producer_connected(producer)
    try:
        ai_results = perform_ai_task(msg['message_send'])
        if ai_results is None:
            logger.error("AI task skipped due to an error in model generation.")
            return
        detail = ai_results.get("detail", "No details available")
        topic = "get_scan_message"
        messagedict = cover_message(
            MessageSend(
                username=msg['username'],
                title=msg['path'],
                level='',
                detail=detail,
            )
        )
        if messagedict:
            byte_value = avro_serializer(messagedict, SerializationContext(topic, MessageField.VALUE))
            producer.produce(
                topic,
                value=byte_value,
                headers={"correlation_id": str(uuid4())},
                callback=delivery_report,
            )
            producer.flush()
        else:
            logger.error("Message serialization failed; skipping production.")
    except KafkaException as e:
        logger.error(f"Kafka error producing message: {e}")
    except Exception as e:
        logger.error(f"Unhandled error in handle_message: {e}")