Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
from PIL import Image, ImageEnhance | |
import cv2 | |
import torch | |
import os | |
import requests | |
from tqdm import tqdm | |
from segment_anything import sam_model_registry, SamPredictor | |
import matplotlib.pyplot as plt | |
# Custom CSS for the Gradio UI. The selectors below match the class names that
# the interface assigns through `elem_classes=` (e.g. "main-container",
# "header", "controls", "button-primary", "output-container", "analysis-text").
css = """
.gradio-container {
    font-family: 'Inter', sans-serif;
    max-width: 1200px !important;
    margin: auto;
}
.main-container {
    background-color: #ffffff;
    border-radius: 15px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    padding: 20px;
    margin: 20px;
}
.header {
    background: linear-gradient(90deg, #2563eb 0%, #3b82f6 100%);
    color: white;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
    text-align: center;
}
.controls {
    background-color: #f8fafc;
    padding: 15px;
    border-radius: 10px;
    margin: 10px 0;
}
.slider-label {
    color: #374151;
    font-weight: 500;
}
.button-primary {
    background: linear-gradient(90deg, #2563eb 0%, #3b82f6 100%);
    border: none;
    padding: 10px 20px;
    border-radius: 8px;
    color: white;
    font-weight: 600;
    cursor: pointer;
    transition: transform 0.2s;
}
.button-primary:hover {
    transform: translateY(-2px);
}
.output-container {
    background-color: #f8fafc;
    padding: 15px;
    border-radius: 10px;
    margin-top: 20px;
}
.analysis-text {
    font-family: 'Mono', monospace;
    background-color: #1e293b;
    color: #f8fafc;
    padding: 15px;
    border-radius: 8px;
    white-space: pre-wrap;
}
.image-input, .image-output {
    border-radius: 10px;
    overflow: hidden;
    border: 2px solid #e2e8f0;
}
"""
def download_sam_model(
    model_path="sam_vit_h_4b8939.pth",
    url="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    chunk_size=1024 * 1024,
    timeout=60,
):
    """Download the SAM ViT-H checkpoint unless it is already on disk.

    Args:
        model_path: local destination of the checkpoint.
        url: where to fetch the checkpoint from.
        chunk_size: bytes read per streamed chunk.
        timeout: seconds passed to ``requests.get`` (connect/read timeout).

    Returns:
        ``model_path``. The file is guaranteed to be a complete download.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on network failures or timeouts.
    """
    if os.path.exists(model_path):
        return model_path

    print("Baixando modelo SAM... Isso pode levar alguns minutos.")
    # Stream into a temporary sibling file and rename it into place only after
    # the download has finished. Otherwise an interrupted download would leave
    # a truncated checkpoint that the exists() check above accepts forever.
    tmp_path = model_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this, an HTTP error page would be saved as the model file.
            response.raise_for_status()
            total_size = int(response.headers.get("content-length", 0))
            with open(tmp_path, "wb") as file, tqdm(
                desc=model_path,
                total=total_size or None,  # unknown length -> indeterminate bar
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                for data in response.iter_content(chunk_size=chunk_size):
                    progress_bar.update(file.write(data))
        os.replace(tmp_path, model_path)
    except BaseException:
        # Clean up the partial file (also on Ctrl+C), then propagate.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return model_path
def load_sam_model():
    """Build a ``SamPredictor`` around the ViT-H SAM checkpoint.

    The checkpoint is obtained through ``download_sam_model()``, which fetches
    it on first use. The model is moved to CUDA when
    ``torch.cuda.is_available()`` reports a GPU, otherwise to the CPU.

    Returns:
        SamPredictor: predictor wrapping the loaded model.
    """
    has_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if has_gpu else "cpu")

    checkpoint = download_sam_model()
    print(f"Carregando modelo SAM no dispositivo: {device}")

    build_sam = sam_model_registry["vit_h"]
    model = build_sam(checkpoint=checkpoint)
    model.to(device=device)
    return SamPredictor(model)
def enhance_xray(image, brightness=1.0, contrast=1.0, sharpness=1.0):
    """Apply brightness, contrast and sharpness adjustments to an X-ray image.

    The input is first normalised to a 3-channel RGB array, so downstream
    consumers (SAM and the mask overlay) always receive the same layout:

    - 2-D grayscale and H x W x 1 inputs are expanded to RGB.
    - An RGBA input has its alpha channel dropped.

    Args:
        image: numpy image of shape H x W, H x W x 1, H x W x 3 or H x W x 4.
            Presumably uint8 as delivered by ``gr.Image(type="numpy")``
            (TODO confirm for other callers).
        brightness: PIL ``ImageEnhance.Brightness`` factor (1.0 = unchanged).
        contrast: PIL ``ImageEnhance.Contrast`` factor (1.0 = unchanged).
        sharpness: PIL ``ImageEnhance.Sharpness`` factor (1.0 = unchanged).

    Returns:
        The enhanced image as an H x W x 3 numpy array.
    """
    image = np.asarray(image)
    # A single-channel 3-D array cannot be handled by Image.fromarray; squeeze it.
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[..., 0]
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        # SAM and the overlay code work on 3-channel RGB images.
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    img = Image.fromarray(image)
    # Enhancements are applied in a fixed order: brightness, contrast, sharpness.
    img = ImageEnhance.Brightness(img).enhance(brightness)
    img = ImageEnhance.Contrast(img).enhance(contrast)
    img = ImageEnhance.Sharpness(img).enhance(sharpness)
    return np.array(img)
def generate_dense_points(image, num_points=100, rng=None):
    """Return prompt points spread over the central 90% of an image.

    A regular grid whose column/row ratio follows the image's aspect ratio is
    laid out first. It is then topped up with uniformly random points until
    exactly ``num_points`` points exist; the grid alone never exceeds
    ``num_points``. All points lie within 5%..95% of the width and height.

    Args:
        image: array whose first two dimensions are (height, width).
        num_points: total number of points to return.
        rng: optional ``numpy.random.Generator`` used for the random top-up.
            When None, the global ``np.random`` state is used.

    Returns:
        Float array of shape (num_points, 2) holding (x, y) pixel coordinates,
        grid points first (x-major order), random points after.
    """
    if num_points <= 0:
        return np.empty((0, 2))

    h, w = image.shape[:2]
    # At least one column: for very tall, narrow images the square root drops
    # below 1, int() gives 0, and the division below would fail.
    num_x = max(1, int(np.sqrt(num_points * w / h)))
    num_y = int(num_points / num_x)
    x_points = np.linspace(w * 0.05, w * 0.95, num_x)
    y_points = np.linspace(h * 0.05, h * 0.95, num_y)

    # Regular grid; indexing="ij" makes x the outer axis (x-major order).
    grid_x, grid_y = np.meshgrid(x_points, y_points, indexing="ij")
    points = np.column_stack([grid_x.ravel(), grid_y.ravel()])

    # Add random points to fill up to num_points for better coverage.
    num_random = num_points - len(points)
    if num_random > 0:
        uniform = np.random.uniform if rng is None else rng.uniform
        random_x = uniform(w * 0.05, w * 0.95, num_random)
        random_y = uniform(h * 0.05, h * 0.95, num_random)
        points = np.vstack([points, np.column_stack([random_x, random_y])])
    return points
def show_mask(mask, image, alpha=0.5, color=None):
    """Blend one semi-transparent, single-colour mask over an image.

    Args:
        mask: array whose last two dimensions are (H, W), matching the image.
            Pixels equal to True/1 are painted; fractional values scale both
            the colour and its opacity.
        image: H x W grayscale, H x W x 3 RGB or H x W x 4 RGBA array. Any
            alpha channel is ignored.
        alpha: opacity of the mask colour, in 0..1.
        color: optional (r, g, b) with components in 0..1. When None, a random
            hue is taken from matplotlib's HSV colormap.

    Returns:
        H x W x 3 uint8 RGB image.
    """
    if color is None:
        color = plt.cm.hsv(np.random.random())[:3]
    rgba = np.array([*color[:3], alpha], dtype=float)

    h, w = mask.shape[-2:]
    # Colour layer quantised to 8 bits (colour and alpha), then back to float.
    overlay = (mask.reshape(h, w, 1) * rgba.reshape(1, 1, -1) * 255).astype(np.uint8).astype(float)

    img = np.asarray(image)
    if img.ndim == 2:
        img = np.repeat(img[..., None], 3, axis=2)
    background = img[..., :3].astype(float)

    # Standard alpha blending, computed directly in RGB. There is no RGBA
    # round-trip: converting a 3-channel array with COLOR_RGBA2RGB raises in OpenCV.
    weight = overlay[..., 3:] / 255.0
    blended = overlay[..., :3] * weight + background * (1.0 - weight)
    return blended.astype(np.uint8)
def segment_teeth(image, predictor, num_points=150, score_threshold=0.5):
    """Segment an X-ray with a SAM predictor prompted by a dense set of points.

    All points are passed as foreground prompts in a single ``predict`` call
    with ``multimask_output=True``. SAM therefore returns a few alternative
    masks for that one prompt set. Masks whose score exceeds
    ``score_threshold`` are kept and overlaid on the image.

    Args:
        image: H x W grayscale, H x W x 3 RGB or H x W x 4 RGBA image
            (assumed uint8).
        predictor: a ``segment_anything.SamPredictor``.
        num_points: number of prompt points from ``generate_dense_points``.
        score_threshold: minimum predicted mask-quality score for a mask to be
            kept.

    Returns:
        (overlay, masks): the RGB image with kept masks and prompt points
        drawn on it, and the list of kept masks.
    """
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        # SamPredictor.set_image expects a 3-channel RGB image.
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    predictor.set_image(image)

    # Dense points give broad coverage; label 1 marks a foreground prompt.
    input_points = generate_dense_points(image, num_points=num_points)
    input_labels = np.ones(len(input_points))

    masks, scores, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=True,
    )
    kept_masks = [mask for mask, score in zip(masks, scores) if score > score_threshold]

    overlay = image.copy()
    # Draw the largest masks first so smaller ones stay visible on top.
    mask_sizes = [np.sum(mask) for mask in kept_masks]
    for idx in np.argsort(mask_sizes)[::-1]:
        overlay = show_mask(kept_masks[idx], overlay, alpha=0.4)

    # The image is in RGB channel order, so orange is (255, 127, 0);
    # (0, 127, 255) would be orange only in OpenCV's BGR order and renders blue here.
    for x, y in input_points:
        cv2.circle(overlay, (int(x), int(y)), 2, (255, 127, 0), -1)
    return overlay, kept_masks
def _mask_coverage_percent(masks, image_hw):
    """Return the percentage of image pixels covered by at least one mask.

    The masks may overlap (multimask output gives alternative masks for the
    same prompt). Summing their areas would double-count pixels and can exceed
    100%, so the union is measured instead.
    """
    image_area = image_hw[0] * image_hw[1]
    if len(masks) == 0 or image_area == 0:
        return 0.0
    union = np.logical_or.reduce([np.asarray(m, dtype=bool) for m in masks])
    return float(union.sum()) / image_area * 100


def analyze_xray(image, brightness, contrast, sharpness, enable_segmentation):
    """Gradio callback: enhance an X-ray and optionally segment it with SAM.

    Args:
        image: uploaded image as a numpy array, or None when nothing was uploaded.
        brightness: factor forwarded to ``enhance_xray``.
        contrast: factor forwarded to ``enhance_xray``.
        sharpness: factor forwarded to ``enhance_xray``.
        enable_segmentation: when truthy, run ``segment_teeth`` on the enhanced
            image.

    Returns:
        (result_image, analysis_text). On failure, result_image is None and the
        text contains the error. The traceback is printed rather than raised,
        so the UI stays responsive.
    """
    global sam_predictor

    if image is None:
        return None, "Por favor, carregue uma imagem."

    try:
        enhanced = enhance_xray(image, brightness, contrast, sharpness)
        result = enhanced

        if enable_segmentation:
            # Load lazily if the predictor was not created at startup.
            if "sam_predictor" not in globals():
                sam_predictor = load_sam_model()
            result, masks = segment_teeth(enhanced, sam_predictor)
            seg_status = "Segmentação realizada com sucesso"
            num_segments = len(masks)
            coverage_percent = _mask_coverage_percent(masks, image.shape[:2])
        else:
            seg_status = "Segmentação desativada"
            num_segments = 0
            coverage_percent = 0

        # Build the report one line at a time.
        report_lines = [
            "📊 Análise da Radiografia:",
            "━━━━━━━━━━━━━━━━━━━━━━━━━━━━",
            "🔆 Parâmetros de Imagem:",
            f"• Brilho: {brightness:.1f}",
            f"• Contraste: {contrast:.1f}",
            f"• Nitidez: {sharpness:.1f}",
            "🔬 Resultado da Segmentação:",
            f"• Status: {seg_status}",
            f"• Estruturas detectadas: {num_segments}",
            f"• Cobertura da análise: {coverage_percent:.1f}%",
            "📝 Observações:",
            "• Processamento concluído com sucesso",
            "• Pontos laranja indicam regiões de análise",
            "• Áreas coloridas mostram estruturas detectadas",
        ]
        return result, "\n".join(report_lines)
    except Exception as e:
        # UI boundary: log the full traceback, show a short message to the user.
        import traceback
        print(traceback.format_exc())
        return None, f"Erro durante o processamento: {str(e)}"
# Build the Gradio interface. The CSS classes referenced via elem_classes are
# defined in the module-level `css` string.
with gr.Blocks(css=css, title="Analisador de Radiografias Dentárias") as app:
    with gr.Column(elem_classes="main-container"):
        # Header. The model in use is the original SAM (ViT-H checkpoint
        # loaded via segment_anything), not SAM 2.
        with gr.Column(elem_classes="header"):
            gr.Markdown("# 🦷 Analisador Inteligente de Radiografias Dentárias")
            gr.Markdown("""
            Análise avançada de radiografias com Inteligência Artificial
            Utilizando o modelo SAM (Segment Anything Model, ViT-H)
            """)

        with gr.Row():
            # Input column: upload, enhancement sliders, segmentation toggle.
            with gr.Column(scale=1):
                with gr.Group(elem_classes="controls"):
                    input_image = gr.Image(
                        label="Radiografia Original",
                        type="numpy",
                        elem_classes="image-input",
                    )
                    with gr.Column():
                        brightness = gr.Slider(
                            minimum=0.1, maximum=2.0, value=1.0,
                            label="Brilho",
                            elem_classes="slider-label",
                        )
                        contrast = gr.Slider(
                            minimum=0.1, maximum=2.0, value=1.0,
                            label="Contraste",
                            elem_classes="slider-label",
                        )
                        sharpness = gr.Slider(
                            minimum=0.1, maximum=2.0, value=1.0,
                            label="Nitidez",
                            elem_classes="slider-label",
                        )
                        enable_segmentation = gr.Checkbox(
                            label="Ativar Segmentação Inteligente",
                            value=True,
                        )
                        analyze_btn = gr.Button(
                            "🔍 Analisar Radiografia",
                            elem_classes="button-primary",
                        )

            # Output column: processed image and textual report.
            with gr.Column(scale=1):
                with gr.Group(elem_classes="output-container"):
                    output_image = gr.Image(
                        label="Resultado da Análise",
                        type="numpy",
                        elem_classes="image-output",
                    )
                    analysis_text = gr.Textbox(
                        label="Análise Detalhada",
                        lines=8,
                        elem_classes="analysis-text",
                    )

        # Footer with usage instructions.
        gr.Markdown("""
        ### 📋 Instruções de Uso
        1. Carregue uma radiografia dental
        2. Ajuste os parâmetros de imagem conforme necessário
        3. Ative a segmentação inteligente para detectar estruturas
        4. Clique em "Analisar Radiografia"
        ⚠️ Nota: O primeiro uso pode levar alguns minutos para baixar o modelo.
        """)

    # Event listeners must be registered inside the Blocks context.
    analyze_btn.click(
        analyze_xray,
        inputs=[input_image, brightness, contrast, sharpness, enable_segmentation],
        outputs=[output_image, analysis_text],
    )
def _serve():
    """Load the SAM predictor into the module-level global, then start the app.

    Binding ``sam_predictor`` at module level lets ``analyze_xray`` reuse it
    instead of loading the model lazily on the first request.
    """
    global sam_predictor
    sam_predictor = load_sam_model()
    app.launch()


# Launch the app only when executed as a script.
if __name__ == "__main__":
    _serve()