Temuzin64 commited on
Commit
2d493f2
Β·
verified Β·
1 Parent(s): 877aa9d

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +61 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,63 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ from PIL import Image, ImageFilter, ImageEnhance
3
+ import tempfile
4
+ import os
5
+ import easyocr
6
+ from transformers import MT5ForConditionalGeneration, MT5Tokenizer, pipeline
7
+
8
+ # Load tokenizer and model once at startup with proper config to avoid warnings
9
+ tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small", legacy=False, use_fast=False)
10
+ model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
11
+ pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
12
+
13
+ # Preprocess uploaded image to improve OCR accuracy
14
+ def preprocess_image_pillow(image):
15
+ img = image.convert("L") # Grayscale
16
+ width, height = img.size
17
+ img = img.resize((width * 2, height * 2), Image.LANCZOS)
18
+ enhancer = ImageEnhance.Contrast(img)
19
+ img = enhancer.enhance(2.0)
20
+ img = img.filter(ImageFilter.SHARPEN)
21
+ return img
22
+
23
+ # Streamlit App UI
24
+ st.set_page_config(page_title="πŸ“ Telugu OCR & Correction", layout="centered")
25
+ st.title("πŸ“ Telugu Handwriting to Typed Text")
26
+
27
+ uploaded_file = st.file_uploader("πŸ“€ Upload Telugu handwritten image", type=["png", "jpg", "jpeg"])
28
+
29
+ if uploaded_file:
30
+ image = Image.open(uploaded_file).convert("RGB")
31
+ enhanced_image = preprocess_image_pillow(image)
32
+ st.image(enhanced_image, caption="Preprocessed Image", use_container_width=True)
33
+
34
+ # Save temporarily for EasyOCR
35
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp:
36
+ enhanced_image.save(temp.name)
37
+
38
+ try:
39
+ reader = easyocr.Reader(['te'], gpu=False)
40
+ results = reader.readtext(temp.name)
41
+
42
+ raw_text = "\n".join([text for (_, text, _) in results])
43
+
44
+ st.markdown("### πŸ“„ OCR Extracted Text")
45
+ st.text_area("πŸ“ Telugu OCR", raw_text, height=150)
46
+
47
+ # Generate correction using mT5
48
+ if raw_text.strip():
49
+ st.markdown("### βœ… LLM Corrected Telugu Text")
50
+ prompt = f"Correct the following Telugu text spelling and grammar:\n{raw_text}"
51
+ try:
52
+ response = pipe(prompt, max_new_tokens=256, do_sample=False)[0]['generated_text']
53
+ st.text_area("πŸ€– Corrected Text", response, height=150)
54
+ st.download_button("⬇️ Download", response, file_name="corrected_telugu.txt")
55
+ except Exception as e:
56
+ st.error(f"LLM Correction Error: {e}")
57
+ else:
58
+ st.warning("OCR did not extract any usable Telugu text.")
59
+ finally:
60
+ # Always remove the temp file
61
+ if os.path.exists(temp.name):
62
+ os.remove(temp.name)
63