File size: 2,683 Bytes
0da30bd
 
 
 
 
 
 
 
9a8d233
0da30bd
 
 
 
 
 
 
9a8d233
 
 
0da30bd
 
 
 
 
 
 
 
 
 
 
 
9a8d233
 
 
 
 
0da30bd
 
 
 
 
 
 
 
 
 
 
9a8d233
 
 
 
 
0da30bd
9a8d233
0da30bd
 
9a8d233
 
 
 
 
0da30bd
 
9a8d233
899c55b
99ac6dd
0da30bd
9a8d233
99ac6dd
0da30bd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import gradio as gr
import speech_recognition as sr
import pyttsx3
import time
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from happytransformer import HappyTextToText, TTSettings  # Using HappyTransformer

# Load models only once for efficiency
def load_models(model_name="prithivida/grammar_error_correcter_v1"):
    """Load the tokenizer, seq2seq model and HappyTextToText wrapper.

    Args:
        model_name: Checkpoint identifier handed to
            ``AutoTokenizer.from_pretrained``,
            ``AutoModelForSeq2SeqLM.from_pretrained`` and ``HappyTextToText``.
            Defaults to the grammar-correction checkpoint this app was
            written for, so calling ``load_models()`` behaves as before.

    Returns:
        A tuple ``(tokenizer, model, happy_tt)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # Use the same checkpoint name for the wrapper so the two correction
    # paths can never drift apart. "T5" is the model type HappyTextToText
    # is told to use for this checkpoint.
    # NOTE(review): the wrapper presumably loads its own copy of the
    # weights, so the checkpoint ends up in memory twice -- confirm.
    happy_tt = HappyTextToText("T5", model_name)
    return tokenizer, model, happy_tt

# Module-level objects: loaded once when this module is executed and then
# read by correct_grammar() on every call.
tokenizer, model, happy_tt = load_models()  # Load models at startup

# Speech-to-text conversion
def transcribe_audio(audio):
    """Turn an audio file into text using ``recognize_google``.

    Args:
        audio: Audio source accepted by ``sr.AudioFile`` (the UI passes a
            file path).

    Returns:
        The recognized text on success. On failure a human-readable message
        is returned instead of raising: a fixed sentence when the speech
        could not be understood, or the request error's description.
    """
    engine = sr.Recognizer()

    # Read the whole clip into memory; the file is closed before recognition.
    with sr.AudioFile(audio) as clip:
        recording = engine.record(clip)

    try:
        transcript = engine.recognize_google(recording)
    except sr.UnknownValueError:
        return "Could not understand the audio."
    except sr.RequestError as err:
        return f"Speech recognition error: {err}"
    return transcript

# Grammar correction function
def correct_grammar(text):
    """Correct the grammar of *text* and score how much it changed.

    Two corrections are produced from the same ``"gec: "``-prefixed prompt:
    one by calling the module-level ``model``/``tokenizer`` directly and one
    through the ``happy_tt`` wrapper with beam search.

    Args:
        text: The sentence to check. ``None``, empty and whitespace-only
            values are all treated as missing input.

    Returns:
        A tuple ``(corrected_text, grammar_score, correction)`` where
        ``corrected_text`` is the direct model output, ``correction`` is the
        ``happy_tt`` output, and ``grammar_score`` is an integer in
        ``0..100``. For missing input the tuple is
        ``("No input provided.", 0, "No correction available.")``.
    """
    # ``text`` is None when neither audio nor typed text reaches this
    # function; guard before calling .strip() so that case does not crash.
    if not text or not text.strip():
        return "No input provided.", 0, "No correction available."

    prompt = "gec: " + text

    inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=128, truncation=True)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model.generate(inputs, max_length=128, num_return_sequences=1)
    corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    args = TTSettings(num_beams=5, min_length=1)
    correction = happy_tt.generate_text(prompt, args=args).text

    # The score is 100 minus the difference in character count between the
    # input and the correction. It is clamped at 0 so a large length change
    # cannot produce a negative score.
    grammar_score = max(0, 100 - abs(len(text) - len(correction)))

    return corrected_text, grammar_score, correction

# Unified function for both speech and text input
def process_input(audio, text):
    """Run grammar correction on spoken or typed input.

    Args:
        audio: Audio input; when truthy it is transcribed and takes
            precedence over ``text``.
        text: Typed sentence, used only when no audio is supplied.

    Returns:
        Whatever ``correct_grammar`` returns for the chosen sentence.
    """
    sentence = transcribe_audio(audio) if audio else text
    return correct_grammar(sentence)

# Gradio UI
def main():
    """Build the Gradio interface around ``process_input`` and launch it."""
    # Input widgets, in the order process_input expects them: audio, text.
    microphone = gr.Audio(sources=["microphone"], type="filepath", label="Speak your sentence")
    typed = gr.Textbox(placeholder="Or type here if not speaking...", label="Text Input")

    app = gr.Interface(
        fn=process_input,
        inputs=[microphone, typed],
        # One output per element of the tuple returned by process_input.
        outputs=["text", "number", "text"],
        title="AI Grammar Checker",
        description="Speak or type a sentence to check its grammar, get corrections, and see a score.",
        live=False,  # run only on explicit submit, not on every change
        api_name="/predict",
    )
    app.launch()

# Start the UI only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()