import gradio as gr
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoProcessor
import torch
import gc
import json
import os
from accelerate import init_empty_weights
from accelerate.utils import load_checkpoint_in_model
import psutil

# Limit thread count for more predictable CPU performance
torch.set_num_threads(4)
device = "cpu"


def get_free_memory():
    """Get available memory in GB."""
    return psutil.virtual_memory().available / (1024 ** 3)


def load_model_in_chunks(model_path, chunk_size_gb=2):
    """Load model weights shard by shard to keep peak memory low."""
    # Build an empty (meta-device) model from the config only, so no weights
    # are materialized until each checkpoint shard is loaded.
    config = AutoConfig.from_pretrained(model_path)
    with init_empty_weights():
        empty_model = AutoModelForCausalLM.from_config(config)

    # Collect checkpoint files. This assumes model_path is a local directory
    # that already contains the downloaded checkpoint shards.
    index_path = os.path.join(model_path, "model.safetensors.index.json")
    if os.path.exists(index_path):
        with open(index_path) as f:
            index = json.load(f)
        checkpoint_files = sorted(
            {os.path.join(model_path, fname) for fname in index["weight_map"].values()}
        )
    else:
        checkpoint_files = [os.path.join(model_path, "pytorch_model.bin")]

    # Load each shard, freeing memory between shards when it runs low
    for checkpoint in checkpoint_files:
        if get_free_memory() < 2:  # less than 2 GB free
            gc.collect()
        load_checkpoint_in_model(empty_model, checkpoint)
        gc.collect()

    return empty_model


def load_model():
    model_name = "forestav/unsloth_vision_radiography_finetune"
    base_model_name = "unsloth/Llama-3.2-11B-Vision-Instruct"

    print("Loading tokenizer and processor...")
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_name,
        trust_remote_code=True,
        cache_dir="model_cache"
    )
    processor = AutoProcessor.from_pretrained(
        base_model_name,
        trust_remote_code=True,
        cache_dir="model_cache"
    )

    print("Loading model in chunks...")
    model = load_model_in_chunks(model_name)

    print("Optimizing model for CPU...")
    # Convert to float32, then dynamically quantize the linear layers to int8
    model = model.to(torch.float32)
    model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear},
        dtype=torch.qint8
    )

    return model, tokenizer, processor


# Create cache directories
os.makedirs("model_cache", exist_ok=True)
os.makedirs("offload", exist_ok=True)

print(f"Available memory before loading: {get_free_memory():.2f} GB")

# Initialize model, tokenizer, and processor globally
print("Starting model initialization...")
try:
    model, tokenizer, processor = load_model()
    print("Model loaded successfully!")
    print(f"Available memory after loading: {get_free_memory():.2f} GB")
except Exception as e:
    print(f"Error loading model: {str(e)}")
    raise


def analyze_image(image, instruction):
    try:
        gc.collect()

        if instruction.strip() == "":
            instruction = "You are an expert radiographer. Describe accurately what you see in this image."

        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": instruction}
            ]}
        ]

        # Free memory before preprocessing if it is running low
        if get_free_memory() < 2:
            gc.collect()

        inputs = processor(
            images=image,
            text=tokenizer.apply_chat_template(messages, add_generation_prompt=True),
            return_tensors="pt"
        )

        # Generate with greedy decoding to keep memory usage minimal
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                use_cache=True,
                pad_token_id=tokenizer.eos_token_id,
                num_beams=1,
                do_sample=False  # greedy decoding; sampling disabled to save memory
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        del outputs, inputs
        gc.collect()

        return response

    except Exception as e:
        return f"Error processing image: {str(e)}\nPlease try again with a smaller image or different settings."
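# A minimal smoke test for analyze_image, left commented out so importing this
# module does not trigger inference. "sample_xray.png" is a hypothetical local
# file, not part of the original app:
#
#     from PIL import Image
#     test_image = Image.open("sample_xray.png").convert("RGB")
#     print(analyze_image(test_image, ""))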
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("""
    # Medical Image Analysis Assistant
    Upload a medical image and receive a professional description from an AI radiographer.
    """)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Medical Image"
            )
            instruction_input = gr.Textbox(
                label="Custom Instruction (optional)",
                placeholder="You are an expert radiographer. Describe accurately what you see in this image.",
                lines=2
            )
            submit_btn = gr.Button("Analyze Image")

        with gr.Column():
            output_text = gr.Textbox(label="Analysis Result", lines=10)

    submit_btn.click(
        fn=analyze_image,
        inputs=[image_input, instruction_input],
        outputs=output_text
    )

    gr.Markdown("""
    ### Notes:
    - The model runs on CPU and may take several minutes to process each image
    - For best results, upload images smaller than 1 MP
    - Initial loading may take some time
    - Please be patient during processing
    """)

if __name__ == "__main__":
    demo.launch()
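# A possible extension (assumption, not part of the original script): Gradio's
# request queue helps avoid HTTP timeouts during multi-minute CPU inference.
# Replacing the launch call above with the following would enable it:
#
#     demo.queue(max_size=4).launch()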