# hscode / main.py
# Uploaded by tien314 ("Upload main.py"), commit c2e612d, 3.67 kB.
import os
import re
from operator import itemgetter

import pandas as pd
import streamlit as st
from langchain.chains.question_answering import load_qa_chain
from langchain.docstore.document import Document
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
from langchain_core.prompts import PromptTemplate
from langchain_groq import ChatGroq
# Translation table that deletes single and double quotes from indexed text.
_STRIP_QUOTES = str.maketrans("", "", "'\"")


def _first_words(text, limit=100):
    """Return the leading whole words of *text* whose space-joined length is <= *limit*.

    Words are split on any whitespace and re-joined with single spaces, so no
    word is ever cut in half. If the first word alone is longer than *limit*
    it is returned whole (the previous splitter-based code produced an empty
    string in that case). Empty input yields ``""``.
    """
    kept = []
    joined_len = -1  # -1 cancels the separator counted in front of the first word
    for word in text.split():
        joined_len += len(word) + 1
        if kept and joined_len > limit:
            break
        kept.append(word)
    return " ".join(kept)


def _bm25_tokenize(text):
    """Lower-case and whitespace-split *text* so BM25 matching ignores case."""
    return text.lower().split()


# cache_resource rather than cache_data: the retriever is a reusable object,
# not data; cache_data would pickle/unpickle a fresh copy on every call.
@st.cache_resource
def load_data(
    path="C:\\Users\\JVC Store\\Downloads\\data\\data 6\\train.csv",
    k=5,
    max_chars=100,
):
    """Load the HS-code training CSV and build a BM25 retriever over it.

    Each row becomes one indexed text ``"hs_code: <code> <description prefix>"``.
    The prefix is the leading whole words of ``full_description`` that fit in
    *max_chars* characters; the rest of the description is not indexed.
    Single and double quotes are removed from every indexed text.

    Args:
        path: CSV with at least ``hs_code`` and ``full_description`` columns.
            The default is the original author's local path; override it.
        k: number of documents the retriever returns per query.
        max_chars: maximum length of the indexed description prefix.

    Returns:
        A ``BM25Retriever`` whose tokenizer lower-cases both documents and
        queries, so matching is case-insensitive.

    Raises:
        ValueError: if the CSV contains no rows.
    """
    # Read only the columns that are used. hs_code is kept as str so any
    # leading zeros present in the file are not parsed away as an int.
    df = pd.read_csv(
        path,
        usecols=["hs_code", "full_description"],
        dtype={"hs_code": str},
    )
    if df.empty:
        raise ValueError(f"no rows found in {path!r}")

    # Missing descriptions (NaN) would otherwise break the text processing.
    descriptions = df["full_description"].fillna("").astype(str)

    texts = []
    for hs_code, description in zip(df["hs_code"], descriptions):
        snippet = _first_words(description, max_chars)
        texts.append(f"hs_code: {hs_code} {snippet}".translate(_STRIP_QUOTES))

    retriever = BM25Retriever.from_texts(texts, preprocess_func=_bm25_tokenize)
    retriever.k = k
    return retriever
def load_llm(model="llama-3.1-70b-versatile", temperature=0, api_key=None):
    """Create the Groq chat model used to pick an HS code.

    Args:
        model: Groq model id. NOTE(review): confirm this id is still served
            by Groq before deploying.
        temperature: sampling temperature passed to the model; 0 requests the
            least random output.
        api_key: Groq API key. Defaults to the ``GROQ_API_KEY`` environment
            variable.

    Returns:
        A configured ``ChatGroq`` instance.

    Raises:
        RuntimeError: if no API key is given and ``GROQ_API_KEY`` is unset.
    """
    # The key used to be hard-coded here. Secrets must not live in source
    # control; the previously committed key should be treated as leaked.
    api_key = api_key or os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError(
            "Groq API key missing: set the GROQ_API_KEY environment variable."
        )
    return ChatGroq(model=model, temperature=temperature, api_key=api_key)
# Prompt sent to the LLM; {context} is the retrieved documents and
# {description} the (lower-cased) user description.
_HS_CODE_TEMPLATE = """
You are an expert in HS Code classification.
Based on the provided product description, accurately determine and return only one 6-digit HS Code that best matches the description.
Always return the HS Code as a 6-digit number only.
example: 123456
Context:\n {context} \n
Description:\n {description} \n
Answer:
"""

# A standalone run of exactly six digits (not part of a longer number).
_SIX_DIGITS = re.compile(r"(?<!\d)\d{6}(?!\d)")


def predict(sentence, retriever, llm2):
    """Classify a product description into a 6-digit HS code.

    The description is lower-cased and used both as the retrieval query and
    as the description shown to the LLM. The retrieved documents'
    ``page_content`` values are joined with blank lines as context.

    Args:
        sentence: free-text product description.
        retriever: object with ``invoke(query) -> list of documents`` (each
            having ``page_content``), e.g. the retriever from ``load_data``.
        llm2: chat model with ``invoke(prompt_str)`` returning a message with
            a ``content`` string, e.g. the model from ``load_llm``.

    Returns:
        The first standalone 6-digit number in the model's reply, or the
        stripped reply itself if it contains none.
    """
    sentence = sentence.lower()
    # retriever.invoke replaces the deprecated get_relevant_documents.
    documents = retriever.invoke(sentence)
    # Same "stuff" formatting the former load_qa_chain used: each document's
    # page_content, separated by a blank line.
    context = "\n\n".join(doc.page_content for doc in documents)
    prompt = _HS_CODE_TEMPLATE.format(context=context, description=sentence)
    answer = llm2.invoke(prompt).content.strip()
    # The prompt asks for a bare code, but models sometimes add prose.
    match = _SIX_DIGITS.search(answer)
    return match.group(0) if match else answer
# Keep the retriever and LLM in session state so they are built once per
# session rather than on every run of this script.
if st.session_state.get("retriever") is None:
    st.session_state.retriever = load_data()
if st.session_state.get("llm") is None:
    st.session_state.llm = load_llm()

sentence = st.text_input("please enter description:")
# Skip empty and whitespace-only input instead of sending it to the model.
description = sentence.strip()
if description:
    with st.spinner("Classifying..."):
        answer = predict(
            description, st.session_state.retriever, st.session_state.llm
        )
    st.write("answer:", answer)