## Setup

We need `transformers`, `torchvision` and `einops` as basic dependencies for the model.
For this example, we also use `decord` for decoding video frames and `mediapy` for saving videos, plus the `wget` command-line tool for fetching data remotely (it is a system program called via `subprocess`, not a pip package).
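Because `wget` is a system tool rather than a Python package, `pip install` will not provide it. A minimal sketch that checks everything up front using only the standard library; the helper name `check_dependencies` is our own and is not part of any of these packages.

```python
import importlib.util
import shutil

# Python packages used in this notebook (their import names match the pip names)
PYTHON_PACKAGES = ["transformers", "torchvision", "einops", "decord", "mediapy"]
# Command-line tools called via subprocess
CLI_TOOLS = ["wget"]


def check_dependencies(packages=PYTHON_PACKAGES, tools=CLI_TOOLS):
    """Return a dict mapping each dependency name to True if it is available."""
    status = {}
    for name in packages:
        # find_spec returns None when a top-level package cannot be found
        status[name] = importlib.util.find_spec(name) is not None
    for name in tools:
        # shutil.which returns None when the executable is not on PATH
        status[name] = shutil.which(name) is not None
    return status


if __name__ == "__main__":
    missing = [name for name, ok in check_dependencies().items() if not ok]
    print("missing:", missing or "none")
```

Anything reported as missing can be installed before running the rest of the notebook.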

In [1]:
!pip install transformers torchvision einops decord mediapy



In [1]:
import decord
import numpy as np
import torch
from transformers import AutoConfig, AutoModel, AutoProcessor
from IPython.display import Video
import subprocess
import io



## Instantiate model

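We use `AutoModel` and `AutoProcessor` to load the weights and the custom inference code for Cosmos-Embed1 from the checkpoint directory `path` (`trust_remote_code=True` is required because the model code ships with the checkpoint rather than with `transformers`). The model was trained in bfloat16, so we cast it to bfloat16 when the GPU supports it; the cell below assumes it does. The preprocessor tokenizes text and resizes/rescales batched video frames. The configuration is loaded unmodified, so the checkpoint's default input resolution is used.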
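The cell below always casts to bfloat16. If the notebook might run on a GPU without bfloat16 support, or on CPU, one option is to choose the device and dtype first. A sketch, assuming float16 as the fallback on older GPUs and float32 on CPU; these fallbacks are our choice and have not been validated for Cosmos-Embed1.

```python
import torch


def pick_device_and_dtype():
    """Choose a device and dtype: bfloat16 on supporting GPUs, else a fallback."""
    if torch.cuda.is_available():
        if torch.cuda.is_bf16_supported():
            return "cuda", torch.bfloat16  # matches the training precision
        return "cuda", torch.float16  # assumption: half precision as a fallback
    return "cpu", torch.float32  # CPU inference in full precision


device, dtype = pick_device_and_dtype()
print(device, dtype)
```

The returned values would then replace the hard-coded `"cuda"` and `torch.bfloat16` in the `.to(...)` calls.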

In [2]:
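path = "../"  # checkpoint directory (here, the parent of the notebook's folder)

config = AutoConfig.from_pretrained(path, trust_remote_code=True)

# Load the model, move it to the GPU and cast it to bfloat16 (the training precision)
model = AutoModel.from_pretrained(path, trust_remote_code=True, config=config).to("cuda", dtype=torch.bfloat16)
model.eval()  # inference mode (disables dropout)
# Processor: tokenizes text and resizes/rescales video frames
preprocess = AutoProcessor.from_pretrained(path, trust_remote_code=True)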

Loading checkpoint shards: 100%|██████████| 5/5 [00:00<00:00, 15.64it/s]


## Fetch data

In [3]:
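video_url = "https://upload.wikimedia.org/wikipedia/commons/3/3d/Branko_Paukovic%2C_javelin_throw.webm"
# Download with the wget CLI. The file is WebM despite the .mp4 name; the extension
# does not matter because decord reads the bytes below, not the filename.
subprocess.check_call(["wget", "-O", "/tmp/output.mp4", video_url])
with open("/tmp/output.mp4", "rb") as f:
    video_bytes = f.read()
assert video_bytes, "download produced an empty file"
Video(video_url)  # display the clip inline from its URL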

--2025-06-03 16:11:10--  https://upload.wikimedia.org/wikipedia/commons/3/3d/Branko_Paukovic%2C_javelin_throw.webm
Resolving upload.wikimedia.org (upload.wikimedia.org)... 198.35.26.112, 2620:0:863:ed1a::2:b
Connecting to upload.wikimedia.org (upload.wikimedia.org)|198.35.26.112|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 159119 (155K) [video/webm]
Saving to: ‘/tmp/output.mp4’

     0K .......... .......... .......... .......... .......... 32% 1.36M 0s
    50K .......... .......... .......... .......... .......... 64% 14.6M 0s
   100K .......... .......... .......... .......... .......... 96% 1.31M 0s
   150K .....                                                 100% 10.0T=0.08s

2025-06-03 16:11:10 (1.98 MB/s) - ‘/tmp/output.mp4’ saved [159119/159119]
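The download above shells out to `wget`. Where `wget` is not installed, the standard library's `urllib.request` can fetch the same bytes without writing a temporary file. A minimal sketch; `fetch_bytes` is a hypothetical helper, and the `user_agent` parameter exists because some servers (Wikimedia among them) may reject requests without a descriptive `User-Agent`.

```python
import urllib.request


def fetch_bytes(url: str, user_agent: str = "cosmos-embed1-example/0.1", timeout: float = 30.0) -> bytes:
    """Download url and return the response body as bytes (hypothetical helper)."""
    request = urllib.request.Request(url, headers={"User-Agent": user_agent})
    with urllib.request.urlopen(request, timeout=timeout) as response:
        data = response.read()
    if not data:
        raise ValueError(f"empty response from {url}")
    return data
```

In the notebook, `video_bytes = fetch_bytes(video_url)` would replace both the `wget` call and the file read.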



We sample 8 frames from the video and build an array of shape `batch_size x num_frames x channel_dim x height x width` (BTCHW). The model was trained on 8 frames sampled at 1-2 FPS. For this example, we instead sample the 8 frames uniformly across the entire ~2 s clip.
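Uniform sampling ignores the clip's frame rate. For longer videos, matching the 1-2 FPS used in training means spacing the samples `native_fps / target_fps` source frames apart. A sketch, where `sample_frame_ids` and its clamping behaviour are our own choices; decord's `reader.get_avg_fps()` could supply `native_fps`.

```python
from typing import List, Optional

import numpy as np


def sample_frame_ids(
    num_total: int,
    native_fps: float,
    num_frames: int = 8,
    target_fps: Optional[float] = None,
) -> List[int]:
    """Pick num_frames frame indices from a video with num_total frames.

    With target_fps=None, indices are spread uniformly over the whole clip
    (what the notebook does). Otherwise, consecutive samples are
    native_fps / target_fps source frames apart, starting at frame 0.
    """
    if target_fps is None:
        return np.linspace(0, num_total - 1, num_frames, dtype=int).tolist()
    step = native_fps / target_fps  # source frames between consecutive samples
    ids = (np.arange(num_frames) * step).astype(int)
    # Clips shorter than num_frames / target_fps seconds repeat the last frame
    return np.minimum(ids, num_total - 1).tolist()


# A 10 s clip at 30 FPS, sampled at 1 FPS: one frame every 30 source frames
print(sample_frame_ids(300, 30.0, target_fps=1.0))  # [0, 30, 60, 90, 120, 150, 180, 210]
```

Repeating the last frame keeps the frame count fixed at 8 for short clips; dropping to uniform sampling in that case would be an equally reasonable choice.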

In [4]:
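with io.BytesIO(video_bytes) as fp:
    reader = decord.VideoReader(fp)
    # 8 evenly spaced frame indices from the first frame to the last
    frame_ids = np.linspace(0, len(reader) - 1, 8, dtype=int).tolist()
    frames = reader.get_batch(frame_ids).asnumpy()  # (T, H, W, C), uint8 RGB
# Add a batch axis and move channels before height/width: (1, T, C, H, W)
batch = np.transpose(np.expand_dims(frames, 0), (0, 1, 4, 2, 3))  # BTCHW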
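The `np.transpose` call turns decord's `(T, H, W, C)` frames into the `(B, T, C, H, W)` layout passed to the preprocessor. A small sketch on a synthetic array (the 90x160 frame size is made up) that checks the shape, verifies where a pixel ends up, and shows an equivalent formulation with `np.moveaxis`. The `demo_` prefixes avoid overwriting the notebook's `frames` and `batch`.

```python
import numpy as np

# Synthetic stand-in for reader.get_batch(...).asnumpy(): 8 RGB frames of 90x160
demo_frames = np.random.randint(0, 256, size=(8, 90, 160, 3), dtype=np.uint8)

# Notebook version: add batch axis, then reorder (B, T, H, W, C) -> (B, T, C, H, W)
demo_batch = np.transpose(np.expand_dims(demo_frames, 0), (0, 1, 4, 2, 3))

# Equivalent: move the channel axis in front of height/width, then add the batch axis
demo_batch_alt = np.moveaxis(demo_frames, -1, 1)[np.newaxis]

print(demo_batch.shape)  # (1, 8, 3, 90, 160)
assert np.array_equal(demo_batch, demo_batch_alt)
# Pixel (t=2, y=10, x=20, channel=1) lands at index [0, 2, 1, 10, 20]
assert demo_batch[0, 2, 1, 10, 20] == demo_frames[2, 10, 20, 1]
```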

## Inference

We run inference on the video batch by preprocessing it, moving it to the GPU and calling the `get_video_embeddings` method.

We run inference on the text captions by preprocessing them into tokens and attention masks, moving them to the GPU and calling the `get_text_embeddings` method.

We then compute the similarity between the video and each caption as the dot product of their projected embeddings (`visual_proj` and `text_proj`), scale it by `model.logit_scale.exp()`, and apply a softmax over the captions to get one probability per caption. The model assigns the highest probability to the correct caption, `a man wearing red spandex throwing a javelin`, ahead of the two distractors that also describe a javelin throw.
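The scoring step follows the CLIP pattern: scaled dot products, then a softmax over captions. The sketch below reproduces it in NumPy on made-up embeddings and returns the full ranking rather than only the top caption. The L2 normalization and the fixed scale of 100 are assumptions for illustration; we have not checked whether Cosmos-Embed1's `visual_proj`/`text_proj` outputs are already normalized or what its `logit_scale` value is.

```python
import numpy as np


def rank_captions(video_emb, text_emb, logit_scale=100.0):
    """Return (probabilities, caption indices sorted best-first).

    video_emb: (1, D) embedding of one video; text_emb: (N, D), one row per caption.
    """
    # Assumption: normalize so the dot product is a cosine similarity
    v = video_emb / np.linalg.norm(video_emb, axis=-1, keepdims=True)
    t = text_emb / np.linalg.norm(text_emb, axis=-1, keepdims=True)
    logits = logit_scale * v @ t.T  # (1, N)
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    probs = probs[0]  # single video
    order = np.argsort(-probs)  # highest probability first
    return probs, order


# Made-up embeddings: the video is a noisy copy of caption 1's embedding
rng = np.random.default_rng(0)
demo_text = rng.normal(size=(3, 16))
demo_video = demo_text[1:2] + 0.1 * rng.normal(size=(1, 16))
demo_probs, demo_order = rank_captions(demo_video, demo_text)
print(demo_order, demo_probs.round(3))
```

In the notebook, `video_out.visual_proj` and `text_out.text_proj` would take the roles of `video_emb` and `text_emb` (converted with `.float().cpu().numpy()`, since NumPy has no bfloat16), with `model.logit_scale.exp().item()` in place of the fixed scale. If those embeddings are already normalized, the extra normalization leaves them unchanged.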

In [5]:
# Preprocess the frames, move them to the GPU in bfloat16 and embed the video
video_inputs = preprocess(videos=batch).to("cuda", dtype=torch.bfloat16)
with torch.no_grad():
    video_out = model.get_video_embeddings(**video_inputs)

captions = [
    "a person riding a motorcycle in the night",
    "a car overtaking a white truck",
    "a video of a knight fighting with a sword",
    "a man wearing red spandex throwing a javelin",
    "a young man javelin throwing during the evening",  # distractor
    "a man throwing a javelin with both hands",  # distractor
]
# Tokenize the captions (input IDs + attention masks), move to the GPU and embed
text_inputs = preprocess(text=captions).to("cuda", dtype=torch.bfloat16)
with torch.no_grad():
    text_out = model.get_text_embeddings(**text_inputs)

# Scaled video-text dot products, softmax over captions; [0] selects the single video
probs = torch.softmax(model.logit_scale.exp() * video_out.visual_proj @ text_out.text_proj.T, dim=-1)[0]
print(captions[probs.argmax().item()])

a man wearing red spandex throwing a javelin
