# knowledge_graph / app.py
# Author: varun500
# Commit: 9b048cd ("Update app.py")
# Size: 3.69 kB
import streamlit as st
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.cluster import KMeans
# Node id for the user's sentence. A tuple can never equal a keyword string,
# so a sentence that happens to match a keyword (e.g. "Revenue") does not
# collapse into that keyword's node or create a self-loop edge.
_SENTENCE_NODE = ("sentence",)

# Fill colours for keyword clusters; the number of KMeans clusters is capped
# at the length of this list so every label has a colour.
_CLUSTER_COLORS = ["lightblue", "lightgreen", "lightyellow"]

# Fill colour for the sentence node, distinct from every cluster colour.
_SENTENCE_COLOR = "lightcoral"


def _build_similarity_graph(
    financial_sentence,
    sentence_embedding,
    keywords,
    keyword_embeddings,
    similarity_scores,
    cluster_labels,
):
    """Return a star graph connecting the sentence node to every keyword.

    The sentence node stores its embedding and display label. Each keyword
    node stores its embedding, its cosine ``similarity`` to the sentence and
    its KMeans ``cluster`` label. Each edge ``weight`` is that similarity.
    """
    graph = nx.Graph()
    graph.add_node(
        _SENTENCE_NODE, embedding=sentence_embedding, label=financial_sentence
    )
    # KMeans was fit on keyword_embeddings, so cluster_labels[i] belongs to
    # keywords[i]; pair them directly rather than with graph.nodes (which
    # starts with the sentence node and would shift every label by one).
    for keyword, embedding, similarity, cluster in zip(
        keywords, keyword_embeddings, similarity_scores, cluster_labels
    ):
        graph.add_node(
            keyword,
            embedding=embedding,
            similarity=similarity,
            cluster=int(cluster),
            label=keyword,
        )
        graph.add_edge(_SENTENCE_NODE, keyword, weight=similarity)
    return graph


def _draw_similarity_graph(graph):
    """Draw *graph* on a new Matplotlib figure and return the figure.

    Keyword nodes are coloured by cluster, the sentence node gets its own
    colour, and edges are labelled with the similarity to 2 decimals.
    """
    fig, ax = plt.subplots()
    # Fixed seed so the layout does not jump around on every Streamlit rerun.
    pos = nx.spring_layout(graph, seed=42)
    node_colors = [
        _SENTENCE_COLOR
        if node == _SENTENCE_NODE
        else _CLUSTER_COLORS[attrs["cluster"]]
        for node, attrs in graph.nodes(data=True)
    ]
    nx.draw_networkx_nodes(graph, pos, ax=ax, node_color=node_colors, node_size=800)
    nx.draw_networkx_edges(graph, pos, ax=ax, edge_color="gray", width=1, alpha=0.7)
    nx.draw_networkx_labels(
        graph,
        pos,
        ax=ax,
        labels=nx.get_node_attributes(graph, "label"),
        font_size=10,
        font_weight="bold",
    )
    edge_labels = {
        edge: f"{weight:.2f}"
        for edge, weight in nx.get_edge_attributes(graph, "weight").items()
    }
    nx.draw_networkx_edge_labels(graph, pos, ax=ax, edge_labels=edge_labels, font_size=8)
    ax.set_title("Financial Context and Keywords")
    ax.axis("off")
    return fig


def main():
    """Render the Streamlit app comparing a sentence with financial keywords.

    Embeds the user's sentence and a fixed keyword list with a
    Sentence-Transformers model, then draws a graph whose edges carry the
    cosine similarities and whose keyword nodes are coloured by KMeans
    cluster. It saves the figure to ``financial_graph.png`` and shows the
    scores in a table.
    """
    st.title("Financial Graph App")
    st.write("Enter a financial sentence and see its similarity to predefined keywords.")
    financial_sentence = st.text_area("Enter the financial sentence", value="")
    # Nothing to compare until the user enters a non-blank sentence.
    if not financial_sentence.strip():
        return

    keywords = [
        "Finance",
        "Fiscal",
        "Quarterly results",
        "Revenue",
        "Profit",
    ]

    # NOTE(review): the model is reloaded on every Streamlit rerun; consider
    # caching it (e.g. st.cache_resource) if the deployed Streamlit supports it.
    model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    # Sentence-level embeddings for the input and for each keyword.
    sentence_embedding = model.encode([financial_sentence])
    keyword_embeddings = model.encode(keywords)
    # Row 0: similarity of the single input sentence to each keyword.
    similarity_scores = cosine_similarity(sentence_embedding, keyword_embeddings)[0]

    # Group keywords by embedding with KMeans. An explicit n_init and
    # random_state make the colouring reproducible across reruns.
    n_clusters = min(len(_CLUSTER_COLORS), len(keywords))
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
    cluster_labels = kmeans.fit_predict(keyword_embeddings)

    graph = _build_similarity_graph(
        financial_sentence,
        sentence_embedding[0],
        keywords,
        keyword_embeddings,
        similarity_scores,
        cluster_labels,
    )
    fig = _draw_similarity_graph(graph)
    fig.savefig("financial_graph.png")
    # Pass the figure explicitly. No-argument st.pyplot() relies on the
    # deprecated global pyplot figure, and the 'deprecation.showPyplotGlobalUse'
    # option used to silence that warning raises on newer Streamlit releases.
    st.pyplot(fig)
    # Release the figure so pyplot does not accumulate figures across reruns.
    plt.close(fig)

    # Show the per-keyword similarity scores as a table.
    df = pd.DataFrame({"Keyword": keywords, "Cosine Similarity": similarity_scores})
    st.write("Similarity Scores:")
    st.dataframe(df)
# Run the app when this file is executed as a script (e.g. `streamlit run app.py`).
if __name__ == "__main__":
    main()