# Hugging Face page metadata captured along with this file (not code):
# chrome_models / 8 /test.py
# dejanseo's picture
# Upload test.py
# ff2efcf verified
# raw
# history blame contribute delete
# 13.3 kB
#!/usr/bin/env python3
"""
Entity extraction demo for a TFLite entity-scoring model.

The script tokenizes a sample text with NLTK and proposes capitalized word
spans as entity candidates. It then feeds the model a uint8 word-embedding
matrix in the shape the script assumes the TFLite model requires
(MAX_WORDS x EMBEDDING_DIM = 64x32); confirm this against the model's
printed input details.

NOTE: the word embeddings are pseudo-random placeholder vectors generated by
this script, not the output of a pre-trained embedding model.
The per-word random seed for unknown words is kept within [0, 2**32 - 1],
the range NumPy accepts (this was the earlier "random seed error" fix).
"""
import hashlib
import os
import re
import traceback

import nltk
import numpy as np
import tensorflow as tf
from nltk.tokenize import word_tokenize
# ---------------------------------------------------------------------------
# Input files. Relative paths are resolved against the current working
# directory, so adjust them to wherever the model files actually live.
# ---------------------------------------------------------------------------
MODEL_PATH = "model.tflite"
# None of the code in this script reads the three paths below; they are
# kept only as a record of the model's companion files.
WORD_EMBEDDINGS_PATH = "word_embeddings"
ENTITIES_METADATA_PATH = "global-entities_metadata"
ENTITIES_NAMES_PATH = "global-entities_names"

# Text analyzed by main().
SAMPLE_TEXT = (
    "Zendesk is a customer service platform used by companies like "
    "Shopify, Airbnb, and Slack to manage support tickets, automate "
    "workflows, and provide omnichannel communication through email, "
    "chat, phone, and social media."
)

# Fixed tensor sizes: word slots, candidate slots, and per-word vector length.
MAX_WORDS, MAX_CANDIDATES, EMBEDDING_DIM = 64, 32, 32
class EntityExtractor:
    """Score candidate entity spans in text with a TFLite model.

    ``extract_entities`` runs the whole pipeline:
      1. NLTK tokenization.
      2. Capitalization-based candidate spans.
      3. A ``(MAX_WORDS, EMBEDDING_DIM)`` uint8 word-embedding matrix.
      4. One TFLite inference.
      5. A score threshold.

    NOTE: the word embeddings are pseudo-random placeholder vectors built in
    ``load_embedding_model`` / ``get_word_embedding``, not learned vectors.
    The scores therefore likely differ from what the model would produce
    with its real embeddings; verify before relying on them.
    """

    # NLTK's word_tokenize rewrites a double quote (") as `` or ''. Map
    # those tokens back to the character to search for in the source text.
    _QUOTE_TOKENS = {"``": '"', "''": '"'}

    def __init__(self, verbose=True):
        """Load the TFLite model and the placeholder word-embedding table.

        Args:
            verbose: When true, print diagnostic information at each step.

        Raises:
            FileNotFoundError: If ``MODEL_PATH`` does not exist.
        """
        self.model_path = MODEL_PATH
        self.verbose = verbose
        # TFLite interpreter with tensors already allocated.
        self.interpreter = self.load_model()
        # Lowercase word -> uint8 embedding vector lookup table.
        self.embedding_model = self.load_embedding_model()
        # Tensor descriptions used to route inputs by name and read output 0.
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        if self.verbose:
            print(f"TFLite model loaded with {len(self.input_details)} inputs and {len(self.output_details)} outputs")
            print("Pre-trained embedding model loaded")
            print("Input details:")
            for detail in self.input_details:
                print(f" - {detail['name']} (index: {detail['index']}, shape: {detail['shape']}, dtype: {detail['dtype']})")

    def load_model(self):
        """Create a TFLite interpreter for ``self.model_path`` and allocate its tensors.

        Raises:
            FileNotFoundError: If the model file does not exist.
        """
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model file not found: {self.model_path}")
        interpreter = tf.lite.Interpreter(model_path=self.model_path)
        interpreter.allocate_tensors()
        return interpreter

    @staticmethod
    def _random_unit_embedding(rng):
        """Draw a unit-length random vector from *rng* and scale it to uint8 (0-255)."""
        embedding = rng.rand(EMBEDDING_DIM)
        embedding = embedding / np.linalg.norm(embedding)
        return (embedding * 255).astype(np.uint8)

    def load_embedding_model(self):
        """Ensure NLTK tokenizer data is present and build the embedding table.

        Returns:
            dict mapping a lowercase word to an ``(EMBEDDING_DIM,)`` uint8
            vector. The vectors are pseudo-random placeholders, not a
            pre-trained model. If anything here raises, an empty dict is
            returned, and every word then uses the hashed per-word fallback
            in ``get_word_embedding``.
        """
        try:
            # word_tokenize loads "punkt" on older NLTK releases and
            # "punkt_tab" on newer ones; fetch whichever is missing.
            for resource, package in (("tokenizers/punkt", "punkt"),
                                      ("tokenizers/punkt_tab", "punkt_tab")):
                try:
                    nltk.data.find(resource)
                except LookupError:
                    nltk.download(package)
            common_words = ["google", "is", "a", "search", "engine", "company", "based", "in", "the", "usa",
                            "and", "of", "to", "for", "with", "on", "by", "at", "from", "as"]
            # Use a private generator. Seeding the *global* NumPy RNG would
            # silently alter every other np.random call in the process.
            # RandomState(42) yields the same vectors the old global seed did.
            rng = np.random.RandomState(42)
            embedding_dict = {word: self._random_unit_embedding(rng) for word in common_words}
            if self.verbose:
                print(f"Created embedding dictionary with {len(embedding_dict)} words")
            return embedding_dict
        except Exception as e:
            # Best effort: fall back to the per-word hashed embeddings.
            if self.verbose:
                print(f"Error loading embedding model: {str(e)}")
                print("Using fallback embedding approach")
            return {}

    def get_word_embedding(self, word):
        """Return the ``(EMBEDDING_DIM,)`` uint8 embedding for *word* (case-insensitive).

        Words in the lookup table use their stored vector. Any other word gets
        a pseudo-random vector seeded from a SHA-256 digest of the lowercased
        word. The same word therefore maps to the same vector in every run,
        and the global NumPy RNG state is left untouched.
        """
        word_lower = word.lower()
        known = self.embedding_model.get(word_lower)
        if known is not None:
            return known
        # hash() on str is salted per process (PYTHONHASHSEED), so it cannot
        # give run-to-run consistency; a cryptographic digest can. Four bytes
        # give a seed in [0, 2**32 - 1], the range RandomState accepts.
        digest = hashlib.sha256(word_lower.encode("utf-8")).digest()
        seed = int.from_bytes(digest[:4], "little")
        return self._random_unit_embedding(np.random.RandomState(seed))

    def tokenize_text(self, text):
        """Tokenize *text* with NLTK and locate each token in the original string.

        Returns:
            ``(words, positions)``: the token list and, per token, a
            ``(start, end)`` character span into *text*. NLTK does not report
            offsets, so spans are recovered by a left-to-right search. A token
            that cannot be found gets a best-guess span at the current cursor.
        """
        words = word_tokenize(text)
        positions = []
        cursor = 0
        for word in words:
            word_pos = text.find(word, cursor)
            match_len = len(word)
            if word_pos == -1 and word in self._QUOTE_TOKENS:
                # Rewritten quote token: search for the source character instead.
                source_char = self._QUOTE_TOKENS[word]
                word_pos = text.find(source_char, cursor)
                match_len = len(source_char)
            if word_pos != -1:
                positions.append((word_pos, word_pos + match_len))
                cursor = word_pos + match_len
            else:
                # Best-guess span. Keep the cursor in place so later tokens are
                # still found rather than being skipped over.
                positions.append((cursor, cursor + len(word)))
        if self.verbose:
            print(f"Tokenized text into {len(words)} words: {words}")
        return words, positions

    def get_word_embeddings_matrix(self, words):
        """Stack per-word embeddings into a ``(MAX_WORDS, EMBEDDING_DIM)`` uint8 matrix.

        Only the first MAX_WORDS words are used; unused rows stay zero.
        """
        result = np.zeros((MAX_WORDS, EMBEDDING_DIM), dtype=np.uint8)
        for i, word in enumerate(words[:MAX_WORDS]):
            result[i] = self.get_word_embedding(word)
        if self.verbose:
            print(f"Created word embeddings matrix with shape {result.shape}")
        return result

    def find_entity_candidates(self, words, positions):
        """Propose token spans that may be entities.

        Every capitalized token starts a 1-, 2- and 3-word candidate, or
        fewer near the end. Only the first MAX_WORDS tokens are considered,
        because only those have rows in the word-embedding matrix. The result
        is truncated to MAX_CANDIDATES.

        Args:
            words: Token list from ``tokenize_text``.
            positions: Character spans from ``tokenize_text``; currently
                unused, kept for interface compatibility.

        Returns:
            List of half-open ``(start_idx, end_idx)`` token ranges.
        """
        limit = min(len(words), MAX_WORDS)
        candidates = []
        for i in range(limit):
            word = words[i]
            # `not word` guards against indexing an empty token.
            if not word or not word[0].isupper():
                continue
            for end in range(i + 1, min(i + 3, limit) + 1):
                candidates.append((i, end))
        candidates = candidates[:MAX_CANDIDATES]
        if self.verbose:
            print(f"Found {len(candidates)} entity candidates:")
            for start, end in candidates:
                print(f" - {' '.join(words[start:end])}")
        return candidates

    def prepare_model_inputs(self, words, candidates, word_embeddings_matrix):
        """Build the input tensors, keyed by interpreter input index.

        Each model input is matched by substring against its tensor name. The
        first matching key in ``named_inputs`` wins. Inputs with no match are
        left unset (a warning is printed when verbose).
        Priors, entity embeddings and link tensors are constant placeholders.
        """
        num_words = min(len(words), MAX_WORDS)
        num_candidates = min(len(candidates), MAX_CANDIDATES)
        ranges_input = np.zeros((MAX_CANDIDATES, 2), dtype=np.int32)
        capitalization_input = np.zeros(MAX_CANDIDATES, dtype=np.int32)
        for i, (start, end) in enumerate(candidates[:MAX_CANDIDATES]):
            ranges_input[i] = (start, end)
            if start < len(words) and words[start][:1].isupper():
                capitalization_input[i] = 1
        # Order matters: "word_embeddings" is tested before "entity_embeddings".
        named_inputs = (
            ("word_embeddings", word_embeddings_matrix),
            ("num_words", np.array([num_words], dtype=np.int32)),
            ("num_candidates", np.array([num_candidates], dtype=np.int32)),
            ("ranges", ranges_input),
            ("capitalization", capitalization_input),
            ("priors", np.full(MAX_CANDIDATES, 0.5, dtype=np.float32)),
            ("entity_embeddings", np.zeros((MAX_CANDIDATES, EMBEDDING_DIM), dtype=np.uint8)),
            ("candidate_links", np.zeros((MAX_CANDIDATES, MAX_CANDIDATES), dtype=np.float32)),
            ("aggregated_entity_links", np.zeros(MAX_CANDIDATES, dtype=np.float32)),
        )
        inputs = {}
        for detail in self.input_details:
            name = detail['name']
            tensor = next((value for key, value in named_inputs if key in name), None)
            if tensor is None:
                if self.verbose:
                    print(f"Warning: no data prepared for model input '{name}'; leaving it unset")
                continue
            inputs[detail['index']] = tensor
        return inputs

    def run_model(self, inputs):
        """Set *inputs* (index -> array), invoke the interpreter, and return output 0.

        The returned array's shape is whatever the model defines; it is
        printed when verbose.
        """
        for index, tensor in inputs.items():
            self.interpreter.set_tensor(index, tensor)
        self.interpreter.invoke()
        output = self.interpreter.get_tensor(self.output_details[0]['index'])
        if self.verbose:
            print(f"Model output shape: {output.shape}")
        return output

    def extract_entities(self, text, threshold=0.5):
        """Extract entity spans from *text* whose model score exceeds *threshold*.

        Args:
            text: Input string.
            threshold: Candidates scoring strictly above this are kept.

        Returns:
            List of dicts with these keys:
              - ``"text"``: the span's tokens joined by spaces.
              - ``"score"``: the candidate's score as a float.
              - ``"position"``: a ``(start, end)`` character span into *text*.
        """
        words, positions = self.tokenize_text(text)
        candidates = self.find_entity_candidates(words, positions)
        word_embeddings_matrix = self.get_word_embeddings_matrix(words)
        inputs = self.prepare_model_inputs(words, candidates, word_embeddings_matrix)
        # Assumes output 0 holds one score per candidate slot. Size-1 axes
        # (e.g. a leading batch axis) are dropped so scores[i] is a scalar.
        # Without this, comparing a row with the threshold raises ValueError.
        # TODO(review): confirm against the model's actual output shape.
        scores = np.atleast_1d(np.squeeze(self.run_model(inputs)))
        entities = []
        for i, (start, end) in enumerate(candidates):
            if i >= len(scores):
                break
            score = float(scores[i])
            # `not score > threshold` (rather than `<=`) also rejects NaN.
            if not score > threshold or end > len(words):
                continue
            entities.append({
                "text": " ".join(words[start:end]),
                "score": score,
                "position": (positions[start][0], positions[end - 1][1]),
            })
        return entities
def main():
    """Run the extractor on SAMPLE_TEXT and print the entities it reports.

    Any exception is caught here. It is reported with a traceback and a list
    of troubleshooting hints instead of propagating.
    """
    print(f"Analyzing text: {SAMPLE_TEXT}")
    try:
        found = EntityExtractor(verbose=True).extract_entities(SAMPLE_TEXT, threshold=0.5)
        print("\nDetected entities:")
        for item in found:
            print(
                f"- {item['text']} (confidence: {item['score']:.2f}, "
                f"position: {item['position']})"
            )
    except Exception as err:
        print(f"Error: {err!s}")
        traceback.print_exc()
        hints = (
            "1. Make sure all file paths are correct",
            "2. Check that TensorFlow is installed (pip install tensorflow)",
            "3. Ensure that NLTK is installed (pip install nltk)",
            "4. Verify that the model file is a valid TFLite model",
        )
        print("\nTroubleshooting tips:")
        for hint in hints:
            print(hint)


if __name__ == "__main__":
    main()