import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
from RealESRGAN import RealESRGAN
import gradio as gr
import numpy as np
import tempfile
import time
import os

# Set device to GPU if available, otherwise use CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the Real-ESRGAN model with specified scale
def load_model(scale):
    model = RealESRGAN(device, scale=scale)
    weights_path = f'weights/RealESRGAN_x{scale}.pth'
    try:
        model.load_weights(weights_path, download=True)
        print(f"Weights for scale {scale} loaded successfully.")
    except Exception as e:
        print(f"Error loading weights for scale {scale}: {e}")
        model.load_weights(weights_path, download=False)
    return model
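
# Note: with download=True, RealESRGAN.load_weights() is expected to fetch the
# .pth file into weights/ on first run (behavior of the ai-forever Real-ESRGAN
# package); the except branch above then falls back to loading a local copy.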

# Load different scales of the Real-ESRGAN model
model2 = load_model(2)
model4 = load_model(4)
model8 = load_model(8)

# Initialize BLIP processor and model for image captioning
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

# Enhance the image using the specified scale
def enhance_image(image, scale):
    try:
        print(f"Enhancing image with scale {scale}...")
        start_time = time.time()
        image_np = np.array(image.convert('RGB'))
        print(f"Image converted to numpy array: shape {image_np.shape}, dtype {image_np.dtype}")
        if scale == '2x':
            result = model2.predict(image_np)
        elif scale == '4x':
            result = model4.predict(image_np)
        else:
            result = model8.predict(image_np)
        enhanced_image = Image.fromarray(np.uint8(result))
        print(f"Image enhanced in {time.time() - start_time:.2f} seconds")
        return enhanced_image
    except Exception as e:
        print(f"Error enhancing image: {e}")
        return image

# Generate a caption for the image using BLIP
def generate_caption(image):
    inputs = processor(images=image, return_tensors="pt").to(device)
    output_ids = model.generate(**inputs)
    caption = processor.decode(output_ids[0], skip_special_tokens=True)
    return caption
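
# (Optional) caption length and decoding can be tuned via the standard
# transformers generate() arguments, e.g. model.generate(**inputs, max_new_tokens=50);
# here the defaults from the model's generation config are used.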

# Adjust the DPI metadata of the image (muda_dpi is Portuguese for "change DPI")
def muda_dpi(input_image, dpi):
    # Gradio's Number component delivers floats, so cast before building the DPI tuple
    dpi_tuple = (int(dpi), int(dpi))
    image = Image.fromarray(input_image.astype('uint8'), 'RGB')
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
    image.save(temp_file, format='JPEG', dpi=dpi_tuple)
    temp_file.close()
    return Image.open(temp_file.name)

# Resize the image to the specified width and height
def resize_image(input_image, width, height):
    image = Image.fromarray(input_image.astype('uint8'), 'RGB')
    # PIL requires integer dimensions; Gradio's Number component delivers floats
    resized_image = image.resize((int(width), int(height)))
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
    resized_image.save(temp_file, format='JPEG')
    temp_file.close()
    return Image.open(temp_file.name)

# Process the images: enhance, adjust DPI, resize, caption, and save
def process_images(image_files, enhance, scale, adjust_dpi, dpi, resize, width, height):
    processed_images = []
    file_paths = []
    captions = []
    for i, image_file in enumerate(image_files):
        input_image = np.array(Image.open(image_file).convert('RGB'))
        original_image = Image.fromarray(input_image.astype('uint8'), 'RGB')
        if enhance:
            original_image = enhance_image(original_image, scale)
        if adjust_dpi:
            original_image = muda_dpi(np.array(original_image), dpi)
        if resize:
            original_image = resize_image(np.array(original_image), width, height)
        # Generate a caption for the image
        caption = generate_caption(original_image)
        captions.append(caption)
        # Save the image under a custom filename in the temp directory
        custom_filename = f"Image_Captioning_with_BLIP_{i+1}.jpg"
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
        temp_file.close()
        original_image.save(temp_file.name, format='JPEG')
        # Rename the temporary file to the custom name
        final_path = os.path.join(os.path.dirname(temp_file.name), custom_filename)
        os.rename(temp_file.name, final_path)
        processed_images.append(original_image)
        file_paths.append(final_path)
    # Join the captions so they display cleanly in the Textbox output
    return processed_images, file_paths, "\n".join(captions)
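
# The three return values map positionally to the Gallery, Files, and Textbox
# outputs declared in the interface below.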

# Gradio interface setup
iface = gr.Interface(
    fn=process_images,
    inputs=[
        gr.Files(label="Upload Image Files"),
        gr.Checkbox(label="Enhance Images (Real-ESRGAN)"),
        gr.Radio(['2x', '4x', '8x'], type="value", value='2x', label='Upscaling factor'),
        gr.Checkbox(label="Adjust DPI"),
        gr.Number(label="DPI", value=300),
        gr.Checkbox(label="Resize"),
        gr.Number(label="Width", value=512),
        gr.Number(label="Height", value=512)
    ],
    outputs=[
        gr.Gallery(label="Final Images"),
        gr.Files(label="Download Final Images"),
        gr.Textbox(label="Image Captions")
    ],
    title="Multi-Image Enhancer with Captioning",
    description="Upload multiple images (.jpg, .png), enhance them with Real-ESRGAN, adjust DPI, resize, generate BLIP captions, and download the final results."
)

# Launch the Gradio interface
iface.launch(debug=True)
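
# debug=True keeps launch() blocking so errors surface in the console while
# developing; it can be dropped for a plain deployment.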