# Dental-X-ray.ai / app.py — by DHEIVER
# Last commit: "Update app.py" (a6e2c32, verified)
import gradio as gr
import numpy as np
from PIL import Image, ImageEnhance
import cv2
import torch
import os
import requests
from tqdm import tqdm
from segment_anything import sam_model_registry, SamPredictor
import matplotlib.pyplot as plt
# Custom CSS for the Gradio interface (passed as ``css=`` to ``gr.Blocks``
# further down). The class selectors below correspond to the ``elem_classes``
# values assigned to the UI components; ``.gradio-container`` targets
# Gradio's own root container to cap the page width and center it.
css = """
.gradio-container {
font-family: 'Inter', sans-serif;
max-width: 1200px !important;
margin: auto;
}
.main-container {
background-color: #ffffff;
border-radius: 15px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
padding: 20px;
margin: 20px;
}
.header {
background: linear-gradient(90deg, #2563eb 0%, #3b82f6 100%);
color: white;
padding: 20px;
border-radius: 10px;
margin-bottom: 20px;
text-align: center;
}
.controls {
background-color: #f8fafc;
padding: 15px;
border-radius: 10px;
margin: 10px 0;
}
.slider-label {
color: #374151;
font-weight: 500;
}
.button-primary {
background: linear-gradient(90deg, #2563eb 0%, #3b82f6 100%);
border: none;
padding: 10px 20px;
border-radius: 8px;
color: white;
font-weight: 600;
cursor: pointer;
transition: transform 0.2s;
}
.button-primary:hover {
transform: translateY(-2px);
}
.output-container {
background-color: #f8fafc;
padding: 15px;
border-radius: 10px;
margin-top: 20px;
}
.analysis-text {
font-family: 'Mono', monospace;
background-color: #1e293b;
color: #f8fafc;
padding: 15px;
border-radius: 8px;
white-space: pre-wrap;
}
.image-input, .image-output {
border-radius: 10px;
overflow: hidden;
border: 2px solid #e2e8f0;
}
"""
def download_sam_model(
    model_path="sam_vit_h_4b8939.pth",
    url="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    timeout=60,
):
    """Download the SAM checkpoint unless it already exists on disk.

    The body is streamed into a temporary ``<model_path>.part`` file, which
    is renamed to ``model_path`` only after the download has finished. An
    interrupted or failed download therefore never leaves a truncated
    checkpoint behind, and later runs would otherwise have mistaken such a
    file for a complete one and skipped the download.

    Args:
        model_path: Destination path of the checkpoint file.
        url: Location to download the checkpoint from.
        timeout: Seconds to wait when connecting or between received bytes,
            so a stalled server cannot hang start-up forever.

    Returns:
        ``model_path``, the path of the checkpoint.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    if os.path.exists(model_path):
        return model_path

    print("Baixando modelo SAM... Isso pode levar alguns minutos.")
    tmp_path = f"{model_path}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Fail loudly instead of saving an HTML error page as the weights.
            response.raise_for_status()
            total_size = int(response.headers.get("content-length", 0))
            with open(tmp_path, "wb") as file, tqdm(
                desc=model_path,
                total=total_size,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                # 1 MiB chunks: 1 KiB meant millions of iterations for ~2.5 GB.
                for data in response.iter_content(chunk_size=1024 * 1024):
                    size = file.write(data)
                    progress_bar.update(size)
        # Atomic rename: model_path only ever appears as a complete file.
        os.replace(tmp_path, model_path)
    finally:
        # After a successful rename this is a no-op; on any failure it removes
        # the partial file so the next attempt starts clean.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return model_path
def load_sam_model():
    """Build a ``SamPredictor`` around the SAM ``vit_h`` model.

    The checkpoint is fetched first via ``download_sam_model`` (a no-op when
    it is already on disk). The model is placed on CUDA when
    ``torch.cuda.is_available()`` reports a GPU, otherwise on the CPU.

    Returns:
        SamPredictor: predictor wrapping the loaded model.
    """
    has_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if has_gpu else "cpu")
    weights_file = download_sam_model()
    print(f"Carregando modelo SAM no dispositivo: {device}")
    model = sam_model_registry["vit_h"](checkpoint=weights_file)
    model.to(device=device)
    return SamPredictor(model)
def enhance_xray(image, brightness=1.0, contrast=1.0, sharpness=1.0):
    """Return an RGB copy of *image* with brightness, contrast and sharpness adjusted.

    The input is first normalized to 3-channel RGB, so that later code (the
    SAM predictor and the mask overlay) always receives RGB.

    Args:
        image: NumPy image, either 2-D grayscale, ``(H, W, 1)``, RGB
            ``(H, W, 3)`` or RGBA ``(H, W, 4)``. It is presumably uint8, as
            delivered by ``gr.Image(type="numpy")``; TODO confirm for other
            callers.
        brightness: ``ImageEnhance.Brightness`` factor; 1.0 leaves it unchanged.
        contrast: ``ImageEnhance.Contrast`` factor; 1.0 leaves it unchanged.
        sharpness: ``ImageEnhance.Sharpness`` factor; 1.0 leaves it unchanged.

    Returns:
        ``(H, W, 3)`` NumPy array with the enhancements applied.
    """
    image = np.asarray(image)
    # PIL cannot build an image from an (H, W, 1) array, so drop the axis.
    if image.ndim == 3 and image.shape[-1] == 1:
        image = image[..., 0]
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[-1] == 4:
        # Drop alpha: SAM's set_image expects a 3-channel RGB image.
        image = np.ascontiguousarray(image[..., :3])

    img = Image.fromarray(image)
    # Enhancements are applied in a fixed order: brightness, contrast, sharpness.
    img = ImageEnhance.Brightness(img).enhance(brightness)
    img = ImageEnhance.Contrast(img).enhance(contrast)
    img = ImageEnhance.Sharpness(img).enhance(sharpness)
    return np.array(img)
def generate_dense_points(image, num_points=100, rng=None):
    """Generate about *num_points* (x, y) prompt coordinates spread over *image*.

    A regular grid is laid over the central 90% of each axis (a 5% margin on
    every side). Its column/row split follows the image's aspect ratio. Any
    points the grid cannot supply are topped up with uniformly random points
    drawn from the same area.

    Args:
        image: Array whose first two dimensions are (height, width).
        num_points: Number of points to return; ``<= 0`` yields an empty
            ``(0, 2)`` array.
        rng: Optional ``numpy.random.Generator`` (or any object with a
            compatible ``uniform(low, high, size)``) for reproducible random
            points. Defaults to the global ``numpy.random`` state.

    Returns:
        Float array of shape ``(num_points, 2)`` holding ``[x, y]`` pixel
        coordinates. Grid points come first, with x as the outer loop and y
        as the inner loop, followed by the random points.
    """
    h, w = image.shape[:2]
    if num_points <= 0:
        return np.empty((0, 2))
    rng = np.random if rng is None else rng

    # Clamp to at least one column (and no more than num_points). For very
    # tall or narrow images the sqrt truncates to 0, which previously caused
    # a ZeroDivisionError in the row count.
    num_x = min(num_points, max(1, int(np.sqrt(num_points * w / h))))
    num_y = max(1, int(num_points / num_x))
    x_points = np.linspace(w * 0.05, w * 0.95, num_x)
    y_points = np.linspace(h * 0.05, h * 0.95, num_y)

    # indexing="ij" followed by a C-order ravel gives x-outer / y-inner order.
    grid_x, grid_y = np.meshgrid(x_points, y_points, indexing="ij")
    points = np.column_stack([grid_x.ravel(), grid_y.ravel()])

    # num_x * num_y <= num_points, so this is never negative.
    num_random = num_points - len(points)
    if num_random > 0:
        random_x = rng.uniform(w * 0.05, w * 0.95, num_random)
        random_y = rng.uniform(h * 0.05, h * 0.95, num_random)
        points = np.vstack([points, np.column_stack([random_x, random_y])])
    return points
def show_mask(mask, image, alpha=0.5, color=None):
    """Alpha-blend one solid-colour *mask* over *image* and return the RGB result.

    Pixels where the mask is set become
    ``color * alpha + image * (1 - alpha)``. All other pixels are left
    unchanged.

    Args:
        mask: Boolean or 0/1 array whose last two dimensions are (H, W),
            for example ``(H, W)`` or ``(1, H, W)``.
        image: Background of shape ``(H, W)``, ``(H, W, 3)`` or ``(H, W, 4)``.
            Any alpha channel is ignored.
        alpha: Opacity of the mask colour where the mask is set.
        color: Optional RGB triple in ``[0, 1]``. When omitted, a random hue
            is taken from matplotlib's ``hsv`` colormap.

    Returns:
        ``(H, W, 3)`` uint8 array. Values are truncated, not rounded.
    """
    if color is None:
        color = plt.cm.hsv(np.random.random())[:3]
    rgb = np.asarray(color, dtype=float)[:3] * 255.0

    h, w = mask.shape[-2:]
    # Per-pixel opacity: alpha inside the mask, 0 (background kept) outside.
    weight = np.asarray(mask, dtype=float).reshape(h, w, 1) * alpha

    bg = np.asarray(image)
    if bg.ndim == 2:
        # Grayscale background: replicate to 3 channels. Slicing [..., :3] on
        # a 2-D array would cut columns, not channels.
        bg = np.repeat(bg[..., None], 3, axis=-1)
    bg = bg[..., :3].astype(float)

    # Blend directly in float. This avoids the uint8 round-trip of the mask
    # and the redundant RGB<->RGBA colour conversions.
    blended = rgb * weight + bg * (1.0 - weight)
    return blended.astype(np.uint8)
def segment_teeth(image, predictor, num_points=150, score_threshold=0.5, alpha=0.4):
    """Prompt SAM with a dense set of points and overlay the resulting masks.

    NOTE(review): all prompt points go into ONE ``predictor.predict`` call,
    each with a positive label. SAM treats that as prompts for a single
    object, and with ``multimask_output=True`` it returns a few alternative
    masks of that object, not one mask per tooth. Instance-level results
    would need per-point prompts or ``SamAutomaticMaskGenerator``. Confirm
    the intended behaviour before relying on the mask count.

    Args:
        image: Image as a NumPy array; RGB, grayscale or RGBA (alpha is
            dropped).
        predictor: ``SamPredictor`` from ``segment_anything`` (SAM v1).
        num_points: Number of prompt points passed to
            ``generate_dense_points``.
        score_threshold: Masks with a score at or below this value are
            discarded.
        alpha: Opacity of each mask overlay.

    Returns:
        ``(visualization, masks)``: the RGB image with the masks and prompt
        points drawn on it, and the list of kept masks.
    """
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[-1] == 4:
        image = np.ascontiguousarray(image[..., :3])
    predictor.set_image(image)

    input_points = generate_dense_points(image, num_points=num_points)
    input_labels = np.ones(len(input_points))  # 1 = foreground point
    masks, scores, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=True,
    )

    filtered_masks = [
        mask for mask, score in zip(masks, scores) if score > score_threshold
    ]

    final_image = image.copy()
    # Paint the largest masks first so smaller structures end up on top.
    for mask in sorted(filtered_masks, key=np.sum, reverse=True):
        final_image = show_mask(mask, final_image, alpha=alpha)

    # The image is RGB, so orange is (255, 127, 0). The previous
    # (0, 127, 255) is orange only in OpenCV's BGR order and rendered blue
    # here, contradicting the "orange points" in the report text.
    point_color = (255, 127, 0)
    for x, y in input_points:
        cv2.circle(final_image, (int(x), int(y)), 2, point_color, -1)
    return final_image, filtered_masks
def analyze_xray(image, brightness, contrast, sharpness, enable_segmentation):
"""Main function to process X-ray images"""
if image is None:
return None, "Por favor, carregue uma imagem."
try:
# Enhance image
enhanced = enhance_xray(image, brightness, contrast, sharpness)
# Initialize result
result = enhanced
# Perform segmentation if enabled
if enable_segmentation:
# Ensure predictor is initialized
global sam_predictor
if 'sam_predictor' not in globals():
sam_predictor = load_sam_model()
result, masks = segment_teeth(enhanced, sam_predictor)
seg_status = "Segmentação realizada com sucesso"
num_segments = len(masks)
# Calcular área total segmentada
total_area = sum(mask.sum() for mask in masks)
image_area = image.shape[0] * image.shape[1]
coverage_percent = (total_area / image_area) * 100
else:
seg_status = "Segmentação desativada"
num_segments = 0
coverage_percent = 0
# Generate analysis text
analysis = f"""📊 Análise da Radiografia:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━
�밝 Parâmetros de Imagem:
• Brilho: {brightness:.1f}
• Contraste: {contrast:.1f}
• Nitidez: {sharpness:.1f}
🔬 Resultado da Segmentação:
• Status: {seg_status}
• Estruturas detectadas: {num_segments}
• Cobertura da análise: {coverage_percent:.1f}%
📝 Observações:
• Processamento concluído com sucesso
• Pontos laranja indicam regiões de análise
• Áreas coloridas mostram estruturas detectadas"""
return result, analysis
except Exception as e:
import traceback
error_trace = traceback.format_exc()
print(error_trace)
return None, f"Erro durante o processamento: {str(e)}"
# Build the Gradio interface: a header, then a row with the input column
# (image, enhancement sliders, segmentation toggle, action button) and the
# output column (annotated image and analysis text), then usage instructions.
with gr.Blocks(css=css, title="Analisador de Radiografias Dentárias") as app:
    with gr.Column(elem_classes="main-container"):
        # Header
        with gr.Column(elem_classes="header"):
            gr.Markdown("# 🦷 Analisador Inteligente de Radiografias Dentárias")
            # The backend loads the original SAM (segment_anything, vit_h
            # checkpoint), not SAM 2, so the description names SAM.
            gr.Markdown("""
Análise avançada de radiografias com Inteligência Artificial
Utilizando o modelo SAM (Segment Anything Model)
""")
        with gr.Row():
            # Input column: source image and enhancement controls.
            with gr.Column(scale=1):
                with gr.Group(elem_classes="controls"):
                    input_image = gr.Image(
                        label="Radiografia Original",
                        type="numpy",
                        elem_classes="image-input",
                    )
                    with gr.Column():
                        brightness = gr.Slider(
                            minimum=0.1, maximum=2.0, value=1.0,
                            label="Brilho",
                            elem_classes="slider-label",
                        )
                        contrast = gr.Slider(
                            minimum=0.1, maximum=2.0, value=1.0,
                            label="Contraste",
                            elem_classes="slider-label",
                        )
                        sharpness = gr.Slider(
                            minimum=0.1, maximum=2.0, value=1.0,
                            label="Nitidez",
                            elem_classes="slider-label",
                        )
                    enable_segmentation = gr.Checkbox(
                        label="Ativar Segmentação Inteligente",
                        value=True,
                    )
                    analyze_btn = gr.Button(
                        "🔍 Analisar Radiografia",
                        elem_classes="button-primary",
                    )
            # Output column: annotated image and textual report.
            with gr.Column(scale=1):
                with gr.Group(elem_classes="output-container"):
                    output_image = gr.Image(
                        label="Resultado da Análise",
                        type="numpy",
                        elem_classes="image-output",
                    )
                    analysis_text = gr.Textbox(
                        label="Análise Detalhada",
                        lines=8,
                        elem_classes="analysis-text",
                    )
        # Footer
        gr.Markdown("""
### 📋 Instruções de Uso
1. Carregue uma radiografia dental
2. Ajuste os parâmetros de imagem conforme necessário
3. Ative a segmentação inteligente para detectar estruturas
4. Clique em "Analisar Radiografia"
⚠️ Nota: O primeiro uso pode levar alguns minutos para baixar o modelo.
""")
    # Wire the button to the processing callback; output order matches the
    # (image, text) tuple returned by analyze_xray.
    analyze_btn.click(
        analyze_xray,
        inputs=[input_image, brightness, contrast, sharpness, enable_segmentation],
        outputs=[output_image, analysis_text],
    )
def _main():
    """Script entry point: preload the SAM predictor, then serve the Gradio app."""
    # Bind the module-level name so analyze_xray finds the preloaded
    # predictor and skips its lazy-load path.
    global sam_predictor
    sam_predictor = load_sam_model()
    app.launch()


# Launch the app
if __name__ == "__main__":
    _main()