import os
import base64
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import gradio as gr
import torchvision.transforms as transforms
from transformers import AutoModel, AutoTokenizer
from diffusers import StableDiffusionPipeline
from torch.utils.data import Dataset, DataLoader
import asyncio
import aiofiles
import fitz
import requests
import logging
from io import BytesIO
from dataclasses import dataclass
from typing import Optional

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        conv_block = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            norm_layer(in_features)
        ]
        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)


class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super(Generator, self).__init__()
        model0 = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7), norm_layer(64), nn.ReLU(inplace=True)]
        self.model0 = nn.Sequential(*model0)

        model1 = []
        in_features, out_features = 64, 128
        for _ in range(2):
            model1 += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), norm_layer(out_features), nn.ReLU(inplace=True)]
            in_features, out_features = out_features, out_features * 2
        self.model1 = nn.Sequential(*model1)

        model2 = [ResidualBlock(in_features) for _ in range(n_residual_blocks)]
        self.model2 = nn.Sequential(*model2)

        model3 = []
        out_features = in_features // 2
        for _ in range(2):
            model3 += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), norm_layer(out_features), nn.ReLU(inplace=True)]
            in_features, out_features = out_features, out_features // 2
        self.model3 = nn.Sequential(*model3)

        model4 = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            model4 += [nn.Sigmoid()]
        self.model4 = nn.Sequential(*model4)

    def forward(self, x, cond=None):
        out = self.model0(x)
        out = self.model1(out)
        out = self.model2(out)
        out = self.model3(out)
        out = self.model4(out)
        return out


# Pre-trained line-drawing generators: model1 handles 'Complex Lines', model2 'Simple Lines'.
model1 = Generator(3, 1, 3)
model2 = Generator(3, 1, 3)
try:
    model1.load_state_dict(torch.load('model.pth', map_location='cpu', weights_only=True))
    model2.load_state_dict(torch.load('model2.pth', map_location='cpu', weights_only=True))
except FileNotFoundError:
    logger.warning("Model files not found. Please ensure 'model.pth' and 'model2.pth' are available.")
model1.eval()
model2.eval()

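# Optional sanity check (a minimal sketch, not part of the app flow): a 3-channel
# 256x256 input should come back as a 1-channel line drawing of the same size.
#     with torch.no_grad():
#         assert model1(torch.zeros(1, 3, 256, 256)).shape == (1, 1, 256, 256)
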
class TinyUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(TinyUNet, self).__init__()
        self.down1 = nn.Conv2d(in_channels, 32, 3, padding=1)
        self.down2 = nn.Conv2d(32, 64, 3, padding=1, stride=2)
        self.mid = nn.Conv2d(64, 128, 3, padding=1)
        self.up1 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)
        self.up2 = nn.Conv2d(64 + 32, 32, 3, padding=1)
        self.out = nn.Conv2d(32, out_channels, 3, padding=1)
        # Project the scalar timestep to the bottleneck width (128 channels) so it
        # can be added to the mid features.
        self.time_embed = nn.Linear(1, 128)

    def forward(self, x, t):
        t_embed = F.relu(self.time_embed(t.unsqueeze(-1)))
        t_embed = t_embed.view(t_embed.size(0), t_embed.size(1), 1, 1)
        x1 = F.relu(self.down1(x))
        x2 = F.relu(self.down2(x1))
        x_mid = F.relu(self.mid(x2)) + t_embed
        x_up1 = F.relu(self.up1(x_mid))
        x_up2 = F.relu(self.up2(torch.cat([x_up1, x1], dim=1)))
        return self.out(x_up2)


class TinyDiffusion:
    def __init__(self, model, timesteps=100):
        self.model = model
        self.timesteps = timesteps
        # Linear noise schedule (DDPM-style betas).
        self.beta = torch.linspace(0.0001, 0.02, timesteps)
        self.alpha = 1 - self.beta
        self.alpha_cumprod = torch.cumprod(self.alpha, dim=0)

    def train(self, images, epochs=10):
        # Resize every image to the generation resolution so the stride-2
        # down/up path produces matching shapes for the skip connection.
        dataset = [
            torch.tensor(np.array(img.convert("RGB").resize((64, 64))).transpose(2, 0, 1), dtype=torch.float32) / 255.0
            for img in images
        ]
        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        device = torch.device("cpu")
        self.model.to(device)
        for epoch in range(epochs):
            total_loss = 0
            for x in dataloader:
                x = x.to(device)
                t = torch.randint(0, self.timesteps, (x.size(0),), device=device).float()
                noise = torch.randn_like(x)
                alpha_t = self.alpha_cumprod[t.long()].view(-1, 1, 1, 1)
                x_noisy = torch.sqrt(alpha_t) * x + torch.sqrt(1 - alpha_t) * noise
                pred_noise = self.model(x_noisy, t)
                loss = F.mse_loss(pred_noise, noise)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            logger.info(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataloader):.4f}")
        return self

    def generate(self, size=(64, 64), steps=100):
        device = torch.device("cpu")
        steps = min(steps, self.timesteps)  # the schedule only covers `timesteps` indices
        x = torch.randn(1, 3, size[0], size[1], device=device)
        with torch.no_grad():
            for t in reversed(range(steps)):
                t_tensor = torch.full((1,), t, device=device, dtype=torch.float32)
                alpha_t = self.alpha_cumprod[t].view(-1, 1, 1, 1)
                pred_noise = self.model(x, t_tensor)
                x = (x - (1 - self.alpha[t]) / torch.sqrt(1 - alpha_t) * pred_noise) / torch.sqrt(self.alpha[t])
                if t > 0:
                    x += torch.sqrt(self.beta[t]) * torch.randn_like(x)
        x = torch.clamp(x * 255, 0, 255).byte()
        return Image.fromarray(x.squeeze(0).permute(1, 2, 0).cpu().numpy())

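# Hypothetical offline usage sketch for TinyDiffusion (not wired into the UI;
# file names are placeholders):
#     unet = TinyUNet()
#     samples = [Image.open(p) for p in ("a.png", "b.png")]
#     art = TinyDiffusion(unet, timesteps=100).train(samples, epochs=5).generate(size=(64, 64))
#     art.save("tiny_diffusion_sample.png")
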
def generate_filename(sequence, ext="png"):
    timestamp = time.strftime("%d%m%Y%H%M%S")
    return f"{sequence}_{timestamp}.{ext}"


def predict_line_drawing(input_img, ver):
    original_img = Image.open(input_img) if isinstance(input_img, str) else input_img
    original_img = original_img.convert("RGB")  # guard against RGBA/grayscale uploads
    original_size = original_img.size
    transform = transforms.Compose([
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    input_tensor = transform(original_img).unsqueeze(0)
    with torch.no_grad():
        output = model2(input_tensor) if ver == 'Simple Lines' else model1(input_tensor)
    output_img = transforms.ToPILImage()(output.squeeze().cpu().clamp(0, 1))
    return output_img.resize(original_size, Image.BICUBIC)

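# Hypothetical direct call outside the UI (the file names are placeholders):
#     predict_line_drawing(Image.open("photo.jpg"), "Simple Lines").save("photo_lines.png")
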
async def process_ocr(image):
    # Note: the tokenizer and model are reloaded on every call; cache them globally
    # if OCR is used repeatedly.
    tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
    model = AutoModel.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
    # GOT-OCR2.0's chat interface expects an image file path, so persist the upload first.
    image_path = generate_filename("ocr_input", "png")
    image.save(image_path)
    result = model.chat(tokenizer, image_path, ocr_type='ocr')
    output_file = generate_filename("ocr_output", "txt")
    async with aiofiles.open(output_file, "w") as f:
        await f.write(result)
    return result, output_file


async def process_diffusion(images):
    unet = TinyUNet()
    diffusion = TinyDiffusion(unet)
    diffusion.train(images)
    gen_image = diffusion.generate()
    output_file = generate_filename("diffusion_output", "png")
    gen_image.save(output_file)
    return gen_image, output_file


def download_pdf(url):
    output_path = f"pdf_{int(time.time())}.pdf"
    response = requests.get(url, stream=True, timeout=10)
    if response.status_code == 200:
        with open(output_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        return output_path
    return None


with gr.Blocks(title="Mystical AI Vision Studio", css="""
    .gr-button {background-color: #4CAF50; color: white;}
    .gr-tab {border: 2px solid #2196F3; border-radius: 5px;}
    #gallery img {border: 1px solid #ddd; border-radius: 4px;}
""") as demo:
    gr.Markdown("<h1 style='text-align: center; color: #2196F3;'>Mystical AI Vision Studio</h1>")
    gr.Markdown("<p style='text-align: center;'>Transform images into line drawings, extract text with OCR, and craft unique art with diffusion!</p>")

    with gr.Tab("Image to Line Drawings"):
        with gr.Row():
            with gr.Column():
                img_input = gr.Image(type="pil", label="Upload Image")
                version = gr.Radio(['Complex Lines', 'Simple Lines'], label='Style', value='Simple Lines')
                submit_btn = gr.Button("Generate Line Drawing")
            with gr.Column():
                line_output = gr.Image(type="pil", label="Line Drawing")
                download_btn = gr.Button("Download Output")
                line_file = gr.File(label="Download Line Drawing")

        def save_line_drawing(img):
            # Persist the generated drawing so the File component can serve it.
            if img is None:
                return None
            path = generate_filename("line_drawing", "png")
            img.save(path)
            return path

        submit_btn.click(predict_line_drawing, inputs=[img_input, version], outputs=line_output)
        download_btn.click(save_line_drawing, inputs=line_output, outputs=line_file)

    with gr.Tab("OCR Vision"):
        with gr.Row():
            with gr.Column():
                ocr_input = gr.Image(type="pil", label="Upload Image or PDF Snapshot")
                ocr_btn = gr.Button("Extract Text")
            with gr.Column():
                ocr_text = gr.Textbox(label="Extracted Text", interactive=False)
                ocr_file = gr.File(label="Download OCR Result")

        async def run_ocr(img):
            result, file_path = await process_ocr(img)
            return result, file_path

        ocr_btn.click(run_ocr, inputs=ocr_input, outputs=[ocr_text, ocr_file])

    with gr.Tab("Custom Diffusion"):
        with gr.Row():
            with gr.Column():
                diffusion_input = gr.File(label="Upload Images for Training", file_count="multiple")
                diffusion_btn = gr.Button("Train & Generate")
            with gr.Column():
                diffusion_output = gr.Image(type="pil", label="Generated Art")
                diffusion_file = gr.File(label="Download Art")

        async def run_diffusion(files):
            # Uploaded files may arrive as tempfile wrappers or plain paths depending
            # on the Gradio version, so resolve each one to a path before opening.
            paths = [f.name if hasattr(f, "name") else f for f in files]
            images = [Image.open(p).convert("RGB") for p in paths]
            img, file_path = await process_diffusion(images)
            return img, file_path

        diffusion_btn.click(run_diffusion, inputs=diffusion_input, outputs=[diffusion_output, diffusion_file])

    with gr.Tab("PDF Downloader"):
        with gr.Row():
            pdf_url = gr.Textbox(label="Enter PDF URL")
            pdf_btn = gr.Button("Download PDF")
        pdf_output = gr.File(label="Downloaded PDF")
        pdf_btn.click(download_pdf, inputs=pdf_url, outputs=pdf_output)

    with gr.Tab("Gallery"):
        gallery = gr.Gallery(label="Processed Outputs", elem_id="gallery")

        def update_gallery():
            # gr.Gallery renders images, so only list image outputs here.
            return [f for f in os.listdir('.') if f.endswith('.png')]

        gr.Button("Refresh Gallery").click(update_gallery, outputs=gallery)

    # Note: some Gradio versions sanitize <script> tags inside gr.HTML, so this
    # hover styling may not take effect everywhere.
    gr.HTML("""
    <script>
    document.addEventListener('DOMContentLoaded', () => {
        const buttons = document.querySelectorAll('.gr-button');
        buttons.forEach(btn => {
            btn.addEventListener('mouseover', () => btn.style.backgroundColor = '#45a049');
            btn.addEventListener('mouseout', () => btn.style.backgroundColor = '#4CAF50');
        });
    });
    </script>
    """)

demo.launch()