pacman2223 committed on
Commit
10ecefc
·
verified ·
1 Parent(s): e61f26a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -15
app.py CHANGED
@@ -16,25 +16,25 @@ image.save("hack.png")
16
 
17
def demo_process(img, question):
    """Answer *question* about the document image *img*.

    Parameters
    ----------
    img : PIL.Image.Image
        The document image to query.
    question : str
        A natural-language question about the document.

    Returns
    -------
    The document-question-answering pipeline's answer payload
    (a list of answer dicts with ``answer``/``score``/span fields).
    """
    # NOTE(review): the pipeline (and its model weights) is re-created on
    # every call — hoist this to module level in production to avoid
    # reloading the model per request.
    qa_pipeline = pipeline("document-question-answering", model="pacman2223/test-mod")
    # Bug fix: the original discarded the pipeline's result and returned
    # the pipeline *object* itself; return the actual answer instead.
    return qa_pipeline(img, question)
 
 
16
 
17
def demo_process(img, question):
    """Answer *question* about the document image *img* with the
    extractive document-QA model at ``model_checkpoint``.

    Parameters
    ----------
    img : PIL.Image.Image
        The document image to query; converted to RGB before encoding.
    question : str
        A natural-language question about the document.

    Returns
    -------
    str
        The decoded answer span predicted by the model.
    """
    # NOTE(review): processor and model are reloaded from the checkpoint
    # on every call — cache them at module level in production.
    processor = AutoProcessor.from_pretrained(model_checkpoint)
    model = AutoModelForDocumentQuestionAnswering.from_pretrained(model_checkpoint)

    # Inference only: disable autograd to save memory and time.
    with torch.no_grad():
        encoding = processor(img.convert("RGB"), question, return_tensors="pt")
        outputs = model(**encoding)
        # Extractive QA: the answer is the token span [start, end] with
        # the highest start/end logits.
        predicted_start_idx = outputs.start_logits.argmax(-1).item()
        predicted_end_idx = outputs.end_logits.argmax(-1).item()

    # Bug fix: the original decoded the span twice (first result was
    # discarded) and ended with an unreachable `return qa_pipeline`
    # referencing an undefined name; both removed.
    predicted_answer_tokens = encoding.input_ids.squeeze()[
        predicted_start_idx : predicted_end_idx + 1
    ]
    return processor.tokenizer.decode(predicted_answer_tokens)
40