Pavanb committed on
Commit
17b89fa
·
verified ·
1 Parent(s): 1c75a3d

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +75 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard library first, then third-party, then local modules.
import os
import tempfile

import gradio as gr

from address_extractor import AddressExtractor

# One shared extractor instance: the tokenizer, BitNet model, and Whisper
# model are loaded once at startup and reused across all requests.
address_extractor = AddressExtractor()
8
+
9
def extract_from_text(input_text):
    """Extract a US address from free-form text with the BitNet chat model.

    Args:
        input_text: Raw user text that may contain an address.

    Returns:
        The model's extracted-address string, an error message when the
        input is blank, or "No address detected." when the model returns
        nothing.
    """
    if not input_text.strip():
        return "Error: No text provided."

    chat = [
        {"role": "system", "content": address_extractor.system_prompt_text},
        {"role": "user", "content": input_text},
    ]
    rendered = address_extractor.tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    encoded = address_extractor.tokenizer(rendered, return_tensors="pt").to(
        address_extractor.bitnet_model.device
    )

    output_ids = address_extractor.bitnet_model.generate(**encoded, max_new_tokens=256)
    # Decode only the freshly generated tokens, skipping the echoed prompt.
    prompt_len = encoded['input_ids'].shape[-1]
    reply = address_extractor.tokenizer.decode(
        output_ids[0][prompt_len:], skip_special_tokens=True
    )
    return reply.strip() or "No address detected."
25
+
26
def extract_from_audio(audio_file):
    """Transcribe an uploaded .wav file and extract a US address from it.

    Args:
        audio_file: Either a filesystem path (what ``gr.Audio(type="filepath")``
            delivers in Gradio 4.x) or a file-like object exposing ``.read()``
            (the shape the original code expected). Accepting both keeps the
            function backward-compatible.

    Returns:
        The extracted-address string, an error message when no audio is
        given, or "No address detected." when the model returns nothing.
    """
    if audio_file is None:
        return "Error: No audio provided."

    # Normalize the input to a path on disk, since the whisper model
    # transcribes from a file path. Only delete files we created ourselves;
    # paths handed to us by the caller (e.g. gradio uploads) are not ours.
    if isinstance(audio_file, str):
        tmp_file_path = audio_file
        cleanup = False
    else:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            tmp_file.write(audio_file.read())
            tmp_file_path = tmp_file.name
        cleanup = True

    try:
        segments = address_extractor.whisper_model.transcribe(tmp_file_path)
        input_text = " ".join(seg.text.strip() for seg in segments)
        input_text = address_extractor.preprocess_text(input_text)

        messages = [
            {"role": "system", "content": address_extractor.system_prompt_speech},
            {"role": "user", "content": input_text},
        ]
        prompt = address_extractor.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        chat_input = address_extractor.tokenizer(prompt, return_tensors="pt").to(
            address_extractor.bitnet_model.device
        )

        chat_outputs = address_extractor.bitnet_model.generate(
            **chat_input, max_new_tokens=256
        )
        # Decode only the newly generated tokens after the prompt.
        generated_text = address_extractor.tokenizer.decode(
            chat_outputs[0][chat_input['input_ids'].shape[-1]:],
            skip_special_tokens=True,
        )
        result = generated_text.strip() or "No address detected."
    finally:
        if cleanup:
            os.remove(tmp_file_path)

    return result
56
+
57
# Gradio UI: two tabs (text and audio input) sharing the same backend.
with gr.Blocks() as demo:
    gr.Markdown("## 📦 US Address Extractor")
    with gr.Tab("Text Input"):
        text_input = gr.Textbox(lines=3, label="Enter Text")
        text_output = gr.Textbox(label="Extracted Address")
        text_button = gr.Button("Extract Address")

        text_button.click(fn=extract_from_text, inputs=text_input, outputs=text_output)

    with gr.Tab("Audio Input (.wav)"):
        # Gradio 4.x (pinned as gradio==4.26.0 in requirements.txt) removed
        # Audio(source=..., type="file"); the current API is sources=[...]
        # and type="filepath", which passes the callback a path string.
        audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload a .wav Audio File")
        audio_output = gr.Textbox(label="Extracted Address")
        audio_button = gr.Button("Extract Address")

        audio_button.click(fn=extract_from_audio, inputs=audio_input, outputs=audio_output)

# Bind to all interfaces on port 7860 (the standard for HF Spaces / containers).
demo.launch(server_name="0.0.0.0", server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers @ git+https://github.com/shumingma/transformers.git@21f5a84cc5624b5f058a4dea435877594ba89bad
2
+ accelerate==1.6.0
3
+ pywhispercpp==1.3.0
4
+ gradio==4.26.0