Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -64,7 +64,7 @@ class StoppingCriteriaSub(StoppingCriteria):
|
|
64 |
|
65 |
return False
|
66 |
|
67 |
-
device = torch.device("
|
68 |
|
69 |
def get_blip_config(model="base"):
|
70 |
config = dict()
|
@@ -129,7 +129,7 @@ tar_img_feats = torch.cat(tar_img_feats, dim=0)
|
|
129 |
|
130 |
class Chat:
|
131 |
|
132 |
-
def __init__(self, model, transform, dataframe, tar_img_feats, device='
|
133 |
self.device = device
|
134 |
self.model = model
|
135 |
self.transform = transform
|
@@ -393,7 +393,7 @@ def answer_formatter_node(question, context):
|
|
393 |
CURR_CONTEXT = ''
|
394 |
CURR_SESSION_KEY = generate_session_key()
|
395 |
|
396 |
-
|
397 |
def get_answer(image=[], message='', sessionID='abc123'):
|
398 |
global CURR_CONTEXT
|
399 |
global CURR_SESSION_KEY
|
@@ -401,7 +401,7 @@ def get_answer(image=[], message='', sessionID='abc123'):
|
|
401 |
if image is not None:
|
402 |
try:
|
403 |
# Process the image and message here
|
404 |
-
device = torch.device("
|
405 |
chat = Chat(model,transform,df,tar_img_feats, device)
|
406 |
chat.encode_image(image)
|
407 |
data = chat.ask()
|
@@ -470,7 +470,7 @@ def handle_message(data):
|
|
470 |
image_bytes = base64.b64decode(image_bytes)
|
471 |
image = Image.open(BytesIO(image_bytes))
|
472 |
image_array = np.array(image)
|
473 |
-
device = torch.device("
|
474 |
chat = Chat(model, transform, df, tar_img_feats, device)
|
475 |
chat.encode_image(image_array)
|
476 |
context = chat.ask()
|
@@ -545,7 +545,7 @@ def handle_message(data):
|
|
545 |
# @spaces.GPU
|
546 |
def respond_to_user(image, message):
|
547 |
# Process the image and message here
|
548 |
-
device = torch.device("
|
549 |
chat = Chat(model,transform,df,tar_img_feats, device)
|
550 |
chat.encode_image(image)
|
551 |
data = chat.ask()
|
|
|
64 |
|
65 |
return False
|
66 |
|
67 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
68 |
|
69 |
def get_blip_config(model="base"):
|
70 |
config = dict()
|
|
|
129 |
|
130 |
class Chat:
|
131 |
|
132 |
+
def __init__(self, model, transform, dataframe, tar_img_feats, device='cuda', stopping_criteria=None):
|
133 |
self.device = device
|
134 |
self.model = model
|
135 |
self.transform = transform
|
|
|
393 |
CURR_CONTEXT = ''
|
394 |
CURR_SESSION_KEY = generate_session_key()
|
395 |
|
396 |
+
@spaces.GPU
|
397 |
def get_answer(image=[], message='', sessionID='abc123'):
|
398 |
global CURR_CONTEXT
|
399 |
global CURR_SESSION_KEY
|
|
|
401 |
if image is not None:
|
402 |
try:
|
403 |
# Process the image and message here
|
404 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
405 |
chat = Chat(model,transform,df,tar_img_feats, device)
|
406 |
chat.encode_image(image)
|
407 |
data = chat.ask()
|
|
|
470 |
image_bytes = base64.b64decode(image_bytes)
|
471 |
image = Image.open(BytesIO(image_bytes))
|
472 |
image_array = np.array(image)
|
473 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
474 |
chat = Chat(model, transform, df, tar_img_feats, device)
|
475 |
chat.encode_image(image_array)
|
476 |
context = chat.ask()
|
|
|
545 |
# @spaces.GPU
|
546 |
def respond_to_user(image, message):
|
547 |
# Process the image and message here
|
548 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
549 |
chat = Chat(model,transform,df,tar_img_feats, device)
|
550 |
chat.encode_image(image)
|
551 |
data = chat.ask()
|