root-sajjan committed
Commit 964fedc · verified · 1 Parent(s): 39acce6

updated error handling

Files changed (1)
  1. llm/inference.py +69 -2
llm/inference.py CHANGED
@@ -12,7 +12,7 @@ nltk.download('averaged_perceptron_tagger')
 
 client = InferenceClient(api_key=api_key)
 
-
+'''
 def extract_product_info(text):
     print(f'Extract function called!')
     # Initialize result dictionary
@@ -57,7 +57,74 @@ def extract_product_info(text):
     result["description"] = " ".join(description_parts)
     print(f'extract function returned:\n{result}')
     return result
-
+'''
+def extract_product_info(text):
+    print(f"Extract function called with input: {text}")
+
+    # Initialize result dictionary
+    result = {"brand": None, "model": None, "description": None, "price": None}
+
+    try:
+        # Extract price using regex
+        price_match = re.search(r'\$\s?\d{1,3}(?:,\d{3})*(?:\.\d{2})?', text)
+        print(f"Price match: {price_match}")
+        if price_match:
+            result["price"] = price_match.group().replace("$", "").replace(",", "").strip()
+            # Remove the price part from the text to prevent interference
+            text = text.replace(price_match.group(), "").strip()
+            print(f"Text after removing price: {text}")
+
+        # Tokenize the remaining text
+        try:
+            tokens = nltk.word_tokenize(text)
+            print(f"Tokens: {tokens}")
+        except Exception as e:
+            print(f"Error during tokenization: {e}")
+            # Fall back to a simple split if tokenization fails
+            tokens = text.split()
+            print(f"Fallback tokens: {tokens}")
+
+        # POS tagging
+        try:
+            pos_tags = nltk.pos_tag(tokens)
+            print(f"POS Tags: {pos_tags}")
+        except Exception as e:
+            print(f"Error during POS tagging: {e}")
+            # If POS tagging fails, create dummy tags
+            pos_tags = [(word, "NN") for word in tokens]
+            print(f"Fallback POS Tags: {pos_tags}")
+
+        # Extract brand, model, and description
+        brand_parts = []
+        model_parts = []
+        description_parts = []
+
+        for word, tag in pos_tags:
+            if tag == 'NNP' or re.match(r'[A-Za-z0-9-]+', word):
+                if len(brand_parts) == 0:  # Assume the first proper noun is the brand
+                    brand_parts.append(word)
+                else:  # Model number tends to follow the brand
+                    model_parts.append(word)
+            else:
+                description_parts.append(word)
+
+        # Assign values to the result dictionary
+        if brand_parts:
+            result["brand"] = " ".join(brand_parts)
+        if model_parts:
+            result["model"] = " ".join(model_parts)
+        if description_parts:
+            result["description"] = " ".join(description_parts)
+
+        print(f"Extract function returned: {result}")
+
+    except Exception as e:
+        print(f"Unexpected error: {e}")
+        # Return a fallback result in case of a critical error
+        result["description"] = text
+        print(f"Fallback result: {result}")
+
+    return result
 
 
 def extract_info(text):
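
As a quick sanity check of the new error handling, the rewritten function can be exercised directly. The snippet below is a minimal sketch, not part of the commit: it assumes llm/inference.py imports cleanly (api_key is available for InferenceClient and the NLTK data downloaded at module import is present), and the sample string is made up for illustration.

# Minimal sketch (not from the commit): assumes importing llm.inference
# succeeds (api_key configured, NLTK downloads completed).
# The sample text below is hypothetical.
from llm.inference import extract_product_info

sample = "Sony WH-1000XM5 wireless headphones $349.99"
info = extract_product_info(sample)

# The price regex should strip "$" and commas, e.g. "349.99" here.
print(info["price"])

# Even if word_tokenize or pos_tag raises (e.g. missing NLTK data), the new
# try/except blocks fall back to text.split() or dummy "NN" tags, and in the
# worst case the whole remaining input is returned as the description.
print(info["brand"], info["model"], info["description"])

The outer try/except means a caller always gets the result dictionary back instead of an exception; when something unexpected fails, the raw text is preserved in the description field.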