Spaces:
Sleeping
Sleeping
import streamlit as st | |
import os | |
import time | |
import numpy as np | |
import pandas as pd | |
def add_custom_css(): | |
st.markdown(""" | |
<style> | |
.container { | |
text-align: center; | |
background-color: #f0f0f0; | |
padding: 20px; | |
} | |
.big-font { | |
font-size: 50px; | |
color: #4CAF50; | |
} | |
.progress-bar { | |
margin-top: 20px; | |
} | |
</style> | |
""", unsafe_allow_html=True) | |
if 'packages_installed' not in st.session_state: | |
st.info("Installing required packages...") | |
os.system("pip install -U sentence-transformers") | |
os.system("pip install pinecone-client") | |
st.session_state['packages_installed'] = True | |
from sentence_transformers import SentenceTransformer | |
from pinecone import Pinecone, ServerlessSpec, PodSpec | |
if 'pc' not in st.session_state: | |
use_serverless = False | |
# Configure Pinecone client | |
api_key = os.environ.get('PINECONE_API_KEY', '28b0fd5a-fdfb-422d-9a44-c0ec09a25074') | |
environment = os.environ.get('PINECONE_ENVIRONMENT', 'gcp-starter') | |
st.session_state['pc'] = Pinecone(api_key=api_key) | |
if use_serverless: | |
spec = ServerlessSpec(cloud='gcp', region='asia-southeast1-gcp') | |
else: | |
spec = PodSpec(environment=environment) | |
if 'model' not in st.session_state: | |
st.session_state['model'] = SentenceTransformer('intfloat/e5-small') | |
index_name = 'dataset' | |
if index_name not in st.session_state.pc.list_indexes().names(): | |
dimensions = 384 | |
st.session_state.pc.create_index( | |
name=index_name, | |
dimension=dimensions, | |
metric='cosine', | |
spec=spec | |
) | |
# Wait until index is ready | |
while not st.session_state.pc.describe_index(index_name).status['ready']: | |
time.sleep(1) | |
if 'index' not in st.session_state: | |
st.session_state['index'] = st.session_state.pc.Index(index_name) | |
# Function to process data and insert into Pinecone index | |
def process_data(data, namespace): | |
input_texts = data['Query'] | |
progress_bar = st.progress(0) | |
total_chunks = len(data) // 1000 + 1 | |
for chunk_start in range(0, len(data), 1000): | |
chunk_end = min(chunk_start + 1000, len(data)) | |
chunk = data.iloc[chunk_start:chunk_end] | |
# Generate embeddings for the current chunk | |
chunk_embeddings = [st.session_state.model.encode(query, normalize_embeddings=True) for query in chunk['Query']] | |
chunk['embedding'] = chunk_embeddings | |
# Upsert embeddings | |
st.session_state.index.upsert(vectors=zip(chunk['id'], chunk['embedding']), namespace=namespace) | |
# Update progress bar | |
progress = (chunk_end / len(data)) * 100 | |
progress_bar.progress(int(progress)) | |
def load_and_process_data(file): | |
data = pd.read_csv(file) | |
data['id'] = data.index.astype(str) | |
namespace = file.name[:15] # Use first 15 characters of file name as namespace | |
if 'embeddings_done' not in st.session_state: | |
process_data(data, namespace) | |
st.session_state['embeddings_done'] = True | |
return data, namespace | |
def main(): | |
add_custom_css() | |
st.markdown(""" | |
<div class='container'> | |
<h1 class='big-font'>Semantic Search Engine</h1> | |
</div> | |
""", unsafe_allow_html=True) | |
# Use session state to retain information across interactions | |
if 'namespace' not in st.session_state: | |
st.session_state.namespace = None | |
if 'df' not in st.session_state: | |
st.session_state.df = None | |
uploaded_file = st.file_uploader("Upload dataset (CSV format)", type=["csv"]) | |
if uploaded_file is not None: | |
filename = uploaded_file.name | |
namespace = filename.split('.')[0] | |
st.info("Dataset Processing Started...") | |
st.session_state.df, st.session_state.namespace = load_and_process_data(uploaded_file) | |
st.info("Dataset Processing Completed...") | |
if st.session_state.namespace: | |
query = st.text_input("Enter your query about the data (or type 'exit' to quit):") | |
if query.lower() != 'exit': | |
vec = st.session_state.model.encode(query) | |
result = None | |
result = st.session_state.index.query( | |
namespace=st.session_state.namespace, | |
vector=vec.tolist(), | |
top_k=5, | |
include_values=False | |
) | |
st.subheader("Query Results:") | |
if result is not None: | |
id = result['matches'][0]['id'] | |
data = st.session_state.df | |
answer = data[data['id'] == id]['Answer'].values[0] | |
st.write(answer) | |
if st.button("Delete Stored Data"): | |
st.session_state.index.delete(deleteAll=True, namespace =st.session_state.namespace) | |
st.stop() | |
if __name__ == "__main__": | |
main() | |