File size: 6,053 Bytes
6822c70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8513c8e
4c3757a
6822c70
8513c8e
6822c70
 
 
8513c8e
6822c70
8513c8e
6822c70
 
 
 
 
 
 
 
 
26aef91
 
6822c70
26aef91
 
6822c70
 
 
 
3c111c9
 
6822c70
 
3c111c9
 
6822c70
 
 
 
 
 
 
 
 
 
 
 
700581f
6822c70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8321a05
6822c70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
940d146
6822c70
940d146
5a46de1
940d146
5a46de1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import streamlit as st
from time import sleep
import json
from pymongo import MongoClient
from bson import ObjectId
from openai import OpenAI
openai_client = OpenAI()
import os

def _build_chart_iframe(restaurant_ids=()):
    """Return the MongoDB Charts embed HTML filtered to ``restaurant_ids``.

    With no ids the filter matches an empty ``restaurant_id``; this is the
    placeholder used when there is nothing to highlight on the map.
    """
    if restaurant_ids:
        quoted_ids = ", ".join(f"'{restaurant_id}'" for restaurant_id in restaurant_ids)
        chart_filter = "{'restaurant_id':{$in:[" + quoted_ids + "]}}"
    else:
        chart_filter = "{'restaurant_id':''}"
    return (
        '<iframe style="background: #FFFFFF;border: none;border-radius: 2px;'
        'box-shadow: 0 2px 10px 0 rgba(70, 76, 79, .2);" width="640" height="480" '
        'src="https://charts.mongodb.com/charts-paveldev-wiumf/embed/charts'
        '?id=65c24b0c-2215-4e6f-829c-f484dfd8a90c&filter=' + chart_filter
        + '&maxDataAge=3600&theme=light&autoRefresh=true"></iframe>'
    )


# Get the restaurants based on the search and location
def get_restaurants(search, location, meters):
    """Find restaurants near ``location`` that match ``search`` and summarize them.

    Steps: stage the restaurants within ``meters`` of ``location`` into the
    ``smart_trips`` collection, embed ``search`` with OpenAI, run a
    ``$vectorSearch`` restricted to the staged documents, and stream a chat
    completion describing the matches into the Streamlit page.

    Args:
        search: Free-text description of what the user wants to eat.
        location: Point passed as ``near`` to the ``$geoNear`` stage.
        meters: Maximum distance from ``location``.

    Returns:
        A 4-tuple ``(answer, iframe_html, pre_aggregation, vector_query)``.
        The last two are ``str()`` renderings of the pipeline/stage that were
        run, or ``"None"`` if the failure happened before they were built.
        ``(None, None, None, None)`` is returned when the MongoDB client
        cannot be set up.
    """
    try:
        uri = os.environ.get('MONGODB_ATLAS_URI')
        client = MongoClient(uri)
        db_name = 'whatscooking'
        collection_name = 'restaurants'
        restaurants_collection = client[db_name][collection_name]
        trips_collection = client[db_name]['smart_trips']
    except Exception:
        st.error("Error Connecting to the MongoDB Atlas Cluster")
        return None, None, None, None

    # Pre-bind everything the except/finally clauses read, so a failure early
    # in the try block is reported as itself instead of an UnboundLocalError.
    trip_id = None
    pre_agg = None
    vector_query = None

    try:
        with st.status("Search data..."):
            trip_id, pre_agg = pre_aggregate_meters(restaurants_collection, location, meters)

            response = openai_client.embeddings.create(
                input=search,
                model="text-embedding-3-small",
                dimensions=256,
            )

            vector_query = {
                "$vectorSearch": {
                    "index": "vector_index",
                    "queryVector": response.data[0].embedding,
                    "path": "embedding",
                    "numCandidates": 10,
                    "limit": 3,
                    # Only consider the documents staged for this search.
                    "filter": {"searchTrip": trip_id},
                }
            }
            st.write("Vector query")
            restaurant_docs = list(trips_collection.aggregate(
                [vector_query, {"$project": {"_id": 0, "embedding": 0}}]
            ))

            st.write("RAG...")
            stream_response = openai_client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": "You are a helpful restaurant assistant. Answer shortly and quickly. You will get a context if the context is not relevant to the user query please address that and not provide by default the restaurants as is."},
                    {"role": "user", "content": f"Find me the 2 best restaurant and why based on {search} and {restaurant_docs}. Shortly explain trades offs and why I should go to each one. You can mention the third option as a possible alternative in one sentence."}
                ],
                stream=True,
            )

        chat_response = st.write_stream(stream_response)

        if not restaurant_docs:
            return "No restaurants found", _build_chart_iframe(), str(pre_agg), str(vector_query)

        # The search may return fewer than three documents; use whatever came
        # back rather than indexing positions 0..2 unconditionally.
        restaurant_ids = [doc['restaurant_id'] for doc in restaurant_docs]
        return chat_response, _build_chart_iframe(restaurant_ids), str(pre_agg), str(vector_query)
    except Exception as e:
        st.error(f"Your query caused an error: {e}")
        return "Your query caused an error, please retry with allowed input only ...", _build_chart_iframe(), str(pre_agg), str(vector_query)
    finally:
        # Always drop the staged documents and release the client, including
        # on the no-result and error paths.
        try:
            if trip_id is not None:
                trips_collection.delete_many({"searchTrip": trip_id})
        except Exception as cleanup_error:
            st.warning(f"Could not remove temporary search data: {cleanup_error}")
        finally:
            client.close()

def pre_aggregate_meters(restaurants_collection, location, meters):
    """Stage the restaurants around ``location`` under a fresh trip id.

    Runs a ``$geoNear`` -> ``$addFields`` -> ``$merge`` aggregation on
    ``restaurants_collection``: documents within ``meters`` of ``location``
    are tagged with a new ``ObjectId`` (``searchTrip``) and its generation
    time (``date``), then merged into the ``smart_trips`` collection.

    Returns:
        A ``(trip_id, pipeline)`` tuple: the generated ``ObjectId`` and the
        list of pipeline stages that was executed.
    """
    trip_id = ObjectId()

    geo_near_stage = {
        "$geoNear": {
            "near": location,
            "distanceField": "distance",
            "maxDistance": meters,
            "spherical": True,
        },
    }
    tag_stage = {
        "$addFields": {
            "searchTrip": trip_id,
            "date": trip_id.generation_time,
        },
    }
    merge_stage = {
        "$merge": {
            "into": "smart_trips",
        },
    }
    pipeline = [geo_near_stage, tag_stage, merge_stage]

    # Called for its side effect only; the returned cursor is not used.
    restaurants_collection.aggregate(pipeline)
    return trip_id, pipeline

# Page header and short usage instructions.
st.markdown(
    """
    # MongoDB's Vector Restaurant Planner 
    Start typing below to see the results. You can search a specific cuisine for you and choose 3 predefined locations.

    The radius specifies the distance from the start search location. This space uses the dataset called [whatscooking.restaurants](https://huggingface.co/datasets/AIatMongoDB/whatscooking.restaurants)
    """
)

# Predefined search origins; each "value" is the point handed to the geo query.
LOCATION_OPTIONS = [
    {"label": "Timesquare Manhattan", "value": {"type": "Point", "coordinates": [-73.98527039999999, 40.7589099]}},
    {"label": "Westside Manhattan", "value": {"type": "Point", "coordinates": [-74.013686, 40.701975]}},
    {"label": "Downtown Manhattan", "value": {"type": "Point", "coordinates": [-74.000468, 40.720777]}},
]


def _location_label(option):
    """Return the text shown in the radio group for one location option."""
    return option['label']


# Input widgets, rendered in this order.
search = st.text_input("What type of dinner are you looking for?")
location = st.radio("Location", options=LOCATION_OPTIONS, format_func=_location_label)
meters = st.slider("Radius in meters", min_value=500, max_value=10000, step=5)

if st.button("Get Restaurants"):
    answer, chart_html, geo_pipeline, vector_stage = get_restaurants(
        search, location['value'], meters
    )
    # A falsy answer (e.g. None) means there is nothing further to render.
    if answer:
        st.subheader("Map")
        st.markdown(chart_html, unsafe_allow_html=True)
        st.subheader("Geo pre aggregation")
        st.code(geo_pipeline)
        st.subheader("Vector query")
        st.code(vector_stage)