Spaces:
Sleeping
Sleeping
tlskillman
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -4,21 +4,15 @@ import io
|
|
4 |
import base64
|
5 |
import requests
|
6 |
import json
|
7 |
-
import threading
|
8 |
-
import uuid
|
9 |
import os
|
10 |
-
from datetime import datetime
|
11 |
import urllib.parse
|
12 |
-
|
13 |
from dotenv import load_dotenv
|
14 |
|
15 |
load_dotenv()
|
16 |
hugging_face_api_key = os.getenv("HUGGING_FACE_API_KEY")
|
17 |
openai_key = os.getenv("OPENAI_API_KEY")
|
18 |
|
19 |
-
|
20 |
-
|
21 |
-
user_prompt_examples =[
|
22 |
"Find me the stars in the hyades cluster.",
|
23 |
"Find me the stars in the orion cluster.",
|
24 |
"Find me the stars in the andromeda cluster.",
|
@@ -37,7 +31,7 @@ user_prompt_examples =[
|
|
37 |
"Find stars in the Scorpius constellation with a magnitude brighter than 5.",
|
38 |
"Search for stars with high proper motion in the Ursa Major group."
|
39 |
]
|
40 |
-
|
41 |
def talk_to_llm(user_prompt):
|
42 |
headers = {
|
43 |
"Content-Type": "application/json",
|
@@ -48,11 +42,9 @@ def talk_to_llm(user_prompt):
|
|
48 |
As StarGateVR, your role is specialized in customizing ADQL (Astronomical Data Query Language)
|
49 |
queries for astronomers. Your focus is particularly on integrating specific 'WHERE' clauses into
|
50 |
a standard query template. We will put your WHERE clause into the completed query template.
|
51 |
-
|
52 |
The query includes essential SELECT fields like source_id, positional data (ra, dec),
|
53 |
motion data (pmra, pmdec), and light parameters. Note that any
|
54 |
fields used in the WHERE clause must also be added to the SELECT clause.
|
55 |
-
|
56 |
Customizing 'WHERE' Clause: Your primary task is to adapt the 'WHERE' clause to fit
|
57 |
the user's specific astronomical requirements. This often involves filtering stars based on
|
58 |
various criteria such as distance, location in the sky, brightness, etc.
|
@@ -71,25 +63,21 @@ def talk_to_llm(user_prompt):
|
|
71 |
phot_g_mean_mag,
|
72 |
bp_rp as bp_rp_mag,
|
73 |
```
|
74 |
-
|
75 |
Note that the WHERE clause must reference variables by the field name and not the "AS" name.
|
76 |
There is a special case for the part of the SELECT that is " 'Gaia DR3 ' || source_id as source_id",
|
77 |
in the WHERE clause this field should always be referred to by "source_id".
|
78 |
-
|
79 |
Here is an example of the WHERE clause:
|
80 |
```
|
81 |
WHERE (parallax >= 11.11 AND parallax_over_error>=20 AND
|
82 |
astrometric_excess_noise<=2)
|
83 |
```
|
84 |
-
|
85 |
Here is the preferred structure for the FROM clause:
|
86 |
```
|
87 |
FROM gaiadr3.gaia_source
|
88 |
```
|
89 |
-
|
90 |
Bounds on Parallax: Always include bounds on parallax in the 'WHERE' clause. This is
|
91 |
important as it helps in retrieving stars within a specified 3D region of space.
|
92 |
-
|
93 |
The json structure to return is
|
94 |
{{
|
95 |
"reasoning": "<Think through what the user is asking for, and what you know about the GAIA DB
|
@@ -99,9 +87,7 @@ def talk_to_llm(user_prompt):
|
|
99 |
the downstream logic.>",
|
100 |
"the_query": "<a properly formatted ADQL query that will return the stars the
|
101 |
user is asking for>"
|
102 |
-
|
103 |
}}
|
104 |
-
|
105 |
The users prompt is "{user_prompt}"
|
106 |
'''
|
107 |
|
@@ -114,38 +100,28 @@ def talk_to_llm(user_prompt):
|
|
114 |
"temperature": .2
|
115 |
}
|
116 |
|
117 |
-
print ("PROMPT TEXT:\n")
|
118 |
-
print (prompt_text)
|
119 |
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=data)
|
120 |
response_json = response.json()
|
121 |
-
|
122 |
-
output_json = json.loads(response_json['choices'][0]['message']['content'])
|
123 |
-
|
124 |
try:
|
|
|
125 |
the_query = output_json['the_query']
|
126 |
reasoning = output_json['reasoning']
|
127 |
-
print(f"dialog: {reasoning}")
|
128 |
-
print(f"query: {the_query}")
|
129 |
-
# Update theme on the server
|
130 |
return reasoning, the_query
|
131 |
except KeyError as e:
|
132 |
print(f"Key error: {e}")
|
133 |
-
return "Failed to generate query."
|
134 |
except json.JSONDecodeError:
|
135 |
print("JSON decoding failed")
|
136 |
-
return "Invalid response query."
|
137 |
|
138 |
def complete_query(partial_query):
|
139 |
query_template = f'''
|
140 |
SELECT TOP 300000
|
141 |
-
|
142 |
-- IMPORTANT NOTE: Parameters that are in units of Magnitude must have an "as" name that ends in "_mag"
|
143 |
-
|
144 |
--Required parameters
|
145 |
-
|
146 |
-- ID - force a leading hash symbol to stop Excel from reading the ID number as a float
|
147 |
'Gaia DR3 ' || source_id as source_id,
|
148 |
-
|
149 |
-- Measured Position
|
150 |
ra,
|
151 |
dec,
|
@@ -154,11 +130,9 @@ def complete_query(partial_query):
|
|
154 |
pmra,
|
155 |
pmdec,
|
156 |
radial_velocity as rv,
|
157 |
-
|
158 |
--Key source light params for HR diagram
|
159 |
phot_g_mean_mag,
|
160 |
bp_rp as bp_rp_mag,
|
161 |
-
|
162 |
--Optional plot parameters (you can add anything you want here, just give good unique "as" names - it will show up in the .cvs and hence in StarGate
|
163 |
phot_rp_mean_mag,
|
164 |
phot_bp_mean_mag,
|
@@ -166,9 +140,7 @@ def complete_query(partial_query):
|
|
166 |
bp_g as bp_g_mag,
|
167 |
radial_velocity_error as rv_error,
|
168 |
parallax_error,
|
169 |
-
|
170 |
-- Additional parameters that appear in the WHERE clause should be added here
|
171 |
-
|
172 |
-- Note: No comma after this last SELECT item
|
173 |
parallax_over_error
|
174 |
-- Use DR3
|
@@ -176,10 +148,6 @@ def complete_query(partial_query):
|
|
176 |
{partial_query}
|
177 |
'''
|
178 |
|
179 |
-
print ("completed query")
|
180 |
-
print (query_template)
|
181 |
-
print (" ")
|
182 |
-
|
183 |
return query_template
|
184 |
|
185 |
def download_url_from_query(query, user_prompt):
|
@@ -187,9 +155,7 @@ def download_url_from_query(query, user_prompt):
|
|
187 |
tap_url = "https://gea.esac.esa.int/tap-server/tap/sync?REQUEST=doQuery&LANG=ADQL&FORMAT=csv&QUERY="
|
188 |
query = urllib.parse.quote_plus(query)
|
189 |
tap_url = tap_url + query
|
190 |
-
print ("downloading from tap")
|
191 |
filename = user_prompt.replace(' ', '_').replace('.', '') + ".csv"
|
192 |
-
print (filename)
|
193 |
download_from_tap(tap_url, filename)
|
194 |
|
195 |
def create_markdown_url_from_query(query):
|
@@ -197,8 +163,6 @@ def create_markdown_url_from_query(query):
|
|
197 |
tap_url = "https://gea.esac.esa.int/tap-server/tap/sync?REQUEST=doQuery&LANG=ADQL&FORMAT=csv&QUERY="
|
198 |
query = urllib.parse.quote_plus(query)
|
199 |
tap_url = tap_url + query
|
200 |
-
print (tap_url)
|
201 |
-
download_from_tap(tap_url, "output_file.csv")
|
202 |
markdown_link = download_data(tap_url)
|
203 |
return markdown_link
|
204 |
|
@@ -221,10 +185,9 @@ def download_data(tap_url):
|
|
221 |
def process_queries():
|
222 |
results = {}
|
223 |
for prompt in user_prompt_examples:
|
224 |
-
print(f"Processing prompt: {prompt}")
|
225 |
reasoning, the_query = talk_to_llm(prompt)
|
226 |
-
|
227 |
-
|
228 |
return results
|
229 |
|
230 |
with gr.Blocks() as demo:
|
@@ -237,10 +200,8 @@ with gr.Blocks() as demo:
|
|
237 |
|
238 |
create_tap_url_bt = gr.Button("Create TAP URL")
|
239 |
download_data_output = gr.Markdown()
|
240 |
-
create_tap_url_bt.click(fn=create_markdown_url_from_query, inputs=the_query_output,
|
241 |
-
outputs=[download_data_output])
|
242 |
|
243 |
-
#demo.launch(server_name="0.0.0.0", server_port=7861, share=True, debug=True)
|
244 |
demo.launch()
|
245 |
|
246 |
|
|
|
4 |
import base64
|
5 |
import requests
|
6 |
import json
|
|
|
|
|
7 |
import os
|
|
|
8 |
import urllib.parse
|
|
|
9 |
from dotenv import load_dotenv
|
10 |
|
11 |
load_dotenv()
|
12 |
hugging_face_api_key = os.getenv("HUGGING_FACE_API_KEY")
|
13 |
openai_key = os.getenv("OPENAI_API_KEY")
|
14 |
|
15 |
+
user_prompt_examples = [
|
|
|
|
|
16 |
"Find me the stars in the hyades cluster.",
|
17 |
"Find me the stars in the orion cluster.",
|
18 |
"Find me the stars in the andromeda cluster.",
|
|
|
31 |
"Find stars in the Scorpius constellation with a magnitude brighter than 5.",
|
32 |
"Search for stars with high proper motion in the Ursa Major group."
|
33 |
]
|
34 |
+
|
35 |
def talk_to_llm(user_prompt):
|
36 |
headers = {
|
37 |
"Content-Type": "application/json",
|
|
|
42 |
As StarGateVR, your role is specialized in customizing ADQL (Astronomical Data Query Language)
|
43 |
queries for astronomers. Your focus is particularly on integrating specific 'WHERE' clauses into
|
44 |
a standard query template. We will put your WHERE clause into the completed query template.
|
|
|
45 |
The query includes essential SELECT fields like source_id, positional data (ra, dec),
|
46 |
motion data (pmra, pmdec), and light parameters. Note that any
|
47 |
fields used in the WHERE clause must also be added to the SELECT clause.
|
|
|
48 |
Customizing 'WHERE' Clause: Your primary task is to adapt the 'WHERE' clause to fit
|
49 |
the user's specific astronomical requirements. This often involves filtering stars based on
|
50 |
various criteria such as distance, location in the sky, brightness, etc.
|
|
|
63 |
phot_g_mean_mag,
|
64 |
bp_rp as bp_rp_mag,
|
65 |
```
|
66 |
+
|
67 |
Note that the WHERE clause must reference variables by the field name and not the "AS" name.
|
68 |
There is a special case for the part of the SELECT that is " 'Gaia DR3 ' || source_id as source_id",
|
69 |
in the WHERE clause this field should always be referred to by "source_id".
|
|
|
70 |
Here is an example of the WHERE clause:
|
71 |
```
|
72 |
WHERE (parallax >= 11.11 AND parallax_over_error>=20 AND
|
73 |
astrometric_excess_noise<=2)
|
74 |
```
|
|
|
75 |
Here is the preferred structure for the FROM clause:
|
76 |
```
|
77 |
FROM gaiadr3.gaia_source
|
78 |
```
|
|
|
79 |
Bounds on Parallax: Always include bounds on parallax in the 'WHERE' clause. This is
|
80 |
important as it helps in retrieving stars within a specified 3D region of space.
|
|
|
81 |
The json structure to return is
|
82 |
{{
|
83 |
"reasoning": "<Think through what the user is asking for, and what you know about the GAIA DB
|
|
|
87 |
the downstream logic.>",
|
88 |
"the_query": "<a properly formatted ADQL query that will return the stars the
|
89 |
user is asking for>"
|
|
|
90 |
}}
|
|
|
91 |
The users prompt is "{user_prompt}"
|
92 |
'''
|
93 |
|
|
|
100 |
"temperature": .2
|
101 |
}
|
102 |
|
|
|
|
|
103 |
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=data)
|
104 |
response_json = response.json()
|
105 |
+
|
|
|
|
|
106 |
try:
|
107 |
+
output_json = json.loads(response_json['choices'][0]['message']['content'])
|
108 |
the_query = output_json['the_query']
|
109 |
reasoning = output_json['reasoning']
|
|
|
|
|
|
|
110 |
return reasoning, the_query
|
111 |
except KeyError as e:
|
112 |
print(f"Key error: {e}")
|
113 |
+
return "Failed to generate query.", ""
|
114 |
except json.JSONDecodeError:
|
115 |
print("JSON decoding failed")
|
116 |
+
return "Invalid response query.", ""
|
117 |
|
118 |
def complete_query(partial_query):
|
119 |
query_template = f'''
|
120 |
SELECT TOP 300000
|
|
|
121 |
-- IMPORTANT NOTE: Parameters that are in units of Magnitude must have an "as" name that ends in "_mag"
|
|
|
122 |
--Required parameters
|
|
|
123 |
-- ID - force a leading hash symbol to stop Excel from reading the ID number as a float
|
124 |
'Gaia DR3 ' || source_id as source_id,
|
|
|
125 |
-- Measured Position
|
126 |
ra,
|
127 |
dec,
|
|
|
130 |
pmra,
|
131 |
pmdec,
|
132 |
radial_velocity as rv,
|
|
|
133 |
--Key source light params for HR diagram
|
134 |
phot_g_mean_mag,
|
135 |
bp_rp as bp_rp_mag,
|
|
|
136 |
--Optional plot parameters (you can add anything you want here, just give good unique "as" names - it will show up in the .cvs and hence in StarGate
|
137 |
phot_rp_mean_mag,
|
138 |
phot_bp_mean_mag,
|
|
|
140 |
bp_g as bp_g_mag,
|
141 |
radial_velocity_error as rv_error,
|
142 |
parallax_error,
|
|
|
143 |
-- Additional parameters that appear in the WHERE clause should be added here
|
|
|
144 |
-- Note: No comma after this last SELECT item
|
145 |
parallax_over_error
|
146 |
-- Use DR3
|
|
|
148 |
{partial_query}
|
149 |
'''
|
150 |
|
|
|
|
|
|
|
|
|
151 |
return query_template
|
152 |
|
153 |
def download_url_from_query(query, user_prompt):
|
|
|
155 |
tap_url = "https://gea.esac.esa.int/tap-server/tap/sync?REQUEST=doQuery&LANG=ADQL&FORMAT=csv&QUERY="
|
156 |
query = urllib.parse.quote_plus(query)
|
157 |
tap_url = tap_url + query
|
|
|
158 |
filename = user_prompt.replace(' ', '_').replace('.', '') + ".csv"
|
|
|
159 |
download_from_tap(tap_url, filename)
|
160 |
|
161 |
def create_markdown_url_from_query(query):
|
|
|
163 |
tap_url = "https://gea.esac.esa.int/tap-server/tap/sync?REQUEST=doQuery&LANG=ADQL&FORMAT=csv&QUERY="
|
164 |
query = urllib.parse.quote_plus(query)
|
165 |
tap_url = tap_url + query
|
|
|
|
|
166 |
markdown_link = download_data(tap_url)
|
167 |
return markdown_link
|
168 |
|
|
|
185 |
def process_queries():
    """Run every example prompt through the LLM and fetch its query results.

    Each entry of ``user_prompt_examples`` is passed to ``talk_to_llm``.
    When the returned query is non-empty it is handed, together with the
    prompt, to ``download_url_from_query``.

    Returns:
        dict: maps each prompt to the ``(reasoning, the_query)`` pair that
        ``talk_to_llm`` returned for it. Prompts whose query came back
        empty are still recorded, so the accompanying reasoning/error text
        is not lost.
    """
    results = {}
    for prompt in user_prompt_examples:
        reasoning, the_query = talk_to_llm(prompt)
        # The dict used to be returned without ever being written to, so the
        # function always produced {}; keep what the LLM answered per prompt.
        results[prompt] = (reasoning, the_query)
        # An empty query means there is nothing to send to the TAP service.
        if the_query:
            download_url_from_query(the_query, prompt)
    return results
|
192 |
|
193 |
with gr.Blocks() as demo:
|
|
|
200 |
|
201 |
create_tap_url_bt = gr.Button("Create TAP URL")
|
202 |
download_data_output = gr.Markdown()
|
203 |
+
create_tap_url_bt.click(fn=create_markdown_url_from_query, inputs=the_query_output, outputs=[download_data_output])
|
|
|
204 |
|
|
|
205 |
demo.launch()
|
206 |
|
207 |
|