Pavanb committed on
Commit
edf14f4
·
verified ·
1 Parent(s): 17b89fa

Upload address_extractor.py

Browse files
Files changed (1) hide show
  1. address_extractor.py +155 -0
address_extractor.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import time, sys
4
+ from os import path, listdir
5
+ from pywhispercpp.model import Model
6
+
7
+
8
+ class AddressExtractor():
9
+ def __init__(self):
10
+ model_id = "microsoft/bitnet-b1.58-2B-4T"
11
+
12
+ # Load tokenizer and model
13
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id)
14
+ self.bitnet_model = AutoModelForCausalLM.from_pretrained(
15
+ model_id,
16
+ torch_dtype=torch.bfloat16,
17
+ device_map = "cpu",
18
+ )
19
+
20
+ # Set pad_token_id to eos_token_id
21
+ self.tokenizer.pad_token = self.tokenizer.eos_token
22
+ self.bitnet_model.config.pad_token_id = self.tokenizer.pad_token_id
23
+
24
+
25
+ self.whisper_model = Model('small.en-q5_1', n_threads = 16, language = 'en')
26
+ # self.whisper_model = Model('small.en', n_threads = 16, language = 'en')
27
+ # self.whisper_model = Model('tiny.en', n_threads = 16, language = 'en')
28
+
29
+ self.system_prompt_speech = """
30
+ Your task is to extract the US address given the ASR inferred text (using whisper-large-v3-turbo model) without generating any additional text description. Only extract the address related entities and generate the final address from the extracted content.
31
+ """
32
+
33
+ self.system_prompt_text = """
34
+ Your task is to extract the US address given the input text without generating any additional text description. Only extract the address related entities and generate the final address from the extracted content.
35
+ """
36
+
37
+ # self.sample_files_path = "./one_sentence_us_address/"
38
+
39
+
40
+
41
+ def compute_latency(self, start_time, end_time):
42
+ tr_duration= end_time-start_time
43
+ hours = tr_duration // 3600
44
+ minutes = (tr_duration - (hours * 3600)) // 60
45
+ seconds = tr_duration - ((hours * 3600) + (minutes * 60))
46
+ msg = f'inference elapsed time was {str(hours)} hours, {minutes:4.1f} minutes, {seconds:4.2f} seconds'
47
+
48
+ return msg
49
+
50
+
51
+ def infer_text_sample(self, input_text):
52
+
53
+ messages = [
54
+ {"role": "system", "content": self.system_prompt_text},
55
+ {"role": "user", "content": input_text},
56
+ ]
57
+
58
+ if input_text.lower().strip() != "":
59
+ prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
60
+ chat_input = self.tokenizer(prompt, return_tensors="pt").to(self.bitnet_model.device)
61
+
62
+ # Generate response
63
+ chat_outputs = self.bitnet_model.generate(**chat_input, max_new_tokens=256)
64
+ generated_text = self.tokenizer.decode(chat_outputs[0][chat_input['input_ids'].shape[-1]:], skip_special_tokens=True) # Decode only the response part
65
+
66
+ if generated_text.strip() != "":
67
+ print("\n\n", "="*100)
68
+ print("Address Extracted: ", generated_text)
69
+ print("="*100, "\n\n")
70
+
71
+
72
+ def preprocess_text(self, input_text):
73
+ ### Preprocessing the ASR generated text
74
+ input_tokens = []
75
+ for word in input_text.split(" "):
76
+ word = word.strip()
77
+ if word != "":
78
+ if "," in word:
79
+ try:
80
+ num = int(word)
81
+ word = word.replace(",", " ")
82
+ except:
83
+ word = word.replace(",", ", ")
84
+ input_tokens.append(word)
85
+ input_text = " ".join(input_tokens)
86
+
87
+ return input_text
88
+
89
+
90
+
91
+ def infer_audio_sample(self, audio_input_file_path):
92
+ input_text = ""
93
+ segments = self.whisper_model.transcribe(audio_input_file_path)
94
+ for segment in segments:
95
+ input_text += segment.text.strip()
96
+
97
+ input_text = self.preprocess_text(input_text)
98
+
99
+ print("\n\n", "="*100)
100
+ print("Transcribe Text: ", input_text)
101
+ print("="*100, "\n")
102
+
103
+ messages = [
104
+ {"role": "system", "content": self.system_prompt_speech},
105
+ {"role": "user", "content": input_text},
106
+ ]
107
+
108
+ if input_text.lower().strip() != "":
109
+ prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
110
+ chat_input = self.tokenizer(prompt, return_tensors="pt").to(self.bitnet_model.device)
111
+
112
+ # Generate response
113
+ chat_outputs = self.bitnet_model.generate(**chat_input, max_new_tokens=256)
114
+ generated_text = self.tokenizer.decode(chat_outputs[0][chat_input['input_ids'].shape[-1]:], skip_special_tokens=True) # Decode only the response part
115
+
116
+ if generated_text.strip() != "":
117
+ print("\n\n", "="*100)
118
+ print("Address Extracted: ", generated_text)
119
+ print("="*100, "\n\n")
120
+
121
+
122
+
123
def main():
    """Interactive loop: read a .wav path or free text from stdin and print
    the extracted address; typing `exit` terminates the process."""
    address_extract = AddressExtractor()

    while True:
        user_input = input("Paste audio path or Text (type `exit` to quit): ").strip()

        if user_input == "exit":
            sys.exit(0)

        if user_input.endswith(".wav"):
            # A .wav suffix means the input is an audio file path.
            if path.exists(user_input):
                address_extract.infer_audio_sample(user_input)
            else:
                print(f"Error: The audio file '{user_input}' does not exist.")
        elif user_input:
            address_extract.infer_text_sample(user_input)
        else:
            print("Error: Please provide the valid input")


if __name__ == "__main__":
    main()