tien314 committed on
Commit
9dad51b
·
verified ·
1 Parent(s): 3b49b20

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import bm25s
import pandas as pd
import streamlit as st
import torch
from langchain.chains import LLMChain
from langchain.docstore.document import Document
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_ollama.llms import OllamaLLM
8
+
9
@st.cache_resource
def load_data():
    """Load the saved BM25 index and its corpus, cached across reruns.

    The index is loaded from ``bm25s_very_big_index`` with ``mmap=True`` and
    ``load_corpus=True``.

    ``st.cache_resource`` is used rather than ``st.cache_data`` because the
    latter serializes the return value and hands back a copy on every hit,
    which is unsuitable for a large, memory-mapped retriever object.

    Returns:
        The loaded ``bm25s.BM25`` retriever.
    """
    return bm25s.BM25.load("bm25s_very_big_index", mmap=True, load_corpus=True)
13
+
14
def load_model():
    """Build the prompt -> LLM chain used to predict an HS Code.

    The prompt has two template variables, ``context`` and ``description``,
    which must be supplied when the returned chain is invoked.

    Returns:
        A runnable made of the chat prompt piped into an ``OllamaLLM``
        configured with model ``gemma2`` and ``temperature=0``.
    """
    # Plain (non-f) string: ``{context}`` / ``{description}`` are LangChain
    # template placeholders, so no brace doubling is needed.
    # The example is 8 digits so that it agrees with the instruction above it.
    template = """
    Extract the appropriate 8-digit HS Code base on the product description and retrieved document by thoroughly analyzing its details and utilizing a reliable and up-to-date HS Code database for accurate results.
    Only return the HS Code as a 8-digit number .
    Example: 12345678
    Context: {context}
    Description: {description}
    Answer:
    """
    prompt = ChatPromptTemplate.from_messages(
        [HumanMessagePromptTemplate.from_template(template)]
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): confirm that OllamaLLM accepts a ``device`` argument;
    # device placement may be decided by the Ollama server instead.
    llm = OllamaLLM(model="gemma2", temperature=0, device=device)

    return prompt | llm
35
+
36
def process_input(sentence):
    """Retrieve the top BM25 matches for a product description.

    Args:
        sentence: Free-text product description entered by the user.

    Returns:
        A list of up to 15 ``Document`` objects whose ``page_content`` is the
        ``'text'`` field of each retrieved corpus entry.
    """
    # The retriever is stored in Streamlit session state by the script body;
    # there is no module-level ``retriever`` name to refer to.
    retriever = st.session_state.retriever

    # The second value returned by ``retrieve`` is not needed here.
    results, _ = retriever.retrieve(bm25s.tokenize(sentence), k=15)

    # ``results[0]`` holds the hits for the single query that was submitted
    # (presumably one row per query -- verify against the bm25s docs).
    return [Document(page_content=hit['text']) for hit in results[0]]
42
+
43
# --- Streamlit script body -------------------------------------------------
# Make sure both heavyweight objects exist in session state before use.
if 'retriever' not in st.session_state:
    st.session_state.retriever = None

if 'chain' not in st.session_state:
    st.session_state.chain = None

if st.session_state.retriever is None:
    st.session_state.retriever = load_data()

if st.session_state.chain is None:
    st.session_state.chain = load_model()

sentence = st.text_input("please enter description:")

if sentence:
    documents = process_input(sentence)
    # Give the prompt the retrieved text itself rather than the repr of a
    # list of Document objects.
    context = "\n\n".join(doc.page_content for doc in documents)
    # The chain lives in session state; a bare ``chain`` name is undefined.
    hscode = st.session_state.chain.invoke(
        {'context': context, 'description': sentence}
    )
    st.write("answer:", hscode)
61
+