# "Since it's an almost example, it probably won't be affected by a license." | |
# Importing required libraries
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever

import warnings
warnings.filterwarnings("ignore")

import datasets
import os
import json
import subprocess
import sys
import joblib

from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent
from llama_cpp_agent import MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
from llama_cpp_agent.llm_output_settings import LlmStructuredOutputSettings
from llama_cpp_agent.messages_formatter import MessagesFormatter, PromptMarkers

import gradio as gr
from huggingface_hub import hf_hub_download
from typing import List, Tuple, Dict, Optional

from logger import logging
from exception import CustomExceptionHandling

from smolagents.gradio_ui import GradioUI
from smolagents import (
    CodeAgent,
    GoogleSearchTool,
    Model,
    Tool,
    LiteLLMModel,
    ToolCallingAgent,
    ChatMessage,
    tool,
    MessageRole,
)
# Load the processed document chunks from cache if available,
# otherwise build them from the Hugging Face documentation dataset.
cache_file = "docs_processed.joblib"
if os.path.exists(cache_file):
    docs_processed = joblib.load(cache_file)
    print("Loaded docs_processed from cache.")
else:
    knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
    source_docs = [
        Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
        for doc in knowledge_base
    ]
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=400,
        chunk_overlap=20,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    docs_processed = text_splitter.split_documents(source_docs)
    joblib.dump(docs_processed, cache_file)
    print("Created and saved docs_processed to cache.")
class RetrieverTool(Tool):
    name = "retriever"
    description = (
        "Uses BM25 (lexical) search to retrieve the parts of the documentation "
        "that could be most relevant to answer your query."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, docs, **kwargs):
        super().__init__(**kwargs)
        # BM25 needs no embedding model, which keeps this Space CPU-friendly.
        self.retriever = BM25Retriever.from_documents(
            docs,
            k=7,
        )

    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"
        docs = self.retriever.invoke(query)
        return "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {i} =====\n{doc.page_content}"
            for i, doc in enumerate(docs)
        )
# Download the GGUF model file from the Hugging Face Hub
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
os.makedirs("models", exist_ok=True)

logging.info("Starting model download")
hf_hub_download(
    repo_id="bartowski/google_gemma-3-4b-it-GGUF",
    filename="google_gemma-3-4b-it-Q4_K_M.gguf",
    local_dir="./models",
    token=huggingface_token,
)

retriever_tool = RetrieverTool(docs_processed)
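# Illustrative use of the tool outside the agent (assumes the dataset/cache step above ran):
# the retriever returns a single formatted string containing the top-7 BM25 matches.
# print(retriever_tool.forward("push a model to the Hub"))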
# Define the prompt markers for Gemma 3
gemma_3_prompt_markers = {
    Roles.system: PromptMarkers("", "\n"),  # System prompt should be included within user message
    Roles.user: PromptMarkers("<start_of_turn>user\n", "<end_of_turn>\n"),
    Roles.assistant: PromptMarkers("<start_of_turn>model\n", "<end_of_turn>\n"),
    Roles.tool: PromptMarkers("", ""),  # If you need tool support
}

# Create the formatter
gemma_3_formatter = MessagesFormatter(
    pre_prompt="",  # No pre-prompt
    prompt_markers=gemma_3_prompt_markers,
    include_sys_prompt_in_first_user_message=True,  # Include system prompt in first user message
    default_stop_sequences=["<end_of_turn>", "<start_of_turn>"],
    strip_prompt=False,  # Don't strip whitespace from the prompt
    bos_token="<bos>",  # Beginning-of-sequence token for Gemma 3
    eos_token="<eos>",  # End-of-sequence token for Gemma 3
)
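# With the markers above, a single turn is rendered roughly like this (sketch, not
# emitted verbatim by llama-cpp-agent):
#   <bos><start_of_turn>user
#   {system prompt}
#   {user message}<end_of_turn>
#   <start_of_turn>model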
# Based on https://github.com/huggingface/smolagents/pull/450
# and largely adapted from https://huggingface.co/spaces/sitammeur/Gemma-llamacpp
class LlamaCppModel(Model):
    def __init__(
        self,
        model_path: Optional[str] = None,
        repo_id: Optional[str] = None,
        filename: Optional[str] = None,
        n_gpu_layers: int = 0,
        n_ctx: int = 8192,
        max_tokens: int = 1024,
        verbose: bool = False,
        **kwargs,
    ):
        """
        Initializes the LlamaCppModel.

        Parameters:
            model_path (str, optional): Path to the local model file.
            repo_id (str, optional): Hugging Face repository ID if loading from Hugging Face.
            filename (str, optional): Specific filename to load from the repository.
            n_gpu_layers (int, default=0): Number of GPU layers to use.
            n_ctx (int, default=8192): Context size for the model.
            max_tokens (int, default=1024): Default maximum number of tokens to generate.
            verbose (bool, default=False): Whether llama.cpp logs verbosely.
            **kwargs: Additional keyword arguments.

        Raises:
            ValueError: If neither model_path nor repo_id+filename are provided.
        """
        super().__init__(**kwargs)
        self.flatten_messages_as_text = True
        self.max_tokens = max_tokens

        if model_path:
            self.llm = Llama(
                model_path=model_path,
                flash_attn=False,
                n_gpu_layers=n_gpu_layers,
                # n_batch=1024,
                n_ctx=n_ctx,
                n_threads=2,
                n_threads_batch=2,
                verbose=verbose,
            )
        elif repo_id and filename:
            self.llm = Llama.from_pretrained(
                repo_id=repo_id,
                filename=filename,
                n_gpu_layers=n_gpu_layers,
                n_ctx=n_ctx,
                verbose=verbose,
                **kwargs,
            )
        else:
            raise ValueError("Must provide either model_path or repo_id+filename")
    def __call__(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        tools_to_call_from: Optional[List[Tool]] = None,
        **kwargs,
    ) -> ChatMessage:
        try:
            completion_kwargs = self._prepare_completion_kwargs(
                messages=messages,
                stop_sequences=stop_sequences,
                grammar=grammar,
                tools_to_call_from=tools_to_call_from,
                **kwargs,
            )

            if not tools_to_call_from:
                completion_kwargs.pop("tools", None)
                completion_kwargs.pop("tool_choice", None)

            # Keyword arguments that llama-cpp-agent does not consume (currently unused).
            filtered_kwargs = {
                k: v
                for k, v in completion_kwargs.items()
                if k not in ["messages", "stop", "grammar", "max_tokens", "tools_to_call_from"]
            }

            max_tokens = kwargs.get("max_tokens") or self.max_tokens or 1024

            provider = LlamaCppPythonProvider(self.llm)
            system_message = completion_kwargs["messages"][0]["content"]
            # The last message is the current user prompt; earlier messages become chat history.
            message = completion_kwargs["messages"].pop()["content"]

            # Create the agent
            agent = LlamaCppAgent(
                provider,
                system_prompt=f"{system_message}",
                custom_messages_formatter=gemma_3_formatter,
                debug_output=True,
            )

            # Sampling settings
            settings = provider.get_provider_default_settings()
            settings.temperature = 0.5
            settings.top_k = 40
            settings.top_p = 0.95
            settings.max_tokens = max_tokens
            settings.repeat_penalty = 1.1
            settings.stream = False

            # Rebuild the conversation history for llama-cpp-agent
            chat_history = BasicChatHistory()
            for from_message in completion_kwargs["messages"]:
                if from_message["role"] == MessageRole.USER:
                    history_message = {"role": Roles.user, "content": from_message["content"]}
                elif from_message["role"] == MessageRole.SYSTEM:
                    history_message = {"role": Roles.system, "content": from_message["content"]}
                else:
                    history_message = {"role": Roles.assistant, "content": from_message["content"]}
                chat_history.add_message(history_message)

            response = agent.get_chat_response(
                message,
                llm_sampling_settings=settings,
                chat_history=chat_history,
                returns_streaming_generator=False,
                print_output=False,
            )

            chat_message = ChatMessage(role=MessageRole.ASSISTANT, content=response)
            if tools_to_call_from is not None:
                return super().parse_tool_args_if_needed(chat_message)
            return chat_message
        except Exception as e:
            logging.error(f"Model error: {e}")
            return ChatMessage(role="assistant", content=f"Error: {str(e)}")
model = LlamaCppModel(
    model_path="models/google_gemma-3-4b-it-Q4_K_M.gguf",
    n_ctx=8192,
    verbose=False,
)
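# Minimal sketch of calling the wrapper directly, outside the agent (assumes the GGUF
# download above succeeded; the message shape follows the smolagents Model interface):
# reply = model([{"role": "user", "content": "What is a GGUF file?"}])
# print(reply.content)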
import yaml

# Load the customized prompt templates for the CodeAgent
with open("retriever.yaml", "r") as f:
    prompt = f.read()

description = """
*CPU RAG Example with LlamaCpp*

It takes a few minutes to answer; the customized prompt is the key.

References
- [Qwen2.5-0.5B-Rag-Thinking](https://huggingface.co/spaces/Akjava/Qwen2.5-0.5B-Rag-Thinking-Flan-T5)
- [smolagents pull-450](https://github.com/huggingface/smolagents/pull/450)
- [Gemma-llamacpp](https://huggingface.co/spaces/sitammeur/Gemma-llamacpp)
- [Dataset(m-ric/huggingface_doc)](https://huggingface.co/datasets/m-ric/huggingface_doc)
"""

agent = CodeAgent(
    prompt_templates=yaml.safe_load(prompt),
    model=model,
    tools=[retriever_tool],
    max_steps=1,
    verbosity_level=0,
    name="AGENT",
    description=description,
)
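# To try the agent without the Gradio UI (illustrative; agent.run() is the standard
# smolagents entry point and invokes the retriever tool through generated code):
# print(agent.run("How can I push a model to the Hub?"))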
demo = GradioUI(agent)

if __name__ == "__main__":
    demo.launch()