# Uncomment to load OPENROUTER_API_KEY from a local .env file:
# from dotenv import find_dotenv, load_dotenv
# _ = load_dotenv(find_dotenv())

import solara

import polars as pl

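# Download the Metallica songs dataset (Song, Album, Artist, Lyrics) into a Polars DataFrame.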
df = pl.read_csv(
    "https://drive.google.com/uc?export=download&id=1uD3h7xYxr9EoZ0Ggoh99JtQXa3AxtxyU"
)

import string

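# Normalize title casing for albums and songs, and fill missing lyrics with "None".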
df = df.with_columns(
    pl.Series("Album", [string.capwords(album) for album in df["Album"]])
)
df = df.with_columns(pl.Series("Song", [string.capwords(song) for song in df["Song"]]))
df = df.with_columns(pl.col("Lyrics").fill_null("None"))

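# Concatenate album, song, and lyrics into a single markdown-style `text` field;
# this is the field that gets embedded and retrieved.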
df = df.with_columns(
    text=pl.lit("# ")
    + pl.col("Album")
    + pl.lit(": ")
    + pl.col("Song")
    + pl.lit("\n\n")
    + pl.col("Lyrics")
)

import shutil
import lancedb

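# Start from a clean local LanceDB directory on every run.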
shutil.rmtree("test_lancedb", ignore_errors=True)
db = lancedb.connect("test_lancedb")

from lancedb.embeddings import get_registry

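# Small CPU-friendly sentence-transformers model; LanceDB uses it to embed `text` automatically.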
embeddings = (
    get_registry()
    .get("sentence-transformers")
    .create(name="TaylorAI/gte-tiny", device="cpu")
)

from lancedb.pydantic import LanceModel, Vector


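# Table schema: `text` is the embedding source field, `vector` holds the computed embedding.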
class Songs(LanceModel):
    Song: str
    Lyrics: str
    Album: str
    Artist: str
    text: str = embeddings.SourceField()
    vector: Vector(embeddings.ndims()) = embeddings.VectorField()

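
# Create the vector table and ingest the DataFrame; embeddings are computed during add().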
table = db.create_table("Songs", schema=Songs)
table.add(data=df)

import os
from typing import Optional

from langchain_community.chat_models import ChatOpenAI

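
# Thin ChatOpenAI wrapper that points LangChain at the OpenRouter endpoint.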
class ChatOpenRouter(ChatOpenAI):
    openai_api_base: str
    openai_api_key: str
    model_name: str

    def __init__(
        self,
        model_name: str,
        openai_api_key: Optional[str] = None,
        openai_api_base: str = "https://openrouter.ai/api/v1",
        **kwargs,
    ):
        # Fall back to the environment variable if no key was passed explicitly.
        openai_api_key = openai_api_key or os.getenv("OPENROUTER_API_KEY")
        super().__init__(
            openai_api_base=openai_api_base,
            openai_api_key=openai_api_key,
            model_name=model_name,
            **kwargs,
        )

llm_openrouter = ChatOpenRouter(
    model_name="meta-llama/llama-3.1-405b-instruct", temperature=0.1
)

def get_relevant_texts(query, table=table):
    # Vector search over the songs table; join the top matches into one context block.
    # Iterating over the actual results avoids an IndexError when fewer than 5 rows match.
    results = table.search(query).limit(5).to_polars()
    return "\n\n---\n\n".join(results["text"].to_list())

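# Plain RAG prompt: retrieved lyrics as context, followed by the user's question.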
def generate_prompt(query, table=table):
    return (
        "Answer the question based only on the following context:\n\n"
        + get_relevant_texts(query, table)
        + "\n\nQuestion: "
        + query
    )

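# Vector-RAG answer: retrieve context, build the prompt, call the LLM.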
def generate_response(query, table=table):
    prompt = generate_prompt(query, table)
    response = llm_openrouter.invoke(input=prompt)
    return response.content

import kuzu

# Start from a clean Kùzu database and define the graph schema:
# (SONG)-[:IN_ALBUM]->(ALBUM)-[:FROM_ARTIST]->(ARTIST).
shutil.rmtree("test_kuzudb", ignore_errors=True)
kuzu_db = kuzu.Database("test_kuzudb")
conn = kuzu.Connection(kuzu_db)

# Create schema
conn.execute("CREATE NODE TABLE ARTIST(name STRING, PRIMARY KEY (name))")
conn.execute("CREATE NODE TABLE ALBUM(name STRING, PRIMARY KEY (name))")
conn.execute("CREATE NODE TABLE SONG(ID SERIAL, name STRING, lyrics STRING, PRIMARY KEY (ID))")
conn.execute("CREATE REL TABLE IN_ALBUM(FROM SONG TO ALBUM)")
conn.execute("CREATE REL TABLE FROM_ARTIST(FROM ALBUM TO ARTIST)")

# Insert nodes
for artist in df["Artist"].unique():
    conn.execute(f"CREATE (artist:ARTIST {{name: '{artist}'}})")

for album in df["Album"].unique():
    conn.execute(f"""CREATE (album:ALBUM {{name: "{album}"}})""")

for song, lyrics in df.select(["Song", "text"]).unique().rows():
    # Swap double quotes for single quotes so the double-quoted Cypher literal stays valid.
    replaced_lyrics = lyrics.replace('"', "'")
    conn.execute(
        f"""CREATE (song:SONG {{name: "{song}", lyrics: "{replaced_lyrics}"}})"""
    )

# Insert edges
# Match songs on both name and lyrics to pin down a unique SONG node.
for song, album, lyrics in df.select(["Song", "Album", "text"]).rows():
    replaced_lyrics = lyrics.replace('"', "'")
    conn.execute(
        f"""
        MATCH (song:SONG), (album:ALBUM)
        WHERE song.name = "{song}" AND song.lyrics = "{replaced_lyrics}" AND album.name = "{album}"
        CREATE (song)-[:IN_ALBUM]->(album)
        """
    )

for album, artist in df.select(["Album", "Artist"]).unique().rows():
    conn.execute(
        f"""
        MATCH (album:ALBUM), (artist:ARTIST) WHERE album.name = "{album}" AND artist.name = "{artist}"
        CREATE (album)-[:FROM_ARTIST]->(artist)
        """
    )

# Sanity check: list the songs on The Black Album.
response = conn.execute(
    """
    MATCH (a:ALBUM {name: 'The Black Album'})<-[:IN_ALBUM]-(s:SONG) RETURN s.name
    """
)

df_response = response.get_as_pl()

from langchain_community.graphs import KuzuGraph

# Expose the graph schema to LangChain so it can be injected into the Cypher prompt.
graph = KuzuGraph(kuzu_db)

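# Prompt that asks the LLM to translate a natural-language question into Kùzu Cypher.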
def generate_kuzu_prompt(user_query):
    return """Task: Generate Kùzu Cypher statement to query a graph database.

Instructions:
Generate the Kùzu dialect of Cypher with the following rules in mind:
1. Do not omit the relationship pattern. Always use `()-[]->()` instead of `()->()`.
2. Do not include triple backticks ``` in your response. Return only Cypher.
3. Do not return any notes or comments in your response.


Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.

Schema:\n""" + graph.get_schema + """\nExample:
The question is:\n"Which songs does the load album have?"
MATCH (a:ALBUM {name: 'Load'})<-[:IN_ALBUM]-(s:SONG) RETURN s.name

Note: Do not include any explanations or apologies in your responses.
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
Do not include any text except the generated Cypher statement.

The question is:\n""" + user_query


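# Prompt that turns a Cypher query result into a human-readable answer.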
def generate_final_prompt(query, cypher_query, col_name, _values):
    return f"""You are an assistant that helps to form nice and human understandable answers.
The information part contains the provided information that you must use to construct an answer.
The provided information is authoritative; you must never doubt it or try to use your internal knowledge to correct it.
Make the answer sound like a response to the question. Do not mention that you based the result on the given information.
Here is an example:
Here is an example:

Question: Which managers own Neo4j stocks?
Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC]
Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks.

Follow this example when generating answers.
If the provided information is empty, say that you don't know the answer.
Query:\n{cypher_query}
Information:
[{col_name}: {_values}]

Question: {query}
Helpful Answer:
"""

def generate_kg_response(query):
    # Graph RAG: have the LLM write a Cypher query, run it against Kùzu,
    # then have the LLM phrase the query result as a natural-language answer.
    prompt = generate_kuzu_prompt(query)
    cypher_query_response = llm_openrouter.invoke(input=prompt)
    cypher_query = cypher_query_response.content
    response = conn.execute(cypher_query)
    df_result = response.get_as_pl()
    col_name = df_result.columns[0]
    _values = df_result[col_name].to_list()
    final_prompt = generate_final_prompt(query, cypher_query, col_name, _values)
    final_response = llm_openrouter.invoke(input=final_prompt)
    return final_response.content, cypher_query

def get_classification(query):
    # Router: YES -> content/lyrics question (vector RAG), NO -> structural question (graph RAG).
    prompt = (
        "Answer only YES or NO. Is the question '"
        + query
        + "' related to the content of a song?"
    )
    response = llm_openrouter.invoke(input=prompt)
    return response.content

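# Solara UI: classify each query, then answer via vector RAG (song content)
# or graph RAG (album/artist structure), showing the supporting context.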
query = solara.reactive("How many songs does the black album have?")


@solara.component
def Page():
    with solara.Column(margin=10):
        solara.Markdown("# Metallica Song Finder Graph RAG")
        solara.InputText("Enter some query:", query, continuous_update=False)
        if query.value != "":
            query_class = get_classification(query.value)
            # Accept minor variations such as "YES." or lowercase from the LLM.
            if query_class.strip().upper().startswith("YES"):
                df_results = table.search(query.value).limit(5).to_polars()
                df_results = df_results.select(['Song', 'Album', '_distance', 'Lyrics', 'Artist'])
                response = generate_response(query.value)
                solara.Markdown("## Answer:")
                solara.Markdown(response)
                solara.Markdown("## Context:")
                solara.DataFrame(df_results, items_per_page=5)
            else:
                response, cypher_query = generate_kg_response(query.value)
                solara.Markdown("## Answer:")
                solara.Markdown(response)
                solara.Markdown("## Cypher query:")
                solara.Markdown(cypher_query)