import base64
from io import BytesIO

import streamlit as st
import streamlit.components.v1 as components
import torch
from PIL import Image, UnidentifiedImageError
from RealESRGAN import RealESRGAN

# Weight files for each supported (scale, anime) combination.
_MODEL_WEIGHTS = {
    (2, False): 'model/RealESRGAN_x2.pth',
    (4, False): 'model/RealESRGAN_x4plus.pth',
    (8, False): 'model/RealESRGAN_x8.pth',
    (4, True): 'model/RealESRGAN_x4plus_anime_6B.pth',
}


# Streamlit re-executes this script on every widget interaction, so without
# caching the weights would be reloaded from disk on every click.
# cache_resource keeps one model instance per (scale, anime) argument pair.
@st.cache_resource
def load_model(scale, anime=False):
    """Build a RealESRGAN model for ``scale``/``anime`` and load its weights.

    Runs on CUDA when available, otherwise on CPU.

    Raises:
        KeyError: if there is no weight file for the (scale, anime) pair.
    """
    # Resolve the weights first so an unsupported combination fails fast,
    # before any model object is constructed.
    try:
        model_path = _MODEL_WEIGHTS[(scale, anime)]
    except KeyError:
        raise KeyError(
            f"No Real-ESRGAN weights for scale={scale}, anime={anime}"
        ) from None

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = RealESRGAN(device, scale=scale, anime=anime)
    model.load_weights(model_path)
    return model


def enhance_image(image, scale, anime):
    """Upscale ``image`` with the model for (scale, anime).

    Returns:
        ``(sr_image, buffer)``, where ``sr_image`` is the model's output image
        and ``buffer`` is a BytesIO holding it PNG-encoded, rewound to 0.
    """
    model = load_model(scale, anime=anime)
    sr_image = model.predict(image)
    buffer = BytesIO()
    sr_image.save(buffer, format="PNG")
    buffer.seek(0)
    return sr_image, buffer


def get_base64_image(image):
    """Return ``image`` PNG-encoded as a base64 ASCII string."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


def image_comparison_slider(original_base64, enhanced_base64, height=600):
    """Render a before/after slider comparing two base64-encoded PNG images.

    The original image is drawn on top of the enhanced one and clipped from
    the right. Dragging anywhere over the picture moves the divider. Both
    images are stretched to the container width, so they line up as long as
    they share an aspect ratio (true for a pure upscale).

    Args:
        original_base64: base64-encoded PNG shown on the left of the divider.
        enhanced_base64: base64-encoded PNG shown on the right of the divider.
        height: iframe height in pixels passed to ``components.html``.
    """
    # Doubled braces are literal CSS braces inside the f-string.
    slider_html = f"""
<style>
  .cmp {{ position: relative; width: 100%; overflow: hidden; }}
  .cmp img {{ display: block; width: 100%; height: auto; }}
  .cmp .top {{ position: absolute; top: 0; left: 0; clip-path: inset(0 50% 0 0); }}
  .cmp .bar {{ position: absolute; top: 0; bottom: 0; left: 50%; width: 2px;
              background: #fff; pointer-events: none; }}
  .cmp .label {{ position: absolute; top: 8px; padding: 2px 6px; color: #fff;
                background: rgba(0, 0, 0, 0.5); font: 12px sans-serif;
                pointer-events: none; }}
  .cmp input {{ position: absolute; top: 0; left: 0; width: 100%; height: 100%;
               margin: 0; opacity: 0; cursor: ew-resize; }}
</style>
<div class="cmp">
  <img src="data:image/png;base64,{enhanced_base64}" alt="Enhanced">
  <img class="top" src="data:image/png;base64,{original_base64}" alt="Original">
  <div class="bar"></div>
  <span class="label" style="left: 8px;">Original</span>
  <span class="label" style="right: 8px;">Enhanced</span>
  <input type="range" min="0" max="100" value="50" oninput="var p = this.value, c = this.parentNode; c.querySelector('.top').style.clipPath = 'inset(0 ' + (100 - p) + '% 0 0)'; c.querySelector('.bar').style.left = p + '%';">
</div>
"""
    components.html(slider_html, height=height, scrolling=True)


def main():
    """Streamlit entry point: upload, upscale, compare and download an image."""
    st.title("Generative AI Image Restoration")

    # Image upload
    uploaded_image = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
    if uploaded_image is None:
        return

    try:
        # Normalize to 3-channel RGB. Uploaded PNGs may be RGBA, palette or
        # greyscale, and would otherwise reach the model unchanged.
        # NOTE(review): RealESRGAN.predict is assumed to require RGB — confirm.
        image = Image.open(uploaded_image).convert("RGB")
    except UnidentifiedImageError:
        st.error("The uploaded file could not be read as an image.")
        return

    st.image(image, caption="Original Image", use_column_width=True)

    # Anime toggle
    anime = st.checkbox("Anime Image", value=False)

    # Conditional scale options: only a 4x anime model is available.
    if anime:
        scale = "4x"
    else:
        scale = st.radio("Upscaling Factor", ["2x", "4x", "8x"], index=0)
    scale_value = int(scale.replace('x', ''))

    # Enhance button
    if st.button("Restore Image"):
        with st.spinner("Restoring image..."):
            _, buffer = enhance_image(image, scale_value, anime)
        png_bytes = buffer.getvalue()

        # Convert images to base64 for the comparison slider. The enhanced
        # PNG bytes are reused rather than re-encoding the large upscaled image.
        original_base64 = get_base64_image(image)
        enhanced_base64 = base64.b64encode(png_bytes).decode()

        # Show comparison slider
        image_comparison_slider(original_base64, enhanced_base64)

        # Download button
        st.download_button(
            label="Download Enhanced Image",
            data=png_bytes,
            file_name="enhanced_image.png",
            mime="image/png",
        )


if __name__ == "__main__":
    main()