import streamlit as st
import numpy as np
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_chroma import Chroma
from langchain_groq import ChatGroq
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from transformers import AutoModel, AutoTokenizer, pipeline
import joblib
from typing import List
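
# Run locally with: streamlit run <this_file>.py  (filename illustrative)
# GROQ_API_KEY and HF_TOKEN are read below via st.secrets, so they must be
# defined in .streamlit/secrets.toml (or in the hosting platform's secrets settings).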

# Define styles for different sections
custom_styles = """
<style>
    [id^=tabs-] {
        padding: 20px;
        border: 1px solid #ccc;
        margin-bottom: 20px;
        box-shadow: 0 4px 8px rgba(0,0,0,0.1);
    }

    .tab h2 {
        font-size: 24px;
        margin-top: 0;
        color: black;
    }

    .response-container {
        padding: 20px;
        border: 1px solid #ccc;
        background-color: #f9f9f9;
        margin-bottom: 20px;
        box-shadow: 0 4px 8px rgba(0,0,0,0.1);
    }

    .response-container h3 {
        font-size: 20px;
        margin-top: 0;
        color: black;
    }
</style>
"""

def setup_page():
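    """Configure the page, render the cached header, and split the layout into
    a query-history column and a main content column."""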
    st.set_page_config(page_title="EAD Generation", layout="wide")
    display_app_header()
    st.markdown(custom_styles, unsafe_allow_html=True)
    # Create two columns - one for query history and one for main content
    col1, col2 = st.columns([1, 4])
    
    with col1:
        st.markdown("### Query History")
    
    return col1, col2

# Cache the header of the app to prevent re-rendering on each load
@st.cache_resource
def display_app_header():
    """Display the header of the Streamlit app."""
    st.title("EAD Generation")
    st.markdown("---")
    # Add a description of the app
    st.markdown("""This app allows you to generate EAD/XML archival descriptions. See this serie of blog posts for explanations :""")                
    st.markdown("""
    - [https://iaetbibliotheques.fr/2024/11/comment-apprendre-lead-a-un-llm](https://iaetbibliotheques.fr/2024/11/comment-apprendre-lead-a-un-llm)
    - [https://iaetbibliotheques.fr/2024/11/comment-apprendre-lead-a-un-llm-rag-23](https://iaetbibliotheques.fr/2024/11/comment-apprendre-lead-a-un-llm-rag-23)
    - [https://iaetbibliotheques.fr/2024/12/comment-apprendre-lead-a-un-llm-fine-tuning-33](https://iaetbibliotheques.fr/2024/12/comment-apprendre-lead-a-un-llm-fine-tuning-33)
    """, unsafe_allow_html=True)
    st.markdown("---")

# Set up the page layout and get the history and main columns
history_col, main_col = setup_page()

def setup_sidebar():    
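    """Render the sidebar radio of available Groq-hosted models and return the selected model name."""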
    groq_models = ["llama3-70b-8192", "llama-3.1-70b-versatile","llama3-8b-8192", "llama-3.1-8b-instant", "mixtral-8x7b-32768","gemma2-9b-it", "gemma-7b-it"]
    selected_groq_models = st.sidebar.radio("Choose a model (used in tabs 1 and 3)", groq_models)
    return selected_groq_models

def create_groq_llm(model):
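    """Build a ChatGroq client for the given model; the API key is read from st.secrets."""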
    return ChatGroq(
        model=model,
        temperature=0.1,
        max_tokens=None,
        timeout=None,
        max_retries=2,
        groq_api_key=st.secrets["GROQ_API_KEY"]
    )
    
def create_aws_ollama_llm():
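    """Build a ChatOllama client for the fine-tuned EAD model (GGUF) served by a
    remote Ollama instance."""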
    return ChatOllama(
        model="hf.co/Geraldine/FineLlama-3.2-3B-Instruct-ead-GGUF:Q5_K_M",
        base_url="http://129.80.86.176:11434",
        temperature=0.1,
        max_tokens=None,
        timeout=None,
        max_retries=2,
    )
    
def setup_zero_shot_tab(llm):
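    """Build a zero-shot prompt | llm | StrOutputParser chain (no examples in the prompt)."""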
    st.header("Zero-shot Prompting")
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are an expert in the domain of archival description in the standardized Encoded Archival Description (EAD) format and an intelligent EAD/XML generator.",
            ),
            ("human", "{question}"),
        ]
    )
    zero_shot_chain = prompt | llm | StrOutputParser()
    return zero_shot_chain

def setup_ead_xsd_tab(llm):
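    """Build a one-shot chain that injects the full EAD 2002 XSD schema into the prompt as context."""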
    st.header("One-shot Prompting with EAD 2002 XSD schema")
    with open("assets/ead_xsd_2002.xml", "r") as f:
        ead_xsd_2002 = f.read()
    
    class CustomRetriever(BaseRetriever):
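        """Trivial retriever that ignores the query and always returns the whole
        XSD schema as a single Document."""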
        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
            return [Document(page_content=ead_xsd_2002)]
    
    retriever = CustomRetriever()
    
    template = """
    ### [INST]
    **You are an expert in the domain of archival description in the standardized Encoded Archival Description (EAD) format** and an intelligent EAD/XML generator.
    Use the following XSD schema context to answer the question by generating compliant XML content.

    {context}

    Please follow the EAD schema guidelines to ensure your output is valid and well-formed. Do not include any markup or comments other than what is specified in the schema.

    ### QUESTION:
    {question}

    [/INST]
    """

    prompt = ChatPromptTemplate.from_template(template)
    retrieval_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    
    return retrieval_chain

def setup_rag_tab(llm):
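    """Build a retrieval-augmented chain over a persisted Chroma index of EAD/XML
    chunks, embedded with a fine-tuned DistilBERT model whose outputs are
    PCA-reduced to match the index dimensionality."""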
    st.header("RAG")

    model = AutoModel.from_pretrained("Geraldine/msmarco-distilbert-base-v4-ead", token=st.secrets["HF_TOKEN"])
    tokenizer = AutoTokenizer.from_pretrained("Geraldine/msmarco-distilbert-base-v4-ead", token=st.secrets["HF_TOKEN"])
    # pca_model.joblib was fetched from the same Hub repo, e.g.:
    # hf_hub_download("Geraldine/msmarco-distilbert-base-v4-ead", "pca_model.joblib", local_dir="assets")
    feature_extraction_pipeline = pipeline("feature-extraction", model=model, tokenizer=tokenizer)
    
    class HuggingFaceEmbeddingFunction:
        def __init__(self, pipeline, pca_model_path):
            self.pipeline = pipeline
            self.pca = joblib.load(pca_model_path)
                
        # Function for embedding documents (lists of text)
        def embed_documents(self, texts):
            # Run the pipeline and keep the first-token ([CLS]) vector of each document
            embeddings = self.pipeline(texts)
            embeddings = [embedding[0][0] for embedding in embeddings]
            embeddings = np.array(embeddings)

            # Transform embeddings using PCA
            reduced_embeddings = self.pca.transform(embeddings)
            return reduced_embeddings.tolist()

        # Function for embedding individual queries
        def embed_query(self, text):
            embedding = self.pipeline(text)
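            # Keep the first-token ([CLS]) vector and reshape to (1, n_features) for PCA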
            embedding = np.array(embedding[0][0]).reshape(1, -1)

            # Transform embedding using PCA
            reduced_embedding = self.pca.transform(embedding)
            return reduced_embedding.flatten().tolist()

    embeddings = HuggingFaceEmbeddingFunction(feature_extraction_pipeline, pca_model_path="assets/pca_model.joblib")
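    # Open the persisted Chroma collection of EAD/XML chunks (assumed to have
    # been built offline with this same embedding function)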
    persist_directory = "assets/chroma_xml_db"
    vector_store = Chroma(
        collection_name="ead-xml",
        embedding_function=embeddings,
        persist_directory=persist_directory,
    )
    retriever = vector_store.as_retriever()
    
    template = """
    # Generate EAD/XML File for Archival Collection

    ## Description
    You are an assistant for the generation of archival descriptions encoded in the EAD/XML format.
    You are an expert in archival description rules and standards and know the EAD format for encoding archival metadata very well.

    ## Instruction
    Answer the question based only on the following context:
    {context}.

    The EAD/XML sections you generate must follow the Library of Congress EAD schema and be in the style of a traditional archival finding aid, as if written by a professional archivist, including the required XML tags and structure.

    ## Question
    {question}

    ## Answer
    """

    prompt = ChatPromptTemplate.from_template(template)
    chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    
    return chain

def setup_fine_tuned_tab():
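    """Build a chain that sends the query to the fine-tuned FineLlama model served by Ollama."""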
    st.header("Fine-tuned FineLlama-3.2-3b-Instruct-ead model")
    
    llm = create_aws_ollama_llm()
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are an expert in the domain of archival description in the standardized Encoded Archival Description (EAD) format and an intelligent EAD/XML generator.",
            ),
            ("human", "{question}"),
        ]
    )
    fine_tuned_chain = prompt | llm | StrOutputParser()
    return fine_tuned_chain

def clear_outputs():
    # Clear all stored responses from session state
    for key in list(st.session_state.keys()):
        if key.startswith('response_'):
            del st.session_state[key]

# Initialize query history in session state if it doesn't exist
if 'query_history' not in st.session_state:
    st.session_state.query_history = []

# Move all main content to the main column
with main_col:
    selected_groq_models = setup_sidebar()

    if 'previous_query' not in st.session_state:
        st.session_state.previous_query = ''
        
    if 'current_query' not in st.session_state:
        st.session_state.current_query = ''

    query = st.chat_input("Enter your query", key="chat_input")
    st.markdown("*Example queries : Generate an EAD description for the personal papers of Marie Curie, Create an EAD inventory for a collection of World War II photographs, Create a EAD compliant `<eadheader>` sections with all necessary attributes and child elements.*")

    tab1, tab2, tab3, tab4 = st.tabs(["Zero-shot prompting","In-context learning with EAD schema", "RAG", "FineLlama-3.2-3B-Instruct-ead"])

    # Display info messages for each tab on app launch
    with tab1:
        st.info("Simple inference with zero-shot prompting : the prompt used to interact with the model does not contain examples or demonstrations. The LLM used id the one selected in the sidebar list.",icon="ℹ️")

    with tab2:
        st.info("One-shot inference with EAD 2002 XSD schema : the prompt used to interact with the model contains the plaintext of the EAD schema as a guideline for the desired output format. The LLM used id the one selected in the sidebar list.",icon="ℹ️")

    with tab3:
        st.info("Retrieval-augmented generation : the prompt used to interact with the model contains the relevant context from an archival collection of EAD files. The LLM used id the one selected in the sidebar list.",icon="ℹ️")

    with tab4:
        st.info("FineLlama-3.2-3b-Instruct-ead model : this is a custom fine-tuned adaptation of llama-3.2-3b-instruct post-trained on a dataset of archival descriptions in the EAD format",icon="ℹ️")

    # Process query for all tabs when submitted
    if query:
        # Add new query to history if it's different from the last one
        if not st.session_state.query_history or query != st.session_state.query_history[-1]:
            st.session_state.query_history.append(query)
        
        # Store the current query
        st.session_state.current_query = query
        # Clear outputs if query has changed
        if query != st.session_state.previous_query:
            clear_outputs()
            st.session_state.previous_query = query
            
        with st.spinner('Processing query across all models...'):
            # Process for Tab 1 - zero-shot inference
            with tab1:
                llm = create_groq_llm(selected_groq_models)
                zero_shot_chain = setup_zero_shot_tab(llm)
                st.session_state.response_zero_shot = zero_shot_chain.invoke({"question": query})
                with st.chat_message("assistant"):
                    st.markdown(st.session_state.response_zero_shot)
            
            # Process for Tab 2 - EAD XSD
            with tab2:
                llm = create_groq_llm("llama-3.1-8b-instant")
                ead_chain = setup_ead_xsd_tab(llm)
                st.session_state.response_ead = ead_chain.invoke(query)
                with st.chat_message("assistant"):
                    st.markdown(st.session_state.response_ead)
            
            # Process for Tab 3 - RAG
            with tab3:
                llm = create_groq_llm(selected_groq_models)
                rag_chain = setup_rag_tab(llm)
                st.session_state.response_rag = rag_chain.invoke(query)
                with st.chat_message("assistant"):
                    st.markdown(st.session_state.response_rag)
                    
            # Process for Tab 4 - Fine-tuned model
            with tab4:
                fine_tuned_chain = setup_fine_tuned_tab()
                st.session_state.response_fine_tuned = fine_tuned_chain.invoke(query)
                with st.chat_message("assistant"):
                    st.markdown(st.session_state.response_fine_tuned)

# Display query history in the sidebar column
with history_col:
    if st.session_state.query_history:
        for i, past_query in enumerate(reversed(st.session_state.query_history)):
            st.text_area(f"Query {len(st.session_state.query_history) - i}", 
                        past_query, 
                        height=100,
                        key=f"history_{i}",
                        disabled=True)
    else:
        st.write("No queries yet")