File size: 2,223 Bytes
73cdc6b 9dad51b 49cc9c5 9dad51b 6321de9 9f2a90c 73cdc6b 9dad51b a7c96df 9dad51b 49cc9c5 647872d 9dad51b 9f2a90c 9dad51b 9f2a90c 307dfae 36f0086 9dad51b a7c96df 9dad51b edf39fd a7c96df 49e025c 9dad51b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import streamlit as st
import pandas as pd
import bm25s
from bm25s.hf import BM25HF
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.docstore.document import Document
import torch
import os
from huggingface_hub import login
from langchain_groq import ChatGroq
@st.cache_resource
def load_data():
    """Load the BM25 retriever published on the Hugging Face Hub.

    The index is fetched from the ``tien314/hscode8`` repository together
    with its corpus (``load_corpus=True``) and opened memory-mapped
    (``mmap=True``). ``st.cache_resource`` keeps the loaded object so that
    Streamlit reruns do not load it again.

    Returns:
        The retriever object returned by ``BM25HF.load_from_hub``.
    """
    return BM25HF.load_from_hub(
        "tien314/hscode8",
        load_corpus=True,
        mmap=True,
    )
def load_model():
prompt = ChatPromptTemplate.from_messages([
HumanMessagePromptTemplate.from_template(
f"""
Extract the appropriate 8-digit HS Code base on the product description and retrieved document by thoroughly analyzing its details and utilizing a reliable and up-to-date HS Code database for accurate results.
Only return the HS Code as a 8-digit number .
Example: 1234567878
Context: {{context}}
Description: {{description}}
Answer:
"""
)
])
#device = "cuda" if torch.cuda.is_available() else "cpu"
#llm = OllamaLLM(model="gemma2", temperature=0, device=device)
#api_key = "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"
api_key = "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"
llm = ChatGroq(model = "gemma2-9b-it", temperature = 0,api_key = api_key)
chain = prompt|llm
return chain
def process_input(sentence):
    """Retrieve candidate passages for a product description.

    The description is tokenized with ``bm25s.tokenize`` and passed to the
    retriever stored in ``st.session_state.retriever``, asking for the top
    15 hits.

    Args:
        sentence: Free-text product description entered by the user.

    Returns:
        A list of ``Document`` objects, one per retrieved hit, each built
        from the hit's ``'text'`` field.
    """
    query_tokens = bm25s.tokenize(sentence)
    hits, _scores = st.session_state.retriever.retrieve(query_tokens, k=15)
    # Only the first row of results is used (presumably the row for the
    # single query passed in - confirm against the bm25s API).
    return [Document(hit['text']) for hit in hits[0]]
# Streamlit re-executes this script on every interaction, so the retriever
# and the chain are kept in session state and only built when missing.
if st.session_state.get("retriever") is None:
    st.session_state.retriever = load_data()
if st.session_state.get("chain") is None:
    st.session_state.chain = load_model()

sentence = st.text_input("please enter description:")

# Ignore empty and whitespace-only input so a blank box never triggers
# retrieval or an LLM call.
if sentence.strip():
    documents = process_input(sentence)
    # Pass the retrieved passages as plain text; formatting the Document
    # objects themselves would put their repr() into the prompt.
    context = "\n".join(doc.page_content for doc in documents)
    hscode = st.session_state.chain.invoke(
        {"context": context, "description": sentence}
    )
    st.write("answer:", hscode.content)
|