# pylint: skip-file
"""Gradio chat demo for Ghost 8B Beta with an optional web-search tool."""
import json
import os
import subprocess
import sys
import time
from functools import lru_cache
from threading import Thread
from typing import Iterator

# flash-attn must be installed before transformers loads the model below.
# The child environment is merged with os.environ: passing only the one flag
# (as before) would drop PATH, HOME, etc. for pip. Running pip through the
# current interpreter guarantees the package lands in this environment.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,  # best effort, as in the original; model loading fails loudly if missing
)

import gradio as gr
import requests
import spaces
import torch
import wikipedia
from bs4 import BeautifulSoup
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 1536
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))

DESCRIPTION = """\
# Playground with Ghost 8B Beta (β, 8k, Online)

**Ghost 8B Beta** model outperforms prominent models such as Llama 3 8B Instruct, GPT 3.5 Turbo in the lc_winrate score. In addition, it also outperforms Claude 3 Opus, Claude 3 Sonnet, GPT-4, and Mistral Large when comparing the winrate score of AlpacaEval 2.0, [*](https://ghost-x.org/docs/models/ghost-8b-beta/).

The model comes in two context length versions, [8k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-8k) and [128k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-128k), along with multilingual function tools support by default.

The languages supported are 🇺🇸 English, 🇫🇷 French, 🇮🇹 Italian, 🇪🇸 Spanish, 🇵🇹 Portuguese, 🇩🇪 German, 🇻🇳 Vietnamese, 🇰🇷 Korean and 🇨🇳 Chinese.

🗞️ **Updates**

* Jul 23, 2024: added support for tools, now available to search for information on the internet.
"""

PLACEHOLDER = """
Ask and share whatever you want ~
"""

# NOTE(review): this copy of the file was damaged around here (HTML stripped,
# lines lost). The CPU warning text survived; it is restored as a DESCRIPTION
# suffix shown only when no GPU is present — confirm against the original.
if not torch.cuda.is_available():
    DESCRIPTION += "\nRunning on CPU 🥶 This demo does not work on CPU."

# NOTE(review): LICENSE and EXAMPLES are used by the UI below, but their
# definitions were missing from this copy of the file. These empty fallbacks
# keep the app importable until the originals are restored.
LICENSE = ""
EXAMPLES = None

if torch.cuda.is_available():
    model_id = "ghost-x/ghost-8b-beta"
    hf_token = os.getenv("HF_TOKEN", None)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
        token=hf_token,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
        token=hf_token,
    )

# Seconds after the first streamed token during which output is held back so a
# short reply can be inspected as a possible tool call before it is shown.
waiting_tools_timeout = 5

# Tool schema passed to the model in the "tools" role. The per-property
# "required": True layout is kept as-is because it is what the model is prompted with.
supported_tools = json.dumps(
    [
        {
            "type": "function",
            "function": {
                "name": "search_on_internet",
                "description": "Use this tool to search for information on the internet to answer questions you are unsure about, don't know or need the latest information (e.g. news, reports, companies, people,...) to give the most accurate results. Note: can only be used or ignored, not asked again",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "keyword": {
                            "type": "string",
                            "description": "Search keywords, rephrase to optimize search results based on questions suitable to the specified search type.",
                            "required": True,
                        },
                        "type": {
                            "type": "string",
                            "description": "Search type, based on the question to determine whether to search for it in 'wikipedia' or 'google', prefer to use wikipedia for information about events, history and people.",
                            "enum": ["wikipedia", "google"],
                            "default": "google",
                            "required": True,
                        },
                        "language": {
                            "type": "string",
                            "description": "Search language, is the user language code with 2 letters, e.g: vi = vietnamese, en = english.",
                            "default": "en",
                            "required": True,
                        },
                    },
                },
            },
        }
    ],
    ensure_ascii=False,
)

_BROWSER_HEADERS = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
}


@lru_cache(maxsize=128)
def extract_text_from_webpage(html_content):
    """Return the visible text of an HTML document.

    Script, style, header, footer, nav, form and svg elements are removed
    first; the remaining text pieces are stripped and joined with single
    spaces. Results are memoized per exact HTML string (up to 128 entries).
    """
    soup = BeautifulSoup(html_content, "html.parser")
    for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
        tag.extract()
    return soup.get_text(strip=True, separator=" ")


def search_with_wikipedia(
    query: str,
    language: str = "en",
):
    """Return a list holding the Wikipedia summary for *query*, or an empty list.

    Best effort: any exception raised by the lookup is logged and yields no
    results instead of failing the chat reply.
    """
    all_results = []
    try:
        # NOTE(review): `wikipedia.set_lang` looks like module-global state, so
        # concurrent requests in different languages may race — confirm.
        wikipedia.set_lang(language)
        all_results.append(wikipedia.summary(query))
    except Exception as e:
        print(f"Error searching Wikipedia for {query!r}: {e}")
    return all_results


def search_with_google(
    query: str,
    num_results: int = 3,
    timeout: int = 5,
    language: str = "en",
    ssl_verify: bool = None,
    max_chars_per_page: int = 4096,
):
    """Search Google for *query* and return the visible text of the result pages.

    Args:
        query: Search terms.
        num_results: Value sent as Google's ``num`` parameter.
        timeout: Seconds allowed for the search request and for each page fetch.
        language: Interface language sent as Google's ``hl`` parameter.
        ssl_verify: Passed as ``verify`` to every request (None keeps the
            session default).
        max_chars_per_page: Each page's extracted text is truncated to this length.

    Returns:
        A list of ``{"link": url, "text": visible_text}`` dicts. Best effort:
        a failed search returns an empty list, and pages that fail to load are
        skipped (both are logged).
    """
    all_results = []
    with requests.Session() as session:
        session.headers.update(_BROWSER_HEADERS)
        try:
            resp = session.get(
                url="https://www.google.com/search",
                params={
                    "q": query,
                    "num": num_results,
                    "udm": 14,
                    "hl": language,
                },
                timeout=timeout,
                verify=ssl_verify,
            )
            resp.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f"Error searching Google for {query!r}: {e}")
            return all_results

        soup = BeautifulSoup(resp.text, "html.parser")
        for result in soup.find_all("div", attrs={"class": "g"}):
            anchor = result.find("a", href=True)
            if anchor is None:
                continue
            link = anchor["href"]
            # Relative links (e.g. Google's own "/search?..." anchors) cannot be fetched.
            if not link.startswith(("http://", "https://")):
                continue
            try:
                # A timeout is required here: without it one slow site can
                # stall the whole reply.
                webpage = session.get(link, timeout=timeout, verify=ssl_verify)
                webpage.raise_for_status()
            except requests.exceptions.RequestException as e:
                print(f"Error fetching or processing {link}: {e}")
                continue
            visible_text = extract_text_from_webpage(webpage.text)
            all_results.append({"link": link, "text": visible_text[:max_chars_per_page]})
    return all_results


def _parse_tool_call(text):
    """Return *text* parsed as a usable ``search_on_internet`` call, else None.

    A usable call is a JSON object with ``type == "function"``,
    ``name == "search_on_internet"`` and an ``arguments`` object carrying a
    non-empty string ``keyword``.
    """
    try:
        call = json.loads(text)
    except (TypeError, ValueError) as e:
        print(e)
        return None
    if not isinstance(call, dict):
        return None
    if call.get("type") != "function" or call.get("name") != "search_on_internet":
        return None
    arguments = call.get("arguments")
    if not isinstance(arguments, dict):
        return None
    keyword = arguments.get("keyword")
    if not isinstance(keyword, str) or not keyword.strip():
        return None
    return call


@spaces.GPU(duration=120)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    allow_used_tools: bool = True,
    system_prompt: str = "",
    max_new_tokens: int = 1536,
    temperature: float = 0.4,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
    other_client_info: str = None,
) -> Iterator[str]:
    """Stream the model's reply to *message*, optionally searching the web first.

    When tools are allowed, the first pass includes the tool schema in the
    prompt and holds output back for ``waiting_tools_timeout`` seconds after the
    first token. If nothing has been shown by the time generation ends, the
    buffered text is treated as a possible tool call:
    - a valid ``search_on_internet`` call runs the search, then a second pass
      runs with the results as references;
    - anything else triggers a second pass without the tool schema.

    Yields the accumulated reply text after each streamed chunk.
    """

    def build_input_ids(
        apply_tools: bool = None,
        references=None,
    ):
        """Tokenize system prompt, tools, reference documents, history and message."""
        conversation = []
        if system_prompt:
            conversation.append({"role": "system", "content": system_prompt})
        if apply_tools is True:
            conversation.append({"role": "tools", "content": supported_tools})
        # Client info goes first. Build a new list rather than inserting into the
        # caller's, and leave out empty client info instead of sending a null document.
        documents = [other_client_info] if other_client_info else []
        documents.extend(references or [])
        if documents:
            conversation.append(
                {
                    "role": "refs",
                    "content": json.dumps(
                        {
                            "instructions": "These are only general documents used for reference to give the most accurate and honest answers possible. Ignore it if it's irrelevant and don't overuse it.",
                            "documents": documents,
                        },
                        indent=2,
                        ensure_ascii=False,
                    ),
                }
            )
        for user, assistant in chat_history:
            conversation.extend(
                [
                    {"role": "user", "content": user},
                    {"role": "assistant", "content": assistant},
                ]
            )
        conversation.append({"role": "user", "content": message})

        input_ids = tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, return_tensors="pt"
        )
        input_ids = input_ids.to(model.device)
        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
            # Keep the most recent tokens.
            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
            gr.Warning(
                f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
            )
        return input_ids

    def generate_chat_responses(
        previous_response: str = None,
    ):
        """Run one generation pass; *previous_response* is the buffered first-pass output."""
        document_references = []
        if previous_response is not None:
            tool_call = _parse_tool_call(previous_response)
            if tool_call is not None:
                arguments = tool_call["arguments"]
                keyword = arguments["keyword"]
                # Fall back to the schema defaults when the model omits an argument.
                search_type = arguments.get("type") or "google"
                language = arguments.get("language") or "en"
                print("scheduled_tools_runs:", tool_call)

                if search_type == "wikipedia":
                    gr.Info(
                        "Searching for information on the Wikipedia.",
                        duration=5,
                        visible=True,
                    )
                    document_references.extend(
                        search_with_wikipedia(query=keyword, language=language)
                    )

                # Google is searched for every tool call; Wikipedia results,
                # when requested, come first.
                gr.Info("Searching for information on the Google.")
                document_references.extend(
                    search_with_google(
                        query=keyword,
                        language=language,
                        num_results=3,
                    )
                )
                print("document_references:", document_references)

        # Tools are only offered on the first pass, so at most two passes run.
        apply_tools = allow_used_tools is True and previous_response is None
        input_ids = build_input_ids(
            apply_tools=apply_tools,
            references=document_references,
        )
        streamer = TextIteratorStreamer(
            tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            repetition_penalty=repetition_penalty,
        )
        if temperature == 0:
            generate_kwargs["do_sample"] = False
        else:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p
            generate_kwargs["top_k"] = top_k

        Thread(target=model.generate, kwargs=generate_kwargs).start()

        outputs = []
        first_token_at = None
        responded = False
        for text in streamer:
            if first_token_at is None:
                first_token_at = time.time()
            outputs.append(text)
            if not apply_tools or first_token_at + waiting_tools_timeout < time.time():
                responded = True
                yield "".join(outputs)

        # Nothing has been shown yet, so the buffered output must be handled here.
        # The old code also re-checked the clock at this point. If generation
        # ended right at the timeout it dropped the reply, and it crashed when
        # no token was produced.
        if apply_tools and not responded:
            yield from generate_chat_responses(previous_response="".join(outputs))

    yield from generate_chat_responses(previous_response=None)


def _client_info() -> str:
    """Describe request-time context (today's date on the server) for the model."""
    return "This user's current time: {}".format(time.strftime("%Y-%m-%d"))


chatbot = gr.Chatbot(
    height=500, placeholder=PLACEHOLDER, label="Ghost 8B Beta", show_copy_button=True
)

chat_interface = gr.ChatInterface(
    fn=generate,
    chatbot=chatbot,
    fill_height=True,
    additional_inputs=[
        gr.Checkbox(
            label="Allow used tools (available: search on internet)", value=True
        ),
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.0,
            maximum=2.0,
            step=0.1,
            value=0.4,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=100,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.0,
        ),
        gr.Textbox(
            label="Other client information",
            lines=1,
            # A callable is re-evaluated on every app load. A plain string would
            # freeze the date at server start.
            value=_client_info,
            visible=False,
        ),
    ],
    stop_btn="Stop",
    cache_examples=False,
    examples=EXAMPLES,
    examples_per_page=9,
    concurrency_limit=100,
)

with gr.Blocks(fill_height=True, css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)