Ritesh-hf committed · Commit 766b4e0 · verified · 1 Parent(s): 686f5f4

Update app.py

Files changed (1): app.py (+6 −6)
app.py CHANGED
@@ -64,7 +64,7 @@ class StoppingCriteriaSub(StoppingCriteria):
 
         return False
 
-device = torch.device("cpu" if torch.cuda.is_available() else "cpu")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 def get_blip_config(model="base"):
     config = dict()
@@ -129,7 +129,7 @@ tar_img_feats = torch.cat(tar_img_feats, dim=0)
 
 class Chat:
 
-    def __init__(self, model, transform, dataframe, tar_img_feats, device='cpu', stopping_criteria=None):
+    def __init__(self, model, transform, dataframe, tar_img_feats, device='cuda', stopping_criteria=None):
         self.device = device
         self.model = model
         self.transform = transform
@@ -393,7 +393,7 @@ def answer_formatter_node(question, context):
 CURR_CONTEXT = ''
 CURR_SESSION_KEY = generate_session_key()
 
-# @spaces.GPU
+@spaces.GPU
 def get_answer(image=[], message='', sessionID='abc123'):
     global CURR_CONTEXT
     global CURR_SESSION_KEY
@@ -401,7 +401,7 @@ def get_answer(image=[], message='', sessionID='abc123'):
     if image is not None:
         try:
             # Process the image and message here
-            device = torch.device("cpu" if torch.cuda.is_available() else "cpu")
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
             chat = Chat(model,transform,df,tar_img_feats, device)
             chat.encode_image(image)
             data = chat.ask()
@@ -470,7 +470,7 @@ def handle_message(data):
     image_bytes = base64.b64decode(image_bytes)
     image = Image.open(BytesIO(image_bytes))
     image_array = np.array(image)
-    device = torch.device("cpu" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     chat = Chat(model, transform, df, tar_img_feats, device)
     chat.encode_image(image_array)
     context = chat.ask()
@@ -545,7 +545,7 @@ def handle_message(data):
 # @spaces.GPU
 def respond_to_user(image, message):
     # Process the image and message here
-    device = torch.device("cpu" if torch.cuda.is_available() else "cpu")
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     chat = Chat(model,transform,df,tar_img_feats, device)
     chat.encode_image(image)
     data = chat.ask()
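
The recurring change above is the device-selection expression: the previous code compared against torch.cuda.is_available() but returned "cpu" in both branches, so inference always ran on CPU. A minimal, self-contained sketch of the corrected pattern (the to_device helper here is hypothetical, not part of app.py):

import torch

# Select CUDA only when a GPU is actually present; the old line in app.py
# evaluated the same condition but fell back to "cpu" in both branches.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def to_device(module: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical helper for illustration: move any module to the chosen device.
    return module.to(device)

layer = to_device(torch.nn.Linear(4, 2))
x = torch.randn(1, 4, device=device)
print(layer(x).device)  # cuda:0 on a GPU machine, cpu otherwise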
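
The other change is un-commenting @spaces.GPU on get_answer. On Hugging Face ZeroGPU Spaces, the GPU decorator from the spaces package attaches a GPU only for the duration of the decorated call (and has no effect on other hardware), which is presumably why the device is also re-resolved inside each handler rather than relying on the module-level value. A short usage sketch, assuming a ZeroGPU Space and a hypothetical run_inference handler:

import spaces  # available inside Hugging Face Spaces; @spaces.GPU is a no-op elsewhere
import torch

@spaces.GPU  # request a GPU only while this function executes (ZeroGPU)
def run_inference(x: torch.Tensor) -> torch.Tensor:
    # CUDA is visible while the decorator holds the GPU, so resolve the device
    # inside the call instead of at import time.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return (x.to(device) * 2).cpu()

print(run_inference(torch.ones(3)))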