nishjay commited on
Commit
3d65ea8
·
verified ·
1 Parent(s): 0694408

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from sentence_transformers import SentenceTransformer
6
+ import faiss
7
+
8
+ # ------------------------------
9
+ # Load Data
10
+ # ------------------------------
11
+ @st.cache_data
12
+ def load_data():
13
+ df = pd.read_csv("mea.csv", parse_dates=["date"])
14
+ # Ensure 'year' column exists
15
+ if "year" not in df.columns:
16
+ df["year"] = df["date"].dt.year
17
+ return df
18
+
19
+ df = load_data()
20
+
21
+ # ------------------------------
22
+ # Extract Unique Country Names
23
+ # ------------------------------
24
+ def get_unique_countries(df):
25
+ country_set = set()
26
+ for entry in df["countries"].dropna():
27
+ for country in entry.split(","):
28
+ country = country.strip()
29
+ if country:
30
+ country_set.add(country)
31
+ return sorted(list(country_set))
32
+
33
+ unique_countries = get_unique_countries(df)
34
+
35
+ # ------------------------------
36
+ # Load SentenceTransformer Model
37
+ # ------------------------------
38
+ @st.cache_resource
39
+ def load_model():
40
+ return SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
41
+
42
+ model = load_model()
43
+
44
+ # ------------------------------
45
+ # Compute Embeddings for Unique Countries
46
+ # ------------------------------
47
+ @st.cache_resource
48
+ def compute_embeddings(countries):
49
+ return model.encode(countries, convert_to_tensor=False)
50
+
51
+ country_embeddings = compute_embeddings(unique_countries)
52
+
53
+ # ------------------------------
54
+ # Build FAISS Index
55
+ # ------------------------------
56
+ @st.cache_resource
57
+ def build_faiss_index(embeddings):
58
+ dimension = embeddings.shape[1]
59
+ index = faiss.IndexFlatL2(dimension)
60
+ index.add(np.array(embeddings))
61
+ return index
62
+
63
+ index = build_faiss_index(country_embeddings)
64
+
65
+ # Create a mapping from index to country name
66
+ country_map = {i: country for i, country in enumerate(unique_countries)}
67
+
68
+ # ------------------------------
69
+ # Sentiment Analysis Function
70
+ # ------------------------------
71
+ def sentiment_analysis(df, query, model, index, country_map, k, start_year):
72
+ # Encode the query term and search for similar country names
73
+ query_embedding = model.encode([query], convert_to_tensor=False)
74
+ distances, indices = index.search(np.array(query_embedding), k)
75
+ similar_countries = [country_map[idx] for idx in indices[0]]
76
+
77
+ st.write("**Similar country names to _{}_:**".format(query))
78
+ st.write(similar_countries)
79
+
80
+ # Filter the DataFrame using the similar country names (regex OR join)
81
+ country_variations = '|'.join(similar_countries)
82
+ filtered_df = df[df['countries'].str.contains(country_variations, case=False, na=False)]
83
+ filtered_df = filtered_df[filtered_df['year'] >= start_year]
84
+
85
+ if filtered_df.empty:
86
+ st.warning("No records found for the given query and start year.")
87
+ return
88
+
89
+ # Plot 1: Mean Sentiment per Year
90
+ mean_sentiment_per_year = filtered_df.groupby('year')['sentiment'].mean()
91
+ fig1, ax1 = plt.subplots()
92
+ ax1.plot(mean_sentiment_per_year.index, mean_sentiment_per_year, marker='o', color='r')
93
+ ax1.set_title(f'Mean Sentiment Score Over Years for "{query}"')
94
+ ax1.set_xlabel('Year')
95
+ ax1.set_ylabel('Mean Sentiment Score')
96
+ ax1.grid(True)
97
+ st.pyplot(fig1)
98
+
99
+ # Plot 2: Sentiment Scores Over Time (Scatter Plot)
100
+ fig2, ax2 = plt.subplots(figsize=(10, 6))
101
+ colors = filtered_df['sentiment'].apply(lambda x: 'red' if x < 0 else 'orange' if x > 0 else 'blue')
102
+ ax2.scatter(filtered_df['date'], filtered_df['sentiment'], marker='o', color=colors)
103
+ ax2.set_title(f'Sentiment Scores Over Time for "{query}"')
104
+ ax2.set_xlabel('Date')
105
+ ax2.set_ylabel('Sentiment Score')
106
+ ax2.grid(True)
107
+ st.pyplot(fig2)
108
+
109
+ # Display the average sentiment
110
+ average_sentiment = filtered_df['sentiment'].mean()
111
+ st.write(f'**Average sentiment of India towards "{query}" from {start_year} onwards = {average_sentiment:.2f}**')
112
+
113
+ # ------------------------------
114
+ # Streamlit User Interface
115
+ # ------------------------------
116
+ st.title("Sentiment Analysis: India & Country Relationship")
117
+ st.write("This app visualizes sentiment trends in India's press releases toward a selected country.")
118
+
119
+ # User inputs
120
+ query = st.text_input("Enter a country name (or variation) to search for:", "United States")
121
+ start_year = st.number_input("Enter start year (e.g., 2010):", min_value=1900, max_value=2100, value=2010)
122
+ k = st.number_input("Enter number of similar country variations to consider:", min_value=1, max_value=100, value=48)
123
+
124
+ if st.button("Analyze"):
125
+ sentiment_analysis(df, query, model, index, country_map, k, start_year)