"""Gradio app that enhances dental X-ray images and segments them with SAM.

The script exposes a small web UI: the user uploads a radiograph, tunes
brightness/contrast/sharpness, and optionally runs Segment Anything (ViT-H
checkpoint) to overlay coloured masks on the enhanced image.
"""

import os
import traceback

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image, ImageEnhance
from segment_anything import sam_model_registry, SamPredictor
from tqdm import tqdm

# Custom CSS for the Gradio interface.
css = """
.gradio-container {
    font-family: 'Inter', sans-serif;
    max-width: 1200px !important;
    margin: auto;
}
.main-container {
    background-color: #ffffff;
    border-radius: 15px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    padding: 20px;
    margin: 20px;
}
.header {
    background: linear-gradient(90deg, #2563eb 0%, #3b82f6 100%);
    color: white;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
    text-align: center;
}
.controls {
    background-color: #f8fafc;
    padding: 15px;
    border-radius: 10px;
    margin: 10px 0;
}
.slider-label {
    color: #374151;
    font-weight: 500;
}
.button-primary {
    background: linear-gradient(90deg, #2563eb 0%, #3b82f6 100%);
    border: none;
    padding: 10px 20px;
    border-radius: 8px;
    color: white;
    font-weight: 600;
    cursor: pointer;
    transition: transform 0.2s;
}
.button-primary:hover {
    transform: translateY(-2px);
}
.output-container {
    background-color: #f8fafc;
    padding: 15px;
    border-radius: 10px;
    margin-top: 20px;
}
.analysis-text {
    font-family: 'Mono', monospace;
    background-color: #1e293b;
    color: #f8fafc;
    padding: 15px;
    border-radius: 8px;
    white-space: pre-wrap;
}
.image-input, .image-output {
    border-radius: 10px;
    overflow: hidden;
    border: 2px solid #e2e8f0;
}
"""

# Lazily initialised SamPredictor shared by all requests; populated either by
# the ``__main__`` guard at startup or on first use by ``_get_predictor``.
sam_predictor = None


def download_sam_model():
    """Download the SAM ViT-H checkpoint if it is not already present.

    The file is first written to ``<name>.part`` and only renamed to its final
    name after the transfer finished, so an interrupted download can never be
    mistaken for a valid checkpoint on the next run.

    Returns:
        The local path of the checkpoint file.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        OSError: if fewer bytes than announced by the server were received.
    """
    model_path = "sam_vit_h_4b8939.pth"
    if os.path.exists(model_path):
        return model_path

    print("Baixando modelo SAM... Isso pode levar alguns minutos.")
    url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    partial_path = model_path + ".part"
    completed = False
    try:
        # (connect timeout, read timeout) so a dead connection cannot hang forever.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            written = 0
            with open(partial_path, 'wb') as file, tqdm(
                desc=model_path,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                # 1 MiB chunks: the checkpoint is large, tiny chunks waste time.
                for data in response.iter_content(chunk_size=1024 * 1024):
                    size = file.write(data)
                    written += size
                    progress_bar.update(size)
        if total_size and written != total_size:
            raise OSError(
                f"Incomplete download: got {written} of {total_size} bytes"
            )
        os.replace(partial_path, model_path)
        completed = True
    finally:
        # Never leave a truncated file behind, whatever interrupted us.
        if not completed and os.path.exists(partial_path):
            os.remove(partial_path)
    return model_path


def load_sam_model():
    """Build a ``SamPredictor`` for the ViT-H model, downloading it if needed.

    The model is placed on CUDA when available, otherwise on the CPU.

    Returns:
        A ``SamPredictor`` wrapping the loaded model.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_type = "vit_h"
    checkpoint_path = download_sam_model()
    print(f"Carregando modelo SAM no dispositivo: {device}")
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
    sam.to(device=device)
    return SamPredictor(sam)


def _get_predictor():
    """Return the shared predictor, loading the model on first use."""
    global sam_predictor
    if sam_predictor is None:
        sam_predictor = load_sam_model()
    return sam_predictor


def enhance_xray(image, brightness=1.0, contrast=1.0, sharpness=1.0):
    """Apply brightness, contrast and sharpness factors to an image.

    Args:
        image: NumPy image array; a 2-D (grayscale) array is expanded to RGB.
        brightness: PIL enhancement factor (1.0 leaves the image unchanged).
        contrast: PIL enhancement factor (1.0 leaves the image unchanged).
        sharpness: PIL enhancement factor (1.0 leaves the image unchanged).

    Returns:
        The enhanced image as a NumPy array.
    """
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    img = Image.fromarray(image)
    img = ImageEnhance.Brightness(img).enhance(brightness)
    img = ImageEnhance.Contrast(img).enhance(contrast)
    img = ImageEnhance.Sharpness(img).enhance(sharpness)
    return np.array(img)


def generate_dense_points(image, num_points=100):
    """Generate roughly ``num_points`` (x, y) prompt points over the image.

    A regular grid matching the image aspect ratio is laid over the central
    90 % of the image; if the grid holds fewer than ``num_points`` points the
    remainder is filled with uniformly random points in the same region.

    Args:
        image: NumPy image array whose first two dimensions are (height, width).
        num_points: Target number of points.

    Returns:
        Float array of shape ``(N, 2)`` with ``[x, y]`` rows. Empty (``N == 0``)
        when ``num_points`` is not positive or the image has no pixels.
    """
    h, w = image.shape[:2]
    if num_points <= 0 or h <= 0 or w <= 0:
        return np.empty((0, 2), dtype=float)

    # Clamp both grid dimensions to at least 1: for very tall/narrow images the
    # square-root estimate floors to 0 and used to raise ZeroDivisionError.
    num_x = min(max(1, int(np.sqrt(num_points * w / h))), num_points)
    num_y = max(1, num_points // num_x)

    x_points = np.linspace(w * 0.05, w * 0.95, num_x)
    y_points = np.linspace(h * 0.05, h * 0.95, num_y)

    # "ij" indexing keeps the original ordering: x varies slowest, y fastest.
    grid_x, grid_y = np.meshgrid(x_points, y_points, indexing="ij")
    points = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)

    # Top up with random points for better coverage.
    num_random = num_points - len(points)
    if num_random > 0:
        random_x = np.random.uniform(w * 0.05, w * 0.95, num_random)
        random_y = np.random.uniform(h * 0.05, h * 0.95, num_random)
        points = np.vstack([points, np.stack([random_x, random_y], axis=1)])

    return points


def show_mask(mask, image, alpha=0.5):
    """Blend a randomly coloured overlay onto ``image`` where ``mask`` is set.

    Args:
        mask: Array whose last two dimensions are (height, width); non-zero
            entries mark the pixels to tint.
        image: Grayscale, RGB or RGBA image array of the same height/width.
        alpha: Overlay opacity, clipped to ``[0, 1]``.

    Returns:
        A new ``uint8`` RGB image; pixels outside the mask are unchanged.
    """
    # Random vivid colour taken from the HSV colormap.
    hue = np.random.random()
    overlay_rgb = (np.asarray(plt.cm.hsv(hue)[:3]) * 255).astype(np.uint8)

    h, w = mask.shape[-2:]
    selected = np.asarray(mask).reshape(h, w).astype(bool)

    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    base = image[..., :3].astype(float)

    # Quantise the opacity to 8 bits, as an RGBA overlay would store it.
    opacity = int(np.clip(alpha, 0.0, 1.0) * 255) / 255.0

    blended = base.copy()
    blended[selected] = (
        overlay_rgb.astype(float) * opacity + base[selected] * (1.0 - opacity)
    )
    return blended.astype(np.uint8)


def segment_teeth(image, predictor, num_points=150, score_threshold=0.5,
                  mask_alpha=0.4):
    """Segment an X-ray with SAM and draw the resulting masks and prompts.

    Args:
        image: Grayscale or RGB image array.
        predictor: A ``SamPredictor`` instance.
        num_points: Number of prompt points spread over the image.
        score_threshold: Masks scoring at or below this value are discarded.
        mask_alpha: Opacity of the coloured mask overlays.

    Returns:
        Tuple ``(annotated_image, masks)`` where ``masks`` is the list of
        masks that passed the score filter.
    """
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    predictor.set_image(image)

    input_points = generate_dense_points(image, num_points=num_points)
    input_labels = np.ones(len(input_points), dtype=int)

    # NOTE(review): all points are passed as foreground prompts to a single
    # predict() call, which SAM treats as prompts for ONE object; with
    # multimask_output=True this yields at most three candidate masks rather
    # than one mask per tooth. Per-structure segmentation would need one
    # prompt per object or an automatic mask generator — TODO confirm the
    # intended behaviour before changing it.
    masks, scores, _logits = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=True,
    )

    filtered_masks = [
        mask for mask, score in zip(masks, scores) if score > score_threshold
    ]

    # Paint the largest masks first so smaller ones stay visible on top.
    final_image = image.copy()
    for mask in sorted(filtered_masks, key=lambda m: int(m.sum()), reverse=True):
        final_image = show_mask(mask, final_image, alpha=mask_alpha)

    # Small orange reference dots. The image is RGB (not OpenCV's BGR), so
    # orange is (255, 127, 0); the previous (0, 127, 255) rendered blue.
    for point in input_points:
        cv2.circle(
            final_image, (int(point[0]), int(point[1])), 2, (255, 127, 0), -1
        )

    return final_image, filtered_masks


def _build_report(brightness, contrast, sharpness, seg_status, num_segments,
                  coverage_percent):
    """Format the user-facing analysis summary shown next to the result."""
    return "\n".join([
        "📊 Análise da Radiografia:",
        "━━━━━━━━━━━━━━━━━━━━━━━━━━━━",
        "🔧 Parâmetros de Imagem:",
        f"• Brilho: {brightness:.1f}",
        f"• Contraste: {contrast:.1f}",
        f"• Nitidez: {sharpness:.1f}",
        "",
        "🔬 Resultado da Segmentação:",
        f"• Status: {seg_status}",
        f"• Estruturas detectadas: {num_segments}",
        f"• Cobertura da análise: {coverage_percent:.1f}%",
        "",
        "📝 Observações:",
        "• Processamento concluído com sucesso",
        "• Pontos laranja indicam regiões de análise",
        "• Áreas coloridas mostram estruturas detectadas",
    ])


def analyze_xray(image, brightness, contrast, sharpness, enable_segmentation):
    """Enhance an uploaded X-ray and optionally segment it.

    This is the Gradio click handler. Any exception is caught here (the UI
    boundary), printed with its traceback, and reported to the user as text.

    Args:
        image: Uploaded image as a NumPy array, or ``None`` if nothing was
            uploaded.
        brightness: Brightness enhancement factor.
        contrast: Contrast enhancement factor.
        sharpness: Sharpness enhancement factor.
        enable_segmentation: Whether to run SAM segmentation.

    Returns:
        Tuple ``(result_image_or_None, message)``.
    """
    if image is None:
        return None, "Por favor, carregue uma imagem."

    try:
        enhanced = enhance_xray(image, brightness, contrast, sharpness)
        result = enhanced

        if enable_segmentation:
            result, masks = segment_teeth(enhanced, _get_predictor())
            seg_status = "Segmentação realizada com sucesso"
            num_segments = len(masks)

            # Use the union of the masks: summing individual areas counts
            # overlapping pixels several times and can exceed 100 %.
            if masks:
                covered = int(np.logical_or.reduce(masks).sum())
            else:
                covered = 0
            image_area = image.shape[0] * image.shape[1]
            coverage_percent = (covered / image_area) * 100 if image_area else 0
        else:
            seg_status = "Segmentação desativada"
            num_segments = 0
            coverage_percent = 0

        analysis = _build_report(
            brightness, contrast, sharpness,
            seg_status, num_segments, coverage_percent,
        )
        return result, analysis

    except Exception as e:
        print(traceback.format_exc())
        return None, f"Erro durante o processamento: {str(e)}"


# Markdown shown in the header. The checkpoint loaded above is the original
# Segment Anything ViT-H model, so the text names SAM rather than SAM 2.
_HEADER_TEXT = (
    "Análise avançada de radiografias com Inteligência Artificial\n"
    "Utilizando o modelo SAM (Segment Anything Model)"
)

# Markdown shown below the two columns.
_FOOTER_TEXT = (
    "### 📋 Instruções de Uso\n"
    "1. Carregue uma radiografia dental\n"
    "2. Ajuste os parâmetros de imagem conforme necessário\n"
    "3. Ative a segmentação inteligente para detectar estruturas\n"
    "4. Clique em \"Analisar Radiografia\"\n"
    "\n"
    "⚠️ Nota: O primeiro uso pode levar alguns minutos para baixar o modelo."
)

# Build the Gradio interface.
with gr.Blocks(css=css, title="Analisador de Radiografias Dentárias") as app:
    with gr.Column(elem_classes="main-container"):
        # Header
        with gr.Column(elem_classes="header"):
            gr.Markdown("# 🦷 Analisador Inteligente de Radiografias Dentárias")
            gr.Markdown(_HEADER_TEXT)

        with gr.Row():
            # Input column
            with gr.Column(scale=1):
                with gr.Group(elem_classes="controls"):
                    input_image = gr.Image(
                        label="Radiografia Original",
                        type="numpy",
                        elem_classes="image-input",
                    )
                    with gr.Column():
                        brightness = gr.Slider(
                            minimum=0.1,
                            maximum=2.0,
                            value=1.0,
                            label="Brilho",
                            elem_classes="slider-label",
                        )
                        contrast = gr.Slider(
                            minimum=0.1,
                            maximum=2.0,
                            value=1.0,
                            label="Contraste",
                            elem_classes="slider-label",
                        )
                        sharpness = gr.Slider(
                            minimum=0.1,
                            maximum=2.0,
                            value=1.0,
                            label="Nitidez",
                            elem_classes="slider-label",
                        )
                        enable_segmentation = gr.Checkbox(
                            label="Ativar Segmentação Inteligente",
                            value=True,
                        )
                    analyze_btn = gr.Button(
                        "🔍 Analisar Radiografia",
                        elem_classes="button-primary",
                    )

            # Output column
            with gr.Column(scale=1):
                with gr.Group(elem_classes="output-container"):
                    output_image = gr.Image(
                        label="Resultado da Análise",
                        type="numpy",
                        elem_classes="image-output",
                    )
                    analysis_text = gr.Textbox(
                        label="Análise Detalhada",
                        lines=8,
                        elem_classes="analysis-text",
                    )

        # Footer
        gr.Markdown(_FOOTER_TEXT)

    analyze_btn.click(
        analyze_xray,
        inputs=[input_image, brightness, contrast, sharpness, enable_segmentation],
        outputs=[output_image, analysis_text],
    )

# Launch the app; the model is loaded eagerly so the first request is fast.
if __name__ == "__main__":
    sam_predictor = load_sam_model()
    app.launch()