ahmed-masry commited on
Commit
66ef4a9
·
verified ·
1 Parent(s): 850d522

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -11,11 +11,13 @@ torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQ
11
  model_name = "ahmed-masry/unichart-base-960"
12
  model = VisionEncoderDecoderModel.from_pretrained(model_name)
13
  processor = DonutProcessor.from_pretrained(model_name)
14
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
- model.to(device)
16
 
17
 
 
18
  def predict(image, input_prompt):
 
 
 
19
  input_prompt += " <s_answer>"
20
  decoder_input_ids = processor.tokenizer(input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
21
  pixel_values = processor(image, return_tensors="pt").pixel_values
 
11
  model_name = "ahmed-masry/unichart-base-960"
12
  model = VisionEncoderDecoderModel.from_pretrained(model_name)
13
  processor = DonutProcessor.from_pretrained(model_name)
 
 
14
 
15
 
16
+ @spaces.GPU
17
  def predict(image, input_prompt):
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ model.to(device)
20
+
21
  input_prompt += " <s_answer>"
22
  decoder_input_ids = processor.tokenizer(input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
23
  pixel_values = processor(image, return_tensors="pt").pixel_values