import chromadb
import streamlit as st
from transformers import AutoTokenizer, BitsAndBytesConfig

from llama_index.core import VectorStoreIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.vector_stores.chroma import ChromaVectorStore
st.title("Aplikacja z LlamaIndex") |
|
|
|
db = chromadb.PersistentClient(path="./abc") |
|
chroma_collection = db.get_or_create_collection("pomoc_ukrainie") |
|
vector_store = ChromaVectorStore(chroma_collection=chroma_collection) |
|
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5", device="cpu") |
|
|
|
|
|
index = VectorStoreIndex.from_vector_store(vector_store, embed_model=embed_model) |
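
# The original file imported SentenceSplitter and IngestionPipeline without
# using them, which suggests the collection is populated in a separate
# ingestion step. A minimal sketch of what that could look like (the ./data
# directory and chunk_size are assumptions, not from the original):
#
# from llama_index.core import SimpleDirectoryReader
# from llama_index.core.ingestion import IngestionPipeline
# from llama_index.core.node_parser import SentenceSplitter
#
# documents = SimpleDirectoryReader("./data").load_data()
# pipeline = IngestionPipeline(
#     transformations=[SentenceSplitter(chunk_size=512), embed_model],
#     vector_store=vector_store,
# )
# pipeline.run(documents=documents)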

# Load the LLM in 4-bit so it fits on modest hardware.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype="float16",
)

llm = HuggingFaceLLM(
    model_name="microsoft/phi-1_5",
    tokenizer=AutoTokenizer.from_pretrained("microsoft/phi-1_5"),
    model_kwargs={
        "quantization_config": quantization_config,
        "device_map": "auto",
    },
)
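
# The original imports also included llama_index.llms.ollama.Ollama, so a
# local Ollama server may have been an alternative backend. A sketch, assuming
# Ollama is running and the (illustrative) model has been pulled:
#
# from llama_index.llms.ollama import Ollama
# llm = Ollama(model="llama3", request_timeout=120.0)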

# "compact" stuffs as many retrieved chunks as possible into each LLM call.
query_engine = index.as_query_engine(
    llm=llm,
    response_mode="compact",
)
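
# Two notes, neither from the original. First, the unused Settings import
# hints at LlamaIndex's global-config alternative to passing llm/embed_model
# explicitly:
#
# from llama_index.core import Settings
# Settings.llm = llm
# Settings.embed_model = embed_model
#
# Second, Streamlit reruns this whole script on every interaction, so the
# index, LLM, and query engine above are rebuilt each time. A cached factory
# would avoid that; a sketch:
#
# @st.cache_resource
# def load_query_engine():
#     ...  # build embed_model, index, llm, and the query engine as above
#     return query_engine
#
# query_engine = load_query_engine()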

# Seed the chat history on first load.
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "assistant", "content": "Ask me a question..."}]

# Replay the conversation so far.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

# Capture new user input ("prompt" avoids shadowing the input() builtin).
if prompt := st.chat_input():
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

# Answer only when the newest message came from the user.
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Please wait, the answer is being generated..."):
            response = query_engine.query(prompt)

            content = str(response.response)
            # Append the top retrieval score when one is available
            # (it can be None, which would break the :.4f format).
            if response.source_nodes and response.source_nodes[0].score is not None:
                content += f"\nScore: {response.source_nodes[0].score:.4f}"

            st.write(content)

            message = {"role": "assistant", "content": content}
            st.session_state.messages.append(message)
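
# To launch (the app.py filename is an assumption):
#   streamlit run app.py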