Commit
·
26d3bd8
1
Parent(s):
d14c041
log: test path
Browse files
model.py
CHANGED
@@ -2,6 +2,7 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
import wandb
|
4 |
import streamlit as st
|
|
|
5 |
|
6 |
import clip
|
7 |
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
@@ -128,8 +129,9 @@ def load_model():
|
|
128 |
gpt_model, tokenizer = load_gpt_model()
|
129 |
|
130 |
|
131 |
-
|
132 |
# Load weights
|
|
|
|
|
133 |
model = ImageCaptioner(clip_model, gpt_model, tokenizer, 0)
|
134 |
checkpoint = torch.load(PATH, map_location=torch.device('cpu'))
|
135 |
model.load_state_dict(checkpoint["state_dict"])
|
|
|
2 |
import torch.nn as nn
|
3 |
import wandb
|
4 |
import streamlit as st
|
5 |
+
import os
|
6 |
|
7 |
import clip
|
8 |
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
|
|
129 |
gpt_model, tokenizer = load_gpt_model()
|
130 |
|
131 |
|
|
|
132 |
# Load weights
|
133 |
+
print(PATH)
|
134 |
+
print(os.getcwd())
|
135 |
model = ImageCaptioner(clip_model, gpt_model, tokenizer, 0)
|
136 |
checkpoint = torch.load(PATH, map_location=torch.device('cpu'))
|
137 |
model.load_state_dict(checkpoint["state_dict"])
|