import streamlit as st
from chat_client import chat
import time
import os
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer  # embedding model loader; the retrieval step is not implemented in this file
load_dotenv()

CHAT_BOTS = {"Mixtral 8x7B v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1"}
SYSTEM_PROMPT = ["Sei BonsiAI e mi aiuterai nelle mie richieste (Parla in ITALIANO)", "Esatto, sono BonsiAI. Di cosa hai bisogno?"]
IDENTITY_CHANGE = ["Sei BonsiAI da ora in poi!", "Certo farò del mio meglio"]
options = {
    'Email Genitori': {'text': 'Scrivi il testo per una mail XXXX su questo stile.', 'description': 'Descrizione aggiuntiva per Email Genitori'},
    'Email Colleghi': {'text': 'Scrivi il testo per una mail XXXX su questo stile.', 'description': 'Descrizione aggiuntiva per Email Colleghi'},
    'Decreti': {'text': 'Cerca testo dei decreti!', 'description': 'Descrizione aggiuntiva per Decreti'}
}

st.set_page_config(page_title="BonsiAI", page_icon="🤖")

def gen_augmented_prompt(prompt, top_k) :
    # Placeholder retrieval: context and source links are left empty here and the
    # top_k parameter is not yet used; the vector-store lookup is not implemented
    # in this file.
    context = ""
    links = ""
    generated_prompt = f"""
    A PARTIRE DAL SEGUENTE CONTESTO: {context},

    ----
    RISPONDI ALLA SEGUENTE RICHIESTA: {prompt}
    """
    return generated_prompt, links
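
# Illustrative sketch (not called anywhere in the app): one possible way to fill in
# the retrieval step above. It assumes an in-memory DEMO_DOCUMENTS list of
# (testo, link) tuples and reuses the SentenceTransformer import for embeddings;
# the model name, the document list, and the helper name are placeholders, not part
# of the original app, which presumably queries its own vector database instead.
import numpy as np

DEMO_DOCUMENTS = []  # hypothetical: populate with (testo, link) tuples

def retrieve_context_sketch(prompt, top_k):
    # Returns (context, links) in the shape gen_augmented_prompt expects.
    if not DEMO_DOCUMENTS:
        return "", []
    embedder = SentenceTransformer("all-MiniLM-L6-v2")  # assumed model choice
    doc_texts = [testo for testo, _ in DEMO_DOCUMENTS]
    doc_vecs = embedder.encode(doc_texts, normalize_embeddings=True)
    query_vec = embedder.encode([prompt], normalize_embeddings=True)[0]
    best = np.argsort(doc_vecs @ query_vec)[::-1][:top_k]
    context = "\n".join(doc_texts[i] for i in best)
    links = [DEMO_DOCUMENTS[i][1] for i in best]
    return context, links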

def init_state() :
    if "messages" not in st.session_state:
        st.session_state.messages = []

    if "temp" not in st.session_state:
        st.session_state.temp = 0.8

    if "history" not in st.session_state:
        st.session_state.history = [SYSTEM_PROMPT]

    if "top_k" not in st.session_state:
        st.session_state.top_k = 5

    if "repetition_penalty" not in st.session_state :
        st.session_state.repetition_penalty = 1

    if "rag_enabled" not in st.session_state :
        st.session_state.rag_enabled = False

    if "chat_bot" not in st.session_state :
        st.session_state.chat_bot = "Mixtral 8x7B v0.1"

def sidebar() :
    def retrieval_settings() :
        st.markdown("# Impostazioni Azioni")
        st.session_state.selected_option_key = st.selectbox('Azione', list(options.keys()) + ['+ Aggiungi'])
        st.session_state.selected_option = options.get(st.session_state.selected_option_key, {})
        st.session_state.selected_option_text = st.session_state.selected_option.get('text', '')
        st.session_state.option_text = st.text_area("Testo Azione", st.session_state.selected_option_text)
        st.session_state.selected_option_description = st.session_state.selected_option.get('description', '')
        if st.session_state.selected_option_key == 'Decreti':
            st.session_state.rag_enabled = st.toggle("Cerca nel DB Vettoriale", value=True)
            st.session_state.top_k = st.slider(label="Documenti da ricercare", min_value=1, max_value=20, value=4, disabled=not st.session_state.rag_enabled)
        st.markdown("---")
    
    def model_settings() :
        st.markdown("# Impostazioni Modello")
        st.session_state.chat_bot = st.radio('Seleziona Modello:', list(CHAT_BOTS.keys()))
        st.session_state.temp = st.slider(label="Creatività", min_value=0.0, max_value=1.0, step=0.1, value=0.9)
        st.session_state.max_tokens = st.slider(label="Lunghezza Output", min_value=64, max_value=2048, step=32, value=512)
        st.session_state.repetition_penalty = st.slider(label="Penalità Ripetizione", min_value=0.0, max_value=1.0, step=0.1, value=1.0)

    with st.sidebar:
        retrieval_settings()
        model_settings()
        st.markdown("""> **Creato da [Matteo Script] 🔗**""")

def header() :
    st.title("BonsiAI")
    with st.expander("Cos'è BonsiAI?"):
        st.info("""BonsiAI Chat è un ChatBot personalizzato basato su un database vettoriale, funziona secondo il principio della Generazione potenziata da Recupero (RAG). 
                   La sua funzione principale ruota attorno alla gestione di un ampio repository di documenti BonsiAI e fornisce agli utenti risposte in linea con le loro domande. 
                   Questo approccio garantisce una risposta più precisa sulla base della richiesta degli utenti.""")

def chat_box() :
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

def generate_chat_stream(prompt) :
    links = []
    if st.session_state.rag_enabled :
        with st.spinner("Ricerca nei documenti...."):
            time.sleep(2)
            prompt, links = gen_augmented_prompt(prompt=prompt, top_k=st.session_state.top_k)        
    with st.spinner("Generazione in corso...") :
        time.sleep(2)
        chat_stream = chat(prompt, st.session_state.history, chat_client=CHAT_BOTS[st.session_state.chat_bot],
                           temperature=st.session_state.temp, max_new_tokens=st.session_state.max_tokens)
    return chat_stream, links

def stream_handler(chat_stream, placeholder, prompt) :
    start_time = time.time()
    full_response = ''

    # Accumulate the streamed chunks, skipping the end-of-sequence marker.
    for chunk in chat_stream :
        if chunk.token.text != '</s>' :
            full_response += chunk.token.text
            placeholder.markdown(full_response + "▌")
    placeholder.markdown(full_response)

    end_time = time.time()
    elapsed_time = end_time - start_time
    # Rough estimates: token counts are approximated from word counts.
    total_tokens_processed = len(full_response.split())
    tokens_per_second = int(total_tokens_processed / elapsed_time)
    len_response = (len(prompt.split()) + len(full_response.split())) * 1.25
    col1, col2 = st.columns(2)

    with col1 :
        st.write(f"**{tokens_per_second} token/secondi**")

    with col2 :
        st.write(f"**{int(len_response)} token totali (stima)**")

    return full_response

def show_source(links) :
    with st.expander("Mostra fonti") :
        for i, link in enumerate(links) :
            st.info(f"{link}")

init_state()
sidebar()
header()
chat_box()

if prompt := st.chat_input("Chatta con BonsiAI..."):
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    chat_stream, links = generate_chat_stream(prompt)

    
    with st.chat_message("assistant"):
        placeholder = st.empty()
        full_response = stream_handler(chat_stream, placeholder, prompt)
        if st.session_state.rag_enabled :
            show_source(links)

    st.session_state.history.append([prompt, full_response])
    st.session_state.messages.append({"role": "assistant", "content": full_response})
    st.success('Generazione Completata', icon="✅")