Spaces:
Paused
Paused
File size: 6,690 Bytes
284b864 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import torch
from PIL import Image
import requests
from io import BytesIO
import gradio as gr
import os
import sys
import time
import warnings
# Set up environment for CUDA.
# The CUDA caching allocator reads PYTORCH_CUDA_ALLOC_CONF when it is first
# initialised, so set it before touching torch.cuda at all.
# setdefault (rather than plain assignment) keeps any value the operator has
# already exported instead of silently overwriting it.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")

# Suppress warnings (blanket: every warning category is ignored from here on).
warnings.filterwarnings("ignore")

# Print basic runtime diagnostics to the logs at startup.
print("Starting InternVL2 with Llama3-76B initialization...")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
# GPU detection: run once at import time and cache the result in USE_GPU.
def check_gpu():
    """Tell whether a CUDA device is present and can actually execute work.

    A device being reported by ``torch.cuda.is_available()`` is not enough on
    its own, so a tiny tensor is allocated on it and added to itself as a
    smoke test. Diagnostics are printed in every case.

    Returns:
        True if the smoke test succeeds, False otherwise.
    """
    if torch.cuda.is_available():
        try:
            probe = torch.rand(10, device="cuda")
            probe + probe  # result unused; we only care that it does not raise
            print(f"GPU is available: {torch.cuda.get_device_name(0)}")
            return True
        except Exception as err:
            print(f"Error initializing GPU: {str(err)}")
            return False

    print("CUDA is not available. This application requires GPU acceleration.")
    return False


# Module-wide flag consulted by the model-loading and inference code.
USE_GPU: bool = check_gpu()
# Import Hugging Face transformers (used to load the InternViT model).
# Failure is recorded in HAS_TRANSFORMERS instead of aborting, so the rest of
# the module still runs and can report the problem.
HAS_TRANSFORMERS = False
try:
    from transformers import AutoModel, AutoProcessor
except ImportError as import_err:
    print(f"Error importing transformers: {str(import_err)}")
else:
    HAS_TRANSFORMERS = True
    print("Successfully imported transformers")
# Initialize models
# Module-level handles filled in by load_models(); they stay None until a
# load succeeds.
internvit_model = None
llama_model = None  # never assigned in this demo (Llama3-76B is not loaded)
processor = None


def load_models():
    """Load the InternViT-6B image preprocessor and vision encoder onto the GPU.

    On success the module-level ``processor`` and ``internvit_model`` globals
    are set together. On any failure both are left as they were. The
    Llama3-76B language model is not loaded, so ``llama_model`` stays None.

    Returns:
        True if both were loaded. False if there is no usable GPU, the
        ``transformers`` import failed, or loading raised an exception (the
        error is printed).
    """
    global internvit_model, llama_model, processor
    if not USE_GPU:
        print("Cannot load models without GPU")
        return False
    if not HAS_TRANSFORMERS:
        # Report clearly here instead of failing below with a NameError on
        # AutoModel, which would be caught and logged as a cryptic message.
        print("Cannot load models: transformers could not be imported")
        return False

    model_id = "OpenGVLab/InternViT-6B-224px"
    try:
        # Function-scope import: only needed here, and only once the
        # transformers import is known to work.
        from transformers import CLIPImageProcessor

        print("Loading InternViT-6B model for visual feature extraction...")
        # NOTE: the InternViT-6B model card loads the preprocessor with
        # CLIPImageProcessor and the model with trust_remote_code=True. The
        # architecture's code lives in the hub repo, not in transformers, so
        # AutoModel cannot build it without running that repository code.
        image_processor = CLIPImageProcessor.from_pretrained(model_id)
        vision_model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
        # Inference only: eval() switches off training-mode behaviour such as
        # dropout.
        vision_model = vision_model.to("cuda").eval()
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        return False

    # Publish the handles only after everything loaded, so a partial failure
    # never leaves a processor without a matching model.
    processor = image_processor
    internvit_model = vision_model
    print("InternViT-6B model loaded successfully!")
    # The Llama3-76B stage is not implemented in this demonstration.
    print("Note: Llama3-76B model loading is commented out for this demonstration")
    return True


# Load models on startup; the inference function and UI read this flag.
MODELS_LOADED = load_models()
def process_image(image_path, sample_url=None):
    """Process an image using InternViT-6B for feature extraction.

    Args:
        image_path: The image to analyse. This is a PIL image (what the Gradio
            ``Image(type="pil")`` component supplies) or a filesystem path. It
            is None/empty when nothing was uploaded.
        sample_url: Optional URL of an image to download. Used only when
            ``image_path`` is empty.

    Returns:
        A human-readable report with the shapes of the extracted features and
        the elapsed time, or an error message. Failures are returned as
        strings, never raised, so they appear in the UI's results box.
    """
    use_url = bool(sample_url) and not image_path
    if not use_url and not image_path:
        return "No image provided"
    # Check before any network/disk work: nothing can be processed anyway.
    if not MODELS_LOADED:
        return "Models failed to load. Please check the logs."

    # Image loading is guarded so that a bad URL, an HTTP error or an
    # unreadable file is reported instead of crashing the request handler.
    try:
        if use_url:
            # Bounded wait so a stalled download cannot hang the worker.
            response = requests.get(sample_url, timeout=30)
            # Do not try to decode an HTML error page as an image.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
            print(f"Loaded sample image from URL: {sample_url}")
        elif isinstance(image_path, str):
            image = Image.open(image_path)
        else:
            image = image_path
        if isinstance(image, Image.Image):
            # convert() forces PIL's lazy decode, so corrupt files fail here.
            # It also normalises RGBA/greyscale/palette images to RGB before
            # preprocessing.
            image = image.convert("RGB")
    except Exception as e:
        return f"Error loading image: {str(e)}"

    try:
        start_time = time.perf_counter()  # monotonic clock for durations
        print("Processing image through InternViT-6B...")
        inputs = processor(images=image, return_tensors="pt")
        if USE_GPU:
            # Match floating-point inputs to the model's parameter dtype, so a
            # half/bfloat16 checkpoint works too. Other tensors are only moved.
            model_dtype = next(internvit_model.parameters()).dtype
            inputs = {
                k: v.to("cuda", dtype=model_dtype) if v.is_floating_point() else v.to("cuda")
                for k, v in inputs.items()
            }
        with torch.no_grad():
            outputs = internvit_model(**inputs)
        # Extract image features
        image_features = outputs.last_hidden_state
        pooled_output = outputs.pooler_output
        pooled_shape = pooled_output.shape if pooled_output is not None else "not provided"
        # In a real implementation, these features would be passed to Llama3-76B.
        # For now, just report the extracted features.
        feature_info = f"""
Image successfully processed through InternViT-6B:
- Last hidden state shape: {image_features.shape}
- Pooled output shape: {pooled_shape}
In a complete implementation, these visual features would be passed to Llama3-76B
for generating text responses about the image.
Note: This is a demonstration of visual feature extraction only.
"""
        elapsed = time.perf_counter() - start_time
        return f"{feature_info}\n\nProcessing completed in {elapsed:.2f} seconds."
    except Exception as e:
        return f"Error processing image: {str(e)}"
# Gradio UI construction.
def create_interface():
    """Build (but do not launch) the Gradio Blocks app for the demo.

    The layout has a header, a status line derived from MODELS_LOADED, and an
    upload column plus a results column. Below them are an "Extract Visual
    Features" button and an explanatory footer. Both buttons write their
    output into the same results textbox.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    demo_image_url = "https://huggingface.co/OpenGVLab/InternVL2/resolve/main/assets/demo.jpg"

    def use_sample():
        # Run the pipeline on the hosted sample instead of an uploaded image.
        return process_image(None, demo_image_url)

    if MODELS_LOADED:
        status_label = "✅ Ready"
    else:
        status_label = "❌ Models failed to load"

    with gr.Blocks(title="InternVL2 with Llama3-76B") as app:
        gr.Markdown("# InternVL2 Visual Feature Extraction Demo")
        gr.Markdown("## Using InternViT-6B for visual feature extraction")
        gr.Markdown(f"### System Status: {status_label}")

        with gr.Row():
            with gr.Column():
                uploaded = gr.Image(type="pil", label="Upload Image")
                sample_button = gr.Button("Use Sample Image")
            with gr.Column():
                results = gr.Textbox(label="Results", lines=10)

        extract_button = gr.Button("Extract Visual Features")
        # Event registration order is kept: extract first, then sample.
        extract_button.click(fn=process_image, inputs=[uploaded], outputs=results)
        sample_button.click(fn=use_sample, inputs=[], outputs=results)

        gr.Markdown("""
## About This Demo
This demonstration shows how to use InternViT-6B for visual feature extraction,
following the instructions from the OpenGVLab/InternVL GitHub repository.
The application extracts visual features from the input image that would typically
be passed to a language model like Llama3-76B. In a complete implementation,
these features would be used to generate text responses about the image.
""")
    return app
# Main function: build the UI and serve it when run as a script.
if __name__ == "__main__":
    demo = create_interface()
    # server_name="0.0.0.0" listens on all network interfaces;
    # share=False disables Gradio's public share link.
    demo.launch(share=False, server_name="0.0.0.0")