import streamlit as st |
import numpy as np |
from langchain_core.retrievers import BaseRetriever |
from langchain_core.callbacks import CallbackManagerForRetrieverRun |
from langchain_core.documents import Document |
from langchain_chroma import Chroma |
from langchain_groq import ChatGroq |
from langchain_ollama import ChatOllama |
from langchain_core.prompts import ChatPromptTemplate |
from langchain_core.runnables import RunnablePassthrough |
from langchain_core.output_parsers import StrOutputParser |
from transformers import AutoModel, AutoTokenizer, pipeline |
import joblib |
from typing import List |
# Page-level CSS passed to st.markdown(..., unsafe_allow_html=True) by
# setup_page(). It frames elements whose id starts with "tabs-" and elements
# with the "response-container" class with padding, a border and a drop shadow.
custom_styles = """
<style>
[id^=tabs-] {
padding: 20px;
border: 1px solid #ccc;
margin-bottom: 20px;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
.tab h2 {
font-size: 24px;
margin-top: 0;
color: black;
}
.response-container {
padding: 20px;
border: 1px solid #ccc;
background-color: #f9f9f9;
margin-bottom: 20px;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
.response-container h3 {
font-size: 20px;
margin-top: 0;
color: black;
}
</style>
"""
def setup_page():
    """Configure the page, draw the header and styles, and build the layout.

    Returns:
        A ``(col1, col2)`` tuple of Streamlit columns created with a 1:4
        width ratio. The script uses the first one for the query history and
        the second one for the main content.
    """
    # Page configuration is issued before any other Streamlit command here.
    st.set_page_config(page_title="EAD Generation", layout="wide")
    display_app_header()
    st.markdown(custom_styles, unsafe_allow_html=True)

    col1, col2 = st.columns([1, 4])
    with col1:
        # Only the section title is drawn here; the history entries are
        # rendered by the script once the session state is available.
        st.markdown("### Query History")
    return col1, col2
def display_app_header():
    """Display the header of the Streamlit app.

    Renders the title, a short description and links to the accompanying blog
    posts, framed by horizontal rules. The function only emits UI elements and
    returns nothing, so it is deliberately not wrapped in a resource cache.
    """
    st.title("EAD Generation")
    st.markdown("---")
    st.markdown("""This app allows you to generate EAD/XML archival descriptions. See this serie of blog posts for explanations :""")
    st.markdown("""
- [https://iaetbibliotheques.fr/2024/11/comment-apprendre-lead-a-un-llm](https://iaetbibliotheques.fr/2024/11/comment-apprendre-lead-a-un-llm)
- [https://iaetbibliotheques.fr/2024/11/comment-apprendre-lead-a-un-llm-rag-23](https://iaetbibliotheques.fr/2024/11/comment-apprendre-lead-a-un-llm-rag-23)
- [https://iaetbibliotheques.fr/2024/12/comment-apprendre-lead-a-un-llm-fine-tuning-33](https://iaetbibliotheques.fr/2024/12/comment-apprendre-lead-a-un-llm-fine-tuning-33)
""", unsafe_allow_html=True)
    st.markdown("---")
# Build the page skeleton: the narrow column shows the query history, the wide
# one hosts the chat input and the result tabs.
history_col, main_col = setup_page()
def setup_sidebar():
    """Show the Groq model picker in the sidebar.

    Returns:
        The value returned by the sidebar radio widget, i.e. the model name
        currently selected among the listed Groq models.
    """
    model_choices = [
        "llama3-70b-8192",
        "llama-3.1-70b-versatile",
        "llama3-8b-8192",
        "llama-3.1-8b-instant",
        "mixtral-8x7b-32768",
        "gemma2-9b-it",
        "gemma-7b-it",
    ]
    return st.sidebar.radio(
        "Choose a model (used in tabs 1, 2 and 3)",
        model_choices,
    )
def create_groq_llm(model):
    """Create a ChatGroq chat model for the given Groq model name.

    Args:
        model: Groq model identifier forwarded to ``ChatGroq``.

    Returns:
        A ``ChatGroq`` instance configured with temperature 0.1, two retries,
        no explicit token or time limit, and the API key read from
        ``st.secrets["GROQ_API_KEY"]``.
    """
    groq_settings = {
        "model": model,
        "temperature": 0.1,
        "max_tokens": None,
        "timeout": None,
        "max_retries": 2,
        "groq_api_key": st.secrets["GROQ_API_KEY"],
    }
    return ChatGroq(**groq_settings)
def create_aws_ollama_llm():
    """Create the ChatOllama client for the fine-tuned FineLlama EAD model.

    Returns:
        A ``ChatOllama`` instance pointing at the Q5_K_M GGUF build of
        ``Geraldine/FineLlama-3.2-3B-Instruct-ead`` with temperature 0.1.
    """
    ollama_settings = {
        "model": "hf.co/Geraldine/FineLlama-3.2-3B-Instruct-ead-GGUF:Q5_K_M",
        # NOTE(review): the base URL is an empty string in this file; confirm
        # which Ollama endpoint is expected at deployment time.
        "base_url": "",
        "temperature": 0.1,
        "max_tokens": None,
        "timeout": None,
        "max_retries": 2,
    }
    return ChatOllama(**ollama_settings)
def setup_zero_shot_tab(llm):
    """Draw the zero-shot tab header and build its prompt -> LLM -> text chain.

    Args:
        llm: Chat model placed between the prompt and the string parser.

    Returns:
        A runnable chain whose prompt expects a ``question`` variable and whose
        output is parsed with ``StrOutputParser``.
    """
    st.header("Zero-shot Prompting")

    system_message = (
        "system",
        "You are an expert in the domain of archival description in the standardized Encoded Archival Description EAD format** and intelligent EAD/XML generator.",
    )
    user_message = ("human", "{question}")
    chat_prompt = ChatPromptTemplate.from_messages([system_message, user_message])

    return chat_prompt | llm | StrOutputParser()
def setup_ead_xsd_tab(llm):
    """Draw the XSD tab header and build a chain that prompts with the EAD schema.

    The whole content of ``assets/ead_xsd_2002.xml`` is injected as the
    ``context`` of every prompt through a retriever that ignores the query.

    Args:
        llm: Chat model placed between the prompt and the string parser.

    Returns:
        A runnable chain that takes the user question as input and yields the
        model answer as a string.

    Raises:
        OSError: If the schema file cannot be opened or read.
    """
    st.header("One-shot Prompting with EAD 2002 XSD schema")

    # The context manager closes the handle as soon as the schema is read, and
    # the explicit encoding keeps the result independent of the host locale.
    with open("assets/ead_xsd_2002.xml", "r", encoding="utf-8") as schema_file:
        ead_xsd_2002 = schema_file.read()

    class CustomRetriever(BaseRetriever):
        # Constant retriever: whatever the query, return the full schema text
        # as a single document.
        def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
            return [Document(page_content=ead_xsd_2002)]

    retriever = CustomRetriever()
    template = """
### [INST]
**You are an expert in the domain of archival description in the standardized Encoded Archival Description EAD format** and intelligent EAD/XML generator.
Use the following xsd schema context to answer the question by generating a compliant xml content.
{context}
Please follow the EAD schema guidelines to ensure your output is valid and well-formed. Do not include any markup or comments other than what is specified in the schema.
{question}
[/INST]
"""
    prompt = ChatPromptTemplate.from_template(template)
    retrieval_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return retrieval_chain
def setup_rag_tab(llm):
    """Draw the RAG tab header and build the retrieval-augmented chain.

    Documents are retrieved from the persisted Chroma collection ``ead-xml``
    using embeddings produced by a Hugging Face feature-extraction pipeline and
    then projected with a PCA model loaded from ``assets/pca_model.joblib``.

    Args:
        llm: Chat model placed between the prompt and the string parser.

    Returns:
        A runnable chain that takes the user question as input and yields the
        model answer as a string.
    """
    st.header("RAG")

    # NOTE(review): the model, tokenizer, PCA and vector store are rebuilt on
    # every call, and the script calls this function for each submitted query.
    embedding_model_id = "Geraldine/msmarco-distilbert-base-v4-ead"
    hf_token = st.secrets["HF_TOKEN"]
    encoder = AutoModel.from_pretrained(embedding_model_id, token=hf_token)
    encoder_tokenizer = AutoTokenizer.from_pretrained(embedding_model_id, token=hf_token)
    extractor = pipeline("feature-extraction", model=encoder, tokenizer=encoder_tokenizer)

    class HuggingFaceEmbeddingFunction:
        """Embedding adapter exposing ``embed_documents`` / ``embed_query``."""

        def __init__(self, pipeline, pca_model_path):
            self.pipeline = pipeline
            self.pca = joblib.load(pca_model_path)

        def embed_documents(self, texts):
            # Keep element [0][0] of each pipeline output (presumably the
            # first token's vector; confirm against the pipeline's output
            # shape), then reduce all rows with the PCA in one call.
            leading_vectors = np.array(
                [features[0][0] for features in self.pipeline(texts)]
            )
            return self.pca.transform(leading_vectors).tolist()

        def embed_query(self, text):
            features = self.pipeline(text)
            # The PCA expects a 2-D array, hence the single-row reshape.
            leading_vector = np.array(features[0][0]).reshape(1, -1)
            return self.pca.transform(leading_vector).flatten().tolist()

    vector_store = Chroma(
        collection_name="ead-xml",
        embedding_function=HuggingFaceEmbeddingFunction(
            extractor, pca_model_path="assets/pca_model.joblib"
        ),
        persist_directory="assets/chroma_xml_db",
    )
    retriever = vector_store.as_retriever()

    template = """
# Generate EAD/XML File for Archival Collection
## Description
You are an assistant for the generation of archives encoded in EAD/XML format.
You are an expert in archival description rules and standards, knowing very well the EAD format for encoding archival metadata.
## Instruction
Answer the question based only on the following context:
{context}.
The EAD/XML sections you generate must follow the Library of Congress EAD schema and be in the style of a traditional archival finding aid, as if written by a professional archivist, including the required XML tags and structure.
## Question
{question}
## Answer
"""
    rag_prompt = ChatPromptTemplate.from_template(template)
    chain_inputs = {"context": retriever, "question": RunnablePassthrough()}
    return chain_inputs | rag_prompt | llm | StrOutputParser()
def setup_fine_tuned_tab():
    """Draw the fine-tuned tab header and build its prompt -> LLM -> text chain.

    The chat model is obtained from ``create_aws_ollama_llm()``.

    Returns:
        A runnable chain whose prompt has a single ``question`` variable and
        whose output is parsed with ``StrOutputParser``.
    """
    st.header("Fine-tuned FineLlama-3.2-3b-Instruct-ead model")
    fine_tuned_llm = create_aws_ollama_llm()

    system_message = (
        "system",
        "You are an expert in the domain of archival description in the standardized Encoded Archival Description EAD format** and intelligent EAD/XML generator.",
    )
    user_message = ("human", "{question}")
    chat_prompt = ChatPromptTemplate.from_messages([system_message, user_message])

    return chat_prompt | fine_tuned_llm | StrOutputParser()
def clear_outputs():
    """Remove every ``response_*`` entry from the Streamlit session state."""
    # Collect the matching keys first so the state is not mutated while its
    # keys are being iterated.
    stale_keys = [
        name for name in st.session_state.keys() if name.startswith('response_')
    ]
    for name in stale_keys:
        del st.session_state[name]
# Session-scoped list of the queries submitted so far.
if 'query_history' not in st.session_state:
    st.session_state.query_history = []

with main_col:
    selected_groq_models = setup_sidebar()

    # Make sure the query-tracking slots exist before they are compared below.
    for state_key in ('previous_query', 'current_query'):
        if state_key not in st.session_state:
            st.session_state[state_key] = ''

    query = st.chat_input("Enter your query", key="chat_input")
    st.markdown("*Example queries : Generate an EAD description for the personal papers of Marie Curie, Create an EAD inventory for a collection of World War II photographs, Create a EAD compliant `<eadheader>` sections with all necessary attributes and child elements.*")

    tab1, tab2, tab3, tab4 = st.tabs(["Zero-shot prompting","In-context learning with EAD schema", "RAG", "FineLlama-3.2-3B-Instruct-ead"])

    # One explanatory banner per tab, listed in tab order.
    tab_notes = (
        "Simple inference with zero-shot prompting : the prompt used to interact with the model does not contain examples or demonstrations. The LLM used id the one selected in the sidebar list.",
        "One-shot inference with EAD 2002 XSD schema : the prompt used to interact with the model contains the plaintext of the EAD schema as a guideline for the desired output format. The LLM used id the one selected in the sidebar list.",
        "Retrieval-augmented generation : the prompt used to interact with the model contains the relevant context from an archival collection of EAD files. The LLM used id the one selected in the sidebar list.",
        "FineLlama-3.2-3b-Instruct-ead model : this is a custom fine-tuned adaptation of llama-3.2-3b-instruct post-trained on a dataset of archival descriptions in the EAD format",
    )
    for tab, note in zip((tab1, tab2, tab3, tab4), tab_notes):
        with tab:
            st.info(note, icon="ℹ️")

    if query:
        # Record the query unless it repeats the most recent history entry.
        history = st.session_state.query_history
        repeats_latest = bool(history) and history[-1] == query
        if not repeats_latest:
            history.append(query)
        st.session_state.current_query = query

        # A different query invalidates the responses stored for the last one.
        if query != st.session_state.previous_query:
            clear_outputs()
            st.session_state.previous_query = query

        with st.spinner('Processing query across all models...'):
            with tab1:
                zero_shot_chain = setup_zero_shot_tab(create_groq_llm(selected_groq_models))
                answer = zero_shot_chain.invoke({"question": query})
                st.session_state.response_zero_shot = answer
                with st.chat_message("assistant"):
                    st.markdown(answer)

            with tab2:
                # This tab always uses a fixed model, not the sidebar selection.
                ead_chain = setup_ead_xsd_tab(create_groq_llm("llama-3.1-8b-instant"))
                answer = ead_chain.invoke(query)
                st.session_state.response_ead = answer
                with st.chat_message("assistant"):
                    st.markdown(answer)

            with tab3:
                rag_chain = setup_rag_tab(create_groq_llm(selected_groq_models))
                answer = rag_chain.invoke(query)
                st.session_state.response_rag = answer
                with st.chat_message("assistant"):
                    st.markdown(answer)

            with tab4:
                fine_tuned_chain = setup_fine_tuned_tab()
                answer = fine_tuned_chain.invoke(query)
                st.session_state.response_fine_tuned = answer
                with st.chat_message("assistant"):
                    st.markdown(answer)
with history_col:
    if st.session_state.query_history:
        # Number the queries chronologically (1-based) and list them newest
        # first. The label and the widget key both use that number, so a given
        # query keeps the same key on every rerun. A key based on the display
        # index would move to a different query each time the history grows.
        numbered_history = list(enumerate(st.session_state.query_history, start=1))
        for position, past_query in reversed(numbered_history):
            st.text_area(
                f"Query {position}",
                past_query,
                height=100,
                key=f"history_{position}",
                disabled=True,
            )
    else:
        st.write("No queries yet")