import os
import random
import uuid
from typing import Tuple

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
)
from transformers.image_utils import load_image
# ---------------------------
# Global Settings and Devices
# ---------------------------
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MAX_SEED = np.iinfo(np.int32).max
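
# Note: this Space targets ZeroGPU ("Running on Zero"), so the GPU-heavy entry points
# below are decorated with @spaces.GPU; the GPU is assumed to be attached per call.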
# ---------------------------
# IMAGE GEN LoRA TAB: SDXL Gen with LoRA Options
# ---------------------------
# Load the SDXL pipeline
MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # Path from env variable
if MODEL_ID_SD is None:
    MODEL_ID_SD = "SG161222/RealVisXL_V4.0_Lightning"  # default fallback
# Load SDXL pipeline (use GPU if available)
sd_pipe = StableDiffusionXLPipeline.from_pretrained(
    MODEL_ID_SD,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True,
    add_watermarker=False,
).to(device)
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
if torch.cuda.is_available():
    sd_pipe.text_encoder = sd_pipe.text_encoder.half()

# Optional: compile or offload if desired
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
if USE_TORCH_COMPILE:
    sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True)
if ENABLE_CPU_OFFLOAD:
    sd_pipe.enable_model_cpu_offload()
def save_image(img: Image.Image) -> str:
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed
# LoRA options and style definitions
LORA_OPTIONS = {
    "Realism (face/character)": ("prithivMLmods/Canopus-Realism-LoRA", "Canopus-Realism-LoRA.safetensors", "rlms"),
    "Pixar (art/toons)": ("prithivMLmods/Canopus-Pixar-Art", "Canopus-Pixar-Art.safetensors", "pixar"),
    "Photoshoot (camera/film)": ("prithivMLmods/Canopus-Photo-Shoot-Mini-LoRA", "Canopus-Photo-Shoot-Mini-LoRA.safetensors", "photo"),
    "Clothing (hoodies/pant/shirts)": ("prithivMLmods/Canopus-Clothing-Adp-LoRA", "Canopus-Dress-Clothing-LoRA.safetensors", "clth"),
    "Interior Architecture (house/hotel)": ("prithivMLmods/Canopus-Interior-Architecture-0.1", "Canopus-Interior-Architecture-0.1δ.safetensors", "arch"),
    "Fashion Product (wearing/usable)": ("prithivMLmods/Canopus-Fashion-Product-Dilation", "Canopus-Fashion-Product-Dilation.safetensors", "fashion"),
    "Minimalistic Image (minimal/detailed)": ("prithivMLmods/Pegasi-Minimalist-Image-Style", "Pegasi-Minimalist-Image-Style.safetensors", "minimalist"),
    "Modern Clothing (trend/new)": ("prithivMLmods/Canopus-Modern-Clothing-Design", "Canopus-Modern-Clothing-Design.safetensors", "mdrnclth"),
    "Animaliea (farm/wild)": ("prithivMLmods/Canopus-Animaliea-Artism", "Canopus-Animaliea-Artism.safetensors", "Animaliea"),
    "Liquid Wallpaper (minimal/illustration)": ("prithivMLmods/Canopus-Liquid-Wallpaper-Art", "Canopus-Liquid-Wallpaper-Minimalize-LoRA.safetensors", "liquid"),
    "Canes Cars (realistic/futurecars)": ("prithivMLmods/Canes-Cars-Model-LoRA", "Canes-Cars-Model-LoRA.safetensors", "car"),
    "Pencil Art (characteristic/creative)": ("prithivMLmods/Canopus-Pencil-Art-LoRA", "Canopus-Pencil-Art-LoRA.safetensors", "Pencil Art"),
    "Art Minimalistic (paint/semireal)": ("prithivMLmods/Canopus-Art-Medium-LoRA", "Canopus-Art-Medium-LoRA.safetensors", "mdm"),
}
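
# Each LORA_OPTIONS value is (Hub repo id, weight filename, adapter name); these feed
# directly into sd_pipe.load_lora_weights() in generate_image_lora() below.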
style_list = [
    {
        "name": "3840 x 2160",
        "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
        "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
    },
    {
        "name": "2560 x 1440",
        "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
        "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
    },
    {
        "name": "HD+",
        "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
        "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
    },
    {
        "name": "Style Zero",
        "prompt": "{prompt}",
        "negative_prompt": "",
    },
]
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
DEFAULT_STYLE_NAME = "3840 x 2160"
STYLE_NAMES = list(styles.keys())
def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
    # Fall back to the default style if the name is unknown
    p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    combined_negative = f"{n}, {negative}" if n and negative else n + negative
    return p.replace("{prompt}", positive), combined_negative
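
# For example, apply_style("HD+", "a lighthouse at dusk") expands the positive prompt to
# "hyper-realistic 2K image of a lighthouse at dusk. ultra-detailed, ..." and returns the
# style's negative prompt alongside it.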
@spaces.GPU
def generate_image_lora(
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = True,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 3,
    randomize_seed: bool = False,
    style_name: str = DEFAULT_STYLE_NAME,
    lora_model: str = "Realism (face/character)",
    progress=gr.Progress(track_tqdm=True),
):
    seed = int(randomize_seed_fn(seed, randomize_seed))
    positive_prompt, effective_negative_prompt = apply_style(style_name, prompt, negative_prompt)
    if not use_negative_prompt:
        effective_negative_prompt = ""
    # Set LoRA adapter based on selection
    model_name, weight_name, adapter_name = LORA_OPTIONS[lora_model]
    sd_pipe.load_lora_weights(model_name, weight_name=weight_name, adapter_name=adapter_name)
    sd_pipe.to(device)
    # Seed the generator so the reported seed actually determines the output
    generator = torch.Generator(device=device).manual_seed(seed)
    outputs = sd_pipe(
        prompt=positive_prompt,
        negative_prompt=effective_negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=20,
        num_images_per_prompt=1,
        generator=generator,
        cross_attention_kwargs={"scale": 0.65},
        output_type="pil",
    )
    image_paths = [save_image(img) for img in outputs.images]
    return image_paths, seed
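
# Note: load_lora_weights() registers a new adapter on every call, so generating twice
# with the same LoRA selection can collide on adapter_name. A minimal sketch of one way
# to handle this (assuming a recent diffusers release with unload_lora_weights()):
#
#     sd_pipe.unload_lora_weights()
#     sd_pipe.load_lora_weights(model_name, weight_name=weight_name, adapter_name=adapter_name)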
# ---------------------------
# Qwen 2 VL OCR TAB
# ---------------------------
MODEL_ID_QWEN = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID_QWEN, trust_remote_code=True)
model_m = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID_QWEN,
    trust_remote_code=True,
    torch_dtype=torch.float16
).to("cuda" if torch.cuda.is_available() else "cpu").eval()
@spaces.GPU
def qwen2vl_ocr_generate(
    prompt: str,
    file: list,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
):
    # In this tab, we assume the user supplies an image (or multiple images) for OCR.
    images = []
    if file:
        # Load image(s) using the transformers helper
        for f in file:
            images.append(load_image(f))
    # Build message content: a simple chat template with text and images.
    messages = [{
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in images],
            {"type": "text", "text": prompt},
        ]
    }]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt_full],
        images=images if images else None,
        return_tensors="pt",
        padding=True,
    ).to("cuda" if torch.cuda.is_available() else "cpu")
    # Use non-streaming generation for simplicity
    output_ids = model_m.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # Decode only the newly generated tokens, not the echoed prompt
    generated_ids = output_ids[:, inputs["input_ids"].shape[-1]:]
    final_response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return final_response
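
# A streaming variant is also possible: run generate() in a background thread with a
# TextIteratorStreamer and yield partial text as it arrives. A minimal sketch, assuming
# the extra imports below:
#
#     from threading import Thread
#     from transformers import TextIteratorStreamer
#
#     streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
#     Thread(target=model_m.generate,
#            kwargs={**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}).start()
#     partial = ""
#     for chunk in streamer:
#         partial += chunk
#         yield partial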
# ---------------------------
# CHAT INTERFACE TAB (Text-only)
# ---------------------------
# Load text-only model and tokenizer
model_id_text = "prithivMLmods/FastThink-0.5B-Tiny"
tokenizer = AutoTokenizer.from_pretrained(model_id_text)
model = AutoModelForCausalLM.from_pretrained(
    model_id_text,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()
@spaces.GPU
def chat_generate(prompt: str, max_new_tokens: int = 1024, temperature: float = 0.6,
                  top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2):
    # For simplicity, use a basic generate without streaming.
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    input_ids = input_ids.to(model.device)
    output_ids = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # Decode only the newly generated tokens, not the echoed prompt
    response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
    return response
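
# If the underlying checkpoint is instruction-tuned, wrapping the user message in the
# tokenizer's chat template usually gives better responses than raw encoding. A minimal
# sketch, assuming the tokenizer ships a chat template:
#
#     messages = [{"role": "user", "content": prompt}]
#     chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
#     input_ids = tokenizer.encode(chat_prompt, return_tensors="pt")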
# ---------------------------
# GRADIO INTERFACE WITH TABS
# ---------------------------
with gr.Blocks(title="Multi-Modal Playground") as demo:
    gr.Markdown("# Multi-Modal Playground")
    with gr.Tab("Image Gen LoRA"):
        gr.Markdown("## Generate Images using SDXL + LoRA")
        with gr.Row():
            prompt_img = gr.Textbox(label="Prompt", placeholder="Enter your image prompt here")
            negative_prompt_img = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt (optional)", lines=2)
        with gr.Row():
            use_negative = gr.Checkbox(label="Use Negative Prompt", value=True)
            seed_img = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
        with gr.Row():
            width_img = gr.Slider(label="Width", minimum=512, maximum=2048, step=8, value=1024)
            height_img = gr.Slider(label="Height", minimum=512, maximum=2048, step=8, value=1024)
        with gr.Row():
            guidance_scale_img = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=20.0, step=0.1, value=3.0)
        with gr.Row():
            style_selection = gr.Radio(choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME, label="Quality Style")
            lora_selection = gr.Dropdown(choices=list(LORA_OPTIONS.keys()), value="Realism (face/character)", label="LoRA Selection")
        run_img = gr.Button("Generate Image")
        gallery = gr.Gallery(label="Generated Images", columns=1)
        output_seed = gr.Number(label="Seed Used")
        run_img.click(
            generate_image_lora,
            inputs=[prompt_img, negative_prompt_img, use_negative, seed_img, width_img, height_img, guidance_scale_img,
                    randomize_seed, style_selection, lora_selection],
            outputs=[gallery, output_seed]
        )
    with gr.Tab("Qwen 2 VL OCR"):
        gr.Markdown("## Extract and Generate Text from Images (OCR)")
        with gr.Row():
            prompt_ocr = gr.Textbox(label="OCR Prompt", placeholder="Enter instructions for OCR/text extraction")
            file_ocr = gr.File(label="Upload Image", file_types=["image"], file_count="multiple")
        run_ocr = gr.Button("Run OCR")
        output_ocr = gr.Textbox(label="OCR Output")
        run_ocr.click(
            qwen2vl_ocr_generate,
            inputs=[prompt_ocr, file_ocr],
            outputs=output_ocr
        )
    with gr.Tab("Chat Interface"):
        gr.Markdown("## Chat with the Text-Only Model")
        chat_input = gr.Textbox(label="Enter your message", placeholder="Say something...")
        max_tokens_chat = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
        temperature_chat = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
        top_p_chat = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
        top_k_chat = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
        rep_penalty_chat = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
        run_chat = gr.Button("Send")
        chat_output = gr.Textbox(label="Response")
        run_chat.click(
            chat_generate,
            inputs=[chat_input, max_tokens_chat, temperature_chat, top_p_chat, top_k_chat, rep_penalty_chat],
            outputs=chat_output
        )
    gr.Markdown("**Adjust parameters in each tab as needed.**")
if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)