File size: 8,171 Bytes
62790e6
 
 
 
 
 
 
 
b9a8926
62790e6
b9a8926
 
ecca2ac
62790e6
b7e50c1
e9fb1d8
62790e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7e50c1
62790e6
 
 
 
 
 
 
51a63fd
 
 
 
 
 
 
 
 
 
0a3a765
 
 
a96002c
0a3a765
 
 
 
 
 
 
 
 
 
b7e50c1
0a3a765
51a63fd
 
ecada79
 
 
 
 
11fe03c
 
 
 
51a63fd
 
62790e6
 
51a63fd
 
 
 
 
 
 
62790e6
 
 
 
 
b8b8f1d
62790e6
 
 
 
b4b37b4
62790e6
 
 
 
b7e50c1
62790e6
b7e50c1
51a63fd
62790e6
51a63fd
62790e6
 
b7e50c1
62790e6
 
b7e50c1
62790e6
 
 
 
 
 
 
124f278
62790e6
 
 
51a63fd
62790e6
 
 
 
 
 
 
 
 
 
 
 
 
 
e6f082e
62790e6
 
 
 
 
 
597d095
62790e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
096512a
62790e6
 
 
 
 
51a63fd
b7e50c1
 
62790e6
 
 
 
e9fb1d8
62790e6
 
51a63fd
c47fca8
62790e6
0aa2eba
62790e6
b7e50c1
62790e6
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import io
import base64
import requests
import json
import os
import urllib.parse
from dotenv import load_dotenv

# Populate os.environ from a local .env file (python-dotenv) so the getenv
# calls below can pick up the API keys.
load_dotenv()
# NOTE(review): not referenced anywhere in the visible code — confirm whether
# this key is still needed.
hugging_face_api_key = os.getenv("HUGGING_FACE_API_KEY")
# Sent as the Bearer token by talk_to_llm(); os.getenv returns None if unset.
openai_key = os.getenv("OPENAI_API_KEY")

# Sample natural-language requests. process_queries() iterates over these when
# run in batch mode; they are not wired into the Gradio UI.
user_prompt_examples = [
    "Find me the stars in the hyades cluster.",
    "Find me the stars in the orion cluster.",
    "Find me the stars in the andromeda cluster.",
    "Find me the stars in the milky way.",
    "Find me the stars in the virgo cluster.",
    "Find me the stars in the bootes cluster.",
    "Find me the stars in the perseus cluster.",
    "Find me the stars in the hydra cluster.",
    "Find me the stars in the Pleiades cluster.",
    "Retrieve stars within 50 light-years of Earth.",
    "List all stars in the Orion Nebula.",
    "Get data on stars with a radial velocity greater than 100 km/s.",
    "Fetch all known white dwarfs in the Sirius star system.",
    "Identify stars in the globular cluster Omega Centauri.",
    "Display stars from the Andromeda Galaxy that are visible from Earth.",
    "Find stars in the Scorpius constellation with a magnitude brighter than 5.",
    "Search for stars with high proper motion in the Ursa Major group."
]

def talk_to_llm(user_prompt):
    """Ask the OpenAI chat API to draft a Gaia DR3 ADQL query for a request.

    The system prompt describes the expected SELECT/FROM template and asks
    the model (JSON mode, via ``response_format``) for an object with
    ``reasoning`` and ``the_query`` keys.

    Args:
        user_prompt: Free-text description of the stars the user wants. It
            is embedded verbatim in the prompt.

    Returns:
        A ``(reasoning, the_query)`` tuple. On any failure the first element
        is a short status message and ``the_query`` is ``""``, so callers
        (e.g. the Gradio click handler) always receive two values.
    """
    if not openai_key:
        # Otherwise the request goes out as "Bearer None" and fails with 401.
        print("OPENAI_API_KEY is not set")
        return "OPENAI_API_KEY is not set.", ""

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai_key}"
    }

    # The template line below uses "--", the ADQL line-comment syntax. It used
    # to say "#", which is not valid ADQL; a model copying the template
    # verbatim would then produce a query the TAP service rejects.
    prompt_text = f'''
    As StarGateVR, your role is specialized in customizing ADQL (Astronomical Data Query Language) 
    queries for astronomers. Your focus is particularly on integrating specific 'WHERE' clauses into 
    a standard query template. We will put your WHERE clause into the completed query template. 
    The query includes essential SELECT fields like source_id, positional data (ra, dec), 
    motion data (pmra, pmdec), and light parameters. Note that any
    fields used in the WHERE clause must also be added to the SELECT clause.
    Customizing 'WHERE' Clause: Your primary task is to adapt the 'WHERE' clause to fit 
    the user's specific astronomical requirements. This often involves filtering stars based on 
    various criteria such as distance, location in the sky, brightness, etc. 
    
    Always include, at a minimum, the SELECT and FROM clauses as given in this template:
    
    ```
    SELECT TOP 300000 -- This limits the query run time and prevents timeouts.
    'Gaia DR3 ' || source_id as source_id,
    ra,
    dec,
    parallax,
    pmra,
    pmdec,
    radial_velocity as rv,
    phot_g_mean_mag,
    bp_rp as bp_rp_mag,
    ```

    Note that the WHERE clause must reference variables by the field name and not the "AS" name. 
    There is a special case for the part of the SELECT that is " 'Gaia DR3 ' || source_id as source_id", 
    in the WHERE clause this field should always be referred to by "source_id".
    Here is an example of the WHERE clause:
    ```
    WHERE (parallax >= 11.11 AND parallax_over_error>=20 AND
    astrometric_excess_noise<=2)
    ```
    Here is the preferred structure for the FROM clause:
    ```
    FROM gaiadr3.gaia_source
    ```
    Bounds on Parallax: Always include bounds on parallax in the 'WHERE' clause. This is 
    important as it helps in retrieving stars within a specified 3D region of space.
    The json structure to return is
    {{
        "reasoning": "<Think through what the user is asking for, and what you know about the GAIA DB 
        and astronomy to create their request. Because the WHERE clause you are generating 
        is going to be concatenated into a larger SQL query, consider how to structure the 
        query such that everything fits in a single WHERE clause.  Otherwise it will break 
        the downstream logic.>",
        "the_query": "<a properly formatted ADQL query that will return the stars the 
        user is asking for>"
    }}
    The users prompt is "{user_prompt}"
    '''

    data = {
        "model": "gpt-4o",
        "response_format": {"type": "json_object"},
        "messages": [
            {"role": "system", "content": prompt_text},
        ],
        "temperature": 0.2
    }

    try:
        # Without a timeout a stalled connection would block the UI forever.
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=data,
            timeout=120,
        )
    except requests.exceptions.RequestException as e:
        print(f"Request to OpenAI failed: {e}")
        return "Failed to contact the LLM service.", ""

    if not response.ok:
        # An error body has no "choices"; log the status and details instead of
        # letting it surface as an opaque KeyError below.
        print(f"OpenAI API error {response.status_code}: {response.text[:500]}")
        return "Failed to generate query.", ""

    try:
        response_json = response.json()
        output_json = json.loads(response_json['choices'][0]['message']['content'])
        the_query = output_json['the_query']
        reasoning = output_json['reasoning']
    except (KeyError, IndexError, TypeError) as e:
        # Missing keys, an empty "choices" list, or null message content.
        print(f"Unexpected response structure: {e!r}")
        return "Failed to generate query.", ""
    except ValueError:
        # json.JSONDecodeError (and requests' JSON decode error) subclass ValueError.
        print("JSON decoding failed")
        return "Invalid response query.", ""
    return reasoning, the_query

def complete_query(partial_query):
    """Splice an ADQL tail (normally a WHERE clause) into the standard Gaia DR3 query.

    The result starts with a newline, puts every template line on its own
    line with a four-space indent, appends ``partial_query`` after the FROM
    clause, and ends with a newline plus four spaces.

    Args:
        partial_query: Text placed after ``FROM gaiadr3.gaia_source``. It is
            formatted verbatim, so ``None`` would appear literally as ``None``.

    Returns:
        The complete ADQL query string.
    """
    line_break = "\n    "
    template_lines = (
        "SELECT TOP 300000",
        '-- IMPORTANT NOTE: Parameters that are in units of Magnitude must have an "as" name that ends in "_mag"',
        "--Required parameters",
        "-- ID - force a leading hash symbol to stop Excel from reading the ID number as a float",
        "'Gaia DR3 ' || source_id as source_id,",
        "-- Measured Position",
        "ra,",
        "dec,",
        "parallax,",
        "-- Measured Motion",
        "pmra,",
        "pmdec,",
        "radial_velocity as rv,",
        "--Key source light params for HR diagram",
        "phot_g_mean_mag,",
        "bp_rp as bp_rp_mag,",
        '--Optional plot parameters (you can add anything you want here, just give good unique "as" names - it will show up in the .cvs and hence in StarGate',
        "phot_rp_mean_mag,",
        "phot_bp_mean_mag,",
        "g_rp as g_rp_mag,",
        "bp_g as bp_g_mag,",
        "radial_velocity_error as rv_error,",
        "parallax_error,",
        "-- Additional parameters that appear in the WHERE clause should be added here",
        "-- Note: No comma after this last SELECT item",
        "parallax_over_error",
        "-- Use DR3",
        # Trailing spaces after the table name are intentional (kept verbatim).
        "FROM gaiadr3.gaia_source    ",
    )
    all_lines = template_lines + (f"{partial_query}",)
    return line_break + line_break.join(all_lines) + line_break

def download_url_from_query(query, user_prompt):
    """Run *query* on the Gaia TAP sync endpoint and save the CSV result locally.

    The output filename is derived from *user_prompt*. Dots are dropped, and
    any character that is not alphanumeric, ``-`` or ``_`` (including spaces)
    becomes ``_``. This keeps the name a single, valid path component.

    Args:
        query: ADQL query text to submit.
        user_prompt: Prompt the query was generated from; used only to name
            the output ``.csv`` file in the current working directory.
    """
    tap_url = "https://gea.esac.esa.int/tap-server/tap/sync?REQUEST=doQuery&LANG=ADQL&FORMAT=csv&QUERY="
    tap_url = tap_url + urllib.parse.quote_plus(query)

    # Previously only ' ' and '.' were handled, so a prompt such as
    # "... greater than 100 km/s." yielded "..._km/s.csv" — a path into a
    # non-existent directory — and the write failed. Same result as before
    # for plain prompts; any other unsafe character is now replaced.
    stem = "".join(
        ch if ch.isalnum() or ch in "-_" else "_"
        for ch in user_prompt.replace(".", "")
    ) or "query"  # avoid a bare ".csv" when nothing usable remains
    download_from_tap(tap_url, stem + ".csv")

def create_markdown_url_from_query(query):
    """Build Markdown for the UI: a link that runs *query* on the Gaia TAP service.

    Args:
        query: ADQL query text, typically the "The Query" textbox contents.
            May be empty (e.g. ``talk_to_llm`` returns ``""`` on failure).

    Returns:
        Markdown text. This is either the download link built by
        ``download_data``, or a short notice when there is no query. The old
        code produced a useless link with an empty QUERY in that case.
    """
    if not query or not query.strip():
        return 'No query to run yet - click "Ask LLM" to generate one first.'
    tap_url = (
        "https://gea.esac.esa.int/tap-server/tap/sync?REQUEST=doQuery&LANG=ADQL&FORMAT=csv&QUERY="
        + urllib.parse.quote_plus(query)
    )
    return download_data(tap_url)

def download_from_tap(url, output_path, timeout=(10, 600)):
    """Fetch *url* and write the raw response body to *output_path*.

    Best-effort: failures are reported with ``print`` rather than raised, so
    a batch run over many prompts keeps going after one bad query.

    Args:
        url: Fully encoded TAP sync URL.
        output_path: Destination file path; overwritten if it already exists.
        timeout: ``requests`` timeout as ``(connect, read)`` seconds. The read
            limit is generous because a large sync query may take a while.
            The original code had no timeout and could hang indefinitely.

    Returns:
        ``True`` if the file was written, otherwise ``False``. The original
        returned ``None``, and existing callers ignore the value.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # Raises an HTTPError for 4xx/5xx responses
        with open(output_path, 'wb') as f:
            f.write(response.content)
    except requests.exceptions.HTTPError as err:
        print(f"HTTP error occurred: {err}")
        return False
    except requests.exceptions.Timeout as err:
        print(f"Request timed out: {err}")
        return False
    except Exception as err:
        # Deliberate catch-all: connection and file-system errors must not
        # abort a batch run.
        print(f"An error occurred: {err}")
        return False
    print(f"Data successfully downloaded to {output_path}.")
    return True

def download_data(tap_url):
    """Wrap *tap_url* in a Markdown link for the Gradio output panel."""
    link_label = "Run Query on GaiaDB and Download CSV Datafile - may need second click to login"
    return "[{}]({})".format(link_label, tap_url)

# Batch helper: send each prompt to the LLM and download the resulting CSVs.
def process_queries(prompts=None):
    """Generate a query for every prompt and download its results.

    Args:
        prompts: Iterable of natural-language prompts. Defaults to
            ``user_prompt_examples``.

    Returns:
        dict mapping each prompt to ``{"reasoning": str, "query": str}``.
        The original always returned an empty dict because ``results`` was
        never filled in.
    """
    if prompts is None:
        prompts = user_prompt_examples
    results = {}
    for prompt in prompts:
        reasoning, the_query = talk_to_llm(prompt)
        results[prompt] = {"reasoning": reasoning, "query": the_query}
        # talk_to_llm returns "" as the query on failure; skip the download then.
        if the_query:
            download_url_from_query(the_query, prompt)
    return results

# Gradio UI: prompt -> LLM (reasoning + query) -> Markdown link to the TAP query.
# Components render in creation order, so statement order defines the layout.
with gr.Blocks() as demo:
    with gr.Row():
        user_prompt = gr.Textbox(label="Enter your query for the LLM", value="Find me the stars in the Hyades cluster.")
        submit_btn = gr.Button("Ask LLM")
    reasoning_output = gr.Textbox(label="Reasoning")
    the_query_output = gr.Textbox(label="The Query")
    # talk_to_llm returns (reasoning, the_query); the tuple fills the two boxes in order.
    submit_btn.click(fn=talk_to_llm, inputs=user_prompt, outputs=[reasoning_output, the_query_output])

    create_tap_url_bt = gr.Button("Create TAP URL")
    download_data_output = gr.Markdown()
    # Turns the text currently in "The Query" into a Markdown download link.
    create_tap_url_bt.click(fn=create_markdown_url_from_query, inputs=the_query_output, outputs=[download_data_output])

# NOTE(review): launches the server at import time (no __main__ guard) — confirm
# this is intended for the deployment target before importing this module elsewhere.
demo.launch()


# Uncomment to run every example prompt in batch mode (LLM call + CSV download).
#process_queries()