dwb2023 commited on
Commit
e38f582
·
verified ·
1 Parent(s): 01a4990

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +4 -4
inference.py CHANGED
@@ -1,10 +1,12 @@
1
- import functools
2
  import os
3
  import PIL.Image
4
  import torch
5
  from huggingface_hub import login
6
  from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
7
- import spaces
 
 
 
8
 
9
  hf_token = os.getenv("HF_TOKEN")
10
  login(token=hf_token, add_to_git_credential=True)
@@ -16,7 +18,6 @@ class PaliGemmaModel:
16
  self.model = PaliGemmaForConditionalGeneration.from_pretrained(self.model_id).eval().to(self.device)
17
  self.processor = PaliGemmaProcessor.from_pretrained(self.model_id)
18
 
19
- @spaces.GPU
20
  def infer(self, image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
21
  inputs = self.processor(text=text, images=image, return_tensors="pt")
22
  inputs = {k: v.to(self.device) for k, v in inputs.items()} # Move inputs to the correct device
@@ -125,4 +126,3 @@ class VAEModel:
125
  return x
126
 
127
  return jax.jit(Decoder().apply, backend='cpu')
128
-
 
 
1
  import os
2
  import PIL.Image
3
  import torch
4
  from huggingface_hub import login
5
  from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ import functools
10
 
11
  hf_token = os.getenv("HF_TOKEN")
12
  login(token=hf_token, add_to_git_credential=True)
 
18
  self.model = PaliGemmaForConditionalGeneration.from_pretrained(self.model_id).eval().to(self.device)
19
  self.processor = PaliGemmaProcessor.from_pretrained(self.model_id)
20
 
 
21
  def infer(self, image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
22
  inputs = self.processor(text=text, images=image, return_tensors="pt")
23
  inputs = {k: v.to(self.device) for k, v in inputs.items()} # Move inputs to the correct device
 
126
  return x
127
 
128
  return jax.jit(Decoder().apply, backend='cpu')