File size: 3,667 Bytes
c2e612d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import streamlit as st
from langchain_community.retrievers import BM25Retriever
import pandas as pd
from langchain.docstore.document import Document
from langchain.text_splitter import CharacterTextSplitter
from operator import itemgetter
from langchain_core.prompts import PromptTemplate
from langchain_groq import ChatGroq
from langchain.chains.question_answering import load_qa_chain
import os


def _leading_chunk(text, limit=100):
    """Return the leading run of whole words from *text* that fits in *limit* characters.

    Words are whitespace-separated and re-joined with single spaces. The first
    word is always kept, even when it alone is longer than *limit*, so a
    non-blank text never yields an empty chunk. Returns '' when *text*
    contains no words at all.
    """
    kept = []
    used = 0
    for word in text.split():
        # One extra character for the joining space once a word is already kept.
        cost = len(word) + (1 if kept else 0)
        if kept and used + cost > limit:
            break
        kept.append(word)
        used += cost
    return ' '.join(kept)


# The retriever is a shared, read-only object rather than plain data, so it is
# cached as a resource (one instance, no pickling/copying on every lookup).
@st.cache_resource
def load_data(csv_path="C:\\Users\\JVC Store\\Downloads\\data\\data 6\\train.csv"):
    """Build a BM25 retriever over the HS-code training descriptions.

    Each indexed text has the form ``hs_code: <code> <description prefix>``,
    where the prefix is the leading whole words of ``full_description`` that
    fit in 100 characters. All single and double quotes are removed from the
    indexed text.

    Args:
        csv_path: Location of the training CSV. It must contain the columns
            ``full_description`` and ``hs_code``. Defaults to the path the
            script was originally written against.

    Returns:
        A ``BM25Retriever`` configured to return the 5 best matches.

    Raises:
        ValueError: If the CSV yields no row with a usable description.
    """
    df = pd.read_csv(csv_path)
    # Rows without a description cannot be indexed; dropping them up front
    # avoids failing on NaN values further down.
    df = df.dropna(subset=['full_description'])

    corpus = []
    # .tolist() yields plain Python scalars, so repr() below matches how the
    # code would appear inside a printed metadata dict.
    rows = zip(df['full_description'].tolist(), df['hs_code'].tolist())
    for description, hs_code in rows:
        snippet = _leading_chunk(str(description))
        if not snippet:
            # Blank description: nothing to index for this row.
            continue
        entry = f"hs_code: {hs_code!r} {snippet}"
        corpus.append(entry.replace('"', '').replace("'", ''))

    if not corpus:
        raise ValueError(f"no usable 'full_description' rows found in {csv_path!r}")

    retriever = BM25Retriever.from_texts(corpus)
    retriever.k = 5
    return retriever


def load_llm():
    """Create the Groq chat model used for HS-code classification.

    The API key is read from the ``GROQ_API_KEY`` environment variable instead
    of being embedded in the source, so the secret never ends up in version
    control. The model name can be overridden with ``GROQ_MODEL``; when unset
    it stays ``llama-3.1-70b-versatile``.

    Returns:
        A ``ChatGroq`` instance with temperature 0.

    Raises:
        RuntimeError: If ``GROQ_API_KEY`` is unset or blank.
    """
    api_key2 = os.environ.get("GROQ_API_KEY", "").strip()
    if not api_key2:
        raise RuntimeError(
            "GROQ_API_KEY is not set; export it before starting the app"
        )

    model_name = os.environ.get("GROQ_MODEL", "llama-3.1-70b-versatile")
    llm2 = ChatGroq(model=model_name, temperature=0, api_key=api_key2)
    return llm2


def predict(sentence,retriever,llm2):
    """Ask the LLM for the HS code that best matches a product description.

    The description is lower-cased, used to fetch documents from *retriever*,
    and sent to *llm2* together with those documents through a "stuff"
    question-answering chain.

    Args:
        sentence: Free-text product description entered by the user.
        retriever: LangChain retriever that supplies the context documents.
        llm2: Chat model handed to ``load_qa_chain``.

    Returns:
        The chain's ``output_text`` value as produced by the model (the prompt
        requests a bare 6-digit code, but the text is returned unvalidated).
    """
    sentence = sentence.lower()
    # Retrievers implement the Runnable interface; invoke() replaces the
    # deprecated get_relevant_documents() call.
    context = retriever.invoke(sentence)
    template2 = """

    You are an expert in HS Code classification. 

    Based on the provided product description, accurately determine and return only one 6-digit HS Code that best matches the description.

    Always return the HS Code as a 6-digit number only.

    example: 123456

    Context:\n {context} \n

    Description:\n {description} \n

    Answer:

    """
    prompt2 = PromptTemplate(template=template2, input_variables=['context','description'])
    # The retrieved documents are passed as 'input_documents'; the prompt
    # receives them through its {context} placeholder.
    chain = load_qa_chain(llm2, chain_type = 'stuff', prompt = prompt2)
    response = chain.invoke({'input_documents': context, 'description':sentence})
    answer = response.get("output_text")
    return answer


# Make sure both per-session slots exist before they are inspected.
for slot in ('retriever', 'llm'):
    if slot not in st.session_state:
        st.session_state[slot] = None

# Build the retriever first, then the LLM, only when the slot is still empty,
# so Streamlit reruns within the same session reuse the existing objects.
if st.session_state['retriever'] is None:
    st.session_state['retriever'] = load_data()

if st.session_state['llm'] is None:
    st.session_state['llm'] = load_llm()

sentence = st.text_input("please enter description:")

# An empty text box means the user has not submitted anything yet.
if sentence:
    answer = predict(
        sentence,
        st.session_state['retriever'],
        st.session_state['llm'],
    )
    st.write("answer:", answer)