mattia-re-learn commited on
Commit
cab6235
·
verified ·
1 Parent(s): c8eff1e

Update code/inference.py

Browse files
Files changed (1) hide show
  1. code/inference.py +10 -0
code/inference.py CHANGED
@@ -3,6 +3,8 @@ from PIL import Image
3
  from io import BytesIO
4
  import torch
5
  from transformers import AutoTokenizer
 
 
6
 
7
  from llava.model import LlavaLlamaForCausalLM
8
  from llava.utils import disable_torch_init
@@ -68,6 +70,14 @@ def predict_fn(data, model_and_tokenizer):
68
  if image_file.startswith("http") or image_file.startswith("https"):
69
  response = requests.get(image_file)
70
  image = Image.open(BytesIO(response.content)).convert("RGB")
 
 
 
 
 
 
 
 
71
  else:
72
  image = Image.open(image_file).convert("RGB")
73
 
 
3
  from io import BytesIO
4
  import torch
5
  from transformers import AutoTokenizer
6
+ import boto3
7
+ import tempfile
8
 
9
  from llava.model import LlavaLlamaForCausalLM
10
  from llava.utils import disable_torch_init
 
70
  if image_file.startswith("http") or image_file.startswith("https"):
71
  response = requests.get(image_file)
72
  image = Image.open(BytesIO(response.content)).convert("RGB")
73
+ elif image_file.startswith("s3://"):
74
+ s3 = boto3.client("s3")
75
+ s3_path = s3_path[5:]
76
+ bucket = s3_path.split('/')[0]
77
+ s3_key = '/'.join(s3_path.split('/')[1:])
78
+ with tempfile.NamedTemporaryFile() as temp_file:
79
+ s3.download_file(bucket, s3_key, temp_file.name)
80
+ image = Image.open(temp_file).convert("RGB")
81
  else:
82
  image = Image.open(image_file).convert("RGB")
83