Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -11,11 +11,13 @@ torch.hub.download_url_to_file('https://raw.githubusercontent.com/vis-nlp/ChartQ
|
|
11 |
model_name = "ahmed-masry/unichart-base-960"
|
12 |
model = VisionEncoderDecoderModel.from_pretrained(model_name)
|
13 |
processor = DonutProcessor.from_pretrained(model_name)
|
14 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
-
model.to(device)
|
16 |
|
17 |
|
|
|
18 |
def predict(image, input_prompt):
|
|
|
|
|
|
|
19 |
input_prompt += " <s_answer>"
|
20 |
decoder_input_ids = processor.tokenizer(input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
|
21 |
pixel_values = processor(image, return_tensors="pt").pixel_values
|
|
|
11 |
model_name = "ahmed-masry/unichart-base-960"
|
12 |
model = VisionEncoderDecoderModel.from_pretrained(model_name)
|
13 |
processor = DonutProcessor.from_pretrained(model_name)
|
|
|
|
|
14 |
|
15 |
|
16 |
+
@spaces.GPU
|
17 |
def predict(image, input_prompt):
|
18 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
19 |
+
model.to(device)
|
20 |
+
|
21 |
input_prompt += " <s_answer>"
|
22 |
decoder_input_ids = processor.tokenizer(input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
|
23 |
pixel_values = processor(image, return_tensors="pt").pixel_values
|