dineth554 committed
Commit 38de785 · verified · 1 Parent(s): 25bcda8

Update app.py

Files changed (1): app.py +182 -2
app.py CHANGED
@@ -1,3 +1,183 @@
- import gradio as gr
 
- gr.load("models/mistralai/Mistral-7B-Instruct-v0.3").launch()
+ import os
+
+ # Install dependencies before the remaining imports so they can be satisfied
+ # at import time (on Hugging Face Spaces, requirements.txt is the usual place
+ # for these).
+ os.system("pip install streamlit llama-cpp-agent huggingface_hub")
+
+ import logging
+ import streamlit as st
+ from huggingface_hub import hf_hub_download
+ from llama_cpp import Llama
+ from llama_cpp_agent.providers import LlamaCppPythonProvider
+ from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
+ from llama_cpp_agent.chat_history import BasicChatHistory
+ from llama_cpp_agent.chat_history.messages import Roles
+ from llama_cpp_agent.llm_output_settings import (
+     LlmStructuredOutputSettings,
+     LlmStructuredOutputType,
+ )
+ from llama_cpp_agent.tools import WebSearchTool
+ from llama_cpp_agent.prompt_templates import web_search_system_prompt, research_system_prompt
+ from utils import CitingSources
+ from settings import get_context_by_model, get_messages_formatter_type
+
+ # Download the models
+ hf_hub_download(
+     repo_id="bartowski/Mistral-7B-Instruct-v0.3-GGUF",
+     filename="Mistral-7B-Instruct-v0.3-Q6_K.gguf",
+     local_dir="./models"
+ )
+ hf_hub_download(
+     repo_id="bartowski/Meta-Llama-3-8B-Instruct-GGUF",
+     filename="Meta-Llama-3-8B-Instruct-Q6_K.gguf",
+     local_dir="./models"
+ )
+ hf_hub_download(
+     repo_id="TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF",
+     filename="mixtral-8x7b-instruct-v0.1.Q5_K_M.gguf",
+     local_dir="./models"
+ )
+
+ # Generator that answers a user message: a web-search agent gathers results,
+ # then a research agent streams an answer grounded in them and cites sources.
+ def respond(message, history, model, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty):
+     chat_template = get_messages_formatter_type(model)
+     llm = Llama(
+         model_path=f"models/{model}",
+         flash_attn=True,
+         n_gpu_layers=81,
+         n_batch=1024,
+         n_ctx=get_context_by_model(model),
+     )
+     provider = LlamaCppPythonProvider(llm)
+     logging.info(f"Loaded chat formatter: {chat_template}")
+     search_tool = WebSearchTool(
+         llm_provider=provider,
+         message_formatter_type=chat_template,
+         max_tokens_search_results=12000,
+         max_tokens_per_summary=2048,
+     )
+
+     web_search_agent = LlamaCppAgent(
+         provider,
+         system_prompt=web_search_system_prompt,
+         predefined_messages_formatter_type=chat_template,
+         debug_output=True,
+     )
+
+     answer_agent = LlamaCppAgent(
+         provider,
+         system_prompt=research_system_prompt,
+         predefined_messages_formatter_type=chat_template,
+         debug_output=True,
+     )
+
+     settings = provider.get_provider_default_settings()
+     settings.stream = False
+     settings.temperature = temperature
+     settings.top_k = top_k
+     settings.top_p = top_p
+     settings.max_tokens = max_tokens
+     settings.repeat_penalty = repeat_penalty
+
+     # Constrain the search agent to call the web-search tool via structured output
+     output_settings = LlmStructuredOutputSettings.from_functions(
+         [search_tool.get_tool()]
+     )
+
+     # Rebuild the chat history from (user, assistant) pairs
+     messages = BasicChatHistory()
+     for msg in history:
+         user = {"role": Roles.user, "content": msg[0]}
+         assistant = {"role": Roles.assistant, "content": msg[1]}
+         messages.add_message(user)
+         messages.add_message(assistant)
+
+     result = web_search_agent.get_chat_response(
+         message,
+         llm_sampling_settings=settings,
+         structured_output_settings=output_settings,
+         add_message_to_chat_history=False,
+         add_response_to_chat_history=False,
+         print_output=False,
+     )
+
+     outputs = ""
+
+     # Stream the final answer, grounded in the search results above
+     settings.stream = True
+     response_text = answer_agent.get_chat_response(
+         f"Write a detailed and complete research document that fulfills the following user request: '{message}', based on the information from the web below.\n\n"
+         + result[0]["return_value"],
+         role=Roles.tool,
+         llm_sampling_settings=settings,
+         chat_history=messages,
+         returns_streaming_generator=True,
+         print_output=False,
+     )
+
+     for text in response_text:
+         outputs += text
+         yield outputs
+
+     # Ask the answer agent for a structured CitingSources object
+     output_settings = LlmStructuredOutputSettings.from_pydantic_models(
+         [CitingSources], LlmStructuredOutputType.object_instance
+     )
+
+     citing_sources = answer_agent.get_chat_response(
+         "Cite the sources you used in your response.",
+         role=Roles.tool,
+         llm_sampling_settings=settings,
+         chat_history=messages,
+         returns_streaming_generator=False,
+         structured_output_settings=output_settings,
+         print_output=False,
+     )
+     outputs += "\n\nSources:\n"
+     outputs += "\n".join(citing_sources.sources)
+     yield outputs
+
+ # Streamlit app
+ st.title("Llama-CPP-Agent Chatbot with Web Search")
+
+ # Sidebar for settings
+ st.sidebar.title("Settings")
+ model = st.sidebar.selectbox(
+     "Model",
+     [
+         'Mistral-7B-Instruct-v0.3-Q6_K.gguf',
+         'mixtral-8x7b-instruct-v0.1.Q5_K_M.gguf',
+         'Meta-Llama-3-8B-Instruct-Q6_K.gguf'
+     ]
+ )
+ system_message = st.sidebar.text_area("System message", value=web_search_system_prompt)
+ max_tokens = st.sidebar.slider("Max tokens", min_value=1, max_value=4096, value=2048, step=1)
+ temperature = st.sidebar.slider("Temperature", min_value=0.1, max_value=1.0, value=0.45, step=0.1)
+ top_p = st.sidebar.slider("Top-p", min_value=0.1, max_value=1.0, value=0.95, step=0.05)
+ top_k = st.sidebar.slider("Top-k", min_value=0, max_value=100, value=40, step=1)
+ repeat_penalty = st.sidebar.slider("Repetition penalty", min_value=0.0, max_value=2.0, value=1.1, step=0.1)
+
+ # Chat history, kept across Streamlit reruns
+ if "history" not in st.session_state:
+     st.session_state.history = []
+
+ # Chat input
+ message = st.text_input("You:", key="input")
+
+ if st.button("Send"):
+     history = st.session_state.history
+     response = respond(
+         message,
+         history,
+         model,
+         system_message,
+         max_tokens,
+         temperature,
+         top_p,
+         top_k,
+         repeat_penalty
+     )
+
+     # Stream partial outputs into one placeholder, then store only the final
+     # response in the history instead of appending every intermediate chunk.
+     placeholder = st.empty()
+     final_output = ""
+     for res in response:
+         final_output = res
+         placeholder.markdown(f"You: {message}\n\nBot: {res}")
+     st.session_state.history.append((message, final_output))
+
+ # Display chat history (unique keys avoid duplicate-widget errors)
+ for i, (user_msg, bot_msg) in enumerate(st.session_state.history):
+     st.text_area("Chat", value=f"You: {user_msg}\nBot: {bot_msg}", height=300, key=f"history_{i}")
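
The new app.py imports two helper modules that are not part of this diff: settings (providing get_context_by_model and get_messages_formatter_type) and utils (providing the CitingSources pydantic model). A minimal sketch of what they might contain, inferred from how they are used above; the context sizes and the model-to-formatter mapping here are assumptions, not values from this repository:

# settings.py — minimal sketch (assumed, not shown in this commit)
from llama_cpp_agent import MessagesFormatterType

def get_context_by_model(model_name):
    # Context-window size per model file; these numbers are assumptions.
    context_sizes = {
        "Mistral-7B-Instruct-v0.3-Q6_K.gguf": 32768,
        "mixtral-8x7b-instruct-v0.1.Q5_K_M.gguf": 32768,
        "Meta-Llama-3-8B-Instruct-Q6_K.gguf": 8192,
    }
    return context_sizes.get(model_name, 8192)

def get_messages_formatter_type(model_name):
    # Map the model file to a prompt format; the mapping is an assumption.
    if "Llama-3" in model_name:
        return MessagesFormatterType.LLAMA_3
    return MessagesFormatterType.MISTRAL

# utils.py — minimal sketch (assumed, not shown in this commit)
from typing import List
from pydantic import BaseModel, Field

class CitingSources(BaseModel):
    # app.py joins citing_sources.sources with newlines, so a plain
    # list of source strings is sufficient here.
    sources: List[str] = Field(..., description="Sources used to answer the question.")

With those two files next to app.py and the dependencies installed, the app starts with: streamlit run app.py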