nileshhanotia commited on
Commit
c92f9e6
·
verified ·
1 Parent(s): b70ad78

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -85
app.py CHANGED
@@ -1,124 +1,191 @@
1
- import streamlit as st
2
- import pandas as pd
3
  import os
 
 
 
 
 
4
  from dotenv import load_dotenv
 
 
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
- import logging
7
 
8
- # Set up logging
9
- logging.basicConfig(level=logging.INFO)
10
- logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- class LLMService:
13
- def __init__(self, db_path):
14
- self.db_path = db_path
15
- # Load tokenizer and model
16
- self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-72B-Instruct")
17
- self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-72B-Instruct")
18
 
19
- def convert_to_sql_query(self, natural_query):
20
- try:
21
- # Tokenize input
22
- inputs = self.tokenizer(f"Translate this to SQL: {natural_query}", return_tensors="pt")
23
- # Generate output
24
- outputs = self.model.generate(**inputs, max_length=512, num_beams=5)
25
- # Decode output
26
- sql_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
27
- return {"success": True, "query": sql_query}
28
  except Exception as e:
29
- logger.error(f"Error generating SQL query: {e}")
30
  return {"success": False, "error": str(e)}
31
 
32
- def execute_query(self, sql_query):
 
 
 
 
 
 
33
  try:
34
- import sqlite3
35
- conn = sqlite3.connect(self.db_path)
36
- cursor = conn.cursor()
37
- cursor.execute(sql_query)
38
- results = cursor.fetchall()
39
- columns = [desc[0] for desc in cursor.description]
40
- conn.close()
41
- return {"success": True, "results": results, "columns": columns}
42
- except Exception as e:
43
- logger.error(f"Error executing SQL query: {e}")
 
 
 
 
 
 
44
  return {"success": False, "error": str(e)}
45
 
46
  def main():
47
- st.title("Natural Language to SQL Query Converter")
48
- st.write("Enter your question about the database in natural language.")
49
-
50
- # Load environment variables
51
  load_dotenv()
52
- db_path = os.getenv("DB_PATH")
53
-
54
- if not db_path:
55
- st.error("Missing database path in environment variables.")
56
- logger.error("DB path not found in environment variables.")
 
 
 
57
  return
58
 
59
- # Initialize LLM Service
60
  try:
61
- llm_service = LLMService(db_path=db_path)
62
  except Exception as e:
63
  st.error(f"Error initializing service: {str(e)}")
64
  return
65
 
66
- # Input for natural language query
67
- natural_query = st.text_area("Enter your query", "Show me all albums by artist 'Queen'", height=100)
68
 
69
- if st.button("Generate and Execute Query"):
70
  if not natural_query.strip():
71
  st.warning("Please enter a valid query.")
72
  return
73
 
74
- # Convert to SQL
75
- with st.spinner("Generating SQL query..."):
76
- sql_result = llm_service.convert_to_sql_query(natural_query)
77
 
78
- if not sql_result["success"]:
79
- st.error(f"Error generating SQL query: {sql_result['error']}")
80
  return
81
 
82
- # Display generated SQL
83
- st.subheader("Generated SQL Query:")
84
- st.code(sql_result["query"], language="sql")
85
 
86
- # Execute query
87
  with st.spinner("Executing query..."):
88
- query_result = llm_service.execute_query(sql_result["query"])
89
 
90
  if not query_result["success"]:
91
  st.error(f"Error executing query: {query_result['error']}")
92
  return
93
 
94
- # Check if there are results
95
- if query_result["results"]:
96
- df = pd.DataFrame(query_result["results"], columns=query_result["columns"])
97
-
98
- # Create a collapsible DataFrame using Streamlit's expander
99
- with st.expander("Click to view query results as a DataFrame"):
100
- st.dataframe(df)
101
-
102
- # Extract product details from the JSON result and display them
103
- json_results = df.to_dict(orient='records')
104
- if "title" in json_results[0] and "images" in json_results[0] and "price" in json_results[0]:
105
- st.subheader("Product Details:")
106
- for product in json_results:
107
- price = product.get("price", "N/A")
108
- title = product.get("handle", "N/A")
109
- src = product.get("src", "N/A")
110
-
111
- # Display product details in a neat format using columns for alignment
112
- with st.container():
113
- col1, col2, col3 = st.columns([1, 2, 3]) # Adjust column widths as needed
114
-
115
- with col1:
116
- st.image(src, use_container_width=True) # Display product image with container width
117
- with col2:
118
- st.write(f"**Price:** {price}") # Display price
119
- st.write(f"**Title:** {title}") # Display title
120
- with col3:
121
- st.write(f"**Image Source:** [Link]( {src} )") # Link to the image if needed
122
  else:
123
  st.info("No results found.")
124
 
 
 
 
1
  import os
2
+ import logging
3
+ import requests
4
+ import json
5
+ from typing import Dict, Any, List
6
+ from dataclasses import dataclass
7
  from dotenv import load_dotenv
8
+ import streamlit as st
9
+ import pandas as pd
10
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
11
 
12
+ @dataclass
13
+ class GraphQLSchemaType:
14
+ """Store GraphQL type information including fields and relationships"""
15
+ name: str
16
+ fields: List[Dict[str, Any]]
17
+ relationships: List[Dict[str, str]]
18
+
19
+ class ShopifyGraphQLConverter:
20
+ def __init__(self, shop_url: str, access_token: str, api_key: str, model_name: str):
21
+ """
22
+ Initialize Shopify GraphQL converter
23
+
24
+ :param shop_url: Shopify store URL
25
+ :param access_token: Shopify Admin API access token
26
+ :param api_key: LLM service API key
27
+ :param model_name: Model name for Hugging Face
28
+ """
29
+ load_dotenv()
30
+
31
+ # Ensure shop URL has https:// scheme
32
+ if not shop_url.startswith(('http://', 'https://')):
33
+ shop_url = f'https://{shop_url}'
34
+
35
+ # Shopify GraphQL endpoint configuration
36
+ self.shop_url = shop_url
37
+ self.graphql_endpoint = f"{shop_url}/admin/api/2024-04/graphql.json"
38
+ self.access_token = access_token
39
+
40
+ # LLM API configuration
41
+ self.api_key = api_key
42
+ self.llm_api_url = "https://api.groq.com/openai/v1/chat/completions"
43
+ self.llm_headers = {
44
+ "Authorization": f"Bearer {api_key}",
45
+ "Content-Type": "application/json"
46
+ }
47
+
48
+ # Load model directly for natural language processing
49
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
50
+ self.model = AutoModelForCausalLM.from_pretrained(model_name)
51
+
52
+ # Predefined schema for Shopify resources
53
+ self.schema = {
54
+ "Product": GraphQLSchemaType(
55
+ name="Product",
56
+ fields=[
57
+ {"name": "id", "type": "ID", "required": False},
58
+ {"name": "title", "type": "String", "required": False},
59
+ {"name": "description", "type": "String", "required": False},
60
+ {"name": "productType", "type": "String", "required": False},
61
+ {"name": "vendor", "type": "String", "required": False},
62
+ {"name": "priceRangeV2", "type": "ProductPriceRangeV2", "required": False}
63
+ ],
64
+ relationships=[
65
+ {"from_field": "variants", "to_type": "ProductVariant"},
66
+ {"from_field": "collections", "to_type": "Collection"}
67
+ ]
68
+ ),
69
+ }
70
+
71
+ # Setup logging
72
+ logging.basicConfig(level=logging.INFO)
73
+ self.logger = logging.getLogger(__name__)
74
+
75
+ def generate_graphql_query(self, natural_query: str) -> str:
76
+ """
77
+ Generate GraphQL query from natural language using Llama model
78
+
79
+ :param natural_query: The query in natural language
80
+ :return: GraphQL query as a string
81
+ """
82
+ inputs = self.tokenizer(natural_query, return_tensors="pt")
83
+ outputs = self.model.generate(**inputs, max_length=500)
84
+ query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
85
+
86
+ return query
87
+
88
+ def convert_to_graphql_query(self, natural_query: str) -> Dict[str, Any]:
89
+ """
90
+ Convert natural language to Shopify GraphQL query
91
+
92
+ :param natural_query: Natural language query string
93
+ :return: Dictionary containing GraphQL query or error
94
+ """
95
+ try:
96
+ query = self.generate_graphql_query(natural_query)
97
 
98
+ # Basic query validation
99
+ if query.startswith("query") and "products" in query:
100
+ return {"success": True, "query": query}
 
 
 
101
 
102
+ return {"success": False, "error": "Failed to generate valid GraphQL query"}
103
+
 
 
 
 
 
 
 
104
  except Exception as e:
105
+ self.logger.error(f"Query generation error: {str(e)}")
106
  return {"success": False, "error": str(e)}
107
 
108
+ def execute_query(self, graphql_query: str) -> Dict[str, Any]:
109
+ """
110
+ Execute the GraphQL query against Shopify Admin API
111
+
112
+ :param graphql_query: GraphQL query to execute
113
+ :return: Dictionary containing query results or error
114
+ """
115
  try:
116
+ payload = {"query": graphql_query}
117
+ response = requests.post(
118
+ self.graphql_endpoint,
119
+ headers={
120
+ "Content-Type": "application/json",
121
+ "X-Shopify-Access-Token": self.access_token
122
+ },
123
+ json=payload
124
+ )
125
+ response.raise_for_status()
126
+
127
+ result = response.json()
128
+ return {"success": True, "data": result.get('data', {}), "errors": result.get('errors', [])}
129
+
130
+ except requests.exceptions.RequestException as e:
131
+ self.logger.error(f"Shopify GraphQL query execution error: {str(e)}")
132
  return {"success": False, "error": str(e)}
133
 
134
  def main():
135
+ st.title("Shopify GraphQL Natural Language Query Converter")
136
+
 
 
137
  load_dotenv()
138
+
139
+ shop_url = os.getenv("SHOPIFY_STORE_URL", "https://agkd0n-fa.myshopify.com")
140
+ access_token = os.getenv("SHOPIFY_ACCESS_TOKEN")
141
+ groq_api_key = os.getenv("GROQ_API_KEY")
142
+ model_name = "Qwen/Qwen2.5-72B-Instruct" # Modify this for Llama3 if needed
143
+
144
+ if not all([shop_url, access_token, groq_api_key]):
145
+ st.error("Missing environment variables. Please set SHOPIFY_STORE_URL, SHOPIFY_ACCESS_TOKEN, and GROQ_API_KEY")
146
  return
147
 
 
148
  try:
149
+ graphql_converter = ShopifyGraphQLConverter(shop_url, access_token, groq_api_key, model_name)
150
  except Exception as e:
151
  st.error(f"Error initializing service: {str(e)}")
152
  return
153
 
154
+ natural_query = st.text_area("Enter your Shopify query in natural language", "Find shirt with red color", height=100)
 
155
 
156
+ if st.button("Generate and Execute GraphQL Query"):
157
  if not natural_query.strip():
158
  st.warning("Please enter a valid query.")
159
  return
160
 
161
+ with st.spinner("Generating GraphQL query..."):
162
+ graphql_result = graphql_converter.convert_to_graphql_query(natural_query)
 
163
 
164
+ if not graphql_result["success"]:
165
+ st.error(f"Error generating GraphQL query: {graphql_result['error']}")
166
  return
167
 
168
+ st.subheader("Generated GraphQL Query:")
169
+ st.code(graphql_result["query"], language="graphql")
 
170
 
 
171
  with st.spinner("Executing query..."):
172
+ query_result = graphql_converter.execute_query(graphql_result["query"])
173
 
174
  if not query_result["success"]:
175
  st.error(f"Error executing query: {query_result['error']}")
176
  return
177
 
178
+ st.subheader("Query Results:")
179
+ if query_result["errors"]:
180
+ st.error(f"GraphQL Errors: {query_result['errors']}")
181
+
182
+ if query_result["data"]:
183
+ products = query_result["data"].get("products", {}).get("edges", [])
184
+ if products:
185
+ product_list = [{"Title": p["node"]["title"], "Vendor": p["node"]["vendor"]} for p in products]
186
+ st.dataframe(pd.DataFrame(product_list))
187
+ else:
188
+ st.info("No products found.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  else:
190
  st.info("No results found.")
191