"""Streamlit app that suggests a 6-digit HS code for a product description.

Flow: the description typed by the user is used to retrieve candidate
documents from a BM25 index built over ``cleaned_list.csv``; the retrieved
texts and the description are then sent to a Groq-hosted chat model, whose
reply is displayed.

Configuration: the Groq API key is read from the ``GROQ_API_KEY`` environment
variable. Never hard-code the key in this file.
"""

import os
import re
from operator import itemgetter

import bm25s
import pandas as pd
import streamlit as st
from langchain.docstore.document import Document
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_groq import ChatGroq

# Single-column, header-less CSV holding one retrievable document per row.
CORPUS_CSV = "cleaned_list.csv"
# Chat model requested from Groq.
GROQ_MODEL = "llama-3.1-70b-versatile"
# Environment variable that must hold the Groq API key.
GROQ_API_KEY_ENV = "GROQ_API_KEY"
# Number of BM25 hits passed to the model as context.
TOP_K = 15

# Model-facing text: the wording is kept verbatim. ``{context}`` and
# ``{description}`` are the template variables filled in at invoke time.
PROMPT_TEMPLATE = (
    " Extract the appropriate 6-digit HS Code base on the product description"
    " and retrieved document by thoroughly analyzing its details and utilizing"
    " a reliable and up-to-date HS Code database for accurate results."
    " Only return the HS Code as a 6-digit number . \n"
    "Example: 123456 Context: {context} Description: {description} Answer: "
)


@st.cache_resource
def load_data():
    """Build the BM25 retriever over the documents in ``cleaned_list.csv``.

    The CSV is read without a header row and its column is named
    ``document``; every row becomes one entry of the corpus. The corpus is
    stored on the retriever (``corpus=``) so that ``retrieve`` returns the
    document texts rather than bare indices.

    ``st.cache_resource`` is used because the return value is a shared index
    object, not plain data: the cached instance is reused as-is instead of
    being serialized and copied on every call.

    Returns:
        bm25s.BM25: the indexed retriever.
    """
    df = pd.read_csv(CORPUS_CSV, header=None)
    df.columns = ["document"]
    corpus = df["document"].to_list()

    retriever = bm25s.BM25(corpus=corpus)
    retriever.index(bm25s.tokenize(corpus))
    return retriever


def load_model():
    """Create the prompt-to-LLM chain used to pick an HS code.

    Returns:
        A runnable (``prompt | llm``) expecting the keys ``context`` and
        ``description`` when invoked.

    Raises:
        RuntimeError: if the ``GROQ_API_KEY`` environment variable is unset
            or empty.
    """
    api_key = os.environ.get(GROQ_API_KEY_ENV)
    if not api_key:
        raise RuntimeError(
            f"Groq API key not found: set the {GROQ_API_KEY_ENV} "
            "environment variable before starting the app."
        )

    prompt = ChatPromptTemplate.from_messages(
        [HumanMessagePromptTemplate.from_template(PROMPT_TEMPLATE)]
    )
    # temperature=0: request the least random completion for a lookup-style task.
    llm = ChatGroq(model=GROQ_MODEL, temperature=0, api_key=api_key)
    return prompt | llm


def process_input(sentence):
    """Retrieve the top BM25 matches for *sentence*.

    Uses the retriever stored in ``st.session_state.retriever``.

    Args:
        sentence: free-text product description to search for.

    Returns:
        list[Document]: one ``Document`` per retrieved corpus entry, in the
        order returned by the retriever.
    """
    docs, _ = st.session_state.retriever.retrieve(
        bm25s.tokenize(sentence), k=TOP_K
    )
    # ``docs`` has one row per query; there is a single query here.
    return [Document(page_content=doc) for doc in docs[0]]


# --- Streamlit page -------------------------------------------------------

# Build the retriever and the chain once per browser session.
if st.session_state.get("retriever") is None:
    st.session_state.retriever = load_data()

if st.session_state.get("chain") is None:
    try:
        st.session_state.chain = load_model()
    except RuntimeError as exc:
        # Show the configuration problem in the page instead of a traceback.
        st.error(str(exc))
        st.stop()

sentence = st.text_input("please enter description:")

# Skip empty / whitespace-only input so no retrieval or LLM call is made.
if sentence.strip():
    documents = process_input(sentence)
    # Pass the retrieved texts, not the repr of the Document objects.
    context = "\n".join(doc.page_content for doc in documents)
    hscode = st.session_state.chain.invoke(
        {"context": context, "description": sentence}
    )
    st.write("answer:", hscode.content)