import streamlit as st
import pandas as pd
from langchain_text_splitters import TokenTextSplitter
from langchain.docstore.document import Document
from torch import cuda
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_qdrant import FastEmbedSparse, RetrievalMode
# pick the device to be used: GPU if available, otherwise CPU
device = 'cuda' if cuda.is_available() else 'cpu'

st.set_page_config(page_title="SEARCH IATI", layout='wide')
st.title("SEARCH IATI Database")
var = st.text_input("enter keyword")
def create_chunks(text):
    """Split a text into chunks of at most 500 tokens."""
    # chunk size is measured in tokens, not characters
    text_splitter = TokenTextSplitter(chunk_size=500, chunk_overlap=0)
    texts = text_splitter.split_text(text)
    return texts
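# Illustrative usage (a sketch, not part of the app flow): TokenTextSplitter
# counts tokens rather than words (by default via tiktoken), so a long
# description comes back as a list of strings of at most 500 tokens each, e.g.
#   create_chunks("word " * 1500)   # -> ['word word ...', 'word word ...']
# while a short text is returned unchanged as a single-element list.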
def get_chunks():
    """
    Read the IATI csv files, merge them into one dataframe and create the chunks.
    """
    orgas_df = pd.read_csv("iati_files/project_orgas.csv")
    region_df = pd.read_csv("iati_files/project_region.csv")
    sector_df = pd.read_csv("iati_files/project_sector.csv")
    status_df = pd.read_csv("iati_files/project_status.csv")
    texts_df = pd.read_csv("iati_files/project_texts.csv")

    projects_df = pd.merge(orgas_df, region_df, on='iati_id', how='inner')
    projects_df = pd.merge(projects_df, sector_df, on='iati_id', how='inner')
    projects_df = pd.merge(projects_df, status_df, on='iati_id', how='inner')
    projects_df = pd.merge(projects_df, texts_df, on='iati_id', how='inner')
    projects_df = projects_df[projects_df.client.str.contains('bmz')].reset_index(drop=True)

    projects_df.drop(columns=['orga_abbreviation', 'client',
                              'orga_full_name', 'country',
                              'country_flag', 'crs_5_code', 'crs_3_code', 'country_code_list',
                              'sgd_pred_code', 'crs_5_name', 'crs_3_name', 'sgd_pred_str'], inplace=True)
    print(projects_df.columns)

    projects_df['text_size'] = projects_df.apply(lambda x: len((x['title_main'] + x['description_main']).split()), axis=1)
    projects_df['chunks'] = projects_df.apply(lambda x: create_chunks(x['title_main'] + x['description_main']), axis=1)
    projects_df = projects_df.explode(column=['chunks'], ignore_index=True)
    projects_df['source'] = 'IATI'
    projects_df.rename(columns={'iati_id': 'id', 'iati_orga_id': 'org'}, inplace=True)
    # read the giz_worldwide data
    giz_df = pd.read_json('iati_files/data_giz_website.json')
    giz_df = giz_df.rename(columns={'content': 'project_description'})
    giz_df['text_size'] = giz_df.apply(lambda x: len((x['project_name'] + x['project_description']).split()), axis=1)
    giz_df['chunks'] = giz_df.apply(lambda x: create_chunks(x['project_name'] + x['project_description']), axis=1)
    giz_df = giz_df.explode(column=['chunks'], ignore_index=True)
    print(giz_df.columns)

    giz_df.drop(columns=['filename', 'url', 'name', 'mail',
                         'language', 'start_year', 'end_year', 'poli_trager'], inplace=True)
    giz_df.rename(columns={'project_name': 'title_main', 'countries': 'country_name',
                           'client': 'org', 'project_description': 'description_main'}, inplace=True)
    giz_df['source'] = 'GIZ_WORLDWIDE'
    giz_df['status'] = "None"

    df = pd.concat([projects_df, giz_df], ignore_index=True)
    print(df.columns)
    print(df)
    # wrap every chunk in a langchain Document, carrying the project fields as metadata
    placeholder = []
    for i in range(len(df)):
        placeholder.append(Document(page_content=df.loc[i, 'chunks'],
                                    metadata={"id": df.loc[i, 'id'],
                                              "org": df.loc[i, 'org'],
                                              "country_name": str(df.loc[i, 'country_name']),
                                              "status": df.loc[i, 'status'],
                                              "title_main": df.loc[i, 'title_main']}))
    return placeholder
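# For illustration, one emitted Document looks roughly like this (all values
# are hypothetical, not taken from the real data):
#   Document(page_content="Improving water supply in ...",
#            metadata={"id": "XM-DAC-...", "org": "...",
#                      "country_name": "Kenya", "status": "active",
#                      "title_main": "Water supply programme"})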
def embed_chunks(chunks):
    """
    Embed the chunks with the dense BAAI/bge-m3 model and store them in a
    local Qdrant collection (the sparse/hybrid part is kept below, commented out).
    """
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name='BAAI/bge-m3'
    )
    # sparse_embeddings = FastEmbedSparse(model_name="Qdrant/bm25")

    print("starting embedding")
    qdrant_collections = {}
    qdrant_collections['all'] = Qdrant.from_documents(
        chunks,
        embeddings,
        path="/data/local_qdrant",
        collection_name='all',
    )
    print(qdrant_collections)
    print("vector embeddings done")
def get_local_qdrant():
    """Once the local Qdrant store is created, connect to the existing store."""
    qdrant_collections = {}
    embeddings = HuggingFaceEmbeddings(
        model_kwargs={'device': device},
        encode_kwargs={'normalize_embeddings': True},
        model_name='BAAI/bge-m3')
    client = QdrantClient(path="/data/local_qdrant")
    print("Collections in local Qdrant:", client.get_collections())
    qdrant_collections['all'] = Qdrant(client=client, collection_name='all', embeddings=embeddings)
    return qdrant_collections
def get_context(vectorstore, query):
    """Retrieve the chunks most similar to the query from the vectorstore."""
    # a metadata filter could be added here to narrow the search
    retriever = vectorstore.as_retriever(search_type="similarity_score_threshold",
                                         search_kwargs={"score_threshold": 0.5,
                                                        "k": 10})
    # # re-ranking the retrieved results
    # model = HuggingFaceCrossEncoder(model_name=model_config.get('ranker','MODEL'))
    # compressor = CrossEncoderReranker(model=model, top_n=int(model_config.get('ranker','TOP_K')))
    # compression_retriever = ContextualCompressionRetriever(
    #     base_compressor=compressor, base_retriever=retriever
    # )
    context_retrieved = retriever.invoke(query)
    print(f"retrieved paragraphs: {len(context_retrieved)}")
    return context_retrieved
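# A sketch of the metadata filter hinted at above (a hypothetical helper, not
# called anywhere; it assumes the langchain_community Qdrant wrapper accepts a
# plain dict as "filter" in search_kwargs and matches it against the payload):
def get_context_filtered(vectorstore, query, source):
    """Same retrieval as get_context, restricted to one source, e.g. 'IATI'."""
    retriever = vectorstore.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.5,
                       "k": 10,
                       "filter": {"source": source}})
    return retriever.invoke(query)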
# first we create the chunks for the iati documents (one-time indexing step)
# chunks = get_chunks()
# print("chunking done")
# once the chunks are done, we perform the embeddings
# embed_chunks(chunks)
vectorstores = get_local_qdrant()
vectorstore = vectorstores['all']
button = st.button("search")
if button:
    # only run retrieval when the button is clicked, not on every rerun
    results = get_context(vectorstore, f"find the relevant paragraphs for: {var}")
    st.write(f"Found {len(results)} results for query: {var}")
    for i in results:
        st.subheader(str(i.metadata['id']) + ": " + str(i.metadata['title_main']))
        st.caption(f"Status: {str(i.metadata['status'])}, Country: {str(i.metadata['country_name'])}")
        st.write(i.page_content)
        st.divider()
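# Possible refinement (a sketch, not wired in): Streamlit re-runs this script
# on every interaction, so the Qdrant client above is re-opened each time.
# Caching the connection with st.cache_resource would avoid that:
#   @st.cache_resource
#   def load_vectorstore():
#       return get_local_qdrant()['all']
#   vectorstore = load_vectorstore()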