Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
from sentence_transformers import SentenceTransformer
|
6 |
+
import faiss
|
7 |
+
|
8 |
+
# ------------------------------
|
9 |
+
# Load Data
|
10 |
+
# ------------------------------
|
11 |
+
@st.cache_data
|
12 |
+
def load_data():
|
13 |
+
df = pd.read_csv("mea.csv", parse_dates=["date"])
|
14 |
+
# Ensure 'year' column exists
|
15 |
+
if "year" not in df.columns:
|
16 |
+
df["year"] = df["date"].dt.year
|
17 |
+
return df
|
18 |
+
|
19 |
+
df = load_data()
|
20 |
+
|
21 |
+
# ------------------------------
|
22 |
+
# Extract Unique Country Names
|
23 |
+
# ------------------------------
|
24 |
+
def get_unique_countries(df):
|
25 |
+
country_set = set()
|
26 |
+
for entry in df["countries"].dropna():
|
27 |
+
for country in entry.split(","):
|
28 |
+
country = country.strip()
|
29 |
+
if country:
|
30 |
+
country_set.add(country)
|
31 |
+
return sorted(list(country_set))
|
32 |
+
|
33 |
+
unique_countries = get_unique_countries(df)
|
34 |
+
|
35 |
+
# ------------------------------
|
36 |
+
# Load SentenceTransformer Model
|
37 |
+
# ------------------------------
|
38 |
+
@st.cache_resource
|
39 |
+
def load_model():
|
40 |
+
return SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
41 |
+
|
42 |
+
model = load_model()
|
43 |
+
|
44 |
+
# ------------------------------
|
45 |
+
# Compute Embeddings for Unique Countries
|
46 |
+
# ------------------------------
|
47 |
+
@st.cache_resource
|
48 |
+
def compute_embeddings(countries):
|
49 |
+
return model.encode(countries, convert_to_tensor=False)
|
50 |
+
|
51 |
+
country_embeddings = compute_embeddings(unique_countries)
|
52 |
+
|
53 |
+
# ------------------------------
|
54 |
+
# Build FAISS Index
|
55 |
+
# ------------------------------
|
56 |
+
@st.cache_resource
|
57 |
+
def build_faiss_index(embeddings):
|
58 |
+
dimension = embeddings.shape[1]
|
59 |
+
index = faiss.IndexFlatL2(dimension)
|
60 |
+
index.add(np.array(embeddings))
|
61 |
+
return index
|
62 |
+
|
63 |
+
index = build_faiss_index(country_embeddings)
|
64 |
+
|
65 |
+
# Create a mapping from index to country name
|
66 |
+
country_map = {i: country for i, country in enumerate(unique_countries)}
|
67 |
+
|
68 |
+
# ------------------------------
|
69 |
+
# Sentiment Analysis Function
|
70 |
+
# ------------------------------
|
71 |
+
def sentiment_analysis(df, query, model, index, country_map, k, start_year):
|
72 |
+
# Encode the query term and search for similar country names
|
73 |
+
query_embedding = model.encode([query], convert_to_tensor=False)
|
74 |
+
distances, indices = index.search(np.array(query_embedding), k)
|
75 |
+
similar_countries = [country_map[idx] for idx in indices[0]]
|
76 |
+
|
77 |
+
st.write("**Similar country names to _{}_:**".format(query))
|
78 |
+
st.write(similar_countries)
|
79 |
+
|
80 |
+
# Filter the DataFrame using the similar country names (regex OR join)
|
81 |
+
country_variations = '|'.join(similar_countries)
|
82 |
+
filtered_df = df[df['countries'].str.contains(country_variations, case=False, na=False)]
|
83 |
+
filtered_df = filtered_df[filtered_df['year'] >= start_year]
|
84 |
+
|
85 |
+
if filtered_df.empty:
|
86 |
+
st.warning("No records found for the given query and start year.")
|
87 |
+
return
|
88 |
+
|
89 |
+
# Plot 1: Mean Sentiment per Year
|
90 |
+
mean_sentiment_per_year = filtered_df.groupby('year')['sentiment'].mean()
|
91 |
+
fig1, ax1 = plt.subplots()
|
92 |
+
ax1.plot(mean_sentiment_per_year.index, mean_sentiment_per_year, marker='o', color='r')
|
93 |
+
ax1.set_title(f'Mean Sentiment Score Over Years for "{query}"')
|
94 |
+
ax1.set_xlabel('Year')
|
95 |
+
ax1.set_ylabel('Mean Sentiment Score')
|
96 |
+
ax1.grid(True)
|
97 |
+
st.pyplot(fig1)
|
98 |
+
|
99 |
+
# Plot 2: Sentiment Scores Over Time (Scatter Plot)
|
100 |
+
fig2, ax2 = plt.subplots(figsize=(10, 6))
|
101 |
+
colors = filtered_df['sentiment'].apply(lambda x: 'red' if x < 0 else 'orange' if x > 0 else 'blue')
|
102 |
+
ax2.scatter(filtered_df['date'], filtered_df['sentiment'], marker='o', color=colors)
|
103 |
+
ax2.set_title(f'Sentiment Scores Over Time for "{query}"')
|
104 |
+
ax2.set_xlabel('Date')
|
105 |
+
ax2.set_ylabel('Sentiment Score')
|
106 |
+
ax2.grid(True)
|
107 |
+
st.pyplot(fig2)
|
108 |
+
|
109 |
+
# Display the average sentiment
|
110 |
+
average_sentiment = filtered_df['sentiment'].mean()
|
111 |
+
st.write(f'**Average sentiment of India towards "{query}" from {start_year} onwards = {average_sentiment:.2f}**')
|
112 |
+
|
113 |
+
# ------------------------------
|
114 |
+
# Streamlit User Interface
|
115 |
+
# ------------------------------
|
116 |
+
st.title("Sentiment Analysis: India & Country Relationship")
|
117 |
+
st.write("This app visualizes sentiment trends in India's press releases toward a selected country.")
|
118 |
+
|
119 |
+
# User inputs
|
120 |
+
query = st.text_input("Enter a country name (or variation) to search for:", "United States")
|
121 |
+
start_year = st.number_input("Enter start year (e.g., 2010):", min_value=1900, max_value=2100, value=2010)
|
122 |
+
k = st.number_input("Enter number of similar country variations to consider:", min_value=1, max_value=100, value=48)
|
123 |
+
|
124 |
+
if st.button("Analyze"):
|
125 |
+
sentiment_analysis(df, query, model, index, country_map, k, start_year)
|