Vinay15 committed on
Commit
aa4c474
·
verified ·
1 Parent(s): 2b3d2ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -27
app.py CHANGED
@@ -1,38 +1,36 @@
1
  import gradio as gr
 
2
  from PIL import Image
3
- # Assuming 'model' and 'tokenizer' are defined elsewhere in your code
4
- # from your_model_file import model, tokenizer
5
 
6
- def load_image(image_file):
7
- """Load and preprocess the image."""
8
- if isinstance(image_file, Image.Image): # Check if the input is an Image object
9
- return image_file.convert("RGB") # Convert to RGB if necessary
10
- elif isinstance(image_file, str) and (image_file.startswith('http') or image_file.startswith('https')):
11
- # Handle URL case (you can use an external library to fetch the image if needed)
12
- return Image.open(requests.get(image_file, stream=True).raw).convert("RGB")
13
- else:
14
- # Handle file path case
15
- return Image.open(image_file).convert("RGB")
16
 
17
  def perform_ocr(image):
18
- """Perform OCR on the uploaded image."""
19
- try:
20
- # Load and preprocess the image
21
- processed_image = load_image(image)
22
- # Use the model for OCR
23
- res = model.chat(tokenizer, processed_image, ocr_type='ocr')
24
- return res
25
- except Exception as e:
26
- return str(e) # Return the error message
 
 
 
 
 
27
 
28
- # Gradio interface setup
29
  iface = gr.Interface(
30
  fn=perform_ocr,
31
- inputs=gr.Image(type="pil"), # Ensure Gradio accepts images as PIL images
32
  outputs="text",
33
- title="OCR Application",
34
- description="Upload an image to perform Optical Character Recognition (OCR)."
35
  )
36
 
37
- if __name__ == "__main__":
38
- iface.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForImageClassification
3
  from PIL import Image
4
+ import torch
 
5
 
6
+ # Load the model and tokenizer
7
+ tokenizer = AutoTokenizer.from_pretrained('stepfun-ai/GOT-OCR2_0')
8
+ model = AutoModelForImageClassification.from_pretrained('stepfun-ai/GOT-OCR2_0')
 
 
 
 
 
 
 
9
 
10
  def perform_ocr(image):
11
+ # Ensure the image is in the right format
12
+ if isinstance(image, Image.Image):
13
+ image = image.convert("RGB")
14
+ else:
15
+ raise ValueError("Input must be a PIL Image")
16
+
17
+ # Use the model to perform OCR
18
+ inputs = tokenizer(image, return_tensors="pt")
19
+ with torch.no_grad():
20
+ outputs = model(**inputs)
21
+
22
+ # Get the predictions
23
+ predictions = outputs.logits.argmax(dim=1).item()
24
+ return predictions
25
 
26
+ # Create the Gradio interface
27
  iface = gr.Interface(
28
  fn=perform_ocr,
29
+ inputs=gr.inputs.Image(type="pil"),
30
  outputs="text",
31
+ title="OCR with GOT-OCR2.0",
32
+ description="Upload an image for Optical Character Recognition."
33
  )
34
 
35
+ # Launch the interface
36
+ iface.launch()