tien314 committed
Commit 3de78eb · verified · 1 Parent(s): 1e3a503

Update app.py

Files changed (1):
  1. app.py +50 -90
app.py CHANGED
@@ -1,108 +1,68 @@
 import streamlit as st
-from langchain_community.retrievers import BM25Retriever
 import pandas as pd
 from langchain.docstore.document import Document
-from langchain.text_splitter import CharacterTextSplitter
-from operator import itemgetter
-from langchain_core.prompts import PromptTemplate
-from langchain_groq import ChatGroq
-from langchain.chains.question_answering import load_qa_chain
 import os


-@st.cache_data
 def load_data():
-    df = pd.read_csv("trained.csv")
-    df = df.drop(columns=['Unnamed: 0', 'hs_code_2', 'hs_code_4'])
-    documents = []
-
-    for index, row in df.iterrows():
-        text = row['full_description']
-        hs_code = row['hs_code']
-        documents.append(Document(page_content=text, metadata={'hs_code': hs_code}))
-
-    splitter = CharacterTextSplitter(
-        chunk_size=100,
-        chunk_overlap=0,
-        separator=' '
-    )
-
-    split_documents = []
-    for doc in documents:
-        chunks = splitter.split_text(doc.page_content)
-        # remove chunk split word
-        word_chunks = []
-        current_chunk = []
-
-        for chunk in chunks:
-            words = chunk.split()
-            for word in words:
-                if len(' '.join(current_chunk + [word])) <= 100:
-                    current_chunk.append(word)
-                else:
-                    word_chunks.append(' '.join(current_chunk))
-                    current_chunk = [word]
-        if current_chunk:
-            word_chunks.append(' '.join(current_chunk))
-
-        split_documents.append(Document(page_content=word_chunks[0], metadata=doc.metadata))
-
-    docs = []
-    for doc in split_documents:
-        metadata = doc.metadata
-        metadata_str = str(metadata).strip('{}')
-        page = doc.page_content
-        docs.append([metadata_str + " " + page])
-
-    cleaned_list = [item.replace('"', '').replace("'", '') for items in docs for item in items]
-    retriever = BM25Retriever.from_texts(cleaned_list)
-    retriever.k = 5
     return retriever

-
-def load_llm():
-    api_key2 = "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"
-    llm2 = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=api_key2)
-    return llm2
-
-
-def predict(sentence, retriever, llm2):
-    sentence = sentence.lower()
-    context = retriever.get_relevant_documents(sentence)
-    # print("context:", context)
-    template2 = """
-    You are an expert in HS Code classification.
-    Based on the provided product description, accurately determine and return only one 6-digit HS Code that best matches the description.
-    Always return the HS Code as a 6-digit number only.
-    example: 123456
-    Context:\n {context} \n
-    Description:\n {description} \n
-    Answer:
-    """
-    prompt2 = PromptTemplate(template=template2, input_variables=['context', 'description'])
-    chain = load_qa_chain(llm2, chain_type='stuff', prompt=prompt2)
-    response = chain.invoke({'input_documents': context, 'description': sentence})
-    answer = response.get("output_text")
-    return answer
-
-
 if 'retriever' not in st.session_state:
     st.session_state.retriever = None
-if 'llm' not in st.session_state:
-    st.session_state.llm = None

 if st.session_state.retriever is None:
     st.session_state.retriever = load_data()

-if st.session_state.llm is None:
-    st.session_state.llm = load_llm()
-
 sentence = st.text_input("please enter description:")

 if sentence != '':
-    answer = predict(sentence, st.session_state.retriever, st.session_state.llm)
-    st.write("answer:", answer)
 
 
 import streamlit as st
 import pandas as pd
+import bm25s
+from bm25s.hf import BM25HF
+from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
 from langchain.docstore.document import Document
+import torch
 import os
+from huggingface_hub import login
+from langchain_groq import ChatGroq


+@st.cache_resource
 def load_data():
+    retriever = BM25HF.load_from_hub(
+        "tien314/hscode8", load_corpus=True, mmap=True)
     return retriever

+def load_model():
+    prompt = ChatPromptTemplate.from_messages([
+        HumanMessagePromptTemplate.from_template(
+            f"""
+            Extract the appropriate 8-digit HS Code based on the product description and the retrieved documents by thoroughly analyzing their details and utilizing a reliable and up-to-date HS Code database for accurate results.
+            Only return the HS Code as an 8-digit number.
+            Example: 12345678
+            Context: {{context}}
+            Description: {{description}}
+            Answer:
+            """
+        )
+    ])
+
+    # device = "cuda" if torch.cuda.is_available() else "cpu"
+    # llm = OllamaLLM(model="gemma2", temperature=0, device=device)
+    # api_key = "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"
+    api_key = "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"
+    llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0, api_key=api_key)
+    chain = prompt | llm
+    return chain
+
+def process_input(sentence):
+    docs, _ = st.session_state.retriever.retrieve(bm25s.tokenize(sentence), k=15)
+    documents = []
+    for doc in docs[0]:
+        documents.append(Document(doc['text']))
+    return documents
+
 if 'retriever' not in st.session_state:
     st.session_state.retriever = None

+if 'chain' not in st.session_state:
+    st.session_state.chain = None
+
 if st.session_state.retriever is None:
     st.session_state.retriever = load_data()

+if st.session_state.chain is None:
+    st.session_state.chain = load_model()
+
 sentence = st.text_input("please enter description:")

 if sentence != '':
+    documents = process_input(sentence)
+    hscode = st.session_state.chain.invoke({'context': documents, 'description': sentence})
+    st.write("answer:", hscode.content)
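
The new load_data() pulls a prebuilt bm25s index from the Hugging Face Hub instead of building a BM25Retriever from trained.csv at startup, as the removed code did. For reference, a minimal sketch of how an index like "tien314/hscode8" could be built and pushed with bm25s is shown below. It is not part of this commit: the CSV and column names are taken from the removed code, while the exact corpus formatting and the HF_TOKEN environment variable are assumptions.

# Sketch (not in this commit): build a BM25 index with bm25s and push it to the Hub,
# so the app can later load it with BM25HF.load_from_hub("tien314/hscode8", load_corpus=True, mmap=True).
import os

import bm25s
import pandas as pd
from bm25s.hf import BM25HF

# trained.csv with 'full_description' and 'hs_code' columns, as in the removed code.
df = pd.read_csv("trained.csv")

# One corpus entry per row; folding the HS code into the text is an assumption that keeps it
# visible in the 'text' field process_input() reads after retrieval.
corpus = [
    f"hs_code: {row['hs_code']} {row['full_description']}".lower()
    for _, row in df.iterrows()
]

retriever = BM25HF(corpus=corpus)        # keep the raw texts so load_corpus=True works later
retriever.index(bm25s.tokenize(corpus))  # tokenize the corpus and build the index

# HF_TOKEN is a hypothetical environment variable holding a write token for the Hub repo.
retriever.save_to_hub("tien314/hscode8", token=os.environ["HF_TOKEN"])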