awacke1 commited on
Commit
d0aa8c4
Β·
verified Β·
1 Parent(s): 14c12eb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +433 -0
app.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import shutil
4
+ import glob
5
+ import base64
6
+ import streamlit as st
7
+ import pandas as pd
8
+ import torch
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+ from torch.utils.data import Dataset, DataLoader
11
+ import csv
12
+ import time
13
+ from dataclasses import dataclass
14
+ from typing import Optional, Tuple
15
+ import zipfile
16
+ import math
17
+ from PIL import Image
18
+ import random
19
+ import logging
20
+ from datetime import datetime
21
+ import pytz
22
+ from diffusers import StableDiffusionPipeline # For diffusion models
23
+ from urllib.parse import quote
24
+
25
+ # Set up logging
26
+ logging.basicConfig(level=logging.INFO)
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # Page Configuration
30
+ st.set_page_config(
31
+ page_title="SFT Tiny Titans πŸš€",
32
+ page_icon="πŸ€–",
33
+ layout="wide",
34
+ initial_sidebar_state="expanded",
35
+ menu_items={
36
+ 'Get Help': 'https://huggingface.co/awacke1',
37
+ 'Report a bug': 'https://huggingface.co/spaces/awacke1',
38
+ 'About': "Tiny Titans: Small models, big dreams, and a sprinkle of chaos! 🌌"
39
+ }
40
+ )
41
+
42
+ # Model Configuration Classes
43
+ @dataclass
44
+ class ModelConfig:
45
+ name: str
46
+ base_model: str
47
+ size: str
48
+ domain: Optional[str] = None
49
+ model_type: str = "causal_lm"
50
+
51
+ @property
52
+ def model_path(self):
53
+ return f"models/{self.name}"
54
+
55
+ @dataclass
56
+ class DiffusionConfig:
57
+ name: str
58
+ base_model: str
59
+ size: str
60
+
61
+ @property
62
+ def model_path(self):
63
+ return f"diffusion_models/{self.name}"
64
+
65
+ # Datasets
66
+ class SFTDataset(Dataset):
67
+ def __init__(self, data, tokenizer, max_length=128):
68
+ self.data = data
69
+ self.tokenizer = tokenizer
70
+ self.max_length = max_length
71
+
72
+ def __len__(self):
73
+ return len(self.data)
74
+
75
+ def __getitem__(self, idx):
76
+ prompt = self.data[idx]["prompt"]
77
+ response = self.data[idx]["response"]
78
+ full_text = f"{prompt} {response}"
79
+ full_encoding = self.tokenizer(full_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
80
+ prompt_encoding = self.tokenizer(prompt, max_length=self.max_length, padding=False, truncation=True, return_tensors="pt")
81
+ input_ids = full_encoding["input_ids"].squeeze()
82
+ attention_mask = full_encoding["attention_mask"].squeeze()
83
+ labels = input_ids.clone()
84
+ prompt_len = prompt_encoding["input_ids"].shape[1]
85
+ if prompt_len < self.max_length:
86
+ labels[:prompt_len] = -100
87
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
88
+
89
+ class DiffusionDataset(Dataset):
90
+ def __init__(self, images, texts):
91
+ self.images = images
92
+ self.texts = texts
93
+
94
+ def __len__(self):
95
+ return len(self.images)
96
+
97
+ def __getitem__(self, idx):
98
+ return {"image": self.images[idx], "text": self.texts[idx]}
99
+
100
+ # Model Builder Classes
101
+ class ModelBuilder:
102
+ def __init__(self):
103
+ self.config = None
104
+ self.model = None
105
+ self.tokenizer = None
106
+ self.sft_data = None
107
+ self.jokes = ["Why did the AI go to therapy? Too many layers to unpack! πŸ˜‚", "Training complete! Time for a binary coffee break. β˜•"]
108
+
109
+ def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
110
+ with st.spinner(f"Loading {model_path}... ⏳"):
111
+ self.model = AutoModelForCausalLM.from_pretrained(model_path)
112
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
113
+ if self.tokenizer.pad_token is None:
114
+ self.tokenizer.pad_token = self.tokenizer.eos_token
115
+ if config:
116
+ self.config = config
117
+ st.success(f"Model loaded! πŸŽ‰ {random.choice(self.jokes)}")
118
+ return self
119
+
120
+ def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
121
+ self.sft_data = []
122
+ with open(csv_path, "r") as f:
123
+ reader = csv.DictReader(f)
124
+ for row in reader:
125
+ self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})
126
+
127
+ dataset = SFTDataset(self.sft_data, self.tokenizer)
128
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
129
+ optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
130
+
131
+ self.model.train()
132
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
133
+ self.model.to(device)
134
+ for epoch in range(epochs):
135
+ with st.spinner(f"Training epoch {epoch + 1}/{epochs}... βš™οΈ"):
136
+ total_loss = 0
137
+ for batch in dataloader:
138
+ optimizer.zero_grad()
139
+ input_ids = batch["input_ids"].to(device)
140
+ attention_mask = batch["attention_mask"].to(device)
141
+ labels = batch["labels"].to(device)
142
+ outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
143
+ loss = outputs.loss
144
+ loss.backward()
145
+ optimizer.step()
146
+ total_loss += loss.item()
147
+ st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
148
+ st.success(f"SFT Fine-tuning completed! πŸŽ‰ {random.choice(self.jokes)}")
149
+ return self
150
+
151
+ def save_model(self, path: str):
152
+ with st.spinner("Saving model... πŸ’Ύ"):
153
+ os.makedirs(os.path.dirname(path), exist_ok=True)
154
+ self.model.save_pretrained(path)
155
+ self.tokenizer.save_pretrained(path)
156
+ st.success(f"Model saved at {path}! βœ…")
157
+
158
+ def evaluate(self, prompt: str, status_container=None):
159
+ self.model.eval()
160
+ if status_container:
161
+ status_container.write("Preparing to evaluate... 🧠")
162
+ try:
163
+ with torch.no_grad():
164
+ inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.model.device)
165
+ outputs = self.model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.95, temperature=0.7)
166
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
167
+ except Exception as e:
168
+ if status_container:
169
+ status_container.error(f"Oops! Something broke: {str(e)} πŸ’₯")
170
+ return f"Error: {str(e)}"
171
+
172
+ class DiffusionBuilder:
173
+ def __init__(self):
174
+ self.config = None
175
+ self.pipeline = None
176
+
177
+ def load_model(self, model_path: str, config: Optional[DiffusionConfig] = None):
178
+ with st.spinner(f"Loading diffusion model {model_path}... ⏳"):
179
+ self.pipeline = StableDiffusionPipeline.from_pretrained(model_path)
180
+ self.pipeline.to("cuda" if torch.cuda.is_available() else "cpu")
181
+ if config:
182
+ self.config = config
183
+ st.success(f"Diffusion model loaded! 🎨")
184
+ return self
185
+
186
+ def fine_tune_sft(self, images, texts, epochs=3):
187
+ dataset = DiffusionDataset(images, texts)
188
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
189
+ optimizer = torch.optim.AdamW(self.pipeline.unet.parameters(), lr=1e-5)
190
+
191
+ self.pipeline.unet.train()
192
+ for epoch in range(epochs):
193
+ with st.spinner(f"Training diffusion epoch {epoch + 1}/{epochs}... βš™οΈ"):
194
+ total_loss = 0
195
+ for batch in dataloader:
196
+ optimizer.zero_grad()
197
+ image = batch["image"].to(self.pipeline.device)
198
+ text = batch["text"]
199
+ latents = self.pipeline.vae.encode(image).latent_dist.sample()
200
+ noise = torch.randn_like(latents)
201
+ timesteps = torch.randint(0, self.pipeline.scheduler.num_train_timesteps, (latents.shape[0],), device=latents.device)
202
+ noisy_latents = self.pipeline.scheduler.add_noise(latents, noise, timesteps)
203
+ text_embeddings = self.pipeline.text_encoder(self.pipeline.tokenizer(text, return_tensors="pt").input_ids.to(self.pipeline.device))[0]
204
+ pred_noise = self.pipeline.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
205
+ loss = torch.nn.functional.mse_loss(pred_noise, noise)
206
+ loss.backward()
207
+ optimizer.step()
208
+ total_loss += loss.item()
209
+ st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
210
+ st.success("Diffusion SFT Fine-tuning completed! 🎨")
211
+ return self
212
+
213
+ def save_model(self, path: str):
214
+ with st.spinner("Saving diffusion model... πŸ’Ύ"):
215
+ os.makedirs(os.path.dirname(path), exist_ok=True)
216
+ self.pipeline.save_pretrained(path)
217
+ st.success(f"Diffusion model saved at {path}! βœ…")
218
+
219
+ # Utility Functions
220
+ def get_download_link(file_path, mime_type="text/plain", label="Download"):
221
+ with open(file_path, 'rb') as f:
222
+ data = f.read()
223
+ b64 = base64.b64encode(data).decode()
224
+ return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label} πŸ“₯</a>'
225
+
226
+ def zip_directory(directory_path, zip_path):
227
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
228
+ for root, _, files in os.walk(directory_path):
229
+ for file in files:
230
+ file_path = os.path.join(root, file)
231
+ arcname = os.path.relpath(file_path, os.path.dirname(directory_path))
232
+ zipf.write(file_path, arcname)
233
+
234
+ def get_model_files(model_type="causal_lm"):
235
+ path = "models/*" if model_type == "causal_lm" else "diffusion_models/*"
236
+ return [d for d in glob.glob(path) if os.path.isdir(d)]
237
+
238
+ def get_gallery_files(file_types):
239
+ files = []
240
+ for ext in file_types:
241
+ files.extend(glob.glob(f"*.{ext}"))
242
+ return sorted(files)
243
+
244
+ def generate_filename(text_line):
245
+ central = pytz.timezone('US/Central')
246
+ timestamp = datetime.now(central).strftime("%Y%m%d_%I%M%S_%p")
247
+ safe_text = ''.join(c if c.isalnum() else '_' for c in text_line[:50])
248
+ return f"{timestamp}_{safe_text}.png"
249
+
250
+ def display_search_links(query):
251
+ search_urls = {
252
+ "ArXiv": f"https://arxiv.org/search/?query={quote(query)}",
253
+ "Wikipedia": f"https://en.wikipedia.org/wiki/{quote(query)}",
254
+ "Google": f"https://www.google.com/search?q={quote(query)}",
255
+ "YouTube": f"https://www.youtube.com/results?search_query={quote(query)}"
256
+ }
257
+ links_md = ' '.join([f"[{name}]({url})" for name, url in search_urls.items()])
258
+ return links_md
259
+
260
+ # Agent Class
261
+ class PartyPlannerAgent:
262
+ def __init__(self, model, tokenizer):
263
+ self.model = model
264
+ self.tokenizer = tokenizer
265
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
266
+ self.model.to(self.device)
267
+
268
+ def generate(self, prompt: str) -> str:
269
+ self.model.eval()
270
+ with torch.no_grad():
271
+ inputs = self.tokenizer(prompt, return_tensors="pt", max_length=128, truncation=True).to(self.device)
272
+ outputs = self.model.generate(**inputs, max_new_tokens=100, do_sample=True, top_p=0.95, temperature=0.7)
273
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
274
+
275
+ def plan_party(self, task: str) -> pd.DataFrame:
276
+ search_result = "Latest trends for 2025: Gold-plated Batman statues, VR superhero battles."
277
+ prompt = f"Given this context: '{search_result}'\n{task}"
278
+ plan_text = self.generate(prompt)
279
+ st.markdown(f"Search Links: {display_search_links('superhero party trends')}", unsafe_allow_html=True)
280
+
281
+ locations = {"Wayne Manor": (42.3601, -71.0589), "New York": (40.7128, -74.0060), "Los Angeles": (34.0522, -118.2437), "London": (51.5074, -0.1278)}
282
+ wayne_coords = locations["Wayne Manor"]
283
+ travel_times = {loc: calculate_cargo_travel_time(coords, wayne_coords) for loc, coords in locations.items() if loc != "Wayne Manor"}
284
+
285
+ data = [
286
+ {"Location": "New York", "Travel Time (hrs)": travel_times["New York"], "Luxury Idea": "Gold-plated Batman statues"},
287
+ {"Location": "Los Angeles", "Travel Time (hrs)": travel_times["Los Angeles"], "Luxury Idea": "VR superhero battles"},
288
+ {"Location": "London", "Travel Time (hrs)": travel_times["London"], "Luxury Idea": "Live stunt shows"},
289
+ {"Location": "Wayne Manor", "Travel Time (hrs)": 0.0, "Luxury Idea": "Holographic displays"}
290
+ ]
291
+ return pd.DataFrame(data)
292
+
293
+ def calculate_cargo_travel_time(origin_coords: Tuple[float, float], destination_coords: Tuple[float, float], cruising_speed_kmh: float = 750.0) -> float:
294
+ def to_radians(degrees: float) -> float:
295
+ return degrees * (math.pi / 180)
296
+ lat1, lon1 = map(to_radians, origin_coords)
297
+ lat2, lon2 = map(to_radians, destination_coords)
298
+ EARTH_RADIUS_KM = 6371.0
299
+ dlon = lon2 - lon1
300
+ dlat = lat2 - lat1
301
+ a = (math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2)
302
+ c = 2 * math.asin(math.sqrt(a))
303
+ distance = EARTH_RADIUS_KM * c
304
+ actual_distance = distance * 1.1
305
+ flight_time = (actual_distance / cruising_speed_kmh) + 1.0
306
+ return round(flight_time, 2)
307
+
308
+ # Main App
309
+ st.title("SFT Tiny Titans πŸš€ (Small but Mighty!)")
310
+
311
+ # Sidebar Galleries
312
+ st.sidebar.header("Galleries 🎨")
313
+ for gallery_type, file_types in [
314
+ ("Image Gallery πŸ“Έ", ["png", "jpg", "jpeg"]),
315
+ ("Video Gallery πŸŽ₯", ["mp4"]),
316
+ ("Audio Gallery 🎢", ["mp3"])
317
+ ]:
318
+ st.sidebar.subheader(gallery_type)
319
+ files = get_gallery_files(file_types)
320
+ if files:
321
+ cols_num = st.sidebar.slider(f"{gallery_type} Columns", 1, 5, 3, key=f"{gallery_type}_cols")
322
+ cols = st.sidebar.columns(cols_num)
323
+ for idx, file in enumerate(files[:cols_num * 2]):
324
+ with cols[idx % cols_num]:
325
+ if "Image" in gallery_type:
326
+ st.image(Image.open(file), caption=file, use_column_width=True)
327
+ elif "Video" in gallery_type:
328
+ st.video(file)
329
+ elif "Audio" in gallery_type:
330
+ st.audio(file)
331
+
332
+ st.sidebar.subheader("Model Management πŸ—‚οΈ")
333
+ model_type = st.sidebar.selectbox("Model Type", ["Causal LM", "Diffusion"])
334
+ model_dirs = get_model_files("causal_lm" if model_type == "Causal LM" else "diffusion")
335
+ selected_model = st.sidebar.selectbox("Select Saved Model", ["None"] + model_dirs)
336
+ if selected_model != "None" and st.sidebar.button("Load Model πŸ“‚"):
337
+ if 'builder' not in st.session_state:
338
+ st.session_state['builder'] = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
339
+ config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=os.path.basename(selected_model), base_model="unknown", size="small")
340
+ st.session_state['builder'].load_model(selected_model, config)
341
+ st.session_state['model_loaded'] = True
342
+ st.rerun()
343
+
344
+ # Tabs
345
+ tab1, tab2, tab3, tab4, tab5 = st.tabs(["Build Tiny Titan 🌱", "Fine-Tune Titan πŸ”§", "Test Titan πŸ§ͺ", "Agentic RAG Party 🌐", "Diffusion SFT 🎨"])
346
+
347
+ with tab1:
348
+ st.header("Build Tiny Titan 🌱")
349
+ model_type = st.selectbox("Model Type", ["Causal LM", "Diffusion"], key="build_type")
350
+ if model_type == "Causal LM":
351
+ base_model = st.selectbox("Select Tiny Model", ["HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-360M", "Qwen/Qwen1.5-0.5B-Chat"])
352
+ else:
353
+ base_model = st.selectbox("Select Tiny Diffusion Model", ["stabilityai/stable-diffusion-2-1", "runwayml/stable-diffusion-v1-5", "CompVis/stable-diffusion-v1-4"])
354
+ model_name = st.text_input("Model Name", f"tiny-titan-{int(time.time())}")
355
+ if st.button("Download Model ⬇️"):
356
+ config = (ModelConfig if model_type == "Causal LM" else DiffusionConfig)(name=model_name, base_model=base_model, size="small")
357
+ builder = ModelBuilder() if model_type == "Causal LM" else DiffusionBuilder()
358
+ builder.load_model(base_model, config)
359
+ builder.save_model(config.model_path)
360
+ st.session_state['builder'] = builder
361
+ st.session_state['model_loaded'] = True
362
+ st.rerun()
363
+
364
+ with tab2:
365
+ st.header("Fine-Tune Titan πŸ”§")
366
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
367
+ st.warning("Please build or load a Titan first! ⚠️")
368
+ else:
369
+ if isinstance(st.session_state['builder'], ModelBuilder):
370
+ uploaded_csv = st.file_uploader("Upload CSV for SFT", type="csv")
371
+ if uploaded_csv and st.button("Fine-Tune with Uploaded CSV πŸ”„"):
372
+ csv_path = f"uploaded_sft_data_{int(time.time())}.csv"
373
+ with open(csv_path, "wb") as f:
374
+ f.write(uploaded_csv.read())
375
+ new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
376
+ new_config = ModelConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small")
377
+ st.session_state['builder'].config = new_config
378
+ st.session_state['builder'].fine_tune_sft(csv_path)
379
+ st.session_state['builder'].save_model(new_config.model_path)
380
+ zip_path = f"{new_config.model_path}.zip"
381
+ zip_directory(new_config.model_path, zip_path)
382
+ st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Titan"), unsafe_allow_html=True)
383
+
384
+ with tab3:
385
+ st.header("Test Titan πŸ§ͺ")
386
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
387
+ st.warning("Please build or load a Titan first! ⚠️")
388
+ else:
389
+ if isinstance(st.session_state['builder'], ModelBuilder):
390
+ test_prompt = st.text_area("Enter Test Prompt", "What is AI?")
391
+ if st.button("Run Test ▢️"):
392
+ result = st.session_state['builder'].evaluate(test_prompt)
393
+ st.write(f"**Generated Response**: {result}")
394
+
395
+ with tab4:
396
+ st.header("Agentic RAG Party 🌐")
397
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False) or not isinstance(st.session_state['builder'], ModelBuilder):
398
+ st.warning("Please build or load a Causal LM Titan first! ⚠️")
399
+ else:
400
+ if st.button("Run Agentic RAG Demo πŸŽ‰"):
401
+ agent = PartyPlannerAgent(model=st.session_state['builder'].model, tokenizer=st.session_state['builder'].tokenizer)
402
+ task = "Plan a luxury superhero-themed party at Wayne Manor."
403
+ plan_df = agent.plan_party(task)
404
+ st.dataframe(plan_df)
405
+
406
+ with tab5:
407
+ st.header("Diffusion SFT 🎨")
408
+ if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False) or not isinstance(st.session_state['builder'], DiffusionBuilder):
409
+ st.warning("Please build or load a Diffusion Titan first! ⚠️")
410
+ else:
411
+ uploaded_files = st.file_uploader("Upload Images/Videos", type=["png", "jpg", "jpeg", "mp4", "mp3"], accept_multiple_files=True)
412
+ text_input = st.text_area("Enter Text (one line per image)", "Line 1\nLine 2\nLine 3")
413
+ if uploaded_files and st.button("Fine-Tune Diffusion Model πŸ”„"):
414
+ images = [Image.open(f) for f in uploaded_files if f.type.startswith("image")]
415
+ texts = text_input.splitlines()
416
+ if len(images) > len(texts):
417
+ texts.extend([""] * (len(images) - len(texts)))
418
+ elif len(texts) > len(images):
419
+ texts = texts[:len(images)]
420
+
421
+ st.session_state['builder'].fine_tune_sft(images, texts)
422
+ new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
423
+ new_config = DiffusionConfig(name=new_model_name, base_model=st.session_state['builder'].config.base_model, size="small")
424
+ st.session_state['builder'].config = new_config
425
+ st.session_state['builder'].save_model(new_config.model_path)
426
+
427
+ for img, text in zip(images, texts):
428
+ filename = generate_filename(text)
429
+ img.save(filename)
430
+ st.image(img, caption=filename)
431
+ zip_path = f"{new_config.model_path}.zip"
432
+ zip_directory(new_config.model_path, zip_path)
433
+ st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Diffusion Model"), unsafe_allow_html=True)