awacke1 commited on
Commit
26a04a2
·
verified ·
1 Parent(s): 3a3e5ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -455
app.py CHANGED
@@ -64,7 +64,11 @@ if 'asset_checkboxes' not in st.session_state:
64
  if 'downloaded_pdfs' not in st.session_state:
65
  st.session_state['downloaded_pdfs'] = {}
66
  if 'unique_counter' not in st.session_state:
67
- st.session_state['unique_counter'] = 0 # For generating unique keys
 
 
 
 
68
 
69
  @dataclass
70
  class ModelConfig:
@@ -87,122 +91,11 @@ class DiffusionConfig:
87
  def model_path(self):
88
  return f"diffusion_models/{self.name}"
89
 
90
- class SFTDataset(Dataset):
91
- def __init__(self, data, tokenizer, max_length=128):
92
- self.data = data
93
- self.tokenizer = tokenizer
94
- self.max_length = max_length
95
- def __len__(self):
96
- return len(self.data)
97
- def __getitem__(self, idx):
98
- prompt = self.data[idx]["prompt"]
99
- response = self.data[idx]["response"]
100
- full_text = f"{prompt} {response}"
101
- full_encoding = self.tokenizer(full_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
102
- prompt_encoding = self.tokenizer(prompt, max_length=self.max_length, padding=False, truncation=True, return_tensors="pt")
103
- input_ids = full_encoding["input_ids"].squeeze()
104
- attention_mask = full_encoding["attention_mask"].squeeze()
105
- labels = input_ids.clone()
106
- prompt_len = prompt_encoding["input_ids"].shape[1]
107
- if prompt_len < self.max_length:
108
- labels[:prompt_len] = -100
109
- return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
110
-
111
- class DiffusionDataset(Dataset):
112
- def __init__(self, images, texts):
113
- self.images = images
114
- self.texts = texts
115
- def __len__(self):
116
- return len(self.images)
117
- def __getitem__(self, idx):
118
- return {"image": self.images[idx], "text": self.texts[idx]}
119
-
120
- class TinyDiffusionDataset(Dataset):
121
- def __init__(self, images):
122
- self.images = [torch.tensor(np.array(img.convert("RGB")).transpose(2, 0, 1), dtype=torch.float32) / 255.0 for img in images]
123
- def __len__(self):
124
- return len(self.images)
125
- def __getitem__(self, idx):
126
- return self.images[idx]
127
-
128
- class TinyUNet(nn.Module):
129
- def __init__(self, in_channels=3, out_channels=3):
130
- super(TinyUNet, self).__init__()
131
- self.down1 = nn.Conv2d(in_channels, 32, 3, padding=1)
132
- self.down2 = nn.Conv2d(32, 64, 3, padding=1, stride=2)
133
- self.mid = nn.Conv2d(64, 128, 3, padding=1)
134
- self.up1 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)
135
- self.up2 = nn.Conv2d(64 + 32, 32, 3, padding=1)
136
- self.out = nn.Conv2d(32, out_channels, 3, padding=1)
137
- self.time_embed = nn.Linear(1, 64)
138
-
139
- def forward(self, x, t):
140
- t_embed = F.relu(self.time_embed(t.unsqueeze(-1)))
141
- t_embed = t_embed.view(t_embed.size(0), t_embed.size(1), 1, 1)
142
-
143
- x1 = F.relu(self.down1(x))
144
- x2 = F.relu(self.down2(x1))
145
- x_mid = F.relu(self.mid(x2)) + t_embed
146
- x_up1 = F.relu(self.up1(x_mid))
147
- x_up2 = F.relu(self.up2(torch.cat([x_up1, x1], dim=1)))
148
- return self.out(x_up2)
149
-
150
- class TinyDiffusion:
151
- def __init__(self, model, timesteps=100):
152
- self.model = model
153
- self.timesteps = timesteps
154
- self.beta = torch.linspace(0.0001, 0.02, timesteps)
155
- self.alpha = 1 - self.beta
156
- self.alpha_cumprod = torch.cumprod(self.alpha, dim=0)
157
-
158
- def train(self, images, epochs=50):
159
- dataset = TinyDiffusionDataset(images)
160
- dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
161
- optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
162
- device = torch.device("cpu")
163
- self.model.to(device)
164
- for epoch in range(epochs):
165
- total_loss = 0
166
- for x in dataloader:
167
- x = x.to(device)
168
- t = torch.randint(0, self.timesteps, (x.size(0),), device=device).float()
169
- noise = torch.randn_like(x)
170
- alpha_t = self.alpha_cumprod[t.long()].view(-1, 1, 1, 1)
171
- x_noisy = torch.sqrt(alpha_t) * x + torch.sqrt(1 - alpha_t) * noise
172
- pred_noise = self.model(x_noisy, t)
173
- loss = F.mse_loss(pred_noise, noise)
174
- optimizer.zero_grad()
175
- loss.backward()
176
- optimizer.step()
177
- total_loss += loss.item()
178
- logger.info(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataloader):.4f}")
179
- return self
180
-
181
- def generate(self, size=(64, 64), steps=100):
182
- device = torch.device("cpu")
183
- x = torch.randn(1, 3, size[0], size[1], device=device)
184
- for t in reversed(range(steps)):
185
- t_tensor = torch.full((1,), t, device=device, dtype=torch.float32)
186
- alpha_t = self.alpha_cumprod[t].view(-1, 1, 1, 1)
187
- pred_noise = self.model(x, t_tensor)
188
- x = (x - (1 - self.alpha[t]) / torch.sqrt(1 - alpha_t) * pred_noise) / torch.sqrt(self.alpha[t])
189
- if t > 0:
190
- x += torch.sqrt(self.beta[t]) * torch.randn_like(x)
191
- x = torch.clamp(x * 255, 0, 255).byte()
192
- return Image.fromarray(x.squeeze(0).permute(1, 2, 0).cpu().numpy())
193
-
194
- def upscale(self, image, scale_factor=2):
195
- img_tensor = torch.tensor(np.array(image.convert("RGB")).transpose(2, 0, 1), dtype=torch.float32).unsqueeze(0) / 255.0
196
- upscaled = F.interpolate(img_tensor, scale_factor=scale_factor, mode='bilinear', align_corners=False)
197
- upscaled = torch.clamp(upscaled * 255, 0, 255).byte()
198
- return Image.fromarray(upscaled.squeeze(0).permute(1, 2, 0).cpu().numpy())
199
-
200
  class ModelBuilder:
201
  def __init__(self):
202
  self.config = None
203
  self.model = None
204
  self.tokenizer = None
205
- self.sft_data = None
206
  self.jokes = ["Why did the AI go to therapy? Too many layers to unpack! 😂", "Training complete! Time for a binary coffee break. ☕"]
207
  def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
208
  with st.spinner(f"Loading {model_path}... ⏳"):
@@ -215,53 +108,12 @@ class ModelBuilder:
215
  self.model.to("cuda" if torch.cuda.is_available() else "cpu")
216
  st.success(f"Model loaded! 🎉 {random.choice(self.jokes)}")
217
  return self
218
- def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
219
- self.sft_data = []
220
- with open(csv_path, "r") as f:
221
- reader = csv.DictReader(f)
222
- for row in reader:
223
- self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})
224
- dataset = SFTDataset(self.sft_data, self.tokenizer)
225
- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
226
- optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
227
- self.model.train()
228
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
229
- self.model.to(device)
230
- for epoch in range(epochs):
231
- with st.spinner(f"Training epoch {epoch + 1}/{epochs}... ⚙️"):
232
- total_loss = 0
233
- for batch in dataloader:
234
- optimizer.zero_grad()
235
- input_ids = batch["input_ids"].to(device)
236
- attention_mask = batch["attention_mask"].to(device)
237
- labels = batch["labels"].to(device)
238
- outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
239
- loss = outputs.loss
240
- loss.backward()
241
- optimizer.step()
242
- total_loss += loss.item()
243
- st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
244
- st.success(f"SFT Fine-tuning completed! 🎉 {random.choice(self.jokes)}")
245
- return self
246
  def save_model(self, path: str):
247
  with st.spinner("Saving model... 💾"):
248
  os.makedirs(os.path.dirname(path), exist_ok=True)
249
  self.model.save_pretrained(path)
250
  self.tokenizer.save_pretrained(path)
251
  st.success(f"Model saved at {path}! ✅")
252
- def evaluate(self, prompt: str, status_container=None):
253
- self.model.eval()
254
- if status_container:
255
- status_container.write("Preparing to evaluate... 🧠")
256
- try:
257
- with torch.no_grad():
258
- inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.model.device)
259
- outputs = self.model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95, temperature=0.7)
260
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
261
- except Exception as e:
262
- if status_container:
263
- status_container.error(f"Oops! Something broke: {str(e)} 💥")
264
- return f"Error: {str(e)}"
265
 
266
  class DiffusionBuilder:
267
  def __init__(self):
@@ -274,31 +126,6 @@ class DiffusionBuilder:
274
  self.config = config
275
  st.success(f"Diffusion model loaded! 🎨")
276
  return self
277
- def fine_tune_sft(self, images, texts, epochs=3):
278
- dataset = DiffusionDataset(images, texts)
279
- dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
280
- optimizer = torch.optim.AdamW(self.pipeline.unet.parameters(), lr=1e-5)
281
- self.pipeline.unet.train()
282
- for epoch in range(epochs):
283
- with st.spinner(f"Training diffusion epoch {epoch + 1}/{epochs}... ⚙️"):
284
- total_loss = 0
285
- for batch in dataloader:
286
- optimizer.zero_grad()
287
- image = batch["image"][0].to(self.pipeline.device)
288
- text = batch["text"][0]
289
- latents = self.pipeline.vae.encode(torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float().to(self.pipeline.device)).latent_dist.sample()
290
- noise = torch.randn_like(latents)
291
- timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
292
- noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
293
- text_embeddings = self.pipeline.text_encoder(self.pipeline.tokenizer(text, return_tensors="pt").input_ids.to(self.pipeline.device))[0]
294
- pred_noise = self.pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
295
- loss = torch.nn.functional.mse_loss(pred_noise, noise)
296
- loss.backward()
297
- optimizer.step()
298
- total_loss += loss.item()
299
- st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
300
- st.success("Diffusion SFT Fine-tuning completed! 🎨")
301
- return self
302
  def save_model(self, path: str):
303
  with st.spinner("Saving diffusion model... 💾"):
304
  os.makedirs(os.path.dirname(path), exist_ok=True)
@@ -329,7 +156,8 @@ def zip_directory(directory_path, zip_path):
329
 
330
  def get_model_files(model_type="causal_lm"):
331
  path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
332
- return [d for d in glob.glob(path) if os.path.isdir(d)]
 
333
 
334
  def get_gallery_files(file_types=["png", "pdf"]):
335
  return sorted(list(set([f for ext in file_types for f in glob.glob(f"*.{ext}")]))) # Deduplicate files
@@ -426,87 +254,102 @@ async def process_custom_diffusion(images, output_file, model_name):
426
  update_gallery()
427
  return upscaled_image
428
 
429
- def mock_search(query: str) -> str:
430
- if "superhero" in query.lower():
431
- return "Latest trends: Gold-plated Batman statues, VR superhero battles."
432
- return "No relevant results found."
433
-
434
- def mock_duckduckgo_search(query: str) -> str:
435
- if "superhero party trends" in query.lower():
436
- return """
437
- Latest trends for 2025:
438
- - Luxury decorations: Gold-plated Batman statues, holographic Avengers displays.
439
- - Entertainment: Live stunt shows with Iron Man suits, VR superhero battles.
440
- - Catering: Gourmet kryptonite-green cocktails, Thor’s hammer-shaped appetizers.
441
- """
442
- return "No relevant results found."
443
-
444
- class PartyPlannerAgent:
445
- def __init__(self, model, tokenizer):
 
 
 
 
 
 
 
446
  self.model = model
447
- self.tokenizer = tokenizer
448
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
449
- self.model.to(self.device)
450
- def generate(self, prompt: str) -> str:
451
- self.model.eval()
452
- with torch.no_grad():
453
- inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.device)
454
- outputs = self.model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.95, temperature=0.7)
455
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
456
- def plan_party(self, task: str) -> pd.DataFrame:
457
- search_result = mock_duckduckgo_search("latest superhero party trends")
458
- prompt = f"Given this context: '{search_result}'\n{task}"
459
- plan_text = self.generate(prompt)
460
- locations = {
461
- "Wayne Manor": (42.3601, -71.0589),
462
- "New York": (40.7128, -74.0060),
463
- "Los Angeles": (34.0522, -118.2437),
464
- "London": (51.5074, -0.1278)
465
- }
466
- wayne_coords = locations["Wayne Manor"]
467
- travel_times = {loc: calculate_cargo_travel_time(coords, wayne_coords) for loc, coords in locations.items() if loc != "Wayne Manor"}
468
- catchphrases = ["To the Batmobile!", "Avengers, assemble!", "I am Iron Man!", "By the power of Grayskull!"]
469
- data = [
470
- {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Luxury Idea": "Gold-plated Batman statues", "Catchphrase": random.choice(catchphrases)},
471
- {"Location": "Los Angeles", "Travel Time (hrs)": travel_times["Los Angeles"], "Luxury Idea": "Holographic Avengers displays", "Catchphrase": random.choice(catchphrases)},
472
- {"Location": "London", "Travel Time (hrs)": travel_times["London"], "Luxury Idea": "Live stunt shows with Iron Man suits", "Catchphrase": random.choice(catchphrases)},
473
- {"Location": "Wayne Manor", "Travel Time (hrs)": 0.0, "Luxury Idea": "VR superhero battles", "Catchphrase": random.choice(catchphrases)},
474
- {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Luxury Idea": "Gourmet kryptonite-green cocktails", "Catchphrase": random.choice(catchphrases)},
475
- {"Location": "Los Angeles", "Travel Time (hrs)": travel_times["Los Angeles"], "Luxury Idea": "Thor’s hammer-shaped appetizers", "Catchphrase": random.choice(catchphrases)},
476
- ]
477
- return pd.DataFrame(data)
478
 
479
- class CVPartyPlannerAgent:
480
- def __init__(self, pipeline):
481
- self.pipeline = pipeline
482
- def generate(self, prompt: str) -> Image.Image:
483
- return self.pipeline(prompt, num_inference_steps=20).images[0]
484
- def plan_party(self, task: str) -> pd.DataFrame:
485
- search_result = mock_search("superhero party trends")
486
- prompt = f"Given this context: '{search_result}'\n{task}"
487
- data = [
488
- {"Theme": "Batman", "Image Idea": "Gold-plated Batman statue"},
489
- {"Theme": "Avengers", "Image Idea": "VR superhero battle scene"}
490
- ]
491
- return pd.DataFrame(data)
492
-
493
- def calculate_cargo_travel_time(origin_coords: Tuple[float, float], destination_coords: Tuple[float, float], cruising_speed_kmh: float = 750.0) -> float:
494
- def to_radians(degrees: float) -> float:
495
- return degrees * (math.pi / 180)
496
- lat1, lon1 = map(to_radians, origin_coords)
497
- lat2, lon2 = map(to_radians, destination_coords)
498
- EARTH_RADIUS_KM = 6371.0
499
- dlon = lon2 - lon1
500
- dlat = lat2 - lat1
501
- a = (math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2)
502
- c = 2 * math.asin(math.sqrt(a))
503
- distance = EARTH_RADIUS_KM * c
504
- actual_distance = distance * 1.1
505
- flight_time = (actual_distance / cruising_speed_kmh) + 1.0
506
- return round(flight_time, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
 
508
  st.title("AI Vision & SFT Titans 🚀")
509
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510
  st.sidebar.header("Captured Files 📜")
511
  cols = st.sidebar.columns(2)
512
  with cols[0]:
@@ -533,7 +376,7 @@ def update_gallery():
533
  cols = st.sidebar.columns(2)
534
  for idx, file in enumerate(all_files[:gallery_size * 2]):
535
  with cols[idx % 2]:
536
- st.session_state['unique_counter'] += 1 # Increment counter for uniqueness
537
  unique_id = st.session_state['unique_counter']
538
  if file.endswith('.png'):
539
  st.image(Image.open(file), caption=os.path.basename(file), use_container_width=True)
@@ -543,7 +386,7 @@ def update_gallery():
543
  img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
544
  st.image(img, caption=os.path.basename(file), use_container_width=True)
545
  doc.close()
546
- checkbox_key = f"asset_{file}_{unique_id}" # Unique key with counter
547
  st.session_state['asset_checkboxes'][file] = st.checkbox(
548
  "Use for SFT/Input",
549
  value=st.session_state['asset_checkboxes'].get(file, False),
@@ -551,7 +394,7 @@ def update_gallery():
551
  )
552
  mime_type = "image/png" if file.endswith('.png') else "application/pdf"
553
  st.markdown(get_download_link(file, mime_type, "Snag It! 📥"), unsafe_allow_html=True)
554
- if st.button("Zap It! 🗑️", key=f"delete_{file}_{unique_id}"): # Unique key with counter
555
  os.remove(file)
556
  if file in st.session_state['asset_checkboxes']:
557
  del st.session_state['asset_checkboxes'][file]
@@ -563,18 +406,6 @@ def update_gallery():
563
  st.rerun()
564
  update_gallery()
565
 
566
- st.sidebar.subheader("Model Management 🗂️")
567
- model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"], key="sidebar_model_type")
568
- model_dirs = get_model_files(model_type)
569
- selected_model = st.sidebar.selectbox("Select Saved Model", ["None"] + model_dirs, key="sidebar_model_select")
570
- if selected_model != "None" and st.sidebar.button("Load Model 📂"):
571
- builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
572
- config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=os.path.basename(selected_model), base_model="unknown", size="small")
573
- builder.load_model(selected_model, config)
574
- st.session_state['builder'] = builder
575
- st.session_state['model_loaded'] = True
576
- st.rerun()
577
-
578
  st.sidebar.subheader("Action Logs 📜")
579
  log_container = st.sidebar.empty()
580
  with log_container:
@@ -587,9 +418,8 @@ with history_container:
587
  for entry in st.session_state['history'][-gallery_size * 2:]:
588
  st.write(entry)
589
 
590
- tab1, tab2, tab3, tab4, tab5, tab6, tab7, tab8, tab9 = st.tabs([
591
- "Camera Snap 📷", "Download PDFs 📥", "Build Titan 🌱", "Fine-Tune Titan 🔧",
592
- "Test Titan 🧪", "Agentic RAG Party 🌐", "Test OCR 🔍", "Test Image Gen 🎨", "Custom Diffusion 🎨🤓"
593
  ])
594
 
595
  with tab1:
@@ -694,6 +524,8 @@ with tab3:
694
  builder.save_model(config.model_path)
695
  st.session_state['builder'] = builder
696
  st.session_state['model_loaded'] = True
 
 
697
  entry = f"Built {model_type} model: {model_name}"
698
  if entry not in st.session_state['history']:
699
  st.session_state['history'].append(entry)
@@ -701,141 +533,30 @@ with tab3:
701
  st.rerun()
702
 
703
  with tab4:
704
- st.header("Fine-Tune Titan 🔧")
705
- if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
706
- st.warning("Please build or load a Titan first! ⚠️")
707
- else:
708
- if isinstance(st.session_state['builder'], ModelBuilder):
709
- if st.button("Generate Sample CSV 📝"):
710
- sample_data = [
711
- {"prompt": "What is AI?", "response": "AI is artificial intelligence, simulating human smarts in machines."},
712
- {"prompt": "Explain machine learning", "response": "Machine learning is AI’s gym where models bulk up on data."},
713
- ]
714
- csv_path = f"sft_data_{int(time.time())}.csv"
715
- with open(csv_path, "w", newline="") as f:
716
- writer = csv.DictWriter(f, fieldnames=["prompt", "response"])
717
- writer.writeheader()
718
- writer.writerows(sample_data)
719
- st.markdown(get_download_link(csv_path, "text/csv", "Download Sample CSV"), unsafe_allow_html=True)
720
- st.success(f"Sample CSV generated as {csv_path}! ✅")
721
-
722
- uploaded_csv = st.file_uploader("Upload CSV for SFT", type="csv")
723
- if uploaded_csv and st.button("Fine-Tune with Uploaded CSV 🔄"):
724
- csv_path = f"uploaded_sft_data_{int(time.time())}.csv"
725
- with open(csv_path, "wb") as f:
726
- f.write(uploaded_csv.read())
727
- new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
728
- new_config = ModelConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small", domain=st.session_state['builder'].config.domain)
729
- st.session_state['builder'].config = new_config
730
- st.session_state['builder'].fine_tune_sft(csv_path)
731
- st.session_state['builder'].save_model(new_config.model_path)
732
- zip_path = f"{new_config.model_path}.zip"
733
- zip_directory(new_config.model_path, zip_path)
734
- entry = f"Fine-tuned Causal LM: {new_model_name}"
735
- if entry not in st.session_state['history']:
736
- st.session_state['history'].append(entry)
737
- st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Titan"), unsafe_allow_html=True)
738
- st.rerun()
739
- elif isinstance(st.session_state['builder'], DiffusionBuilder):
740
- selected_files = [path for path in get_gallery_files() if st.session_state['asset_checkboxes'].get(path, False)]
741
- if len(selected_files) >= 2:
742
- demo_data = [{"image": file, "text": f"Asset {os.path.basename(file).split('.')[0]}"} for file in selected_files]
743
- edited_data = st.data_editor(pd.DataFrame(demo_data), num_rows="dynamic")
744
- if st.button("Fine-Tune with Dataset 🔄"):
745
- images = [Image.open(row["image"]) if row["image"].endswith('.png') else Image.frombytes("RGB", fitz.open(row["image"])[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0)).size, fitz.open(row["image"])[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0)).samples) for _, row in edited_data.iterrows()]
746
- texts = [row["text"] for _, row in edited_data.iterrows()]
747
- new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
748
- new_config = DiffusionConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small", domain=st.session_state['builder'].config.domain)
749
- st.session_state['builder'].config = new_config
750
- st.session_state['builder'].fine_tune_sft(images, texts)
751
- st.session_state['builder'].save_model(new_config.model_path)
752
- zip_path = f"{new_config.model_path}.zip"
753
- zip_directory(new_config.model_path, zip_path)
754
- entry = f"Fine-tuned Diffusion: {new_model_name}"
755
- if entry not in st.session_state['history']:
756
- st.session_state['history'].append(entry)
757
- st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Diffusion Model"), unsafe_allow_html=True)
758
- csv_path = f"sft_dataset_{int(time.time())}.csv"
759
- with open(csv_path, "w", newline="") as f:
760
- writer = csv.writer(f)
761
- writer.writerow(["image", "text"])
762
- for _, row in edited_data.iterrows():
763
- writer.writerow([row["image"], row["text"]])
764
- st.markdown(get_download_link(csv_path, "text/csv", "Download SFT Dataset CSV"), unsafe_allow_html=True)
765
-
766
- with tab5:
767
- st.header("Test Titan 🧪")
768
- if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
769
- st.warning("Please build or load a Titan first! ⚠️")
770
- else:
771
- if isinstance(st.session_state['builder'], ModelBuilder):
772
- if st.session_state['builder'].sft_data:
773
- st.write("Testing with SFT Data:")
774
- for item in st.session_state['builder'].sft_data[:3]:
775
- prompt = item["prompt"]
776
- expected = item["response"]
777
- status_container = st.empty()
778
- generated = st.session_state['builder'].evaluate(prompt, status_container)
779
- st.write(f"**Prompt**: {prompt}")
780
- st.write(f"**Expected**: {expected}")
781
- st.write(f"**Generated**: {generated}")
782
- st.write("---")
783
- status_container.empty()
784
- test_prompt = st.text_area("Enter Test Prompt", "What is AI?")
785
- if st.button("Run Test ▶️"):
786
- status_container = st.empty()
787
- result = st.session_state['builder'].evaluate(test_prompt, status_container)
788
- entry = f"Causal LM Test: {test_prompt} -> {result}"
789
- if entry not in st.session_state['history']:
790
- st.session_state['history'].append(entry)
791
- st.write(f"**Generated Response**: {result}")
792
- status_container.empty()
793
- elif isinstance(st.session_state['builder'], DiffusionBuilder):
794
- test_prompt = st.text_area("Enter Test Prompt", "Neon Batman")
795
- if st.button("Run Test ▶️"):
796
- image = st.session_state['builder'].generate(test_prompt)
797
- output_file = generate_filename("diffusion_test", "png")
798
- image.save(output_file)
799
- entry = f"Diffusion Test: {test_prompt} -> {output_file}"
800
- if entry not in st.session_state['history']:
801
- st.session_state['history'].append(entry)
802
- st.image(image, caption="Generated Image")
803
- update_gallery()
804
-
805
- with tab6:
806
- st.header("Agentic RAG Party 🌐")
807
- if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
808
- st.warning("Please build or load a Titan first! ⚠️")
809
- else:
810
- if isinstance(st.session_state['builder'], ModelBuilder):
811
- if st.button("Run NLP RAG Demo 🎉"):
812
- agent = PartyPlannerAgent(st.session_state['builder'].model, st.session_state['builder'].tokenizer)
813
- task = "Plan a luxury superhero-themed party at Wayne Manor."
814
- plan_df = agent.plan_party(task)
815
- entry = f"NLP RAG Demo: Planned party at Wayne Manor"
816
- if entry not in st.session_state['history']:
817
- st.session_state['history'].append(entry)
818
- st.dataframe(plan_df)
819
- elif isinstance(st.session_state['builder'], DiffusionBuilder):
820
- if st.button("Run CV RAG Demo 🎉"):
821
- agent = CVPartyPlannerAgent(st.session_state['builder'].pipeline)
822
- task = "Generate images for a luxury superhero-themed party."
823
- plan_df = agent.plan_party(task)
824
- entry = f"CV RAG Demo: Generated party images"
825
- if entry not in st.session_state['history']:
826
- st.session_state['history'].append(entry)
827
- st.dataframe(plan_df)
828
- for _, row in plan_df.iterrows():
829
- image = agent.generate(row["Image Idea"])
830
- output_file = generate_filename(f"cv_rag_{row['Theme'].lower()}", "png")
831
- image.save(output_file)
832
- st.image(image, caption=f"{row['Theme']} - {row['Image Idea']}")
833
- update_gallery()
834
-
835
- with tab7:
836
  st.header("Test OCR 🔍")
837
- all_files = [path for path in get_gallery_files() if st.session_state['asset_checkboxes'].get(path, False)]
838
  if all_files:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
839
  selected_file = st.selectbox("Select Image or PDF", all_files, key="ocr_select")
840
  if selected_file:
841
  if selected_file.endswith('.png'):
@@ -856,12 +577,29 @@ with tab7:
856
  st.text_area("OCR Result", result, height=200, key="ocr_result")
857
  st.success(f"OCR output saved to {output_file}")
858
  st.session_state['processing']['ocr'] = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
859
  else:
860
- st.warning("No images or PDFs selected yet. Check some boxes in the sidebar gallery!")
861
 
862
- with tab8:
863
  st.header("Test Image Gen 🎨")
864
- all_files = [path for path in get_gallery_files() if st.session_state['asset_checkboxes'].get(path, False)]
865
  if all_files:
866
  selected_file = st.selectbox("Select Image or PDF", all_files, key="gen_select")
867
  if selected_file:
@@ -873,7 +611,7 @@ with tab8:
873
  image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
874
  doc.close()
875
  st.image(image, caption="Reference Image", use_container_width=True)
876
- prompt = st.text_area("Prompt", "Generate a similar superhero image", key="gen_prompt")
877
  if st.button("Run Image Gen 🚀", key="gen_run"):
878
  output_file = generate_filename("gen_output", "png")
879
  st.session_state['processing']['gen'] = True
@@ -885,50 +623,6 @@ with tab8:
885
  st.success(f"Image saved to {output_file}")
886
  st.session_state['processing']['gen'] = False
887
  else:
888
- st.warning("No images or PDFs selected yet. Check some boxes in the sidebar gallery!")
889
-
890
- with tab9:
891
- st.header("Custom Diffusion 🎨🤓")
892
- st.write("Unleash your inner artist with our tiny diffusion models!")
893
- all_files = [path for path in get_gallery_files() if st.session_state['asset_checkboxes'].get(path, False)]
894
- if all_files:
895
- st.subheader("Select Images or PDFs to Train")
896
- selected_files = st.multiselect("Pick Images or PDFs", all_files, key="diffusion_select")
897
- images = []
898
- for file in selected_files:
899
- if file.endswith('.png'):
900
- images.append(Image.open(file))
901
- else:
902
- doc = fitz.open(file)
903
- pix = doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
904
- images.append(Image.frombytes("RGB", [pix.width, pix.height], pix.samples))
905
- doc.close()
906
-
907
- model_options = [
908
- ("PixelTickler 🎨✨", "OFA-Sys/small-stable-diffusion-v0"),
909
- ("DreamWeaver 🌙🖌️", "stabilityai/stable-diffusion-2-base"),
910
- ("TinyArtBot 🤖🖼️", "custom")
911
- ]
912
- model_choice = st.selectbox("Choose Your Diffusion Dynamo", [opt[0] for opt in model_options], key="diffusion_model")
913
- model_name = next(opt[1] for opt in model_options if opt[0] == model_choice)
914
-
915
- if st.button("Train & Generate 🚀", key="diffusion_run"):
916
- output_file = generate_filename("custom_diffusion", "png")
917
- st.session_state['processing']['diffusion'] = True
918
- if model_name == "custom":
919
- result = asyncio.run(process_custom_diffusion(images, output_file, model_choice))
920
- else:
921
- builder = DiffusionBuilder()
922
- builder.load_model(model_name)
923
- result = builder.generate("A superhero scene inspired by captured images")
924
- result.save(output_file)
925
- entry = f"Custom Diffusion: {model_choice} -> {output_file}"
926
- if entry not in st.session_state['history']:
927
- st.session_state['history'].append(entry)
928
- st.image(result, caption=f"{model_choice} Masterpiece", use_container_width=True)
929
- st.success(f"Image saved to {output_file}")
930
- st.session_state['processing']['diffusion'] = False
931
- else:
932
- st.warning("No images or PDFs selected yet. Check some boxes in the sidebar gallery!")
933
 
934
  update_gallery()
 
64
  if 'downloaded_pdfs' not in st.session_state:
65
  st.session_state['downloaded_pdfs'] = {}
66
  if 'unique_counter' not in st.session_state:
67
+ st.session_state['unique_counter'] = 0
68
+ if 'selected_model_type' not in st.session_state:
69
+ st.session_state['selected_model_type'] = "Causal LM"
70
+ if 'selected_model' not in st.session_state:
71
+ st.session_state['selected_model'] = "None"
72
 
73
  @dataclass
74
  class ModelConfig:
 
91
  def model_path(self):
92
  return f"diffusion_models/{self.name}"
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  class ModelBuilder:
95
  def __init__(self):
96
  self.config = None
97
  self.model = None
98
  self.tokenizer = None
 
99
  self.jokes = ["Why did the AI go to therapy? Too many layers to unpack! 😂", "Training complete! Time for a binary coffee break. ☕"]
100
  def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
101
  with st.spinner(f"Loading {model_path}... ⏳"):
 
108
  self.model.to("cuda" if torch.cuda.is_available() else "cpu")
109
  st.success(f"Model loaded! 🎉 {random.choice(self.jokes)}")
110
  return self
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  def save_model(self, path: str):
112
  with st.spinner("Saving model... 💾"):
113
  os.makedirs(os.path.dirname(path), exist_ok=True)
114
  self.model.save_pretrained(path)
115
  self.tokenizer.save_pretrained(path)
116
  st.success(f"Model saved at {path}! ✅")
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  class DiffusionBuilder:
119
  def __init__(self):
 
126
  self.config = config
127
  st.success(f"Diffusion model loaded! 🎨")
128
  return self
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  def save_model(self, path: str):
130
  with st.spinner("Saving diffusion model... 💾"):
131
  os.makedirs(os.path.dirname(path), exist_ok=True)
 
156
 
157
  def get_model_files(model_type="causal_lm"):
158
  path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
159
+ dirs = [d for d in glob.glob(path) if os.path.isdir(d)]
160
+ return dirs if dirs else ["None"]
161
 
162
  def get_gallery_files(file_types=["png", "pdf"]):
163
  return sorted(list(set([f for ext in file_types for f in glob.glob(f"*.{ext}")]))) # Deduplicate files
 
254
  update_gallery()
255
  return upscaled_image
256
 
257
+ class TinyUNet(nn.Module):
258
+ def __init__(self, in_channels=3, out_channels=3):
259
+ super(TinyUNet, self).__init__()
260
+ self.down1 = nn.Conv2d(in_channels, 32, 3, padding=1)
261
+ self.down2 = nn.Conv2d(32, 64, 3, padding=1, stride=2)
262
+ self.mid = nn.Conv2d(64, 128, 3, padding=1)
263
+ self.up1 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)
264
+ self.up2 = nn.Conv2d(64 + 32, 32, 3, padding=1)
265
+ self.out = nn.Conv2d(32, out_channels, 3, padding=1)
266
+ self.time_embed = nn.Linear(1, 64)
267
+
268
+ def forward(self, x, t):
269
+ t_embed = F.relu(self.time_embed(t.unsqueeze(-1)))
270
+ t_embed = t_embed.view(t_embed.size(0), t_embed.size(1), 1, 1)
271
+
272
+ x1 = F.relu(self.down1(x))
273
+ x2 = F.relu(self.down2(x1))
274
+ x_mid = F.relu(self.mid(x2)) + t_embed
275
+ x_up1 = F.relu(self.up1(x_mid))
276
+ x_up2 = F.relu(self.up2(torch.cat([x_up1, x1], dim=1)))
277
+ return self.out(x_up2)
278
+
279
+ class TinyDiffusion:
280
+ def __init__(self, model, timesteps=100):
281
  self.model = model
282
+ self.timesteps = timesteps
283
+ self.beta = torch.linspace(0.0001, 0.02, timesteps)
284
+ self.alpha = 1 - self.beta
285
+ self.alpha_cumprod = torch.cumprod(self.alpha, dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
+ def train(self, images, epochs=50):
288
+ dataset = TinyDiffusionDataset(images)
289
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
290
+ optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
291
+ device = torch.device("cpu")
292
+ self.model.to(device)
293
+ for epoch in range(epochs):
294
+ total_loss = 0
295
+ for x in dataloader:
296
+ x = x.to(device)
297
+ t = torch.randint(0, self.timesteps, (x.size(0),), device=device).float()
298
+ noise = torch.randn_like(x)
299
+ alpha_t = self.alpha_cumprod[t.long()].view(-1, 1, 1, 1)
300
+ x_noisy = torch.sqrt(alpha_t) * x + torch.sqrt(1 - alpha_t) * noise
301
+ pred_noise = self.model(x_noisy, t)
302
+ loss = F.mse_loss(pred_noise, noise)
303
+ optimizer.zero_grad()
304
+ loss.backward()
305
+ optimizer.step()
306
+ total_loss += loss.item()
307
+ logger.info(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(dataloader):.4f}")
308
+ return self
309
+
310
+ def generate(self, size=(64, 64), steps=100):
311
+ device = torch.device("cpu")
312
+ x = torch.randn(1, 3, size[0], size[1], device=device)
313
+ for t in reversed(range(steps)):
314
+ t_tensor = torch.full((1,), t, device=device, dtype=torch.float32)
315
+ alpha_t = self.alpha_cumprod[t].view(-1, 1, 1, 1)
316
+ pred_noise = self.model(x, t_tensor)
317
+ x = (x - (1 - self.alpha[t]) / torch.sqrt(1 - alpha_t) * pred_noise) / torch.sqrt(self.alpha[t])
318
+ if t > 0:
319
+ x += torch.sqrt(self.beta[t]) * torch.randn_like(x)
320
+ x = torch.clamp(x * 255, 0, 255).byte()
321
+ return Image.fromarray(x.squeeze(0).permute(1, 2, 0).cpu().numpy())
322
+
323
+ def upscale(self, image, scale_factor=2):
324
+ img_tensor = torch.tensor(np.array(image.convert("RGB")).transpose(2, 0, 1), dtype=torch.float32).unsqueeze(0) / 255.0
325
+ upscaled = F.interpolate(img_tensor, scale_factor=scale_factor, mode='bilinear', align_corners=False)
326
+ upscaled = torch.clamp(upscaled * 255, 0, 255).byte()
327
+ return Image.fromarray(upscaled.squeeze(0).permute(1, 2, 0).cpu().numpy())
328
+
329
+ class TinyDiffusionDataset(Dataset):
330
+ def __init__(self, images):
331
+ self.images = [torch.tensor(np.array(img.convert("RGB")).transpose(2, 0, 1), dtype=torch.float32) / 255.0 for img in images]
332
+ def __len__(self):
333
+ return len(self.images)
334
+ def __getitem__(self, idx):
335
+ return self.images[idx]
336
 
337
  st.title("AI Vision & SFT Titans 🚀")
338
 
339
+ # Sidebar
340
+ model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"], key="sidebar_model_type", index=0 if st.session_state['selected_model_type'] == "Causal LM" else 1)
341
+ model_dirs = get_model_files(model_type)
342
+ if model_dirs and st.session_state['selected_model'] == "None" and "None" not in model_dirs:
343
+ st.session_state['selected_model'] = model_dirs[0]
344
+ selected_model = st.sidebar.selectbox("Select Saved Model", model_dirs, key="sidebar_model_select", index=model_dirs.index(st.session_state['selected_model']) if st.session_state['selected_model'] in model_dirs else 0)
345
+ if selected_model != "None" and st.sidebar.button("Load Model 📂"):
346
+ builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
347
+ config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=os.path.basename(selected_model), base_model="unknown", size="small")
348
+ builder.load_model(selected_model, config)
349
+ st.session_state['builder'] = builder
350
+ st.session_state['model_loaded'] = True
351
+ st.rerun()
352
+
353
  st.sidebar.header("Captured Files 📜")
354
  cols = st.sidebar.columns(2)
355
  with cols[0]:
 
376
  cols = st.sidebar.columns(2)
377
  for idx, file in enumerate(all_files[:gallery_size * 2]):
378
  with cols[idx % 2]:
379
+ st.session_state['unique_counter'] += 1
380
  unique_id = st.session_state['unique_counter']
381
  if file.endswith('.png'):
382
  st.image(Image.open(file), caption=os.path.basename(file), use_container_width=True)
 
386
  img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
387
  st.image(img, caption=os.path.basename(file), use_container_width=True)
388
  doc.close()
389
+ checkbox_key = f"asset_{file}_{unique_id}"
390
  st.session_state['asset_checkboxes'][file] = st.checkbox(
391
  "Use for SFT/Input",
392
  value=st.session_state['asset_checkboxes'].get(file, False),
 
394
  )
395
  mime_type = "image/png" if file.endswith('.png') else "application/pdf"
396
  st.markdown(get_download_link(file, mime_type, "Snag It! 📥"), unsafe_allow_html=True)
397
+ if st.button("Zap It! 🗑️", key=f"delete_{file}_{unique_id}"):
398
  os.remove(file)
399
  if file in st.session_state['asset_checkboxes']:
400
  del st.session_state['asset_checkboxes'][file]
 
406
  st.rerun()
407
  update_gallery()
408
 
 
 
 
 
 
 
 
 
 
 
 
 
409
  st.sidebar.subheader("Action Logs 📜")
410
  log_container = st.sidebar.empty()
411
  with log_container:
 
418
  for entry in st.session_state['history'][-gallery_size * 2:]:
419
  st.write(entry)
420
 
421
+ tab1, tab2, tab3, tab4, tab5 = st.tabs([
422
+ "Camera Snap 📷", "Download PDFs 📥", "Build Titan 🌱", "Test OCR 🔍", "Test Image Gen 🎨"
 
423
  ])
424
 
425
  with tab1:
 
524
  builder.save_model(config.model_path)
525
  st.session_state['builder'] = builder
526
  st.session_state['model_loaded'] = True
527
+ st.session_state['selected_model_type'] = model_type
528
+ st.session_state['selected_model'] = config.model_path
529
  entry = f"Built {model_type} model: {model_name}"
530
  if entry not in st.session_state['history']:
531
  st.session_state['history'].append(entry)
 
533
  st.rerun()
534
 
535
  with tab4:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
536
  st.header("Test OCR 🔍")
537
+ all_files = get_gallery_files()
538
  if all_files:
539
+ if st.button("OCR All Assets 🚀"):
540
+ full_text = "# OCR Results\n\n"
541
+ for file in all_files:
542
+ if file.endswith('.png'):
543
+ image = Image.open(file)
544
+ else:
545
+ doc = fitz.open(file)
546
+ pix = doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
547
+ image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
548
+ doc.close()
549
+ output_file = generate_filename(f"ocr_{os.path.basename(file)}", "txt")
550
+ result = asyncio.run(process_ocr(image, output_file))
551
+ full_text += f"## {os.path.basename(file)}\n\n{result}\n\n"
552
+ entry = f"OCR Test: {file} -> {output_file}"
553
+ if entry not in st.session_state['history']:
554
+ st.session_state['history'].append(entry)
555
+ md_output_file = f"full_ocr_{int(time.time())}.md"
556
+ with open(md_output_file, "w") as f:
557
+ f.write(full_text)
558
+ st.success(f"Full OCR saved to {md_output_file}")
559
+ st.markdown(get_download_link(md_output_file, "text/markdown", "Download Full OCR Markdown"), unsafe_allow_html=True)
560
  selected_file = st.selectbox("Select Image or PDF", all_files, key="ocr_select")
561
  if selected_file:
562
  if selected_file.endswith('.png'):
 
577
  st.text_area("OCR Result", result, height=200, key="ocr_result")
578
  st.success(f"OCR output saved to {output_file}")
579
  st.session_state['processing']['ocr'] = False
580
+ if selected_file.endswith('.pdf') and st.button("OCR All Pages 🚀", key="ocr_all_pages"):
581
+ doc = fitz.open(selected_file)
582
+ full_text = f"# OCR Results for {os.path.basename(selected_file)}\n\n"
583
+ for i in range(len(doc)):
584
+ pix = doc[i].get_pixmap(matrix=fitz.Matrix(2.0, 2.0))
585
+ image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
586
+ output_file = generate_filename(f"ocr_page_{i}", "txt")
587
+ result = asyncio.run(process_ocr(image, output_file))
588
+ full_text += f"## Page {i + 1}\n\n{result}\n\n"
589
+ entry = f"OCR Test: {selected_file} Page {i + 1} -> {output_file}"
590
+ if entry not in st.session_state['history']:
591
+ st.session_state['history'].append(entry)
592
+ md_output_file = f"full_ocr_{os.path.basename(selected_file)}_{int(time.time())}.md"
593
+ with open(md_output_file, "w") as f:
594
+ f.write(full_text)
595
+ st.success(f"Full OCR saved to {md_output_file}")
596
+ st.markdown(get_download_link(md_output_file, "text/markdown", "Download Full OCR Markdown"), unsafe_allow_html=True)
597
  else:
598
+ st.warning("No assets in gallery yet. Use Camera Snap or Download PDFs!")
599
 
600
+ with tab5:
601
  st.header("Test Image Gen 🎨")
602
+ all_files = get_gallery_files()
603
  if all_files:
604
  selected_file = st.selectbox("Select Image or PDF", all_files, key="gen_select")
605
  if selected_file:
 
611
  image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
612
  doc.close()
613
  st.image(image, caption="Reference Image", use_container_width=True)
614
+ prompt = st.text_area("Prompt", "Generate a neon superhero version of this image", key="gen_prompt")
615
  if st.button("Run Image Gen 🚀", key="gen_run"):
616
  output_file = generate_filename("gen_output", "png")
617
  st.session_state['processing']['gen'] = True
 
623
  st.success(f"Image saved to {output_file}")
624
  st.session_state['processing']['gen'] = False
625
  else:
626
+ st.warning("No images or PDFs in gallery yet. Use Camera Snap or Download PDFs!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
627
 
628
  update_gallery()