awacke1 committed (verified)
Commit 2fc2f40 · 1 Parent(s): a41511f

Update app.py

Files changed (1):
  1. app.py +842 -188

app.py CHANGED
@@ -1,96 +1,130 @@
#!/usr/bin/env python3
import os
+ import glob
import base64
import time
+ import shutil
+ import streamlit as st
+ import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
- import numpy as np
- from PIL import Image
- import gradio as gr
- import torchvision.transforms as transforms
- from transformers import AutoModel, AutoTokenizer
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from diffusers import StableDiffusionPipeline
from torch.utils.data import Dataset, DataLoader
- import asyncio
- import aiofiles
- import fitz  # PyMuPDF
+ import csv
+ import fitz
import requests
+ from PIL import Image
+ import cv2
+ import numpy as np
import logging
+ import asyncio
+ import aiofiles
from io import BytesIO
from dataclasses import dataclass
- from typing import Optional
+ from typing import Optional, Tuple
+ import zipfile
+ import math
+ import random
+ import re

- # Logging setup
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
+ log_records = []
+
+ class LogCaptureHandler(logging.Handler):
+     def emit(self, record):
+         log_records.append(record)
+
+ logger.addHandler(LogCaptureHandler())
+
+ st.set_page_config(
+     page_title="AI Vision & SFT Titans 🚀",
+     page_icon="🤖",
+     layout="wide",
+     initial_sidebar_state="expanded",
+     menu_items={
+         'Get Help': 'https://huggingface.co/awacke1',
+         'Report a Bug': 'https://huggingface.co/spaces/awacke1',
+         'About': "AI Vision & SFT Titans: PDFs, OCR, Image Gen, Line Drawings, Custom Diffusion, and SFT on CPU! 🌌"
+     }
+ )
+
+ if 'history' not in st.session_state:
+     st.session_state['history'] = []
+ if 'builder' not in st.session_state:
+     st.session_state['builder'] = None
+ if 'model_loaded' not in st.session_state:
+     st.session_state['model_loaded'] = False
+ if 'processing' not in st.session_state:
+     st.session_state['processing'] = {}
+ if 'asset_checkboxes' not in st.session_state:
+     st.session_state['asset_checkboxes'] = {}
+ if 'downloaded_pdfs' not in st.session_state:
+     st.session_state['downloaded_pdfs'] = {}
+ if 'unique_counter' not in st.session_state:
+     st.session_state['unique_counter'] = 0  # For generating unique keys
+
+ @dataclass
+ class ModelConfig:
+     name: str
+     base_model: str
+     size: str
+     domain: Optional[str] = None
+     model_type: str = "causal_lm"
+     @property
+     def model_path(self):
+         return f"models/{self.name}"
+
+ @dataclass
+ class DiffusionConfig:
+     name: str
+     base_model: str
+     size: str
+     domain: Optional[str] = None
+     @property
+     def model_path(self):
+         return f"diffusion_models/{self.name}"
+
+ class SFTDataset(Dataset):
+     def __init__(self, data, tokenizer, max_length=128):
+         self.data = data
+         self.tokenizer = tokenizer
+         self.max_length = max_length
+     def __len__(self):
+         return len(self.data)
+     def __getitem__(self, idx):
+         prompt = self.data[idx]["prompt"]
+         response = self.data[idx]["response"]
+         full_text = f"{prompt} {response}"
+         full_encoding = self.tokenizer(full_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
+         prompt_encoding = self.tokenizer(prompt, max_length=self.max_length, padding=False, truncation=True, return_tensors="pt")
+         input_ids = full_encoding["input_ids"].squeeze()
+         attention_mask = full_encoding["attention_mask"].squeeze()
+         labels = input_ids.clone()
+         prompt_len = prompt_encoding["input_ids"].shape[1]
+         if prompt_len < self.max_length:
+             labels[:prompt_len] = -100
+         return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+
+ class DiffusionDataset(Dataset):
+     def __init__(self, images, texts):
+         self.images = images
+         self.texts = texts
+     def __len__(self):
+         return len(self.images)
+     def __getitem__(self, idx):
+         return {"image": self.images[idx], "text": self.texts[idx]}
+
+ class TinyDiffusionDataset(Dataset):
+     def __init__(self, images):
+         self.images = [torch.tensor(np.array(img.convert("RGB")).transpose(2, 0, 1), dtype=torch.float32) / 255.0 for img in images]
+     def __len__(self):
+         return len(self.images)
+     def __getitem__(self, idx):
+         return self.images[idx]

- # Neural network layers for line drawing
- norm_layer = nn.InstanceNorm2d
-
- # Residual Block for Generator
- class ResidualBlock(nn.Module):
-     def __init__(self, in_features):
-         super(ResidualBlock, self).__init__()
-         conv_block = [
-             nn.ReflectionPad2d(1),
-             nn.Conv2d(in_features, in_features, 3),
-             norm_layer(in_features),
-             nn.ReLU(inplace=True),
-             nn.ReflectionPad2d(1),
-             nn.Conv2d(in_features, in_features, 3),
-             norm_layer(in_features)
-         ]
-         self.conv_block = nn.Sequential(*conv_block)
-
-     def forward(self, x):
-         return x + self.conv_block(x)
-
- # Generator for Line Drawings
- class Generator(nn.Module):
-     def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
-         super(Generator, self).__init__()
-         model0 = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7), norm_layer(64), nn.ReLU(inplace=True)]
-         self.model0 = nn.Sequential(*model0)
-         model1 = []
-         in_features, out_features = 64, 128
-         for _ in range(2):
-             model1 += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), norm_layer(out_features), nn.ReLU(inplace=True)]
-             in_features, out_features = out_features, out_features * 2
-         self.model1 = nn.Sequential(*model1)
-         model2 = [ResidualBlock(in_features) for _ in range(n_residual_blocks)]
-         self.model2 = nn.Sequential(*model2)
-         model3 = []
-         out_features = in_features // 2
-         for _ in range(2):
-             model3 += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), norm_layer(out_features), nn.ReLU(inplace=True)]
-             in_features, out_features = out_features, out_features // 2
-         self.model3 = nn.Sequential(*model3)
-         model4 = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
-         if sigmoid:
-             model4 += [nn.Sigmoid()]
-         self.model4 = nn.Sequential(*model4)
-
-     def forward(self, x, cond=None):
-         out = self.model0(x)
-         out = self.model1(out)
-         out = self.model2(out)
-         out = self.model3(out)
-         out = self.model4(out)
-         return out
-
- # Load Line Drawing Models
- model1 = Generator(3, 1, 3)
- model2 = Generator(3, 1, 3)
- try:
-     model1.load_state_dict(torch.load('model.pth', map_location='cpu', weights_only=True))
-     model2.load_state_dict(torch.load('model2.pth', map_location='cpu', weights_only=True))
- except FileNotFoundError:
-     logger.warning("Model files not found. Please ensure 'model.pth' and 'model2.pth' are available.")
- model1.eval()
- model2.eval()
-
- # Tiny Diffusion Model
class TinyUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(TinyUNet, self).__init__()
@@ -103,7 +137,9 @@ class TinyUNet(nn.Module):
        self.time_embed = nn.Linear(1, 64)

    def forward(self, x, t):
-         t_embed = F.relu(self.time_embed(t.unsqueeze(-1))).view(t_embed.size(0), t_embed.size(1), 1, 1)
+         t_embed = F.relu(self.time_embed(t.unsqueeze(-1)))
+         t_embed = t_embed.view(t_embed.size(0), t_embed.size(1), 1, 1)
+
        x1 = F.relu(self.down1(x))
        x2 = F.relu(self.down2(x1))
        x_mid = F.relu(self.mid(x2)) + t_embed
@@ -119,8 +155,8 @@ class TinyDiffusion:
        self.alpha = 1 - self.beta
        self.alpha_cumprod = torch.cumprod(self.alpha, dim=0)

-     def train(self, images, epochs=10):
-         dataset = [torch.tensor(np.array(img.convert("RGB")).transpose(2, 0, 1), dtype=torch.float32) / 255.0 for img in images]
+     def train(self, images, epochs=50):
+         dataset = TinyDiffusionDataset(images)
        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        device = torch.device("cpu")
@@ -155,126 +191,744 @@ class TinyDiffusion:
        x = torch.clamp(x * 255, 0, 255).byte()
        return Image.fromarray(x.squeeze(0).permute(1, 2, 0).cpu().numpy())

- # Utility Functions
+     def upscale(self, image, scale_factor=2):
+         img_tensor = torch.tensor(np.array(image.convert("RGB")).transpose(2, 0, 1), dtype=torch.float32).unsqueeze(0) / 255.0
+         upscaled = F.interpolate(img_tensor, scale_factor=scale_factor, mode='bilinear', align_corners=False)
+         upscaled = torch.clamp(upscaled * 255, 0, 255).byte()
+         return Image.fromarray(upscaled.squeeze(0).permute(1, 2, 0).cpu().numpy())
+
+ class ModelBuilder:
+     def __init__(self):
+         self.config = None
+         self.model = None
+         self.tokenizer = None
+         self.sft_data = None
+         self.jokes = ["Why did the AI go to therapy? Too many layers to unpack! 😂", "Training complete! Time for a binary coffee break. ☕"]
+     def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
+         with st.spinner(f"Loading {model_path}... ⏳"):
+             self.model = AutoModelForCausalLM.from_pretrained(model_path)
+             self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+             if self.tokenizer.pad_token is None:
+                 self.tokenizer.pad_token = self.tokenizer.eos_token
+             if config:
+                 self.config = config
+             self.model.to("cuda" if torch.cuda.is_available() else "cpu")
+         st.success(f"Model loaded! 🎉 {random.choice(self.jokes)}")
+         return self
+     def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
+         self.sft_data = []
+         with open(csv_path, "r") as f:
+             reader = csv.DictReader(f)
+             for row in reader:
+                 self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})
+         dataset = SFTDataset(self.sft_data, self.tokenizer)
+         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+         optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
+         self.model.train()
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model.to(device)
+         for epoch in range(epochs):
+             with st.spinner(f"Training epoch {epoch + 1}/{epochs}... ⚙️"):
+                 total_loss = 0
+                 for batch in dataloader:
+                     optimizer.zero_grad()
+                     input_ids = batch["input_ids"].to(device)
+                     attention_mask = batch["attention_mask"].to(device)
+                     labels = batch["labels"].to(device)
+                     outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+                     loss = outputs.loss
+                     loss.backward()
+                     optimizer.step()
+                     total_loss += loss.item()
+                 st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
+         st.success(f"SFT Fine-tuning completed! 🎉 {random.choice(self.jokes)}")
+         return self
+     def save_model(self, path: str):
+         with st.spinner("Saving model... 💾"):
+             os.makedirs(os.path.dirname(path), exist_ok=True)
+             self.model.save_pretrained(path)
+             self.tokenizer.save_pretrained(path)
+         st.success(f"Model saved at {path}! ✅")
+     def evaluate(self, prompt: str, status_container=None):
+         self.model.eval()
+         if status_container:
+             status_container.write("Preparing to evaluate... 🧠")
+         try:
+             with torch.no_grad():
+                 inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.model.device)
+                 outputs = self.model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95, temperature=0.7)
+                 return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+         except Exception as e:
+             if status_container:
+                 status_container.error(f"Oops! Something broke: {str(e)} 💥")
+             return f"Error: {str(e)}"
+
+ class DiffusionBuilder:
+     def __init__(self):
+         self.config = None
+         self.pipeline = None
+     def load_model(self, model_path: str, config: Optional[DiffusionConfig] = None):
+         with st.spinner(f"Loading diffusion model {model_path}... ⏳"):
+             self.pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float32).to("cpu")
+             if config:
+                 self.config = config
+         st.success(f"Diffusion model loaded! 🎨")
+         return self
+     def fine_tune_sft(self, images, texts, epochs=3):
+         dataset = DiffusionDataset(images, texts)
+         dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
+         optimizer = torch.optim.AdamW(self.pipeline.unet.parameters(), lr=1e-5)
+         self.pipeline.unet.train()
+         for epoch in range(epochs):
+             with st.spinner(f"Training diffusion epoch {epoch + 1}/{epochs}... ⚙️"):
+                 total_loss = 0
+                 for batch in dataloader:
+                     optimizer.zero_grad()
+                     image = batch["image"][0].to(self.pipeline.device)
+                     text = batch["text"][0]
+                     latents = self.pipeline.vae.encode(torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float().to(self.pipeline.device)).latent_dist.sample()
+                     noise = torch.randn_like(latents)
+                     timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
+                     noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
+                     text_embeddings = self.pipeline.text_encoder(self.pipeline.tokenizer(text, return_tensors="pt").input_ids.to(self.pipeline.device))[0]
+                     pred_noise = self.pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
+                     loss = torch.nn.functional.mse_loss(pred_noise, noise)
+                     loss.backward()
+                     optimizer.step()
+                     total_loss += loss.item()
+                 st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
+         st.success("Diffusion SFT Fine-tuning completed! 🎨")
+         return self
+     def save_model(self, path: str):
+         with st.spinner("Saving diffusion model... 💾"):
+             os.makedirs(os.path.dirname(path), exist_ok=True)
+             self.pipeline.save_pretrained(path)
+         st.success(f"Diffusion model saved at {path}! ✅")
+     def generate(self, prompt: str):
+         return self.pipeline(prompt, num_inference_steps=20).images[0]
+
def generate_filename(sequence, ext="png"):
    timestamp = time.strftime("%d%m%Y%H%M%S")
    return f"{sequence}_{timestamp}.{ext}"

- def predict_line_drawing(input_img, ver):
-     original_img = Image.open(input_img) if isinstance(input_img, str) else input_img
-     original_size = original_img.size
-     transform = transforms.Compose([
-         transforms.Resize(256, Image.BICUBIC),
-         transforms.ToTensor(),
-         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
-     ])
-     input_tensor = transform(original_img).unsqueeze(0)
-     with torch.no_grad():
-         output = model2(input_tensor) if ver == 'Simple Lines' else model1(input_tensor)
-     output_img = transforms.ToPILImage()(output.squeeze().cpu().clamp(0, 1))
-     return output_img.resize(original_size, Image.BICUBIC)
-
- async def process_ocr(image):
+ def pdf_url_to_filename(url):
+     safe_name = re.sub(r'[<>:"/\\|?*]', '_', url)
+     return f"{safe_name}.pdf"
+
+ def get_download_link(file_path, mime_type="application/pdf", label="Download"):
+     with open(file_path, 'rb') as f:
+         data = f.read()
+     b64 = base64.b64encode(data).decode()
+     return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label}</a>'
+
+ def zip_directory(directory_path, zip_path):
+     with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
+         for root, _, files in os.walk(directory_path):
+             for file in files:
+                 zipf.write(os.path.join(root, file), os.path.relpath(os.path.join(root, file), os.path.dirname(directory_path)))
+
+ def get_model_files(model_type="causal_lm"):
+     path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
+     return [d for d in glob.glob(path) if os.path.isdir(d)]
+
+ def get_gallery_files(file_types=["png", "pdf"]):
+     return sorted(list(set([f for ext in file_types for f in glob.glob(f"*.{ext}")])))  # Deduplicate files
+
+ def get_pdf_files():
+     return sorted(glob.glob("*.pdf"))
+
+ def download_pdf(url, output_path):
+     try:
+         response = requests.get(url, stream=True, timeout=10)
+         if response.status_code == 200:
+             with open(output_path, "wb") as f:
+                 for chunk in response.iter_content(chunk_size=8192):
+                     f.write(chunk)
+             return True
+     except requests.RequestException as e:
+         logger.error(f"Failed to download {url}: {e}")
+     return False
+
+ async def process_pdf_snapshot(pdf_path, mode="single"):
+     start_time = time.time()
+     status = st.empty()
+     status.text(f"Processing PDF Snapshot ({mode})... (0s)")
+     try:
+         doc = fitz.open(pdf_path)
+         output_files = []
+         if mode == "single":
+             page = doc[0]
+             pix = page.get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
+             output_file = generate_filename("single", "png")
+             pix.save(output_file)
+             output_files.append(output_file)
+         elif mode == "twopage":
+             for i in range(min(2, len(doc))):
+                 page = doc[i]
+                 pix = page.get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
+                 output_file = generate_filename(f"twopage_{i}", "png")
+                 pix.save(output_file)
+                 output_files.append(output_file)
+         elif mode == "allpages":
+             for i in range(len(doc)):
+                 page = doc[i]
+                 pix = page.get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
+                 output_file = generate_filename(f"page_{i}", "png")
+                 pix.save(output_file)
+                 output_files.append(output_file)
+         doc.close()
+         elapsed = int(time.time() - start_time)
+         status.text(f"PDF Snapshot ({mode}) completed in {elapsed}s!")
+         update_gallery()
+         return output_files
+     except Exception as e:
+         status.error(f"Failed to process PDF: {str(e)}")
+         return []
+
+ async def process_ocr(image, output_file):
+     start_time = time.time()
+     status = st.empty()
+     status.text("Processing GOT-OCR2_0... (0s)")
    tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
    model = AutoModel.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True, torch_dtype=torch.float32).to("cpu").eval()
    result = model.chat(tokenizer, image, ocr_type='ocr')
-     output_file = generate_filename("ocr_output", "txt")
+     elapsed = int(time.time() - start_time)
+     status.text(f"GOT-OCR2_0 completed in {elapsed}s!")
    async with aiofiles.open(output_file, "w") as f:
        await f.write(result)
-     return result, output_file
+     update_gallery()
+     return result

- async def process_diffusion(images):
+ async def process_image_gen(prompt, output_file):
+     start_time = time.time()
+     status = st.empty()
+     status.text("Processing Image Gen... (0s)")
+     pipeline = StableDiffusionPipeline.from_pretrained("OFA-Sys/small-stable-diffusion-v0", torch_dtype=torch.float32).to("cpu")
+     gen_image = pipeline(prompt, num_inference_steps=20).images[0]
+     elapsed = int(time.time() - start_time)
+     status.text(f"Image Gen completed in {elapsed}s!")
+     gen_image.save(output_file)
+     update_gallery()
+     return gen_image
+
+ async def process_custom_diffusion(images, output_file, model_name):
+     start_time = time.time()
+     status = st.empty()
+     status.text(f"Training {model_name}... (0s)")
    unet = TinyUNet()
    diffusion = TinyDiffusion(unet)
    diffusion.train(images)
    gen_image = diffusion.generate()
-     output_file = generate_filename("diffusion_output", "png")
-     gen_image.save(output_file)
-     return gen_image, output_file
-
- def download_pdf(url):
-     output_path = f"pdf_{int(time.time())}.pdf"
-     response = requests.get(url, stream=True, timeout=10)
-     if response.status_code == 200:
-         with open(output_path, "wb") as f:
-             for chunk in response.iter_content(chunk_size=8192):
-                 f.write(chunk)
-         return output_path
-     return None
-
- # Gradio Blocks UI
- with gr.Blocks(title="Mystical AI Vision Studio 🌌", css="""
-     .gr-button {background-color: #4CAF50; color: white;}
-     .gr-tab {border: 2px solid #2196F3; border-radius: 5px;}
-     #gallery img {border: 1px solid #ddd; border-radius: 4px;}
- """) as demo:
-     gr.Markdown("<h1 style='text-align: center; color: #2196F3;'>Mystical AI Vision Studio 🌌</h1>")
-     gr.Markdown("<p style='text-align: center;'>Transform images into line drawings, extract text with OCR, and craft unique art with diffusion!</p>")
-
-     with gr.Tab("Image to Line Drawings 🎨"):
-         with gr.Row():
-             with gr.Column():
-                 img_input = gr.Image(type="pil", label="Upload Image")
-                 version = gr.Radio(['Complex Lines', 'Simple Lines'], label='Style', value='Simple Lines')
-                 submit_btn = gr.Button("Generate Line Drawing")
-             with gr.Column():
-                 line_output = gr.Image(type="pil", label="Line Drawing")
-                 download_btn = gr.Button("Download Output")
-         submit_btn.click(predict_line_drawing, inputs=[img_input, version], outputs=line_output)
-         download_btn.click(lambda x: gr.File(x, label="Download Line Drawing"), inputs=line_output, outputs=None)
-
-     with gr.Tab("OCR Vision 🔍"):
-         with gr.Row():
-             with gr.Column():
-                 ocr_input = gr.Image(type="pil", label="Upload Image or PDF Snapshot")
-                 ocr_btn = gr.Button("Extract Text")
-             with gr.Column():
-                 ocr_text = gr.Textbox(label="Extracted Text", interactive=False)
-                 ocr_file = gr.File(label="Download OCR Result")
-         async def run_ocr(img):
-             result, file_path = await process_ocr(img)
-             return result, file_path
-         ocr_btn.click(run_ocr, inputs=ocr_input, outputs=[ocr_text, ocr_file])
-
-     with gr.Tab("Custom Diffusion 🎨🤓"):
-         with gr.Row():
-             with gr.Column():
-                 diffusion_input = gr.File(label="Upload Images for Training", multiple=True)
-                 diffusion_btn = gr.Button("Train & Generate")
-             with gr.Column():
-                 diffusion_output = gr.Image(type="pil", label="Generated Art")
-                 diffusion_file = gr.File(label="Download Art")
-         async def run_diffusion(files):
-             images = [Image.open(BytesIO(f.read())) for f in files]
-             img, file_path = await process_diffusion(images)
-             return img, file_path
-         diffusion_btn.click(run_diffusion, inputs=diffusion_input, outputs=[diffusion_output, diffusion_file])
-
-     with gr.Tab("PDF Downloader 📥"):
-         with gr.Row():
-             pdf_url = gr.Textbox(label="Enter PDF URL")
-             pdf_btn = gr.Button("Download PDF")
-             pdf_output = gr.File(label="Downloaded PDF")
-         pdf_btn.click(download_pdf, inputs=pdf_url, outputs=pdf_output)
-
-     with gr.Tab("Gallery 📸"):
-         gallery = gr.Gallery(label="Processed Outputs", elem_id="gallery")
-         def update_gallery():
-             files = [f for f in os.listdir('.') if f.endswith(('.png', '.txt', '.pdf'))]
-             return [f for f in files]
-         gr.Button("Refresh Gallery").click(update_gallery, outputs=gallery)
-
-     # JavaScript for dynamic UI enhancements
-     gr.HTML("""
-     <script>
-     document.addEventListener('DOMContentLoaded', () => {
-         const buttons = document.querySelectorAll('.gr-button');
-         buttons.forEach(btn => {
-             btn.addEventListener('mouseover', () => btn.style.backgroundColor = '#45a049');
-             btn.addEventListener('mouseout', () => btn.style.backgroundColor = '#4CAF50');
-         });
-     });
-     </script>
-     """)
-
- demo.launch()
+     upscaled_image = diffusion.upscale(gen_image, scale_factor=2)
+     elapsed = int(time.time() - start_time)
+     status.text(f"{model_name} completed in {elapsed}s!")
+     upscaled_image.save(output_file)
+     update_gallery()
+     return upscaled_image
+
+ def mock_search(query: str) -> str:
+     if "superhero" in query.lower():
+         return "Latest trends: Gold-plated Batman statues, VR superhero battles."
+     return "No relevant results found."
+
+ def mock_duckduckgo_search(query: str) -> str:
+     if "superhero party trends" in query.lower():
+         return """
+         Latest trends for 2025:
+         - Luxury decorations: Gold-plated Batman statues, holographic Avengers displays.
+         - Entertainment: Live stunt shows with Iron Man suits, VR superhero battles.
+         - Catering: Gourmet kryptonite-green cocktails, Thor’s hammer-shaped appetizers.
+         """
+     return "No relevant results found."
+
+ class PartyPlannerAgent:
+     def __init__(self, model, tokenizer):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+     def generate(self, prompt: str) -> str:
+         self.model.eval()
+         with torch.no_grad():
+             inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.device)
+             outputs = self.model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.95, temperature=0.7)
+             return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+     def plan_party(self, task: str) -> pd.DataFrame:
+         search_result = mock_duckduckgo_search("latest superhero party trends")
+         prompt = f"Given this context: '{search_result}'\n{task}"
+         plan_text = self.generate(prompt)
+         locations = {
+             "Wayne Manor": (42.3601, -71.0589),
+             "New York": (40.7128, -74.0060),
+             "Los Angeles": (34.0522, -118.2437),
+             "London": (51.5074, -0.1278)
+         }
+         wayne_coords = locations["Wayne Manor"]
+         travel_times = {loc: calculate_cargo_travel_time(coords, wayne_coords) for loc, coords in locations.items() if loc != "Wayne Manor"}
+         catchphrases = ["To the Batmobile!", "Avengers, assemble!", "I am Iron Man!", "By the power of Grayskull!"]
+         data = [
+             {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Luxury Idea": "Gold-plated Batman statues", "Catchphrase": random.choice(catchphrases)},
+             {"Location": "Los Angeles", "Travel Time (hrs)": travel_times["Los Angeles"], "Luxury Idea": "Holographic Avengers displays", "Catchphrase": random.choice(catchphrases)},
+             {"Location": "London", "Travel Time (hrs)": travel_times["London"], "Luxury Idea": "Live stunt shows with Iron Man suits", "Catchphrase": random.choice(catchphrases)},
+             {"Location": "Wayne Manor", "Travel Time (hrs)": 0.0, "Luxury Idea": "VR superhero battles", "Catchphrase": random.choice(catchphrases)},
+             {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Luxury Idea": "Gourmet kryptonite-green cocktails", "Catchphrase": random.choice(catchphrases)},
+             {"Location": "Los Angeles", "Travel Time (hrs)": travel_times["Los Angeles"], "Luxury Idea": "Thor’s hammer-shaped appetizers", "Catchphrase": random.choice(catchphrases)},
+         ]
+         return pd.DataFrame(data)
+
+ class CVPartyPlannerAgent:
+     def __init__(self, pipeline):
+         self.pipeline = pipeline
+     def generate(self, prompt: str) -> Image.Image:
+         return self.pipeline(prompt, num_inference_steps=20).images[0]
+     def plan_party(self, task: str) -> pd.DataFrame:
+         search_result = mock_search("superhero party trends")
+         prompt = f"Given this context: '{search_result}'\n{task}"
+         data = [
+             {"Theme": "Batman", "Image Idea": "Gold-plated Batman statue"},
+             {"Theme": "Avengers", "Image Idea": "VR superhero battle scene"}
+         ]
+         return pd.DataFrame(data)
+
+ def calculate_cargo_travel_time(origin_coords: Tuple[float, float], destination_coords: Tuple[float, float], cruising_speed_kmh: float = 750.0) -> float:
+     def to_radians(degrees: float) -> float:
+         return degrees * (math.pi / 180)
+     lat1, lon1 = map(to_radians, origin_coords)
+     lat2, lon2 = map(to_radians, destination_coords)
+     EARTH_RADIUS_KM = 6371.0
+     dlon = lon2 - lon1
+     dlat = lat2 - lat1
+     a = (math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2)
+     c = 2 * math.asin(math.sqrt(a))
+     distance = EARTH_RADIUS_KM * c
+     actual_distance = distance * 1.1
+     flight_time = (actual_distance / cruising_speed_kmh) + 1.0
+     return round(flight_time, 2)
+
+ st.title("AI Vision & SFT Titans 🚀")
+
+ st.sidebar.header("Captured Files 📜")
+ cols = st.sidebar.columns(2)
+ with cols[0]:
+     if st.button("Zip All 🤐"):
+         zip_path = f"all_assets_{int(time.time())}.zip"
+         with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
+             for file in get_gallery_files():
+                 zipf.write(file, os.path.basename(file))
+         st.sidebar.markdown(get_download_link(zip_path, "application/zip", "Download All Assets"), unsafe_allow_html=True)
+ with cols[1]:
+     if st.button("Zap All! 🗑️"):
+         for file in get_gallery_files():
+             os.remove(file)
+         st.session_state['asset_checkboxes'].clear()
+         st.session_state['downloaded_pdfs'].clear()
+         st.sidebar.success("All assets vaporized! 💨")
+         st.rerun()
+
+ gallery_size = st.sidebar.slider("Gallery Size", 1, 10, 2)
+ def update_gallery():
+     all_files = get_gallery_files()
+     if all_files:
+         st.sidebar.subheader("Asset Gallery 📸📖")
+         cols = st.sidebar.columns(2)
+         for idx, file in enumerate(all_files[:gallery_size * 2]):
+             with cols[idx % 2]:
+                 st.session_state['unique_counter'] += 1  # Increment counter for uniqueness
+                 unique_id = st.session_state['unique_counter']
+                 if file.endswith('.png'):
+                     st.image(Image.open(file), caption=os.path.basename(file), use_container_width=True)
+                 else:
+                     doc = fitz.open(file)
+                     pix = doc[0].get_pixmap(matrix=fitz.Matrix(0.5, 0.5))
+                     img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+                     st.image(img, caption=os.path.basename(file), use_container_width=True)
+                     doc.close()
+                 checkbox_key = f"asset_{file}_{unique_id}"  # Unique key with counter
+                 st.session_state['asset_checkboxes'][file] = st.checkbox(
+                     "Use for SFT/Input",
+                     value=st.session_state['asset_checkboxes'].get(file, False),
+                     key=checkbox_key
+                 )
+                 mime_type = "image/png" if file.endswith('.png') else "application/pdf"
+                 st.markdown(get_download_link(file, mime_type, "Snag It! 📥"), unsafe_allow_html=True)
+                 if st.button("Zap It! 🗑️", key=f"delete_{file}_{unique_id}"):  # Unique key with counter
+                     os.remove(file)
+                     if file in st.session_state['asset_checkboxes']:
+                         del st.session_state['asset_checkboxes'][file]
+                     if file.endswith('.pdf'):
+                         url_key = next((k for k, v in st.session_state['downloaded_pdfs'].items() if v == file), None)
+                         if url_key:
+                             del st.session_state['downloaded_pdfs'][url_key]
+                     st.sidebar.success(f"Asset {os.path.basename(file)} vaporized! 💨")
+                     st.rerun()
+ update_gallery()
+
+ st.sidebar.subheader("Model Management 🗂️")
+ model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"], key="sidebar_model_type")
+ model_dirs = get_model_files(model_type)
+ selected_model = st.sidebar.selectbox("Select Saved Model", ["None"] + model_dirs, key="sidebar_model_select")
+ if selected_model != "None" and st.sidebar.button("Load Model 📂"):
+     builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
+     config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=os.path.basename(selected_model), base_model="unknown", size="small")
+     builder.load_model(selected_model, config)
+     st.session_state['builder'] = builder
+     st.session_state['model_loaded'] = True
+     st.rerun()
+
+ st.sidebar.subheader("Action Logs 📜")
+ log_container = st.sidebar.empty()
+ with log_container:
+     for record in log_records:
+         st.write(f"{record.asctime} - {record.levelname} - {record.message}")
+
+ st.sidebar.subheader("History 📜")
+ history_container = st.sidebar.empty()
+ with history_container:
+     for entry in st.session_state['history'][-gallery_size * 2:]:
+         st.write(entry)
+
+ tab1, tab2, tab3, tab4, tab5, tab6, tab7, tab8, tab9 = st.tabs([
+     "Camera Snap 📷", "Download PDFs 📥", "Build Titan 🌱", "Fine-Tune Titan 🔧",
+     "Test Titan 🧪", "Agentic RAG Party 🌐", "Test OCR 🔍", "Test Image Gen 🎨", "Custom Diffusion 🎨🤓"
+ ])
+
+ with tab1:
+     st.header("Camera Snap 📷")
+     st.subheader("Single Capture")
+     cols = st.columns(2)
+     with cols[0]:
+         cam0_img = st.camera_input("Take a picture - Cam 0", key="cam0")
+         if cam0_img:
+             filename = generate_filename("cam0")
+             with open(filename, "wb") as f:
+                 f.write(cam0_img.getvalue())
+             entry = f"Snapshot from Cam 0: {filename}"
+             if entry not in st.session_state['history']:
+                 st.session_state['history'] = [e for e in st.session_state['history'] if not e.startswith("Snapshot from Cam 0:")] + [entry]
+             st.image(Image.open(filename), caption="Camera 0", use_container_width=True)
+             logger.info(f"Saved snapshot from Camera 0: {filename}")
+             update_gallery()
+     with cols[1]:
+         cam1_img = st.camera_input("Take a picture - Cam 1", key="cam1")
+         if cam1_img:
+             filename = generate_filename("cam1")
+             with open(filename, "wb") as f:
+                 f.write(cam1_img.getvalue())
+             entry = f"Snapshot from Cam 1: {filename}"
+             if entry not in st.session_state['history']:
+                 st.session_state['history'] = [e for e in st.session_state['history'] if not e.startswith("Snapshot from Cam 1:")] + [entry]
+             st.image(Image.open(filename), caption="Camera 1", use_container_width=True)
+             logger.info(f"Saved snapshot from Camera 1: {filename}")
+             update_gallery()
+
+ with tab2:
+     st.header("Download PDFs 📥")
+     if st.button("Examples 📚"):
+         example_urls = [
+             "https://arxiv.org/pdf/2308.03892",
+             "https://arxiv.org/pdf/1912.01703",
+             "https://arxiv.org/pdf/2408.11039",
+             "https://arxiv.org/pdf/2109.10282",
+             "https://arxiv.org/pdf/2112.10752",
+             "https://arxiv.org/pdf/2308.11236",
+             "https://arxiv.org/pdf/1706.03762",
+             "https://arxiv.org/pdf/2006.11239",
+             "https://arxiv.org/pdf/2305.11207",
+             "https://arxiv.org/pdf/2106.09685",
+             "https://arxiv.org/pdf/2005.11401",
+             "https://arxiv.org/pdf/2106.10504"
+         ]
+         st.session_state['pdf_urls'] = "\n".join(example_urls)
+
+     url_input = st.text_area("Enter PDF URLs (one per line)", value=st.session_state.get('pdf_urls', ""), height=200)
+     if st.button("Robo-Download 🤖"):
+         urls = url_input.strip().split("\n")
+         progress_bar = st.progress(0)
+         status_text = st.empty()
+         total_urls = len(urls)
+         existing_pdfs = get_pdf_files()
+         for idx, url in enumerate(urls):
+             if url:
+                 output_path = pdf_url_to_filename(url)
+                 status_text.text(f"Fetching {idx + 1}/{total_urls}: {os.path.basename(output_path)}...")
+                 if output_path not in existing_pdfs:
+                     if download_pdf(url, output_path):
+                         st.session_state['downloaded_pdfs'][url] = output_path
+                         logger.info(f"Downloaded PDF from {url} to {output_path}")
+                         entry = f"Downloaded PDF: {output_path}"
+                         if entry not in st.session_state['history']:
+                             st.session_state['history'].append(entry)
+                     else:
+                         st.error(f"Failed to nab {url} 😿")
+                 else:
+                     st.info(f"Already got {os.path.basename(output_path)}! Skipping... 🐾")
+                     st.session_state['downloaded_pdfs'][url] = output_path
+             progress_bar.progress((idx + 1) / total_urls)
+         status_text.text("Robo-Download complete! 🚀")
+         update_gallery()
+
+     mode = st.selectbox("Snapshot Mode", ["Single Page (High-Res)", "Two Pages (High-Res)", "All Pages (High-Res)"], key="download_mode")
+     if st.button("Snapshot Selected 📸"):
+         selected_pdfs = [path for path in get_gallery_files() if path.endswith('.pdf') and st.session_state['asset_checkboxes'].get(path, False)]
+         if selected_pdfs:
+             for pdf_path in selected_pdfs:
+                 mode_key = {"Single Page (High-Res)": "single", "Two Pages (High-Res)": "twopage", "All Pages (High-Res)": "allpages"}[mode]
+                 snapshots = asyncio.run(process_pdf_snapshot(pdf_path, mode_key))
+                 for snapshot in snapshots:
+                     st.image(Image.open(snapshot), caption=snapshot, use_container_width=True)
+         else:
+             st.warning("No PDFs selected for snapshotting! Check some boxes in the sidebar gallery.")
+
+ with tab3:
+     st.header("Build Titan 🌱")
+     model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
+     base_model = st.selectbox("Select Tiny Model",
+         ["HuggingFaceTB/SmolLM-135M", "Qwen/Qwen1.5-0.5B-Chat"] if model_type == "Causal LM" else
+         ["OFA-Sys/small-stable-diffusion-v0", "stabilityai/stable-diffusion-2-base"])
+     model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
+     domain = st.text_input("Target Domain", "general")
+     if st.button("Download Model ⬇️"):
+         config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=model_name, base_model=base_model, size="small", domain=domain)
+         builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
+         builder.load_model(base_model, config)
+         builder.save_model(config.model_path)
+         st.session_state['builder'] = builder
+         st.session_state['model_loaded'] = True
+         entry = f"Built {model_type} model: {model_name}"
+         if entry not in st.session_state['history']:
+             st.session_state['history'].append(entry)
+         st.success(f"Model downloaded and saved to {config.model_path}! 🎉")
+         st.rerun()
+
+ with tab4:
+     st.header("Fine-Tune Titan 🔧")
+     if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
+         st.warning("Please build or load a Titan first! ⚠️")
+     else:
+         if isinstance(st.session_state['builder'], ModelBuilder):
+             if st.button("Generate Sample CSV 📝"):
+                 sample_data = [
+                     {"prompt": "What is AI?", "response": "AI is artificial intelligence, simulating human smarts in machines."},
+                     {"prompt": "Explain machine learning", "response": "Machine learning is AI’s gym where models bulk up on data."},
+                 ]
+                 csv_path = f"sft_data_{int(time.time())}.csv"
+                 with open(csv_path, "w", newline="") as f:
+                     writer = csv.DictWriter(f, fieldnames=["prompt", "response"])
+                     writer.writeheader()
+                     writer.writerows(sample_data)
+                 st.markdown(get_download_link(csv_path, "text/csv", "Download Sample CSV"), unsafe_allow_html=True)
+                 st.success(f"Sample CSV generated as {csv_path}! ✅")
+
+             uploaded_csv = st.file_uploader("Upload CSV for SFT", type="csv")
+             if uploaded_csv and st.button("Fine-Tune with Uploaded CSV 🔄"):
+                 csv_path = f"uploaded_sft_data_{int(time.time())}.csv"
+                 with open(csv_path, "wb") as f:
+                     f.write(uploaded_csv.read())
+                 new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
+                 new_config = ModelConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small", domain=st.session_state['builder'].config.domain)
+                 st.session_state['builder'].config = new_config
+                 st.session_state['builder'].fine_tune_sft(csv_path)
+                 st.session_state['builder'].save_model(new_config.model_path)
+                 zip_path = f"{new_config.model_path}.zip"
+                 zip_directory(new_config.model_path, zip_path)
+                 entry = f"Fine-tuned Causal LM: {new_model_name}"
+                 if entry not in st.session_state['history']:
+                     st.session_state['history'].append(entry)
+                 st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Titan"), unsafe_allow_html=True)
+                 st.rerun()
+         elif isinstance(st.session_state['builder'], DiffusionBuilder):
+             selected_files = [path for path in get_gallery_files() if st.session_state['asset_checkboxes'].get(path, False)]
+             if len(selected_files) >= 2:
+                 demo_data = [{"image": file, "text": f"Asset {os.path.basename(file).split('.')[0]}"} for file in selected_files]
+                 edited_data = st.data_editor(pd.DataFrame(demo_data), num_rows="dynamic")
+                 if st.button("Fine-Tune with Dataset 🔄"):
+                     images = [Image.open(row["image"]) if row["image"].endswith('.png') else Image.frombytes("RGB", fitz.open(row["image"])[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0)).size, fitz.open(row["image"])[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0)).samples) for _, row in edited_data.iterrows()]
+                     texts = [row["text"] for _, row in edited_data.iterrows()]
+                     new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
+                     new_config = DiffusionConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small", domain=st.session_state['builder'].config.domain)
+                     st.session_state['builder'].config = new_config
+                     st.session_state['builder'].fine_tune_sft(images, texts)
+                     st.session_state['builder'].save_model(new_config.model_path)
+                     zip_path = f"{new_config.model_path}.zip"
+                     zip_directory(new_config.model_path, zip_path)
+                     entry = f"Fine-tuned Diffusion: {new_model_name}"
+                     if entry not in st.session_state['history']:
+                         st.session_state['history'].append(entry)
+                     st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Diffusion Model"), unsafe_allow_html=True)
+                     csv_path = f"sft_dataset_{int(time.time())}.csv"
+                     with open(csv_path, "w", newline="") as f:
+                         writer = csv.writer(f)
+                         writer.writerow(["image", "text"])
+                         for _, row in edited_data.iterrows():
+                             writer.writerow([row["image"], row["text"]])
+                     st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
+
+ with tab5:
+     st.header("Test Titan 🧪")
+     if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
+         st.warning("Please build or load a Titan first! ⚠️")
+     else:
+         if isinstance(st.session_state['builder'], ModelBuilder):
+             if st.session_state['builder'].sft_data:
+                 st.write("Testing with SFT Data:")
+                 for item in st.session_state['builder'].sft_data[:3]:
+                     prompt = item["prompt"]
+                     expected = item["response"]
+                     status_container = st.empty()
+                     generated = st.session_state['builder'].evaluate(prompt, status_container)
+                     st.write(f"**Prompt**: {prompt}")
+                     st.write(f"**Expected**: {expected}")
+                     st.write(f"**Generated**: {generated}")
+                     st.write("---")
+                     status_container.empty()
+             test_prompt = st.text_area("Enter Test Prompt", "What is AI?")
+             if st.button("Run Test ▶️"):
+                 status_container = st.empty()
+                 result = st.session_state['builder'].evaluate(test_prompt, status_container)
+                 entry = f"Causal LM Test: {test_prompt} -> {result}"
+                 if entry not in st.session_state['history']:
+                     st.session_state['history'].append(entry)
+                 st.write(f"**Generated Response**: {result}")
+                 status_container.empty()
+         elif isinstance(st.session_state['builder'], DiffusionBuilder):
+             test_prompt = st.text_area("Enter Test Prompt", "Neon Batman")
+             if st.button("Run Test ▶️"):
+                 image = st.session_state['builder'].generate(test_prompt)
+                 output_file = generate_filename("diffusion_test", "png")
+                 image.save(output_file)
+                 entry = f"Diffusion Test: {test_prompt} -> {output_file}"
+                 if entry not in st.session_state['history']:
+                     st.session_state['history'].append(entry)
+                 st.image(image, caption="Generated Image")
+                 update_gallery()
+
+ with tab6:
+     st.header("Agentic RAG Party 🌐")
+     if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
+         st.warning("Please build or load a Titan first! ⚠️")
+     else:
+         if isinstance(st.session_state['builder'], ModelBuilder):
+             if st.button("Run NLP RAG Demo 🎉"):
+                 agent = PartyPlannerAgent(st.session_state['builder'].model, st.session_state['builder'].tokenizer)
+                 task = "Plan a luxury superhero-themed party at Wayne Manor."
+                 plan_df = agent.plan_party(task)
+                 entry = f"NLP RAG Demo: Planned party at Wayne Manor"
+                 if entry not in st.session_state['history']:
+                     st.session_state['history'].append(entry)
+                 st.dataframe(plan_df)
+         elif isinstance(st.session_state['builder'], DiffusionBuilder):
+             if st.button("Run CV RAG Demo 🎉"):
+                 agent = CVPartyPlannerAgent(st.session_state['builder'].pipeline)
+                 task = "Generate images for a luxury superhero-themed party."
+                 plan_df = agent.plan_party(task)
+                 entry = f"CV RAG Demo: Generated party images"
+                 if entry not in st.session_state['history']:
+                     st.session_state['history'].append(entry)
+                 st.dataframe(plan_df)
+                 for _, row in plan_df.iterrows():
+                     image = agent.generate(row["Image Idea"])
+                     output_file = generate_filename(f"cv_rag_{row['Theme'].lower()}", "png")
+                     image.save(output_file)
+                     st.image(image, caption=f"{row['Theme']} - {row['Image Idea']}")
+                 update_gallery()
+
+ with tab7:
+     st.header("Test OCR 🔍")
+     all_files = [path for path in get_gallery_files() if st.session_state['asset_checkboxes'].get(path, False)]
+     if all_files:
+         selected_file = st.selectbox("Select Image or PDF", all_files, key="ocr_select")
+         if selected_file:
+             if selected_file.endswith('.png'):
+                 image = Image.open(selected_file)
+             else:
+                 doc = fitz.open(selected_file)
+                 pix = doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
+                 image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+                 doc.close()
+             st.image(image, caption="Input Image", use_container_width=True)
+             if st.button("Run OCR 🚀", key="ocr_run"):
+                 output_file = generate_filename("ocr_output", "txt")
+                 st.session_state['processing']['ocr'] = True
+                 result = asyncio.run(process_ocr(image, output_file))
+                 entry = f"OCR Test: {selected_file} -> {output_file}"
+                 if entry not in st.session_state['history']:
+                     st.session_state['history'].append(entry)
+                 st.text_area("OCR Result", result, height=200, key="ocr_result")
+                 st.success(f"OCR output saved to {output_file}")
+                 st.session_state['processing']['ocr'] = False
+     else:
+         st.warning("No images or PDFs selected yet. Check some boxes in the sidebar gallery!")
+
+ with tab8:
+     st.header("Test Image Gen 🎨")
+     all_files = [path for path in get_gallery_files() if st.session_state['asset_checkboxes'].get(path, False)]
+     if all_files:
+         selected_file = st.selectbox("Select Image or PDF", all_files, key="gen_select")
+         if selected_file:
+             if selected_file.endswith('.png'):
+                 image = Image.open(selected_file)
+             else:
+                 doc = fitz.open(selected_file)
+                 pix = doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
+                 image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+                 doc.close()
+             st.image(image, caption="Reference Image", use_container_width=True)
+             prompt = st.text_area("Prompt", "Generate a similar superhero image", key="gen_prompt")
+             if st.button("Run Image Gen 🚀", key="gen_run"):
+                 output_file = generate_filename("gen_output", "png")
+                 st.session_state['processing']['gen'] = True
+                 result = asyncio.run(process_image_gen(prompt, output_file))
+                 entry = f"Image Gen Test: {prompt} -> {output_file}"
+                 if entry not in st.session_state['history']:
+                     st.session_state['history'].append(entry)
+                 st.image(result, caption="Generated Image", use_container_width=True)
+                 st.success(f"Image saved to {output_file}")
+                 st.session_state['processing']['gen'] = False
+     else:
+         st.warning("No images or PDFs selected yet. Check some boxes in the sidebar gallery!")
+
+ with tab9:
+     st.header("Custom Diffusion 🎨🤓")
+     st.write("Unleash your inner artist with our tiny diffusion models!")
+     all_files = [path for path in get_gallery_files() if st.session_state['asset_checkboxes'].get(path, False)]
+     if all_files:
+         st.subheader("Select Images or PDFs to Train")
+         selected_files = st.multiselect("Pick Images or PDFs", all_files, key="diffusion_select")
+         images = []
+         for file in selected_files:
+             if file.endswith('.png'):
+                 images.append(Image.open(file))
+             else:
+                 doc = fitz.open(file)
+                 pix = doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
+                 images.append(Image.frombytes("RGB", [pix.width, pix.height], pix.samples))
+                 doc.close()
+
+         model_options = [
+             ("PixelTickler 🎨✨", "OFA-Sys/small-stable-diffusion-v0"),
+             ("DreamWeaver 🌙🖌️", "stabilityai/stable-diffusion-2-base"),
+             ("TinyArtBot 🤖🖼️", "custom")
+         ]
+         model_choice = st.selectbox("Choose Your Diffusion Dynamo", [opt[0] for opt in model_options], key="diffusion_model")
+         model_name = next(opt[1] for opt in model_options if opt[0] == model_choice)
+
+         if st.button("Train & Generate 🚀", key="diffusion_run"):
+             output_file = generate_filename("custom_diffusion", "png")
+             st.session_state['processing']['diffusion'] = True
+             if model_name == "custom":
+                 result = asyncio.run(process_custom_diffusion(images, output_file, model_choice))
+             else:
+                 builder = DiffusionBuilder()
+                 builder.load_model(model_name)
+                 result = builder.generate("A superhero scene inspired by captured images")
+                 result.save(output_file)
+             entry = f"Custom Diffusion: {model_choice} -> {output_file}"
+             if entry not in st.session_state['history']:
+                 st.session_state['history'].append(entry)
+             st.image(result, caption=f"{model_choice} Masterpiece", use_container_width=True)
+             st.success(f"Image saved to {output_file}")
+             st.session_state['processing']['diffusion'] = False
+     else:
+         st.warning("No images or PDFs selected yet. Check some boxes in the sidebar gallery!")
+
+ update_gallery()