Spaces:
Sleeping
Sleeping
import gradio as gr | |
from PIL import Image, ImageDraw, ImageFont | |
import io | |
import base64 | |
import requests | |
import json | |
import os | |
import urllib.parse | |
from dotenv import load_dotenv | |
# Load variables from a local .env file (if present) into the process environment.
load_dotenv()
# NOTE(review): hugging_face_api_key is not used by the code in this file.
hugging_face_api_key = os.getenv("HUGGING_FACE_API_KEY")
openai_key = os.getenv("OPENAI_API_KEY")
if not openai_key:
    # The key is sent as a Bearer token to the OpenAI API; without it every LLM
    # request is rejected, so make the misconfiguration visible at startup.
    print("Warning: OPENAI_API_KEY is not set; LLM requests will fail.")

# Example prompts that process_queries() runs in batch mode.
user_prompt_examples = [
    "Find me the stars in the hyades cluster.",
    "Find me the stars in the orion cluster.",
    "Find me the stars in the andromeda cluster.",
    "Find me the stars in the milky way.",
    "Find me the stars in the virgo cluster.",
    "Find me the stars in the bootes cluster.",
    "Find me the stars in the perseus cluster.",
    "Find me the stars in the hydra cluster.",
    "Find me the stars in the Pleiades cluster.",
    "Retrieve stars within 50 light-years of Earth.",
    "List all stars in the Orion Nebula.",
    "Get data on stars with a radial velocity greater than 100 km/s.",
    "Fetch all known white dwarfs in the Sirius star system.",
    "Identify stars in the globular cluster Omega Centauri.",
    "Display stars from the Andromeda Galaxy that are visible from Earth.",
    "Find stars in the Scorpius constellation with a magnitude brighter than 5.",
    "Search for stars with high proper motion in the Ursa Major group."
]
def _build_prompt(user_prompt):
    """Return the StarGateVR system prompt with *user_prompt* embedded at the end.

    The model is asked to reply with a JSON object holding "reasoning" and
    "the_query". Literal JSON braces are doubled because this is an f-string.
    The inline comment in the SELECT template uses ADQL's ``--`` syntax, since
    the model is told to copy that template into its query.
    """
    return f'''
As StarGateVR, your role is specialized in customizing ADQL (Astronomical Data Query Language)
queries for astronomers. Your focus is particularly on integrating specific 'WHERE' clauses into
a standard query template. We will put your WHERE clause into the completed query template.
The query includes essential SELECT fields like source_id, positional data (ra, dec),
motion data (pmra, pmdec), and light parameters. Note that any
fields used in the WHERE clause must also be added to the SELECT clause.
Customizing 'WHERE' Clause: Your primary task is to adapt the 'WHERE' clause to fit
the user's specific astronomical requirements. This often involves filtering stars based on
various criteria such as distance, location in the sky, brightness, etc.
Always include, at a minimum, the SELECT and FROM clauses as given in this template:
```
SELECT TOP 300000 -- This limits the query run time and prevents timeouts.
'Gaia DR3 ' || source_id as source_id,
ra,
dec,
parallax,
pmra,
pmdec,
radial_velocity as rv,
phot_g_mean_mag,
bp_rp as bp_rp_mag,
```
Note that the WHERE clause must reference variables by the field name and not the "AS" name.
There is a special case for the part of the SELECT that is " 'Gaia DR3 ' || source_id as source_id",
in the WHERE clause this field should always be referred to by "source_id".
Here is an example of the WHERE clause:
```
WHERE (parallax >= 11.11 AND parallax_over_error>=20 AND
astrometric_excess_noise<=2)
```
Here is the preferred structure for the FROM clause:
```
FROM gaiadr3.gaia_source
```
Bounds on Parallax: Always include bounds on parallax in the 'WHERE' clause. This is
important as it helps in retrieving stars within a specified 3D region of space.
The json structure to return is
{{
"reasoning": "<Think through what the user is asking for, and what you know about the GAIA DB
and astronomy to create their request. Because the WHERE clause you are generating
is going to be concatenated into a larger SQL query, consider how to structure the
query such that everything fits in a single WHERE clause. Otherwise it will break
the downstream logic.>",
"the_query": "<a properly formatted ADQL query that will return the stars the
user is asking for>"
}}
The users prompt is "{user_prompt}"
'''


def _parse_completion(response_json):
    """Extract ``(reasoning, the_query)`` from a Chat Completions response body.

    Returns ``("Failed to generate query.", "")`` when an expected field is
    missing (e.g. an error body without "choices", an empty "choices" list,
    null content, or a non-object payload). Returns
    ``("Invalid response query.", "")`` when the message content is not JSON.
    """
    try:
        content = response_json['choices'][0]['message']['content']
        output_json = json.loads(content)
        the_query = output_json['the_query']
        reasoning = output_json['reasoning']
    except json.JSONDecodeError:
        print("JSON decoding failed")
        return "Invalid response query.", ""
    except (KeyError, IndexError, TypeError) as e:
        # IndexError: empty "choices"; TypeError: null content or non-dict data.
        print(f"Key error: {e}")
        return "Failed to generate query.", ""
    return reasoning, the_query


def talk_to_llm(user_prompt):
    """Ask the OpenAI Chat Completions API for an ADQL query matching *user_prompt*.

    Args:
        user_prompt: Free-text description of the stars the user wants.

    Returns:
        ``(reasoning, the_query)``. On any failure (network error, non-JSON
        reply, API error, malformed model output) ``reasoning`` holds a short
        failure message and ``the_query`` is ``""``, so callers can test the
        query for truthiness.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai_key}",
    }
    data = {
        "model": "gpt-4o",
        # Ask the API to constrain the reply to a JSON object.
        "response_format": {"type": "json_object"},
        "messages": [
            {"role": "system", "content": _build_prompt(user_prompt)},
        ],
        "temperature": 0.2,
    }
    try:
        # Without a timeout, requests can wait indefinitely on a stalled connection.
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=data,
            timeout=120,
        )
    except requests.exceptions.RequestException as e:
        print(f"Request to OpenAI failed: {e}")
        return "Failed to generate query.", ""
    try:
        response_json = response.json()
    except ValueError:
        # Non-JSON body, e.g. an HTML error page from an intermediary.
        print("JSON decoding failed")
        return "Invalid response query.", ""
    if isinstance(response_json, dict) and "error" in response_json:
        # Surface API errors (bad key, quota, ...) rather than only a KeyError.
        print(f"OpenAI API error: {response_json['error']}")
    return _parse_completion(response_json)
def complete_query(partial_query):
    """Wrap a caller-supplied clause in the standard StarGate Gaia DR3 query.

    The fixed part selects the standard StarGate columns from
    ``gaiadr3.gaia_source``. *partial_query* (normally a WHERE clause) is
    appended verbatim on the line after FROM and is followed by a newline.

    NOTE(review): nothing else in this file calls this, and the LLM prompt
    asks for a complete query rather than a bare WHERE clause. Confirm
    whether it is still needed.
    """
    select_from = '''
SELECT TOP 300000
-- IMPORTANT NOTE: Parameters that are in units of Magnitude must have an "as" name that ends in "_mag"
--Required parameters
-- ID - force a leading hash symbol to stop Excel from reading the ID number as a float
'Gaia DR3 ' || source_id as source_id,
-- Measured Position
ra,
dec,
parallax,
-- Measured Motion
pmra,
pmdec,
radial_velocity as rv,
--Key source light params for HR diagram
phot_g_mean_mag,
bp_rp as bp_rp_mag,
--Optional plot parameters (you can add anything you want here, just give good unique "as" names - it will show up in the .cvs and hence in StarGate
phot_rp_mean_mag,
phot_bp_mean_mag,
g_rp as g_rp_mag,
bp_g as bp_g_mag,
radial_velocity_error as rv_error,
parallax_error,
-- Additional parameters that appear in the WHERE clause should be added here
-- Note: No comma after this last SELECT item
parallax_over_error
-- Use DR3
FROM gaiadr3.gaia_source
'''
    return f"{select_from}{partial_query}\n"
# Gaia archive synchronous TAP endpoint; the URL-encoded ADQL query is appended.
_GAIA_TAP_SYNC_URL = (
    "https://gea.esac.esa.int/tap-server/tap/sync"
    "?REQUEST=doQuery&LANG=ADQL&FORMAT=csv&QUERY="
)


def _build_tap_url(query):
    """Return the Gaia TAP sync URL that runs *query* and returns CSV."""
    return _GAIA_TAP_SYNC_URL + urllib.parse.quote_plus(query)


def _filename_from_prompt(user_prompt):
    """Derive a filesystem-safe ``.csv`` filename from a free-text prompt.

    Spaces become underscores and periods are dropped, as before. Any other
    character that is not alphanumeric, ``_`` or ``-`` (e.g. the ``/`` in
    "km/s") becomes ``_`` so the name can never point into a directory.
    Falls back to ``query.csv`` for an empty prompt.
    """
    stem = user_prompt.replace(' ', '_').replace('.', '')
    stem = "".join(c if c.isalnum() or c in "_-" else "_" for c in stem)
    return (stem or "query") + ".csv"


def download_url_from_query(query, user_prompt):
    """Run *query* on the Gaia TAP service and save the CSV next to the app.

    The output filename is derived from *user_prompt*. Download errors are
    reported by download_from_tap rather than raised.
    """
    download_from_tap(_build_tap_url(query), _filename_from_prompt(user_prompt))


def create_markdown_url_from_query(query):
    """Return a Markdown link that runs *query* on the Gaia TAP service.

    If *query* is empty (e.g. the button is pressed before the LLM produced
    a query), return a hint instead of a link that would send an empty query.
    """
    if not query or not query.strip():
        return "No query to run yet - ask the LLM for a query first."
    return download_data(_build_tap_url(query))
def download_from_tap(url, output_path, timeout=600):
    """Download the result of a TAP request at *url* into *output_path*.

    Best-effort: failures are reported with print() and not re-raised, so a
    batch run can continue past a failed query. If the transfer fails
    part-way, a partial file may be left at *output_path*.

    Args:
        url: Fully built TAP URL (query already URL-encoded).
        output_path: Destination file, written in binary mode.
        timeout: Seconds passed to requests as its connect/read timeout, so a
            stalled server cannot hang the caller indefinitely.
    """
    try:
        # Stream to disk instead of buffering: a TOP 300000 result can be large.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()  # Raises an HTTPError for bad responses (4xx, 5xx)
            with open(output_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=64 * 1024):
                    f.write(chunk)
        print(f"Data successfully downloaded to {output_path}.")
    except requests.exceptions.HTTPError as err:
        print(f"HTTP error occurred: {err}")  # Handle specific HTTP errors
    except requests.exceptions.RequestException as err:
        # Connection errors, timeouts, broken transfers. Must precede OSError,
        # because RequestException subclasses IOError.
        print(f"Request failed: {err}")
    except OSError as err:
        print(f"Could not write {output_path}: {err}")  # e.g. bad path, disk full
def download_data(tap_url):
    """Return *tap_url* wrapped in a Markdown link with the download label."""
    label = "Run Query on GaiaDB and Download CSV Datafile - may need second click to login"
    return "[{}]({})".format(label, tap_url)
# Batch entry point: generate a query for each prompt and download its CSV.
def process_queries(prompts=None):
    """Run each prompt through the LLM and download the resulting query's CSV.

    Args:
        prompts: Iterable of prompt strings; defaults to user_prompt_examples.

    Returns:
        Dict mapping each prompt to ``{"reasoning": ..., "query": ...}``. The
        query is ``""`` when the LLM produced none, and in that case nothing
        is downloaded for that prompt. Duplicate prompts keep the last result.
    """
    if prompts is None:
        prompts = user_prompt_examples
    results = {}
    for prompt in prompts:
        reasoning, the_query = talk_to_llm(prompt)
        results[prompt] = {"reasoning": reasoning, "query": the_query}
        if the_query:
            download_url_from_query(the_query, prompt)
    return results
# Gradio UI: turn a natural-language request into an ADQL query via the LLM,
# then into a Gaia TAP link that runs that query and downloads the CSV.
# NOTE(review): confirm which components are meant to sit inside gr.Row();
# the intended layout is not evident from this copy of the file.
with gr.Blocks() as demo:
    with gr.Row():
        user_prompt = gr.Textbox(label="Enter your query for the LLM", value="Find me the stars in the Hyades cluster.")
        submit_btn = gr.Button("Ask LLM")
    reasoning_output = gr.Textbox(label="Reasoning")
    the_query_output = gr.Textbox(label="The Query")
    # talk_to_llm returns (reasoning, the_query), filling these two boxes in order.
    submit_btn.click(fn=talk_to_llm, inputs=user_prompt, outputs=[reasoning_output, the_query_output])
    create_tap_url_bt = gr.Button("Create TAP URL")
    download_data_output = gr.Markdown()
    # Passes the current text of the query box and renders the returned Markdown link.
    create_tap_url_bt.click(fn=create_markdown_url_from_query, inputs=the_query_output, outputs=[download_data_output])
# Starts the Gradio app as soon as this module runs (no __main__ guard).
demo.launch()
# Batch mode: uncomment to run every example prompt and save each result CSV.
# NOTE(review): launch() usually blocks when run as a script, so this call would
# only run after the server stops - confirm before enabling.
#process_queries()