import streamlit as st
from PIL import Image
import io
import base64
import requests
import os
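
# Usage sketch (assumes this file is saved as app.py; the filename is illustrative):
#   export HF_API_KEY=hf_xxxxxxxx        # or add HF_API_KEY to .streamlit/secrets.toml
#   streamlit run app.py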

# Page configuration
st.set_page_config(
    page_title="Vision OCR",
    page_icon="πŸ”Ž",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Set up Hugging Face API
HF_API_KEY = os.environ.get("HF_API_KEY", "")  # Get API key from environment variable
if not HF_API_KEY:
    try:
        HF_API_KEY = st.secrets.get("HF_API_KEY", "")  # Fall back to Streamlit secrets
    except FileNotFoundError:
        HF_API_KEY = ""  # No secrets.toml available locally

# Hugging Face API function
def process_image_with_hf(image_bytes, model_id):
    API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
    headers = {"Authorization": f"Bearer {HF_API_KEY}"}
    
    # Convert image to base64
    image_b64 = base64.b64encode(image_bytes).decode('utf-8')
    
    # Prepare payload based on model type
    if "llava" in model_id.lower():
        payload = {
            "inputs": {
                "image": image_b64,
                "prompt": """Analyze the text in the provided image. Extract all readable content
                        and present it in a structured Markdown format that is clear, concise, 
                        and well-organized. Ensure proper formatting (e.g., headings, lists, or
                        code blocks) as necessary to represent the content effectively."""
            },
            "parameters": {
                "max_new_tokens": 1024
            }
        }
    else:
        # Generic payload format for other models
        payload = {
            "inputs": {
                "image": image_b64,
                "text": """Analyze the text in the provided image. Extract all readable content
                        and present it in a structured Markdown format that is clear, concise, 
                        and well-organized. Ensure proper formatting (e.g., headings, lists, or
                        code blocks) as necessary to represent the content effectively."""
            }
        }
    
    # Make API request
    response = requests.post(API_URL, headers=headers, json=payload, timeout=120)  # generous timeout for cold model starts
    
    if response.status_code != 200:
        raise Exception(f"API request failed with status code {response.status_code}: {response.text}")
    
    # Handle different response formats
    response_json = response.json()
    if isinstance(response_json, list) and response_json and "generated_text" in response_json[0]:
        return response_json[0]["generated_text"]
    elif isinstance(response_json, dict):
        if "generated_text" in response_json:
            return response_json["generated_text"]
        elif "text" in response_json:
            return response_json["text"]
    
    # Fallback
    return str(response_json)

# Title and description in main area
try:
    # Try to load the logo from the assets folder and embed it in the page title
    with open("./assets/gemma3.png", "rb") as logo_file:
        logo_b64 = base64.b64encode(logo_file.read()).decode()
    st.markdown(
        f'# <img src="data:image/png;base64,{logo_b64}" width="50" style="vertical-align: -12px;"> Vision OCR',
        unsafe_allow_html=True
    )
except FileNotFoundError:
    # Fallback if image doesn't exist
    st.title("Vision OCR")

# Add clear button to top right
col1, col2 = st.columns([6,1])
with col2:
    if st.button("Clear πŸ—‘οΈ"):
        if 'ocr_result' in st.session_state:
            del st.session_state['ocr_result']
        st.rerun()

st.markdown('<p style="margin-top: -20px;">Extract structured text from images using advanced vision models!</p>', unsafe_allow_html=True)
st.markdown("---")

# Add model selection
with st.sidebar:
    st.header("Settings")
    model_option = st.selectbox(
        "Select Vision Model",
        ["LLaVA-1.5-7B", "MiniGPT-4", "Idefics"],
        index=0
    )
    
    # Updated model mapping with confirmed working models
    model_mapping = {
        "LLaVA-1.5-7B": "llava-hf/llava-1.5-7b-hf",
        "MiniGPT-4": "Vision-CAIR/MiniGPT-4",
        "Idefics": "HuggingFaceM4/idefics-9b-instruct"
    }
    
    selected_model = model_mapping[model_option]
    
    st.header("Upload Image")
    uploaded_file = st.file_uploader("Choose an image...", type=['png', 'jpg', 'jpeg'])
    
    if uploaded_file is not None:
        # Display the uploaded image
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image")
        
        # Check if API key is available
        if not HF_API_KEY:
            st.error("Hugging Face API key is missing. Please set it as an environment variable or in Streamlit secrets.")
        else:
            if st.button("Extract Text πŸ”", type="primary"):
                with st.spinner(f"Processing image with {model_option}..."):
                    try:
                        # Get image bytes
                        img_bytes = uploaded_file.getvalue()
                        
                        # Process with Hugging Face API using selected model
                        result = process_image_with_hf(img_bytes, selected_model)
                        st.session_state['ocr_result'] = result
                    except Exception as e:
                        st.error(f"Error processing image: {str(e)}")
                        st.info("Try selecting a different model from the dropdown.")

# Main content area for results
if 'ocr_result' in st.session_state:
    st.markdown(st.session_state['ocr_result'])
else:
    st.info("Upload an image and click 'Extract Text' to see the results here.")

# Footer
st.markdown("---")
st.markdown("Made with ❀️ using Hugging Face Vision Models | [Report an Issue](https://github.com/bulentsoykan/streamlit-OCR-app/issues)")