nileshhanotia committed on
Commit
c52a0cb
·
verified ·
1 Parent(s): c92f9e6

Create nl_converter.py

Browse files
Files changed (1) hide show
  1. nl_converter.py +117 -0
nl_converter.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import requests
3
+ import json
4
+ from typing import Dict, Any, List
5
+ import spacy
6
+ from spacy.cli import download
7
+
8
class NLConverter:
    """Turn tabular query results into a natural-language explanation.

    Named entities are extracted from the result rows with spaCy and sent,
    together with the original question and the rows themselves, to an
    OpenAI-compatible chat-completions endpoint hosted at ``api.groq.com``.
    """

    def __init__(self, api_key: str):
        """Store the API credentials and load the ``en_core_web_sm`` model.

        Args:
            api_key: Token sent as ``Bearer <api_key>`` in the
                ``Authorization`` header of every request.
        """
        self.api_key = api_key
        self.api_url = "https://api.groq.com/openai/v1/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        # Configure logging BEFORE the first log call. A root-logger call made
        # while no handler exists installs a default WARNING-level handler,
        # which would drop the "Downloading..." message and turn this
        # basicConfig(level=INFO) into a no-op.
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Attempt to load the spaCy model, and install it if missing.
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            self.logger.info("Model 'en_core_web_sm' not found. Downloading...")
            download("en_core_web_sm")
            self.nlp = spacy.load("en_core_web_sm")

    def extract_entities(self, text: str) -> List[Dict[str, Any]]:
        """Use spaCy to extract named entities from ``text``.

        Args:
            text: Raw text to run through the loaded spaCy pipeline.

        Returns:
            One ``{"text": ..., "label": ...}`` dict per entity, in the order
            the pipeline reports them.
        """
        doc = self.nlp(text)
        entities = [{"text": ent.text, "label": ent.label_} for ent in doc.ents]
        self.logger.info("Extracted entities: %s", entities)

        # Echo the entities to stdout as well (kept from the original
        # behaviour so existing console output does not change).
        for entity in entities:
            print(f"Entity: {entity['text']} - Label: {entity['label']}")

        return entities

    def convert_to_natural_language(self, query_result: Dict[str, Any], original_query: str) -> Dict[str, Any]:
        """Convert query results to natural language using LLaMA-3 and spaCy NER.

        Args:
            query_result: Mapping with a truthy ``"success"`` flag, a
                ``"columns"`` sequence and a ``"results"`` sequence of rows.
            original_query: The user's question, included in the prompt.

        Returns:
            ``{"success": True, "explanation": <str>}`` on success, otherwise
            ``{"success": False, "error": <str>}``. Failed queries, transport
            errors, non-JSON bodies and malformed completions are all reported
            through the error dict rather than raised.
        """
        # A missing/None result or a missing "success" key is treated like a
        # failed query instead of raising KeyError/TypeError.
        if not query_result or not query_result.get("success"):
            self.logger.error("Query execution failed; no results to process.")
            return {"success": False, "error": "No results to process"}

        # Pair every row with the column names, then flatten to text so the
        # same string feeds both entity extraction and the prompt.
        formatted_data = [dict(zip(query_result["columns"], row)) for row in query_result["results"]]
        formatted_text = "\n".join(str(row) for row in formatted_data)

        entities = self.extract_entities(formatted_text)

        system_prompt = (
            "You are a data interpreter that uses named entities to create a clear, natural language explanation. "
            "Your job is to make sense of the given entities, summarize key insights, and answer the original question."
        )

        # Include the original query and the extracted entities in the user prompt.
        user_prompt = (
            "\n"
            f"Original question: {original_query}\n"
            "Extracted Entities:\n"
            f"{json.dumps(entities, indent=2)}\n"
            "Data Summary:\n"
            f"{formatted_text}\n"
            "Based on this information, generate a natural language explanation of the query results.\n"
        )

        payload = {
            "model": "llama3-8b-8192",  # Adjust model name if necessary
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "max_tokens": 500,
            "temperature": 0.3,
        }

        try:
            response = requests.post(self.api_url, headers=self.headers, json=payload, timeout=30)
            response.raise_for_status()
            result = response.json()
        except requests.exceptions.RequestException as e:
            self.logger.error("API request error: %s", e)
            return {"success": False, "error": f"API request failed: {str(e)}"}
        except ValueError as e:
            # Older requests releases raise a plain ValueError (not a
            # RequestException) when the body is not valid JSON.
            self.logger.error("API returned a non-JSON response: %s", e)
            return {"success": False, "error": f"API request failed: {str(e)}"}

        # Any structural surprise in the completion (missing key, empty
        # choices, null content) is reported the same way as "no choices".
        try:
            explanation = result["choices"][0]["message"]["content"].strip()
        except (KeyError, IndexError, TypeError, AttributeError):
            return {"success": False, "error": "Failed to generate explanation"}

        self.logger.info("Generated natural language explanation: %s", explanation)
        return {"success": True, "explanation": explanation}
98
+
99
# Example Usage
if __name__ == "__main__":
    import os

    # Credentials must never be committed to source control: read the key
    # from the environment instead. NOTE: any key that was previously
    # hard-coded here should be treated as leaked and revoked.
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise SystemExit("Error: set the GROQ_API_KEY environment variable to run this example.")

    # Minimal, hand-written query result in the shape
    # convert_to_natural_language() reads: success flag, column names, rows.
    query_result = {
        "success": True,
        "columns": ["order_id", "total_price", "order_date"],
        "results": [
            [1001, 150, "2024-10-01"],
            [1002, 200, "2024-10-02"],
        ],
    }
    original_query = "Show me the orders with total price greater than 100"

    nl_converter = NLConverter(api_key)
    result = nl_converter.convert_to_natural_language(query_result, original_query)
    if result["success"]:
        print(result["explanation"])
    else:
        print(f"Error: {result['error']}")