|
|
|
|
|
|
|
""" |
|
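Test script for the InternVL2-40B-AWQ model using lmdeploy.

Builds an lmdeploy pipeline with an AWQ TurboMind engine config, runs a single
image + prompt query against it, and prints the response together with the
model-load and inference timings.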
|
""" |
|
|
|
import argparse |
|
import os |
|
import time |
|
from PIL import Image |
|
|
|
from lmdeploy import pipeline, TurbomindEngineConfig |
|
from lmdeploy.vl import load_image |
|
|
|
|
|
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the InternVL2 test script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process command line, which is
            what the original zero-argument call did.

    Returns:
        A namespace with three string attributes:
        ``model`` (model name or path, defaults to
        ``OpenGVLab/InternVL2-40B-AWQ``), ``image`` (required path or URL of
        the test image) and ``prompt`` (defaults to
        ``Describe this image in detail.``).

    Raises:
        SystemExit: Raised by argparse when ``--image`` is missing or an
            unknown option is supplied.
    """
    parser = argparse.ArgumentParser(description="Test InternVL2 model")
    parser.add_argument(
        "--model",
        type=str,
        default="OpenGVLab/InternVL2-40B-AWQ",
        help="Model name or path",
    )
    parser.add_argument(
        "--image",
        type=str,
        required=True,
        help="Path to the test image",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="Describe this image in detail.",
        help="Prompt for the model",
    )
    # Forwarding argv lets callers (and tests) supply arguments explicitly;
    # argparse falls back to sys.argv[1:] when it is None.
    return parser.parse_args(argv)
|
|
|
|
|
def _load_image(source):
    """Load the image referenced by ``source`` (an http(s) URL or a local path).

    URLs are delegated to lmdeploy's ``load_image``; local files are opened
    with PIL and converted to RGB.

    Raises:
        FileNotFoundError: If ``source`` is a local path that does not exist.
    """
    if source.startswith(('http://', 'https://')):
        image = load_image(source)
        print(f"Loaded image from URL: {source}")
        return image

    image_path = os.path.abspath(source)
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    # Image.open() keeps the file handle open until the image is closed;
    # convert() returns an independent, fully loaded copy, so the handle can
    # be released as soon as the conversion is done.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')
    print(f"Loaded image from path: {image_path}")
    return image


def main():
    """Run one image + prompt query through the model and print the result.

    Steps: parse the CLI arguments, load the image, build the lmdeploy
    pipeline with an AWQ TurboMind engine config, run inference, and print
    the response text plus the load and inference timings.

    Returns:
        int: ``0`` on success, ``1`` if the image could not be loaded.
    """
    args = parse_args()

    # Load the image before the model: this step is cheap, and failing here
    # avoids paying for the model load only to discover a bad path or URL.
    try:
        image = _load_image(args.image)
    except Exception as e:
        # Script-level boundary: URL fetches and image decoding can fail in
        # many ways, so report whatever went wrong and signal failure.
        print(f"Error loading image: {e}")
        return 1

    print(f"Loading model: {args.model}")
    # perf_counter() is monotonic, so the measured durations are not skewed
    # by system clock adjustments.
    start_time = time.perf_counter()
    backend_config = TurbomindEngineConfig(model_format='awq')
    pipe = pipeline(args.model, backend_config=backend_config, log_level='INFO')
    load_time = time.perf_counter() - start_time
    print(f"Model loaded in {load_time:.2f} seconds")

    print(f"Running inference with prompt: '{args.prompt}'")
    start_time = time.perf_counter()
    response = pipe((args.prompt, image))
    inference_time = time.perf_counter() - start_time
    print(f"Inference completed in {inference_time:.2f} seconds")

    print("\n--- RESULT ---")
    print(response.text)
    print("-------------\n")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shells and CI can detect a failed run.
    raise SystemExit(main())