Vinay15 commited on
Commit
26236d1
·
verified ·
1 Parent(s): 3a5fa1a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -30
app.py CHANGED
@@ -1,36 +1,31 @@
1
- import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForImageClassification
3
- from PIL import Image
4
- import torch
5
 
6
- # Load the model and tokenizer with trust_remote_code set to True
7
- tokenizer = AutoTokenizer.from_pretrained('stepfun-ai/GOT-OCR2_0', trust_remote_code=True)
8
- model = AutoModelForImageClassification.from_pretrained('stepfun-ai/GOT-OCR2_0', trust_remote_code=True)
 
 
 
 
 
9
 
10
def perform_ocr(image):
    """Run OCR on *image* and return the recognized text.

    Parameters
    ----------
    image : PIL.Image.Image
        The page/photo to transcribe; normalized to RGB before inference.

    Returns
    -------
    str
        The text extracted by the GOT-OCR2.0 model.

    Raises
    ------
    ValueError
        If *image* is not a PIL Image.
    """
    import os
    import tempfile

    if not isinstance(image, Image.Image):
        raise ValueError("Input must be a PIL Image")
    image = image.convert("RGB")

    # BUG FIX: GOT-OCR2.0 is a generative OCR model, not an image
    # classifier. The previous code passed a PIL image to the *tokenizer*
    # (which only tokenizes text, so this raises) and then read
    # `outputs.logits.argmax(...)` — a class index, never a transcription.
    # The model card's supported entry point is `model.chat(...)`, which
    # expects an image *file path*, so the image is staged to a temporary
    # PNG first. (NOTE(review): assumes `model` was loaded with
    # AutoModel + trust_remote_code per the model card — confirm.)
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image.save(tmp.name)
        tmp_path = tmp.name
    try:
        with torch.no_grad():  # inference only — no gradients needed
            text = model.chat(tokenizer, tmp_path, ocr_type="ocr")
    finally:
        os.remove(tmp_path)  # always clean up the staging file
    return text
25
 
26
# Gradio UI wiring.
# BUG FIX: `gr.inputs.Image` was deprecated in Gradio 3.x and removed in
# 4.x; the current component API is `gr.Image`. Everything else is
# preserved: a single PIL-image input mapped to a plain-text output.
iface = gr.Interface(
    fn=perform_ocr,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="OCR with GOT-OCR2.0",
    description="Upload an image for Optical Character Recognition.",
)

# Launch the interface (blocks until the server is stopped).
iface.launch()
 
 
 
 
 
 
 
1
  from transformers import AutoTokenizer, AutoModelForImageClassification
 
 
2
 
3
def load_model_and_tokenizer(
    model_id='stepfun-ai/GOT-OCR2_0',
    revision='cf6b7386bc89a54f09785612ba74cb12de6fa17c',
):
    """Load the GOT-OCR2.0 model and tokenizer from the Hugging Face Hub.

    Parameters
    ----------
    model_id : str, optional
        Hub repository to load from. Defaults to the GOT-OCR2.0 release,
        so existing callers are unaffected.
    revision : str, optional
        Commit hash to pin. Pinning protects against upstream changes to
        the repository's remote code.

    Returns
    -------
    tuple
        ``(model, tokenizer)`` on success, ``(None, None)`` on any failure.
    """
    try:
        # trust_remote_code is required: this repo ships custom modeling
        # code. Combined with the pinned revision, exactly the audited
        # code at that commit is executed.
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=True,
        )
        model = AutoModelForImageClassification.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=True,
        )
        return model, tokenizer
    except Exception as e:
        # Broad catch is deliberate: network, auth, and deserialization
        # errors should all degrade to the (None, None) sentinel that the
        # caller checks, rather than crashing the app at startup.
        print(f"An error occurred while loading the model and tokenizer: {e}")
        return None, None
 
 
 
 
 
24
 
25
# Example usage
if __name__ == "__main__":
    # Manual smoke test: load once and report the outcome on stdout.
    model, tokenizer = load_model_and_tokenizer()
    status = (
        "Model and tokenizer loaded successfully!"
        if model and tokenizer
        else "Failed to load model and tokenizer."
    )
    print(status)