Only two steps are needed.

Step 1: Clone the repository and install the required packages.

git clone https://github.com/ByungKwanLee/TroL
cd TroL
bash install

Step 2: Open demo.py, edit it to your needs, and run it.

import torch
from config import *
from PIL import Image
from utils.utils import *
import torch.nn.functional as F
from trol.load_trol import load_trol
from torchvision.transforms.functional import pil_to_tensor

# Model selection: pick one of 'TroL-1.8B' | 'TroL-3.8B' | 'TroL-7B'.
link = "TroL-3.8B"

# User prompt.
# prompt_type must be either "text_only" or "with_image";
# img_path is only read when prompt_type == "with_image".
prompt_type = "with_image"
img_path = 'figures/demo.png'
question = "What is the troll doing? Provide the detail in the image and imagine what the event happens."

# Single device string shared by the model weights and the processed inputs.
device = 'cuda:0'

# Loading model
model, tokenizer = load_trol(link=link)

# CPU -> GPU: move every parameter that is not already on a CUDA device to
# `device`. Parameters already on some CUDA device are left where they are.
# NOTE(review): only parameters are moved here, not buffers -- confirm that
# load_trol does not leave CPU-resident buffers that the forward pass needs.
for param in model.parameters():
    if not param.is_cuda:
        param.data = param.to(device)

# Prompt type -> model input list.
# image_token_number stays None unless the branch below sets it.
image_token_number = None
if prompt_type == 'with_image':
    # Load the image as an RGB tensor. The context manager closes the file
    # handle; pil_to_tensor copies the pixel data, so the tensor outlives it.
    with Image.open(img_path) as img:
        image = pil_to_tensor(img.convert("RGB"))
    if "3.8B" not in link:
        # Non-3.8B variants: resize to 490x490 and request 1225 image tokens.
        # The 3.8B variant keeps the original image resolution and
        # image_token_number=None.
        image_token_number = 1225
        image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
    inputs = [{'image': image, 'question': question}]
elif prompt_type == 'text_only':
    inputs = [{'question': question}]
else:
    # Fail fast with a clear message instead of a NameError on `inputs` below.
    raise ValueError(
        f"Unknown prompt_type {prompt_type!r}; expected 'text_only' or 'with_image'"
    )

# Generate
with torch.inference_mode():
    _inputs = model.eval_process(inputs=inputs,
                                 data='demo',
                                 tokenizer=tokenizer,
                                 device=device,
                                 img_token_number=image_token_number)
    generate_ids = model.generate(**_inputs, max_new_tokens=256, use_cache=True)
    # Decode the first sequence with special tokens kept, then let
    # output_filtering post-process the raw decoded text.
    response = output_filtering(tokenizer.batch_decode(generate_ids, skip_special_tokens=False)[0], model)
print(response)

So easy! Let's say TroL!

Paper: https://arxiv.org/abs/2406.12246

Downloads last month: 36

Safetensors
Model size: 4.15B params
Tensor type: FP16
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Spaces using BK-Lee/TroL-3.8B: 1

Collection including BK-Lee/TroL-3.8B