|
import io |
|
import os |
|
import re |
|
import base64 |
|
import glob |
|
import logging |
|
import random |
|
import shutil |
|
import time |
|
import zipfile |
|
import json |
|
import asyncio |
|
import aiofiles |
|
import toml |
|
from datetime import datetime |
|
from collections import Counter |
|
from dataclasses import dataclass, field |
|
from io import BytesIO |
|
from typing import Optional, List, Dict, Any |
|
import pandas as pd |
|
import pytz |
|
import streamlit as st |
|
from PIL import Image, ImageDraw |
|
from reportlab.pdfgen import canvas |
|
from reportlab.lib.utils import ImageReader |
|
from reportlab.lib.pagesizes import letter |
|
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, PageBreak |
|
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle |
|
from reportlab.lib.enums import TA_JUSTIFY |
|
import fitz |
|
import requests |
|
try: |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, AutoModelForVision2Seq, pipeline |
|
_transformers_available = True |
|
except ImportError: |
|
_transformers_available = False |
|
st.sidebar.warning("AI/ML libraries (torch, transformers) not found. Local model features disabled.") |
|
try: |
|
from diffusers import StableDiffusionPipeline |
|
_diffusers_available = True |
|
except ImportError: |
|
_diffusers_available = False |
|
if _transformers_available: |
|
st.sidebar.warning("Diffusers library not found. Diffusion model features disabled.") |
|
try: |
|
from openai import OpenAI |
|
_openai_available = True |
|
except ImportError: |
|
_openai_available = False |
|
st.sidebar.warning("OpenAI library not found. OpenAI model features disabled.") |
|
from huggingface_hub import InferenceClient, HfApi, list_models |
|
from huggingface_hub.utils import RepositoryNotFoundError, GatedRepoError |
|
|
|
|
|
# Page-level Streamlit configuration. st.set_page_config is called here, before any
# other element is rendered on the page.
# NOTE(review): page_title / page_icon / 'About' contain mis-encoded emoji (mojibake);
# the intended characters should be restored from the original source.
st.set_page_config(
    page_title="Vision & Layout Titans ππΌοΈ",
    page_icon="π€",
    layout="wide",
    initial_sidebar_state="expanded",
    # Entries for the app's "⋮" menu; None hides the bug-report entry.
    menu_items={
        'Get Help': 'https://huggingface.co/docs',
        'Report a Bug': None,
        'About': "Combined App: Image/MD->PDF Layout + AI-Powered Tools π"
    }
)
|
|
|
|
|
# Credential lookup order: .streamlit/secrets.toml first, then environment
# variables, then "". If the secrets file cannot be loaded, the error is shown
# and only the environment is consulted.
try:
    if os.path.exists(".streamlit/secrets.toml"):
        secrets = toml.load(".streamlit/secrets.toml")
    else:
        secrets = {}
    HF_TOKEN = secrets.get("HF_TOKEN", os.getenv("HF_TOKEN", ""))
    OPENAI_API_KEY = secrets.get("OPENAI_API_KEY", os.getenv("OPENAI_API_KEY", ""))
except Exception as e:
    st.error(f"Error loading secrets: {e}")
    HF_TOKEN = os.getenv("HF_TOKEN", "")
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")

# Surface missing credentials in the sidebar instead of failing later.
if not HF_TOKEN:
    st.sidebar.warning("Hugging Face token not found in secrets or environment. Some features may be limited.")
if _openai_available and not OPENAI_API_KEY:
    st.sidebar.warning("OpenAI API key not found in secrets or environment. OpenAI features disabled.")
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

logger = logging.getLogger(__name__)

# In-memory list of LogRecord objects captured from `logger`.
log_records = []


class LogCaptureHandler(logging.Handler):
    """Logging handler that appends every record it receives to ``log_records``."""

    def emit(self, record):
        log_records.append(record)


# Loggers live in the process-wide logging registry, but this module body can be
# executed more than once in the same process (Streamlit re-runs the script on each
# interaction). Remove capture handlers left over from earlier executions before
# attaching a fresh one. Handlers are matched by class *name* because the class object
# is recreated on every execution. Without this, each record would be captured once
# per previous run and old buffers would be kept alive.
for _stale_handler in [h for h in logger.handlers if type(h).__name__ == LogCaptureHandler.__name__]:
    logger.removeHandler(_stale_handler)
logger.addHandler(LogCaptureHandler())
|
|
|
|
|
# Default Hugging Face inference provider used when the session has not chosen one.
DEFAULT_PROVIDER = "hf-inference"

# Chat/instruct models offered for Hugging Face API text processing.
FEATURED_MODELS_LIST = [
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.3",
    "google/gemma-2-9b-it",
    "Qwen/Qwen2-7B-Instruct",
    "microsoft/Phi-3-mini-4k-instruct",
    "HuggingFaceH4/zephyr-7b-beta",
    "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
    "HuggingFaceTB/SmolLM-1.7B-Instruct",
]

# Image captioning / OCR / vision models.
VISION_MODELS_LIST = [
    "Salesforce/blip-image-captioning-large",
    "microsoft/trocr-large-handwritten",
    "llava-hf/llava-1.5-7b-hf",
    "google/vit-base-patch16-224",
]

# Text-to-image diffusion checkpoints.
DIFFUSION_MODELS_LIST = [
    "stabilityai/stable-diffusion-xl-base-1.0",
    "runwayml/stable-diffusion-v1-5",
    "OFA-Sys/small-stable-diffusion-v0",
]

# OpenAI models usable with the Chat Completions API, which is the only OpenAI
# endpoint this app calls. The completions-only "text-davinci-003" was removed:
# chat.completions rejects it, so selecting it could never succeed.
OPENAI_MODELS_LIST = [
    "gpt-4o",
    "gpt-4-turbo",
    "gpt-3.5-turbo",
]
|
# Session slots for loaded local models and API client handles (only filled if absent).
for _state_key, _state_default in (
    ('local_models', {}),
    ('hf_inference_client', None),
    ('openai_client', None),
):
    st.session_state.setdefault(_state_key, _state_default)
del _state_key, _state_default

# Build the OpenAI client on every run when the library is installed and a key is
# configured; on failure, report it and leave the slot empty.
if _openai_available and OPENAI_API_KEY:
    try:
        st.session_state['openai_client'] = OpenAI(api_key=OPENAI_API_KEY)
    except Exception as e:
        st.error(f"Failed to initialize OpenAI client: {e}")
        logger.error(f"OpenAI client initialization failed: {e}")
        st.session_state['openai_client'] = None
    else:
        logger.info("OpenAI client initialized successfully.")
|
|
|
|
|
# Per-session defaults for layout tools, galleries, characters, HF/OpenAI model
# selection and generation settings. setdefault only fills keys that are missing,
# so values already in the session are left untouched.
_session_defaults = {
    'layout_snapshots': [],
    'layout_new_uploads': [],
    'history': [],
    'processing': {},
    'asset_checkboxes': {'image': {}, 'md': {}, 'pdf': {}},
    'downloaded_pdfs': {},
    'unique_counter': 0,
    'cam0_file': None,
    'cam1_file': None,
    'characters': [],
    'char_form_reset_key': 0,
    'gallery_size': 10,
    'hf_provider': DEFAULT_PROVIDER,
    'hf_custom_key': "",
    'hf_selected_api_model': FEATURED_MODELS_LIST[0],
    'hf_custom_api_model': "",
    'openai_selected_model': OPENAI_MODELS_LIST[0] if _openai_available else "",
    'selected_local_model_path': None,
    'gen_max_tokens': 512,
    'gen_temperature': 0.7,
    'gen_top_p': 0.95,
    'gen_frequency_penalty': 0.0,
}
for _state_key, _state_default in _session_defaults.items():
    st.session_state.setdefault(_state_key, _state_default)
del _session_defaults, _state_key, _state_default

# Sidebar placeholders for the three asset galleries; created only when missing.
if 'asset_gallery_container' not in st.session_state:
    st.session_state['asset_gallery_container'] = {
        gallery_kind: st.sidebar.empty() for gallery_kind in ('image', 'md', 'pdf')
    }
|
|
|
|
|
@dataclass
class LocalModelConfig:
    """Metadata for a Hugging Face model stored on local disk.

    ``local_path`` is derived rather than passed in. It is
    ``<model_type>_models/<sanitized name>``, where every run of characters other
    than word characters and ``-`` in ``name`` is replaced by a single ``_``.
    """

    name: str
    hf_id: str
    model_type: str
    size_category: str = "unknown"
    domain: Optional[str] = None
    local_path: str = field(init=False)

    def __post_init__(self):
        sanitized = re.sub(r'[^\w\-]+', '_', self.name)
        self.local_path = os.path.join(f"{self.model_type}_models", sanitized)

    def get_full_path(self):
        """Return ``local_path`` resolved against the current working directory."""
        return os.path.abspath(self.local_path)
|
|
|
@dataclass
class DiffusionConfig:
    """Description of a diffusion model and the folder it is stored under."""

    name: str
    base_model: str
    size: str
    domain: Optional[str] = None

    @property
    def model_path(self):
        """Relative storage path: ``diffusion_models/<name>`` (name used unsanitized)."""
        return "diffusion_models/{}".format(self.name)
|
|
|
|
|
def generate_filename(sequence, ext="png"):
    """Build ``<sanitized sequence>_<YYYYmmdd_HHMMSS>.<ext>`` using local time.

    ``sequence`` is stringified and every run of characters other than word
    characters and ``-`` becomes ``_``.
    """
    stem = re.sub(r'[^\w\-]+', '_', str(sequence))
    stamp = time.strftime('%Y%m%d_%H%M%S')
    return "{}_{}.{}".format(stem, stamp, ext)
|
|
|
def pdf_url_to_filename(url):
    """Turn a URL into a filesystem-safe ``.pdf`` filename.

    The http/https scheme is dropped, characters that are invalid in common
    filesystems are replaced by ``_``, and the stem is capped at 100 characters.
    """
    without_scheme = re.sub(r'^https?://', '', url)
    safe_stem = re.sub(r'[<>:"/\\|?*]', '_', without_scheme)
    return f"{safe_stem[:100]}.pdf"
|
|
|
def get_download_link(file_path, mime_type="application/octet-stream", label="Download"):
    """Return an HTML ``<a>`` tag that downloads *file_path* from an inline data URI.

    The whole file is embedded base64-encoded in the link. ``label`` is inserted
    as-is (callers may pass markup). The download filename is HTML-escaped so that
    quotes or angle brackets in a filename cannot break out of the ``download``
    attribute.

    Returns:
        The anchor tag, ``"<label> (File not found)"`` if the path does not exist,
        or ``"<label> (Error)"`` if the file cannot be read.
    """
    import html

    if not os.path.exists(file_path):
        return f"{label} (File not found)"
    try:
        with open(file_path, "rb") as f:
            b64 = base64.b64encode(f.read()).decode()
    except OSError as e:
        logger.error(f"Error creating download link for {file_path}: {e}")
        return f"{label} (Error)"
    download_name = html.escape(os.path.basename(file_path), quote=True)
    return f'<a href="data:{mime_type};base64,{b64}" download="{download_name}">{label}</a>'
|
|
|
def zip_directory(directory_path, zip_path):
    """Write every file under *directory_path* into a deflate-compressed zip.

    Archive member names are relative to the parent of *directory_path*, so the
    directory's own name is the top-level folder inside the archive.
    """
    archive_root = os.path.dirname(directory_path)
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for folder, _subdirs, names in os.walk(directory_path):
            for name in names:
                full_path = os.path.join(folder, name)
                archive.write(full_path, os.path.relpath(full_path, archive_root))
|
|
|
def get_local_model_paths(model_type="causal"):
    """List subdirectories of ``<model_type>_models/`` (relative to the CWD), in glob order."""
    return list(filter(os.path.isdir, glob.glob(f"{model_type}_models/*")))
|
|
|
def get_gallery_files(file_types=("png", "pdf", "jpg", "jpeg", "md", "txt")):
    """Return sorted files in the CWD whose extension is in *file_types*.

    Each extension is matched in all-lowercase and all-uppercase form. Any
    ``README.md`` (case-insensitive) is excluded.
    """
    found = set()
    for ext in file_types:
        for variant in (ext.lower(), ext.upper()):
            found.update(glob.glob(f"*.{variant}"))
    return sorted(path for path in found if os.path.basename(path).lower() != 'readme.md')
|
|
|
def get_typed_gallery_files(file_type):
    """Return gallery files for a gallery kind ('image', 'md' or 'pdf'); [] for anything else."""
    extensions_by_kind = {
        'image': ('png', 'jpg', 'jpeg'),
        'md': ('md',),
        'pdf': ('pdf',),
    }
    extensions = extensions_by_kind.get(file_type)
    if extensions is None:
        return []
    return get_gallery_files(extensions)
|
|
|
def _discard_partial_download(output_path):
    """Best-effort removal of a partially written download; failures are logged, not raised."""
    try:
        os.remove(output_path)
    except FileNotFoundError:
        pass
    except OSError as cleanup_err:
        logger.warning(f"Could not remove partial download {output_path}: {cleanup_err}")


def download_pdf(url, output_path):
    """Stream *url* to *output_path*.

    Returns:
        True on success. False on any failure; the failure is logged, and a file
        this call started writing is removed. A pre-existing file at *output_path*
        is left alone if the request fails before writing begins.
    """
    started_writing = False
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        # `with` releases the streamed connection even if writing fails midway.
        with requests.get(url, stream=True, timeout=20, headers=headers) as response:
            response.raise_for_status()
            with open(output_path, "wb") as f:
                started_writing = True
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        logger.info(f"Successfully downloaded {url} to {output_path}")
        return True
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to download {url}: {e}")
    except Exception as e:
        logger.error(f"An unexpected error occurred during download of {url}: {e}")
    if started_writing:
        _discard_partial_download(output_path)
    return False
|
|
|
async def process_pdf_snapshot(pdf_path, mode="single", resolution_factor=2.0):
    """Render the leading pages of a PDF to PNG files in the current directory.

    Args:
        pdf_path: Path of the PDF to rasterize.
        mode: ``"single"`` renders page 1, ``"twopage"`` renders pages 1-2, and any
            other value renders every page.
        resolution_factor: Zoom applied to both axes of each page.

    Returns:
        The PNG filenames written, or ``[]`` on failure. On failure, files already
        written by this call are deleted.
    """
    start_time = time.time()
    status_placeholder = st.empty()
    status_placeholder.text(f"Processing PDF Snapshot ({mode}, Res: {resolution_factor}x)... (0s)")
    output_files = []
    page_limits = {"single": 1, "twopage": 2}
    try:
        # The context manager closes the document even if rendering or saving raises.
        with fitz.open(pdf_path) as doc:
            matrix = fitz.Matrix(resolution_factor, resolution_factor)
            limit = page_limits.get(mode)
            num_pages_to_process = len(doc) if limit is None else min(limit, len(doc))
            base_name = os.path.splitext(os.path.basename(pdf_path))[0]
            for i in range(num_pages_to_process):
                page_start_time = time.time()
                pix = doc[i].get_pixmap(matrix=matrix)
                output_file = generate_filename(f"{base_name}_pg{i+1}_{mode}", "png")
                # Saving is file I/O; run it off the event loop.
                await asyncio.to_thread(pix.save, output_file)
                output_files.append(output_file)
                elapsed_page = int(time.time() - page_start_time)
                status_placeholder.text(f"Processing PDF Snapshot ({mode}, Res: {resolution_factor}x)... Page {i+1}/{num_pages_to_process} done ({elapsed_page}s)")
                await asyncio.sleep(0.01)
        elapsed = int(time.time() - start_time)
        status_placeholder.success(f"PDF Snapshot ({mode}, {len(output_files)} files) completed in {elapsed}s!")
        return output_files
    except Exception as e:
        logger.error(f"Failed to process PDF snapshot for {pdf_path}: {e}")
        status_placeholder.error(f"Failed to process PDF {os.path.basename(pdf_path)}: {e}")
        for partial in output_files:
            try:
                os.remove(partial)
            except FileNotFoundError:
                pass
            except OSError as cleanup_err:
                logger.warning(f"Could not remove partial snapshot {partial}: {cleanup_err}")
        return []
|
|
|
def get_hf_client() -> Optional[InferenceClient]:
    """Return a cached Hugging Face ``InferenceClient`` for the current settings.

    The token is the session's custom key if set, otherwise ``HF_TOKEN``. The
    ``hf-inference`` provider may be used anonymously; other providers require a
    token. The client is rebuilt only when the (provider, token) pair changes. That
    pair is stored in ``st.session_state['hf_client_signature']`` instead of being
    read back from private client attributes.

    Returns:
        The client, or None if it could not be created (an error is shown).
    """
    provider = st.session_state.hf_provider
    custom_key = st.session_state.hf_custom_key.strip()
    token_to_use = custom_key or HF_TOKEN or None
    if token_to_use is None and provider != "hf-inference":
        st.error(f"Provider '{provider}' requires a Hugging Face API token.")
        return None
    if token_to_use is None:
        logger.warning("Using hf-inference provider without a token. Rate limits may apply.")

    signature = (provider, token_to_use)
    current_client = st.session_state.get('hf_inference_client')
    if current_client is not None and st.session_state.get('hf_client_signature') == signature:
        return current_client

    try:
        logger.info(f"Initializing InferenceClient for provider: {provider}.")
        st.session_state.hf_inference_client = InferenceClient(token=token_to_use, provider=provider)
        st.session_state['hf_client_signature'] = signature
        logger.info("InferenceClient initialized successfully.")
    except Exception as e:
        st.error(f"Failed to initialize Hugging Face client: {e}")
        logger.error(f"InferenceClient initialization failed: {e}")
        st.session_state.hf_inference_client = None
        st.session_state.pop('hf_client_signature', None)
    return st.session_state.hf_inference_client
|
|
|
def process_text_hf(text: str, prompt: str, use_api: bool, model_id: str = None) -> str:
    """Apply an instruction to *text* with a Hugging Face model.

    The instruction and text are sent as a system + user chat, either to the HF
    Inference API (``use_api=True``) or to the causal LM selected in
    ``st.session_state['selected_local_model_path']``. Generation settings come
    from the ``gen_*`` session keys. A ``gen_seed`` of -1, or a missing key, means
    no fixed seed.

    Args:
        text: The text to process.
        prompt: Instruction describing what to do with *text*.
        use_api: Use the Inference API instead of a local model.
        model_id: API model override; ignored for local inference.

    Returns:
        The generated text, or a message starting with ``"Error"`` on failure.
    """
    status_placeholder = st.empty()
    start_time = time.time()
    result_text = ""
    succeeded = False
    params = {
        "max_new_tokens": st.session_state.gen_max_tokens,
        "temperature": st.session_state.gen_temperature,
        "top_p": st.session_state.gen_top_p,
        # The UI's frequency penalty is mapped onto transformers' repetition penalty.
        "repetition_penalty": st.session_state.gen_frequency_penalty + 1.0,
    }
    # gen_seed is not among the gen_* defaults set alongside the others, so
    # tolerate it being unset.
    seed = st.session_state.get('gen_seed', -1)
    if seed is not None and seed != -1:
        params["seed"] = seed
    system_prompt = "You are a helpful assistant. Process the following text based on the user's request."
    full_prompt = f"{prompt}\n\n---\n\n{text}"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": full_prompt},
    ]

    if use_api:
        status_placeholder.info("Processing text using Hugging Face API...")
        client = get_hf_client()
        if not client:
            return "Error: Hugging Face client not available."
        model_id = model_id or st.session_state.hf_custom_api_model.strip() or st.session_state.hf_selected_api_model
        status_placeholder.info(f"Using API Model: {model_id}")
        optional_kwargs = {"seed": params["seed"]} if "seed" in params else {}
        try:
            response = client.chat_completion(
                model=model_id,
                messages=messages,
                max_tokens=params['max_new_tokens'],
                temperature=params['temperature'],
                top_p=params['top_p'],
                **optional_kwargs,
            )
            result_text = response.choices[0].message.content or ""
            succeeded = True
            logger.info(f"HF API text processing successful for model {model_id}.")
        except Exception as e:
            logger.error(f"HF API text processing failed for model {model_id}: {e}")
            result_text = f"Error during Hugging Face API inference: {str(e)}"
    else:
        status_placeholder.info("Processing text using local model...")
        if not _transformers_available:
            return "Error: Transformers library not available."
        model_path = st.session_state.get('selected_local_model_path')
        if not model_path or model_path not in st.session_state.get('local_models', {}):
            return "Error: No suitable local model selected."
        local_model_data = st.session_state['local_models'][model_path]
        if local_model_data.get('type') != 'causal':
            return f"Error: Loaded model '{os.path.basename(model_path)}' is not a Causal LM."
        status_placeholder.info(f"Using Local Model: {os.path.basename(model_path)}")
        model = local_model_data.get('model')
        tokenizer = local_model_data.get('tokenizer')
        if not model or not tokenizer:
            return f"Error: Model or tokenizer not found for {os.path.basename(model_path)}."
        try:
            if "seed" in params:
                torch.manual_seed(params["seed"])
            try:
                prompt_for_model = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                template_applied = True
            except Exception:
                logger.warning(f"Could not apply chat template for {model_path}. Using basic formatting.")
                prompt_for_model = f"System: {system_prompt}\nUser: {full_prompt}\nAssistant:"
                template_applied = False
            # A rendered chat template already carries the model's special tokens, so
            # don't let the tokenizer add them again. No padding (it raises for
            # tokenizers without a pad token) and no truncation: the previous
            # max_length (2x the *output* budget) cut the prompt from the right,
            # dropping the end of the text and the assistant cue.
            inputs = tokenizer(prompt_for_model, return_tensors="pt", add_special_tokens=not template_applied)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            generate_params = {
                "max_new_tokens": params['max_new_tokens'],
                "temperature": params['temperature'],
                "top_p": params['top_p'],
                "repetition_penalty": params.get('repetition_penalty', 1.0),
                "do_sample": params['temperature'] > 0.1,
                "pad_token_id": tokenizer.eos_token_id,
            }
            with torch.no_grad():
                outputs = model.generate(**inputs, **generate_params)
            # Decode only the newly generated tokens, not the echoed prompt.
            input_length = inputs['input_ids'].shape[1]
            result_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
            succeeded = True
            logger.info(f"Local text processing successful for model {model_path}.")
        except Exception as e:
            logger.error(f"Local text processing failed for model {model_path}: {e}")
            result_text = f"Error during local model inference: {str(e)}"

    elapsed = int(time.time() - start_time)
    if succeeded:
        status_placeholder.success(f"Text processing completed in {elapsed}s.")
    else:
        status_placeholder.error(f"Text processing failed after {elapsed}s.")
    return result_text
|
|
|
def process_text_openai(text: str, prompt: str, model_id: str) -> str:
    """Apply an instruction to *text* with an OpenAI chat model.

    Uses the session's OpenAI client and the ``gen_max_tokens`` /
    ``gen_temperature`` / ``gen_top_p`` settings.

    Returns:
        The model's reply, or a message starting with ``"Error"`` on failure.
        The status banner shows success only when the call actually succeeded.
    """
    if not _openai_available or not st.session_state.get('openai_client'):
        return "Error: OpenAI client not available or API key missing."
    status_placeholder = st.empty()
    start_time = time.time()
    client = st.session_state['openai_client']
    system_prompt = "You are a helpful assistant. Process the following text based on the user's request."
    full_prompt = f"{prompt}\n\n---\n\n{text}"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": full_prompt},
    ]
    status_placeholder.info(f"Processing text using OpenAI model: {model_id}...")
    succeeded = False
    try:
        response = client.chat.completions.create(
            model=model_id,
            messages=messages,
            max_tokens=st.session_state.gen_max_tokens,
            temperature=st.session_state.gen_temperature,
            top_p=st.session_state.gen_top_p,
        )
        result_text = response.choices[0].message.content or ""
        succeeded = True
        logger.info(f"OpenAI text processing successful for model {model_id}.")
    except Exception as e:
        logger.error(f"OpenAI text processing failed for model {model_id}: {e}")
        result_text = f"Error during OpenAI inference: {str(e)}"
    elapsed = int(time.time() - start_time)
    if succeeded:
        status_placeholder.success(f"Text processing completed in {elapsed}s.")
    else:
        status_placeholder.error(f"Text processing failed after {elapsed}s.")
    return result_text
|
|
|
def _extract_generated_text(response):
    """Pull ``generated_text`` out of an image-to-text API response, or return None.

    The original code expected a list of dicts. Current ``huggingface_hub``
    documents a single ``ImageToTextOutput`` object instead, so accept a list,
    a dict-like object, or an object with a ``generated_text`` attribute.
    """
    if isinstance(response, list):
        response = response[0] if response else None
    if response is None:
        return None
    if isinstance(response, dict):
        return response.get('generated_text')
    return getattr(response, 'generated_text', None)


def process_image_hf(image: Image.Image, prompt: str, use_api: bool, model_id: str = None) -> str:
    """Caption/OCR *image* with a Hugging Face model.

    With ``use_api=True``, the image is sent to the Inference API image-to-text
    task (default model ``Salesforce/blip-image-captioning-large``); *prompt* is
    not sent in that case. Otherwise the selected local model is used: ``'vision'``
    models receive *prompt* as text input, ``'ocr'`` models receive pixels only.

    Returns:
        The generated text, or a message starting with ``"Error"`` on failure.
    """
    status_placeholder = st.empty()
    start_time = time.time()
    result_text = ""
    succeeded = False
    if use_api:
        status_placeholder.info("Processing image using Hugging Face API...")
        client = get_hf_client()
        if not client:
            return "Error: HF client not configured."
        model_id = model_id or "Salesforce/blip-image-captioning-large"
        status_placeholder.info(f"Using API Image-to-Text Model: {model_id}")
        try:
            # Encoding can fail for some image modes, so keep it inside the try.
            buffered = BytesIO()
            image.save(buffered, format="PNG" if image.format != 'JPEG' else 'JPEG')
            # The image is passed positionally: the parameter is named `image`, so
            # the previous `data=` keyword raised TypeError.
            response = client.image_to_text(buffered.getvalue(), model=model_id)
            generated = _extract_generated_text(response)
            if generated is not None:
                result_text = generated
                succeeded = True
                logger.info(f"HF API image captioning successful for model {model_id}.")
            else:
                result_text = "Error: Unexpected response format from image-to-text API."
                logger.warning(f"Unexpected API response for image-to-text: {response}")
        except Exception as e:
            logger.error(f"HF API image processing failed: {e}")
            result_text = f"Error during Hugging Face API image inference: {str(e)}"
    else:
        status_placeholder.info("Processing image using local model...")
        if not _transformers_available:
            return "Error: Transformers library needed."
        model_path = st.session_state.get('selected_local_model_path')
        if not model_path or model_path not in st.session_state.get('local_models', {}):
            return "Error: No suitable local model selected."
        local_model_data = st.session_state['local_models'][model_path]
        model_type = local_model_data.get('type')
        processor = local_model_data.get('processor')
        model = local_model_data.get('model')
        task_label = "vision" if model_type == 'vision' else "OCR"
        if model_type not in ('vision', 'ocr'):
            result_text = f"Error: Loaded model '{os.path.basename(model_path)}' is not a recognized vision/OCR type."
        elif not (processor and model):
            result_text = f"Error: Processor or model missing for local {task_label} task."
        else:
            try:
                with torch.no_grad():
                    if model_type == 'vision':
                        inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
                        generated_ids = model.generate(**inputs, max_new_tokens=st.session_state.gen_max_tokens)
                        result_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
                    else:
                        pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(model.device)
                        generated_ids = model.generate(pixel_values, max_new_tokens=st.session_state.gen_max_tokens)
                        result_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                succeeded = True
            except Exception as e:
                logger.error(f"Local {task_label} inference failed for {model_path}: {e}")
                result_text = f"Error during local {task_label} model inference: {e}"
    elapsed = int(time.time() - start_time)
    if succeeded:
        status_placeholder.success(f"Image processing completed in {elapsed}s.")
    else:
        status_placeholder.error(f"Image processing failed after {elapsed}s.")
    return result_text
|
|
|
def process_image_openai(image: Image.Image, prompt: str, model_id: str = "gpt-4o") -> str:
    """Send *image* (as a base64 PNG data URL) plus *prompt* to an OpenAI vision model.

    Returns:
        The model's reply, or a message starting with ``"Error"`` on failure
        (including images that cannot be encoded as PNG).
    """
    if not _openai_available or not st.session_state.get('openai_client'):
        return "Error: OpenAI client not available or API key missing."
    status_placeholder = st.empty()
    start_time = time.time()
    client = st.session_state['openai_client']
    status_placeholder.info(f"Processing image using OpenAI model: {model_id}...")
    succeeded = False
    try:
        # PNG encoding raises for some modes (e.g. CMYK); report it like any other failure.
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_b64 = base64.b64encode(buffered.getvalue()).decode()
        response = client.chat.completions.create(
            model=model_id,
            messages=[
                {"role": "user", "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}},
                ]},
            ],
            max_tokens=st.session_state.gen_max_tokens,
            temperature=st.session_state.gen_temperature,
        )
        result_text = response.choices[0].message.content or ""
        succeeded = True
        logger.info(f"OpenAI image processing successful for model {model_id}.")
    except Exception as e:
        logger.error(f"OpenAI image processing failed for model {model_id}: {e}")
        result_text = f"Error during OpenAI image inference: {str(e)}"
    elapsed = int(time.time() - start_time)
    if succeeded:
        status_placeholder.success(f"Image processing completed in {elapsed}s.")
    else:
        status_placeholder.error(f"Image processing failed after {elapsed}s.")
    return result_text
|
|
|
async def process_hf_ocr(image: Image.Image, output_file: str, use_api: bool, model_id: str = None) -> str:
    """Run HF OCR on *image* and save usable text to *output_file*.

    A result counts as usable when it is non-empty and starts with neither
    ``"Error"`` nor ``"["``. Otherwise any existing *output_file* is removed
    (best effort). A save failure is appended to the returned text.
    """
    effective_model = model_id or "microsoft/trocr-large-handwritten"
    result = process_image_hf(image, "Extract text content from this image.", use_api, model_id=effective_model)
    usable = bool(result) and not result.startswith(("Error", "["))
    if not usable:
        if os.path.exists(output_file):
            try:
                os.remove(output_file)
            except OSError:
                pass
        return result
    try:
        async with aiofiles.open(output_file, "w", encoding='utf-8') as handle:
            await handle.write(result)
        logger.info(f"HF OCR result saved to {output_file}")
    except IOError as err:
        logger.error(f"Failed to save HF OCR output to {output_file}: {err}")
        result += f"\n[Error saving file: {err}]"
    return result
|
|
|
async def process_openai_ocr(image: Image.Image, output_file: str, model_id: str = "gpt-4o") -> str:
    """Run OpenAI OCR on *image* and save non-error text to *output_file*.

    If the result is empty or starts with ``"Error"``, any existing *output_file*
    is removed (best effort). A save failure is appended to the returned text.
    """
    result = process_image_openai(image, "Extract text content from this image.", model_id)
    if not result or result.startswith("Error"):
        if os.path.exists(output_file):
            try:
                os.remove(output_file)
            except OSError:
                pass
        return result
    try:
        async with aiofiles.open(output_file, "w", encoding='utf-8') as handle:
            await handle.write(result)
        logger.info(f"OpenAI OCR result saved to {output_file}")
    except IOError as err:
        logger.error(f"Failed to save OpenAI OCR output to {output_file}: {err}")
        result += f"\n[Error saving file: {err}]"
    return result
|
|
|
def randomize_character_content():
    """Generate a random placeholder character.

    Returns:
        ``(name, gender, intro, greeting)``:
        - ``name`` is ``Character_NNNN`` with four random digits;
        - ``gender`` is ``"Male"`` or ``"Female"``;
        - ``intro`` and ``greeting`` are random templates with the name filled in.
    """
    intro_templates = [
        "{char} is a valiant knight...", "{char} is a mischievous thief...",
        "{char} is a wise scholar...", "{char} is a fiery warrior...", "{char} is a gentle healer...",
    ]
    # Greetings are single-quoted dialogue. They use plain ASCII punctuation because
    # the previous typographic dashes/apostrophes had been mangled into mojibake.
    # One greeting was also missing its closing quote.
    greeting_templates = [
        "'I am from the knight's guild...'",
        "'I heard you needed help - name's {char}...'",
        "'Oh, hello! I'm {char}, didn't see you there...'",
        "'I'm {char}, and I'm here to fight...'",
        "'I'm {char}, here to heal...'",
    ]
    name = f"Character_{random.randint(1000, 9999)}"
    gender = random.choice(["Male", "Female"])
    intro = random.choice(intro_templates).format(char=name)
    greeting = random.choice(greeting_templates).format(char=name)
    return name, gender, intro, greeting
|
|
|
def save_character(character_data):
    """Append *character_data* to the saved characters and persist them to characters.json.

    Names must be unique. The file is written to a temporary path and then moved
    into place, so a failed write cannot corrupt the existing file. The session's
    character list is updated only after the write succeeds, keeping it in sync
    with disk.

    Returns:
        True on success; False (with an on-screen error) for a duplicate name or a
        failed write.
    """
    characters = st.session_state.get('characters', [])
    if any(c['name'] == character_data['name'] for c in characters):
        st.error(f"Character name '{character_data['name']}' already exists.")
        return False
    updated = [*characters, character_data]
    tmp_path = "characters.json.tmp"
    try:
        with open(tmp_path, "w", encoding='utf-8') as f:
            json.dump(updated, f, indent=2)
        os.replace(tmp_path, "characters.json")
    except (OSError, TypeError, ValueError) as e:
        # TypeError/ValueError come from json.dump on non-serializable data.
        logger.error(f"Failed to save characters.json: {e}")
        st.error(f"Failed to save character file: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return False
    st.session_state['characters'] = updated
    logger.info(f"Saved character: {character_data['name']}")
    return True
|
|
|
def _quarantine_character_file():
    """Move characters.json aside to a timestamped backup; failures are logged, not raised."""
    corrupt_filename = f"characters_corrupt_{int(time.time())}.json"
    try:
        os.replace("characters.json", corrupt_filename)
        logger.info(f"Backed up corrupted character file to {corrupt_filename}")
    except OSError as backup_e:
        logger.error(f"Could not backup corrupted character file: {backup_e}")


def load_characters():
    """Load characters.json into ``st.session_state['characters']``.

    - Missing file: start with an empty list.
    - Unreadable file (OS error): start empty but leave the file untouched.
    - Undecodable file, or JSON that is not a list: start empty and move the file
      to a ``characters_corrupt_<ts>.json`` backup, so no data is deleted.
    """
    if not os.path.exists("characters.json"):
        st.session_state['characters'] = []
        return
    try:
        with open("characters.json", "r", encoding='utf-8') as f:
            characters = json.load(f)
    except OSError as e:
        logger.error(f"Failed to read characters.json: {e}")
        st.error(f"Error loading character file: {e}. Starting fresh.")
        st.session_state['characters'] = []
        return
    except ValueError as e:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError subclasses).
        logger.error(f"Failed to load or decode characters.json: {e}")
        st.error(f"Error loading character file: {e}. Starting fresh.")
        st.session_state['characters'] = []
        _quarantine_character_file()
        return
    if isinstance(characters, list):
        st.session_state['characters'] = characters
        logger.info(f"Loaded {len(characters)} characters.")
    else:
        st.session_state['characters'] = []
        logger.warning("characters.json is not a list, resetting.")
        _quarantine_character_file()
|
|
|
def clean_stem(fn: str) -> str:
    """Turn a path into a title-cased caption: drop directory and extension, treat ``-``/``_`` as spaces."""
    stem, _ext = os.path.splitext(os.path.basename(fn))
    return stem.translate(str.maketrans('-_', '  ')).strip().title()
|
|
|
def _markdown_source_to_pdf(path, md_style):
    """Render one Markdown file onto letter-size pages and return the PDF bytes.

    Image references (``![alt](url)``) are stripped and the text is split into
    paragraphs on blank lines. Each paragraph is HTML-escaped so characters like
    ``<`` and ``&`` are shown literally instead of being parsed as ReportLab markup,
    which previously raised on ordinary Markdown.
    """
    import html

    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()
    content = re.sub(r'!\[.*?\]\(.*?\)', '', content)
    story = [
        Paragraph(html.escape(para.strip(), quote=False), md_style)
        for para in content.split('\n\n')
        if para.strip()
    ]
    buf = BytesIO()
    doc = SimpleDocTemplate(buf, pagesize=letter, rightMargin=36, leftMargin=36, topMargin=36, bottomMargin=36)
    # A placeholder flowable keeps the build producing a (blank) page for empty files.
    doc.build(story or [Spacer(1, 1)])
    return buf.getvalue()


def _image_source_to_pdf(img_obj, filename, page_number, cap_h=30):
    """Render a PIL image on one page sized to the image plus a caption strip; return PDF bytes."""
    iw, ih = img_obj.size
    buf = BytesIO()
    c = canvas.Canvas(buf, pagesize=(iw, ih + cap_h))
    c.drawImage(ImageReader(img_obj), 0, cap_h, width=iw, height=ih, preserveAspectRatio=True, anchor='c', mask='auto')
    c.setFont('Helvetica', 12)
    c.setFillColorRGB(0, 0, 0)
    c.drawCentredString(iw / 2, cap_h / 2 + 3, clean_stem(filename))
    c.setFont('Helvetica', 8)
    c.setFillColorRGB(0.5, 0.5, 0.5)
    c.drawRightString(iw - 10, 8, f"Page {page_number}")
    c.save()
    return buf.getvalue()


def make_image_sized_pdf(sources, is_markdown_flags):
    """Combine images and Markdown files into one PDF.

    Each image gets its own page sized to the image (plus a caption strip); each
    Markdown file is laid out on letter-size pages. Every source is rendered to its
    own PDF and the parts are merged in order with PyMuPDF. Previously image pages
    were drawn on a throwaway canvas and only blank page breaks reached the output.

    Args:
        sources: File paths or file-like uploads (images only for file-like objects).
        is_markdown_flags: Parallel booleans; True marks a Markdown path.

    Returns:
        The PDF bytes, or None if nothing could be rendered or a fatal error occurred.
        Sources that fail individually are reported and skipped.
    """
    if not sources:
        st.warning("No sources provided for PDF generation.")
        return None
    md_style = ParagraphStyle(
        name='Markdown',
        fontSize=10,
        leading=12,
        spaceAfter=6,
        alignment=TA_JUSTIFY,
        fontName='Helvetica',
    )
    merged = fitz.open()
    try:
        for idx, (src, is_md) in enumerate(zip(sources, is_markdown_flags), start=1):
            status_placeholder = st.empty()
            status_placeholder.info(f"Adding page {idx}/{len(sources)}: {os.path.basename(str(src))}...")
            try:
                if is_md:
                    part_bytes = _markdown_source_to_pdf(src, md_style)
                    filename = os.path.basename(src)
                    kind = "markdown"
                else:
                    if isinstance(src, str):
                        if not os.path.exists(src):
                            logger.warning(f"Image file not found: {src}. Skipping.")
                            status_placeholder.warning(f"Skipping missing file: {os.path.basename(src)}")
                            continue
                        img_obj = Image.open(src)
                        filename = os.path.basename(src)
                    else:
                        src.seek(0)
                        img_obj = Image.open(src)
                        filename = getattr(src, 'name', f'uploaded_image_{idx}')
                    with img_obj:
                        iw, ih = img_obj.size
                        if iw <= 0 or ih <= 0:
                            logger.warning(f"Invalid image dimensions ({iw}x{ih}) for {filename}. Skipping.")
                            status_placeholder.warning(f"Skipping invalid image: {filename}")
                            continue
                        part_bytes = _image_source_to_pdf(img_obj, filename, idx)
                    if not isinstance(src, str):
                        # Leave the caller's upload rewound for later reuse.
                        src.seek(0)
                    kind = "image"
                with fitz.open(stream=part_bytes, filetype="pdf") as part:
                    merged.insert_pdf(part)
                status_placeholder.success(f"Added {kind} page {idx}/{len(sources)}: {filename}")
            except Exception as e:
                logger.error(f"Error processing source {src}: {e}")
                status_placeholder.error(f"Error adding page {idx}: {e}")
        if len(merged) == 0:
            st.error("PDF generation resulted in an empty file.")
            return None
        return merged.tobytes()
    except Exception as e:
        logger.error(f"Fatal error during PDF generation: {e}")
        st.error(f"PDF Generation Failed: {e}")
        return None
    finally:
        merged.close()
|
|
|
# Download MIME type per file extension; any other extension is served as a
# generic binary stream.
_GALLERY_MIME_TYPES = {
    '.png': 'image/png',
    '.jpg': 'image/jpeg',
    '.jpeg': 'image/jpeg',
    '.pdf': 'application/pdf',
    '.md': 'text/markdown',
}


def _render_gallery_preview(gallery_type, file, file_ext):
    """Render a collapsed preview of *file* when its extension matches *gallery_type*.

    Files whose extension does not match the gallery type get no preview.
    Every file/document handle opened here is closed before returning, even
    if rendering raises.
    """
    if gallery_type == 'image' and file_ext in ('.png', '.jpg', '.jpeg'):
        with st.expander("Preview", expanded=False):
            # Image.open keeps the file handle open until the image is closed;
            # st.image consumes the pixels synchronously, so close right after.
            with Image.open(file) as img:
                st.image(img, use_container_width=True)
    elif gallery_type == 'pdf' and file_ext == '.pdf':
        with st.expander("Preview (Page 1)", expanded=False):
            # Context manager closes the document even if rasterising fails
            # (the previous explicit close() was skipped on any exception).
            with fitz.open(file) as doc:
                if len(doc) > 0:
                    # Render the first page at half scale to keep the preview light.
                    pix = doc[0].get_pixmap(matrix=fitz.Matrix(0.5, 0.5))
                    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                    st.image(img, use_container_width=True)
                else:
                    st.warning("Empty PDF")
    elif gallery_type == 'md' and file_ext == '.md':
        with st.expander("Preview (Start)", expanded=False):
            with open(file, 'r', encoding='utf-8', errors='ignore') as f:
                content_preview = f.read(200)
            st.code(content_preview + "...", language='markdown')


def _render_gallery_actions(gallery_type, file, basename, file_ext, item_key_base):
    """Render the Select / Download / Delete controls for one gallery item.

    The checkbox result is stored in
    ``st.session_state['asset_checkboxes'][gallery_type][file]``. Deleting
    removes the file from disk, drops it from the selection and from
    ``st.session_state['layout_snapshots']`` (when listed there), shows a
    toast, then reruns the script.
    """
    select_col, download_col, delete_col = st.columns(3)

    with select_col:
        selections = st.session_state['asset_checkboxes'][gallery_type]
        selections[file] = st.checkbox(
            "Select",
            value=selections.get(file, False),
            key=f"cb_{item_key_base}",
        )

    with download_col:
        mime_type = _GALLERY_MIME_TYPES.get(file_ext, "application/octet-stream")
        try:
            with open(file, "rb") as fp:
                st.download_button(
                    label="📥",
                    data=fp,
                    file_name=basename,
                    mime=mime_type,
                    key=f"dl_{item_key_base}",
                    help="Download this file",
                )
        except Exception as dl_e:
            # Best effort: a failed download button must not hide the item.
            st.error(f"Download Error: {dl_e}")

    with delete_col:
        if st.button("🗑️", key=f"del_{item_key_base}", help=f"Delete {basename}"):
            try:
                os.remove(file)
            except OSError as e:
                logger.error(f"Error deleting file {file}: {e}")
                st.error(f"Could not delete {basename}")
            else:
                st.session_state['asset_checkboxes'][gallery_type].pop(file, None)
                if file in st.session_state.get('layout_snapshots', []):
                    st.session_state['layout_snapshots'].remove(file)
                logger.info(f"Deleted {gallery_type} asset: {file}")
                st.toast(f"Deleted {basename}!", icon="✅")
                st.rerun()


def update_gallery(gallery_type='image'):
    """Render the asset gallery for one asset type into its pre-created container.

    The container is ``st.session_state['asset_gallery_container'][gallery_type]``.
    It lists up to ``st.session_state.gallery_size`` files returned by
    ``get_typed_gallery_files(gallery_type)``. Each file gets its basename, an
    optional collapsed preview, and select/download/delete controls. A failure
    while rendering one item is reported in the UI and logged, and the
    remaining items are still shown.

    Args:
        gallery_type: Asset kind to show; this file calls it with ``'image'``,
            ``'md'`` and ``'pdf'``, which are also the kinds that get previews.
    """
    container = st.session_state['asset_gallery_container'][gallery_type]
    with container:
        st.markdown(f"### {gallery_type.capitalize()} Gallery 📸")
        files = get_typed_gallery_files(gallery_type)
        if not files:
            st.info(f"No {gallery_type} assets found yet.")
            return
        st.caption(f"Found {len(files)} assets:")

        for file in files[:st.session_state.gallery_size]:
            # NOTE(review): the ever-growing counter keeps widget keys unique,
            # but unless 'unique_counter' is reset at the start of every script
            # run (not visible here) the keys also change on each rerun.
            # Streamlit then treats the widgets as new, so checkbox toggles and
            # delete clicks may be lost. Consider stable, path-based keys — TODO
            # confirm against how often update_gallery is called per run.
            st.session_state['unique_counter'] += 1
            unique_id = st.session_state['unique_counter']
            basename = os.path.basename(file)
            item_key_base = f"{gallery_type}_gallery_item_{basename}_{unique_id}"

            st.markdown(f"**{basename}**")
            try:
                file_ext = os.path.splitext(file)[1].lower()
                _render_gallery_preview(gallery_type, file, file_ext)
                _render_gallery_actions(gallery_type, file, basename, file_ext, item_key_base)
            except Exception as e:
                st.error(f"Error displaying {basename}: {e}")
                logger.error(f"Error displaying asset {file}: {e}")
            st.markdown("---")
|
|
|
|
|
def _option_index(options, preferred, fallback=0):
    """Return the position of *preferred* in *options*, or *fallback* if absent.

    ``list.index`` raises ``ValueError`` for a value that is not in the list.
    That would crash the page whenever session state holds a stale or
    unexpected choice, e.g. a model that is no longer in the featured list.
    """
    return options.index(preferred) if preferred in options else fallback


st.sidebar.subheader("🤖 AI Settings")

with st.sidebar.expander("API Inference Settings", expanded=False):
    st.session_state.hf_custom_key = st.text_input(
        "Custom HF Token",
        value=st.session_state.get('hf_custom_key', ""),
        type="password",
        key="hf_custom_key_input",
    )
    # A user-supplied key takes precedence over the HF_TOKEN default.
    if st.session_state.hf_custom_key:
        token_status = "Custom Key Set"
    elif HF_TOKEN:
        token_status = "Default HF_TOKEN Set"
    else:
        token_status = "No Token Set"
    st.caption(f"HF Token Status: {token_status}")

    providers_list = ["hf-inference", "cerebras", "together", "sambanova", "novita", "cohere", "fireworks-ai", "hyperbolic", "nebius"]
    st.session_state.hf_provider = st.selectbox(
        "HF Inference Provider",
        options=providers_list,
        index=_option_index(providers_list, st.session_state.get('hf_provider', DEFAULT_PROVIDER)),
        key="hf_provider_select",
    )

    st.session_state.hf_custom_api_model = st.text_input(
        "Custom HF API Model ID",
        value=st.session_state.get('hf_custom_api_model', ""),
        key="hf_custom_model_input",
    )
    st.session_state.hf_selected_api_model = st.selectbox(
        "Featured HF API Model",
        options=FEATURED_MODELS_LIST,
        index=_option_index(FEATURED_MODELS_LIST, st.session_state.get('hf_selected_api_model', FEATURED_MODELS_LIST[0])),
        key="hf_featured_model_select",
    )
    # Resolve the effective model only after both widgets have run. It used to
    # be computed before the selectbox, so it reflected the previous run's
    # choice and raised AttributeError when the selection was not yet in state.
    # A non-blank custom model ID overrides the featured selection.
    effective_hf_model = st.session_state.hf_custom_api_model.strip() or st.session_state.hf_selected_api_model
    st.caption(f"Effective HF API Model: {effective_hf_model}")

    if _openai_available:
        st.session_state.openai_selected_model = st.selectbox(
            "OpenAI Model",
            options=OPENAI_MODELS_LIST,
            index=_option_index(OPENAI_MODELS_LIST, st.session_state.get('openai_selected_model', OPENAI_MODELS_LIST[0])),
            key="openai_model_select",
        )
|
|
|
# Sidebar picker for the active locally loaded model. The sentinel option
# "None" means "no local model" and is stored in session state as Python None.
with st.sidebar.expander("Local Model Selection", expanded=True):
    if _transformers_available:
        # Iterating the dict yields its keys (the model paths), after the sentinel.
        local_model_options = ["None", *st.session_state.get('local_models', {})]

        # Re-select the previous choice when it is still a valid option,
        # otherwise fall back to the sentinel.
        current_selection = st.session_state.get('selected_local_model_path', "None")
        current_selection = current_selection if current_selection in local_model_options else "None"

        selected_path = st.selectbox(
            "Active Local Model",
            options=local_model_options,
            index=local_model_options.index(current_selection),
            format_func=lambda option: "None" if option == "None" else os.path.basename(option),
            key="local_model_selector",
        )
        st.session_state.selected_local_model_path = None if selected_path == "None" else selected_path

        if not st.session_state.selected_local_model_path:
            st.caption("No local model selected.")
        else:
            model_info = st.session_state.local_models[st.session_state.selected_local_model_path]
            st.caption(f"Type: {model_info.get('type', 'Unknown')}")
            st.caption(f"Device: {model_info.get('model').device if model_info.get('model') else 'N/A'}")
    else:
        st.warning("Transformers library not found. Cannot load local models.")
|
|
|
# Sampling-parameter sliders, rendered in table order. Each slider starts from
# the value previously stored in session state (or the table default) and
# writes the new value back under the same session-state name.
# NOTE(review): "Repetition Penalty" is stored as gen_frequency_penalty with a
# 0.0-1.0 range — confirm which backend parameter consumers map it to.
with st.sidebar.expander("Generation Parameters", expanded=False):
    for _state_name, _label, _low, _high, _default, _step, _widget_key in (
        ("gen_max_tokens", "Max New Tokens", 1, 4096, 512, None, "param_max_tokens"),
        ("gen_temperature", "Temperature", 0.01, 2.0, 0.7, 0.01, "param_temp"),
        ("gen_top_p", "Top-P", 0.01, 1.0, 0.95, 0.01, "param_top_p"),
        ("gen_frequency_penalty", "Repetition Penalty", 0.0, 1.0, 0.0, 0.05, "param_repetition"),
        ("gen_seed", "Seed", -1, 65535, -1, 1, "param_seed"),
    ):
        st.session_state[_state_name] = st.slider(
            _label,
            min_value=_low,
            max_value=_high,
            value=st.session_state.get(_state_name, _default),
            step=_step,
            key=_widget_key,
        )
|
|
|
st.sidebar.subheader("🖼️ Gallery Settings")

# Number of items each gallery lists. The slider must use st.sidebar so it
# sits under this sidebar heading; plain st.slider rendered it in the main
# page instead. The widget's return value equals its session-state entry
# ("gallery_size_slider"), mirrored into gallery_size for the gallery renderer.
st.session_state.gallery_size = st.sidebar.slider(
    "Max Items Shown",
    min_value=2,
    max_value=50,
    value=st.session_state.get('gallery_size', 10),
    key="gallery_size_slider",
)

st.sidebar.markdown("---")

# Render each asset gallery into its pre-created container.
update_gallery('image')
update_gallery('md')
update_gallery('pdf')
|
|
|
|
|
# Main page header and tab bar. The tabs are addressed by position later in
# the script (e.g. ``with tabs[0]:``), so the order of labels here decides
# which section appears in which tab.
# NOTE(review): the emoji in these labels look mis-encoded (UTF-8 bytes shown
# through a legacy code page); restore them from a correctly encoded copy of
# the source — TODO confirm.
st.title("Vision & Layout Titans ππΌοΈπ")

st.markdown("Create PDFs from images and markdown, process with AI, and manage characters.")

tabs = st.tabs([
    "Image/MD->PDF Layout πΌοΈβ‘οΈπ",
    "Camera Snap π·",
    "Download PDFs π₯",
    "Build Titan (Local Models) π±",
    "PDF Process (AI) π",
    "Image Process (AI) πΌοΈ",
    "Text Process (AI) π",
    "Test OCR (AI) π",
    "Test Image Gen (Diffusers) π¨",
    "Character Editor π§βπ¨",
    "Character Gallery πΌοΈ"
])
|
|
|
with tabs[0]: |
|
st.header("Image/Markdown to PDF Layout Generator") |
|
st.markdown("Select images and markdown files, reorder them, and generate a PDF.") |
|
col1, col2 = st.columns(2) |
|
with col1: |
|
st.subheader("A. Select Assets") |
|
selected_images = [f for f in get_typed_gallery_files('image') if st.session_state['asset_checkboxes']['image'].get(f, False)] |
|
selected_mds = [f for f in get_typed_gallery_files('md') if st.session_state['asset_checkboxes']['md'].get(f, False)] |
|
st.write(f"Selected Images: {len(selected_images)}") |
|
st.write(f"Selected Markdown Files: {len(selected_mds)}") |
|
with col2: |
|
st.subheader("B. Review and Reorder") |
|
layout_records = [] |
|
for idx, path in enumerate(selected_images + selected_mds, start=1): |
|
is_md = path in selected_mds |
|
try: |
|
if is_md: |
|
with open(path, 'r', encoding='utf-8') as f: |
|
content = f.read(50) |
|
layout_records.append({ |
|
"filename": os.path.basename(path), |
|
"source": path, |
|
"type": "Markdown", |
|
"preview": content + "...", |
|
"order": idx |
|
}) |
|
else: |
|
with Image.open(path) as im: |
|
w, h = im.size |
|
ar = round(w / h, 2) if h > 0 else 0 |
|
orient = "Square" if 0.9 <= ar <= 1.1 else ("Landscape" if ar > 1.1 else "Portrait") |
|
layout_records.append({ |
|
"filename": os.path.basename(path), |
|
"source": path, |
|
"type": "Image", |
|
"width": w, |
|
"height": h, |
|
"aspect_ratio": ar, |
|
"orientation": orient, |
|
"order": idx |
|
}) |
|
except Exception as e: |
|
logger.warning(f"Could not process {path}: {e}") |
|
st.warning(f"Skipping invalid file: {os.path.basename(path)}") |
|
if not layout_records: |
|
st.infoperiod |