Spaces:
Runtime error
Runtime error
Last commit not found
import json | |
import os | |
import werkzeug | |
import tensorflow as tf | |
from config import config, parseArgs, configPDF | |
from extract_feature import get_img_feat, build_model | |
from main import setSession, loadWeights, setSavers | |
from model import MACnet | |
from preprocess import Preprocesser | |
import warnings | |
# Checkpoint index passed to config.weightsFile() when restoring the EMA
# weights (presumably the training epoch — confirm against the weights dir).
_RESTORE_EPOCH = 25


def _configure_pdf_experiment():
    """Populate the global ``config`` for the ``PDF_exp_extra`` MAC network.

    Parses command-line arguments, applies the fixed hyper-parameter
    overrides below, calls ``configPDF()``, dumps ``vars(config)`` as JSON to
    ``config.configFile()`` and, if ``config.gpus`` is non-empty, exports it
    through ``CUDA_VISIBLE_DEVICES``.
    """
    # NOTE(review): parseArgs() presumably parses sys.argv; when predict() is
    # called from a server process that argv belongs to the server — confirm
    # this cannot fail or pick up unexpected flags.
    parseArgs()

    # Settings the restored checkpoint must be built with; applied in this
    # order, exactly as individual attribute assignments would be.
    overrides = {
        "parallel": True,
        "evalTrain": True,
        "retainVal": True,
        "useEMA": True,
        "lrReduce": True,
        "adam": True,
        "clip": True,
        "memoryVariationalDropout": True,
        "relu": 'ELU',
        "encBi": True,
        "wrdEmbRandom": True,
        "wrdEmbUniform": True,
        "outQuestion": True,
        "initCtrl": 'Q',
        "controlContextual": True,
        "controlInputUnshared": True,
        "readProjInputs": True,
        "readMemConcatKB": True,
        "readMemConcatProj": True,
        "readMemProj": True,
        "readCtrl": True,
        "writeMemProj": True,
        "restore": True,
        "expName": 'PDF_exp_extra',
        "netLength": 16,
    }
    for name, value in overrides.items():
        setattr(config, name, value)
    configPDF()

    # NOTE(review): "a+" appends one JSON object per call, so after several
    # requests the file is a concatenation of dumps rather than one valid
    # JSON document — consider "w" if nothing relies on the history.
    with open(config.configFile(), "a+") as outFile:
        json.dump(vars(config), outFile)

    if config.gpus != "":
        config.gpusNum = len(config.gpus.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = config.gpus


def _answer_to_sentence(evalRes):
    """Map the raw answer returned by ``MACnet.runBatch`` to a sentence.

    'top'/'bottom' describe the caption position, 'True'/'False' whether a
    title object is present, and any other value is reported as an object
    count.
    """
    if evalRes in ['top', 'bottom']:
        return 'The caption at the %s side of the object.' % evalRes
    if evalRes == 'True':
        return 'There is at least one title object in this image.'
    if evalRes == 'False':
        # A negative answer must not be reported as "at least one".
        return 'There is no title object in this image.'
    return 'This image contain %s specific object(s).' % evalRes


def predict(image, question):
    """Answer ``question`` about ``image`` with the restored MAC network.

    Args:
        image: Input passed unchanged to ``get_img_feat`` (its accepted
            type — path vs. array — is defined there; TODO confirm).
        question: Question passed to ``Preprocesser.preprocessData``.

    Returns:
        str: A natural-language sentence built from the model's answer.

    Side effects:
        Mutates the global ``config``, appends to ``config.configFile()``,
        may set ``CUDA_VISIBLE_DEVICES`` and resets the TensorFlow default
        graph. Restoring presumably raises if the checkpoint at
        ``config.weightsFile(_RESTORE_EPOCH)`` is missing.
    """
    _configure_pdf_experiment()

    # Fresh default graph so a second call does not collide with variables
    # built by the previous one.
    tf.reset_default_graph()
    tf.logging.set_verbosity(tf.logging.ERROR)

    # NOTE(review): both the CNN and the MAC graph are rebuilt on every
    # call; caching them would avoid the per-request construction cost.
    cnn_model = build_model()
    imageData = get_img_feat(cnn_model, image)

    preprocessor = Preprocesser()
    qData, embeddings, answerDict = preprocessor.preprocessData(question)
    model = MACnet(embeddings, answerDict)

    # Only the EMA saver is needed: weights are restored, never initialised.
    emaSaver = setSavers(model)["emaSaver"]
    sessionConfig = setSession()

    with tf.Session(config=sessionConfig) as sess:
        # Freeze the graph so runBatch cannot silently add new ops.
        sess.graph.finalize()
        emaSaver.restore(sess, config.weightsFile(_RESTORE_EPOCH))
        evalRes = model.runBatch(sess, qData, imageData, False)

    return _answer_to_sentence(evalRes)