|
import streamlit as st
|
|
from langchain_community.retrievers import BM25Retriever
|
|
import pandas as pd
|
|
from langchain.docstore.document import Document
|
|
from langchain.text_splitter import CharacterTextSplitter
|
|
from operator import itemgetter
|
|
from langchain_core.prompts import PromptTemplate
|
|
from langchain_groq import ChatGroq
|
|
from langchain.chains.question_answering import load_qa_chain
|
|
import os
|
|
|
|
|
|
@st.cache_data
|
|
def load_data():
|
|
df = pd.read_csv("C:\\Users\\JVC Store\\Downloads\\data\\data 6\\train.csv")
|
|
df = df.drop(columns = ['Unnamed: 0','hs_code_2','hs_code_4'])
|
|
documents = []
|
|
|
|
for index, row in df.iterrows():
|
|
text = row['full_description']
|
|
hs_code = row['hs_code']
|
|
documents.append(Document(page_content=text, metadata={'hs_code': hs_code}))
|
|
|
|
splitter = CharacterTextSplitter(
|
|
chunk_size=100,
|
|
chunk_overlap=0,
|
|
separator = ' '
|
|
)
|
|
|
|
split_documents = []
|
|
for doc in documents:
|
|
chunks = splitter.split_text(doc.page_content)
|
|
|
|
word_chunks = []
|
|
current_chunk = []
|
|
|
|
for chunk in chunks:
|
|
words = chunk.split()
|
|
for word in words:
|
|
if len(' '.join(current_chunk + [word])) <=100:
|
|
current_chunk.append(word)
|
|
else:
|
|
word_chunks.append(' '.join(current_chunk))
|
|
current_chunk = [word]
|
|
if current_chunk:
|
|
word_chunks.append(' '.join(current_chunk))
|
|
|
|
split_documents.append(Document(page_content=word_chunks[0], metadata=doc.metadata))
|
|
|
|
|
|
docs = []
|
|
for doc in split_documents:
|
|
metadata = doc.metadata
|
|
metadata_str = str(metadata).strip('{}')
|
|
page = doc.page_content
|
|
docs.append([metadata_str + " " + page])
|
|
|
|
|
|
cleaned_list = [item.replace('"','').replace("'",'') for items in docs for item in items]
|
|
retriever = BM25Retriever.from_texts(cleaned_list)
|
|
retriever.k = 5
|
|
return retriever
|
|
|
|
|
|
def load_llm():
|
|
|
|
api_key2 = "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"
|
|
|
|
llm2 = ChatGroq(model = "llama-3.1-70b-versatile", temperature = 0,api_key = api_key2)
|
|
return llm2
|
|
|
|
|
|
def predict(sentence,retriever,llm2):
|
|
sentence = sentence.lower()
|
|
context = retriever.get_relevant_documents(sentence)
|
|
|
|
template2 = """
|
|
You are an expert in HS Code classification.
|
|
Based on the provided product description, accurately determine and return only one 6-digit HS Code that best matches the description.
|
|
Always return the HS Code as a 6-digit number only.
|
|
example: 123456
|
|
Context:\n {context} \n
|
|
Description:\n {description} \n
|
|
Answer:
|
|
"""
|
|
prompt2 = PromptTemplate(template=template2, input_variables=['context','description'])
|
|
chain = load_qa_chain(llm2, chain_type = 'stuff', prompt = prompt2)
|
|
response = chain.invoke({'input_documents': context, 'description':sentence})
|
|
answer = response.get("output_text")
|
|
return answer
|
|
|
|
|
|
if 'retriever' not in st.session_state:
|
|
st.session_state.retriever = None
|
|
if 'llm' not in st.session_state:
|
|
st.session_state.llm = None
|
|
|
|
if st.session_state.retriever is None:
|
|
st.session_state.retriever = load_data()
|
|
|
|
if st.session_state.llm is None:
|
|
st.session_state.llm = load_llm()
|
|
|
|
sentence = st.text_input("please enter description:")
|
|
|
|
if sentence !='':
|
|
answer = predict(sentence,st.session_state.retriever,st.session_state.llm )
|
|
st.write("answer:",answer) |