Create nl_converter.py
nl_converter.py
ADDED
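Adds a converter that turns SQL query results into natural-language explanations: spaCy extracts named entities from the result rows, and Groq's OpenAI-compatible chat completions endpoint (LLaMA 3 8B) generates the final summary.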
import logging
import os
import json
from typing import Dict, Any, List

import requests
import spacy
from spacy.cli import download


class NLConverter:
    def __init__(self, api_key: str):
        self.api_key = api_key
        self.api_url = "https://api.groq.com/openai/v1/chat/completions"
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

        # Configure logging before the model check so the download message is visible.
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Attempt to load the spaCy model, downloading it if missing.
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            self.logger.info("Model 'en_core_web_sm' not found. Downloading...")
            download("en_core_web_sm")
            self.nlp = spacy.load("en_core_web_sm")

    def extract_entities(self, text: str) -> List[Dict[str, Any]]:
        """Use spaCy to extract named entities from text."""
        doc = self.nlp(text)
        entities = [{"text": ent.text, "label": ent.label_} for ent in doc.ents]
        self.logger.info(f"Extracted entities: {entities}")

        # Print extracted entities
        for entity in entities:
            print(f"Entity: {entity['text']} - Label: {entity['label']}")

        return entities

    def convert_to_natural_language(self, query_result: Dict[str, Any], original_query: str) -> Dict[str, Any]:
        """Convert query results to natural language using LLaMA-3 and spaCy for NER."""
        if not query_result["success"]:
            self.logger.error("Query execution failed; no results to process.")
            return {"success": False, "error": "No results to process"}

        # Pair each result row with its column names.
        formatted_data = [dict(zip(query_result["columns"], row)) for row in query_result["results"]]

        # Convert the formatted data to a string for entity extraction.
        formatted_text = "\n".join(str(row) for row in formatted_data)

        # Extract named entities from the query result.
        entities = self.extract_entities(formatted_text)

        # Prepare system and user prompts.
        system_prompt = (
            "You are a data interpreter that uses named entities to create a clear, natural language explanation. "
            "Your job is to make sense of the given entities, summarize key insights, and answer the original question."
        )

        # Include the original query and the extracted entities in the user prompt.
        user_prompt = f"""
Original question: {original_query}
Extracted Entities:
{json.dumps(entities, indent=2)}
Data Summary:
{formatted_text}
Based on this information, generate a natural language explanation of the query results.
"""

        try:
            # Prepare the payload for the chat completions API.
            payload = {
                "model": "llama3-8b-8192",  # Adjust the model name if necessary.
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                "max_tokens": 500,
                "temperature": 0.3
            }

            # Send the request to the API.
            response = requests.post(self.api_url, headers=self.headers, json=payload, timeout=30)
            response.raise_for_status()

            result = response.json()
            if 'choices' in result and result['choices']:
                explanation = result['choices'][0]['message']['content'].strip()
                self.logger.info(f"Generated natural language explanation: {explanation}")
                return {"success": True, "explanation": explanation}

            return {"success": False, "error": "Failed to generate explanation"}

        except requests.exceptions.RequestException as e:
            self.logger.error(f"API request error: {str(e)}")
            return {"success": False, "error": f"API request failed: {str(e)}"}


# Example Usage
if __name__ == "__main__":
    # Read the API key from the environment rather than hardcoding a secret in source.
    api_key = os.environ.get("GROQ_API_KEY", "")
    query_result = {
        "success": True,
        "columns": ["order_id", "total_price", "order_date"],
        "results": [
            [1001, 150, "2024-10-01"],
            [1002, 200, "2024-10-02"]
        ]
    }
    original_query = "Show me the orders with total price greater than 100"

    nl_converter = NLConverter(api_key)
    result = nl_converter.convert_to_natural_language(query_result, original_query)
    if result["success"]:
        print(result["explanation"])
    else:
        print(f"Error: {result['error']}")
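A minimal sketch of exercising the entity-extraction step on its own, assuming the file is importable as nl_converter and that the key lives in a GROQ_API_KEY environment variable (both names are assumptions of this sketch, not part of the commit); extract_entities never calls the API, so an empty key is harmless here:

import os
from nl_converter import NLConverter  # assumes nl_converter.py is on the import path

converter = NLConverter(os.environ.get("GROQ_API_KEY", ""))
# en_core_web_sm will usually tag the money amount and may tag the date,
# though exact labels depend on the model version.
entities = converter.extract_entities("Order 1001 was placed on 2024-10-01 for $150.")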