lamhieu commited on
Commit
bf7aef8
โ€ข
1 Parent(s): b54a72a

chore: support tools with search on internet

Browse files
Files changed (2) hide show
  1. app.py +241 -49
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,6 +1,8 @@
1
  # pylint: skip-file
2
 
3
  import subprocess
 
 
4
 
5
  subprocess.run(
6
  f"pip install flash-attn --no-build-isolation",
@@ -15,7 +17,11 @@ from typing import Iterator
15
  import gradio as gr
16
  import spaces
17
  import torch
 
 
18
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 
19
 
20
 
21
  MAX_MAX_NEW_TOKENS = 4096
@@ -25,13 +31,12 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
25
  DESCRIPTION = """\
26
  # Playground with Ghost 8B Beta (ฮฒ, 8k)
27
 
28
- **Ghost 8B Beta** is a large language model developed with goals that include excellent multilingual support, superior knowledge capabilities, and cost-effectiveness. The model comes in two context length versions, [8k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-8k) and [128k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-128k), along with multilingual function tools support by default.
29
-
30
- The Ghost 8B Beta model outperforms prominent models such as Llama 3 8B Instruct, GPT 3.5 Turbo in the lc_winrate score. In addition, it also outperforms Claude 3 Opus, Claude 3 Sonnet, GPT-4, and Mistral Large when comparing the winrate score of AlpacaEval 2.0, [*](https://ghost-x.org/docs/models/ghost-8b-beta/).
31
 
32
  The languages supported are ๐Ÿ‡บ๐Ÿ‡ธ English, ๐Ÿ‡ซ๐Ÿ‡ท French, ๐Ÿ‡ฎ๐Ÿ‡น Italian, ๐Ÿ‡ช๐Ÿ‡ธ Spanish, ๐Ÿ‡ต๐Ÿ‡น Portuguese, ๐Ÿ‡ฉ๐Ÿ‡ช German, ๐Ÿ‡ป๐Ÿ‡ณ Vietnamese, ๐Ÿ‡ฐ๐Ÿ‡ท Korean and ๐Ÿ‡จ๐Ÿ‡ณ Chinese.
33
 
34
- ๐Ÿ“‹ Note: current model version is "disl-0x5" (10 Jul 2024), context length 8k (8192 tokens) and current status is "moderating / previewing". For detailed information about the model, see [here](https://ghost-x.org/docs/models/ghost-8b-beta/). Try to experience it the way you want!
 
35
  """
36
 
37
 
@@ -250,88 +255,274 @@ if not torch.cuda.is_available():
250
 
251
  if torch.cuda.is_available():
252
  model_id = "ghost-x/ghost-8b-beta"
253
- model_tk = os.getenv("HF_TOKEN", None)
254
  model = AutoModelForCausalLM.from_pretrained(
255
  model_id,
256
  device_map="auto",
257
  torch_dtype=torch.bfloat16,
258
  attn_implementation="flash_attention_2",
259
  trust_remote_code=True,
260
- token=model_tk,
261
  )
262
  tokenizer = AutoTokenizer.from_pretrained(
263
  model_id,
264
  trust_remote_code=True,
265
- token=model_tk,
266
  )
267
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
  @spaces.GPU(duration=120)
270
  def generate(
271
  message: str,
272
  chat_history: list[tuple[str, str]],
273
- system_prompt: str,
 
274
  max_new_tokens: int = 1536,
275
  temperature: float = 0.4,
276
  top_p: float = 0.95,
277
  top_k: int = 50,
278
  repetition_penalty: float = 1.0,
279
  ) -> Iterator[str]:
280
- conversation = []
281
- if system_prompt:
282
- conversation.append({"role": "system", "content": system_prompt})
283
- for user, assistant in chat_history:
284
- conversation.extend(
285
- [
286
- {"role": "user", "content": user},
287
- {"role": "assistant", "content": assistant},
288
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  )
290
- conversation.append({"role": "user", "content": message})
 
 
 
 
 
 
291
 
292
- input_ids = tokenizer.apply_chat_template(
293
- conversation, add_generation_prompt=True, return_tensors="pt"
294
- )
295
- input_ids = input_ids.to(model.device)
296
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
297
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
298
- gr.Warning(
299
- f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  )
 
 
 
 
 
 
301
 
302
- streamer = TextIteratorStreamer(
303
- tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
304
- )
305
- generate_kwargs = dict(
306
- input_ids=input_ids,
307
- streamer=streamer,
308
- max_new_tokens=max_new_tokens,
309
- do_sample=True,
310
- repetition_penalty=repetition_penalty,
311
- )
312
- if temperature == 0:
313
- generate_kwargs["do_sample"] = False
314
- else:
315
- generate_kwargs["temperature"] = temperature
316
- generate_kwargs["top_p"] = top_p
317
- generate_kwargs["top_k"] = top_k
318
 
319
- t = Thread(target=model.generate, kwargs=generate_kwargs)
320
- t.start()
 
 
 
 
 
 
 
 
 
 
321
 
322
- outputs = []
323
- for text in streamer:
324
- outputs.append(text)
325
- yield "".join(outputs)
 
 
 
326
 
 
327
 
328
- chatbot = gr.Chatbot(height=500, placeholder=PLACEHOLDER, label="Ghost 8B Beta")
 
 
 
329
 
330
  chat_interface = gr.ChatInterface(
331
  fn=generate,
332
  chatbot=chatbot,
333
  fill_height=True,
334
  additional_inputs=[
 
 
 
335
  gr.Textbox(label="System prompt", lines=6),
336
  gr.Slider(
337
  label="Max new tokens",
@@ -373,6 +564,7 @@ chat_interface = gr.ChatInterface(
373
  cache_examples=False,
374
  examples=EXAMPLES,
375
  examples_per_page=9,
 
376
  )
377
 
378
  with gr.Blocks(fill_height=True, css="style.css") as demo:
 
1
  # pylint: skip-file
2
 
3
  import subprocess
4
+ import json
5
+ import requests
6
 
7
  subprocess.run(
8
  f"pip install flash-attn --no-build-isolation",
 
17
  import gradio as gr
18
  import spaces
19
  import torch
20
+ import wikipedia
21
+ import time
22
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
23
+ from bs4 import BeautifulSoup
24
+ from functools import lru_cache
25
 
26
 
27
  MAX_MAX_NEW_TOKENS = 4096
 
31
  DESCRIPTION = """\
32
  # Playground with Ghost 8B Beta (ฮฒ, 8k)
33
 
34
+ **Ghost 8B Beta** model outperforms prominent models such as Llama 3 8B Instruct, GPT 3.5 Turbo in the lc_winrate score. In addition, it also outperforms Claude 3 Opus, Claude 3 Sonnet, GPT-4, and Mistral Large when comparing the winrate score of AlpacaEval 2.0, [*](https://ghost-x.org/docs/models/ghost-8b-beta/). The model comes in two context length versions, [8k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-8k) and [128k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-128k), along with multilingual function tools support by default.
 
 
35
 
36
  The languages supported are ๐Ÿ‡บ๐Ÿ‡ธ English, ๐Ÿ‡ซ๐Ÿ‡ท French, ๐Ÿ‡ฎ๐Ÿ‡น Italian, ๐Ÿ‡ช๐Ÿ‡ธ Spanish, ๐Ÿ‡ต๐Ÿ‡น Portuguese, ๐Ÿ‡ฉ๐Ÿ‡ช German, ๐Ÿ‡ป๐Ÿ‡ณ Vietnamese, ๐Ÿ‡ฐ๐Ÿ‡ท Korean and ๐Ÿ‡จ๐Ÿ‡ณ Chinese.
37
 
38
+ ๐Ÿ—ž๏ธ **Updates**
39
+ * Jul 23, 2024: added support for tools, now available to search for information on the internet.
40
  """
41
 
42
 
 
255
 
256
  if torch.cuda.is_available():
257
  model_id = "ghost-x/ghost-8b-beta"
258
+ hf_serect = os.getenv("HF_TOKEN", None)
259
  model = AutoModelForCausalLM.from_pretrained(
260
  model_id,
261
  device_map="auto",
262
  torch_dtype=torch.bfloat16,
263
  attn_implementation="flash_attention_2",
264
  trust_remote_code=True,
265
+ token=hf_serect,
266
  )
267
  tokenizer = AutoTokenizer.from_pretrained(
268
  model_id,
269
  trust_remote_code=True,
270
+ token=hf_serect,
271
  )
272
 
273
+ waiting_tools_timeout = 5
274
+ supported_tools = json.dumps(
275
+ [
276
+ {
277
+ "type": "function",
278
+ "function": {
279
+ "name": "search_on_internet",
280
+ "description": "Use this tool to search online, only use it for information you don't know or are unsure of, don't abuse it.",
281
+ "parameters": {
282
+ "type": "object",
283
+ "properties": {
284
+ "keyword": {
285
+ "type": "string",
286
+ "description": "Search keywords, rephrase to optimize search results based on questions suitable to the specified search type.",
287
+ "required": True,
288
+ },
289
+ "type": {
290
+ "type": "string",
291
+ "description": "Search type, based on the question to determine whether to search for it in 'wikipedia' or 'google', prefer to use wikipedia for information about events, history and people.",
292
+ "enum": ["wikipedia", "google"],
293
+ "default": "google",
294
+ "required": True,
295
+ },
296
+ },
297
+ },
298
+ },
299
+ }
300
+ ],
301
+ ensure_ascii=False,
302
+ )
303
+
304
+
305
+ @lru_cache(maxsize=128)
306
+ def extract_text_from_webpage(html_content):
307
+ soup = BeautifulSoup(html_content, "html.parser")
308
+ for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
309
+ tag.extract()
310
+ visible_text = soup.get_text(strip=True, separator=" ")
311
+ return visible_text
312
+
313
+
314
+ def search_with_wikipedia(query: str):
315
+ all_results = []
316
+ try:
317
+ all_results.append(wikipedia.summary(query))
318
+ except Exception as e:
319
+ pass
320
+ return all_results
321
+
322
+
323
+ def search_with_google(
324
+ query: str,
325
+ num_results: int = 3,
326
+ timeout: int = 5,
327
+ ssl_verify: bool = None,
328
+ ):
329
+ all_results = []
330
+ max_chars_per_page = 4096
331
+ with requests.Session() as session:
332
+ resp = session.get(
333
+ url="https://www.google.com/search",
334
+ headers={
335
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
336
+ },
337
+ params={
338
+ "q": query,
339
+ "num": num_results,
340
+ "udm": 14,
341
+ },
342
+ timeout=timeout,
343
+ verify=ssl_verify,
344
+ )
345
+ resp.raise_for_status()
346
+ soup = BeautifulSoup(resp.text, "html.parser")
347
+ result_block = soup.find_all("div", attrs={"class": "g"})
348
+ for result in result_block:
349
+ link = result.find("a", href=True)
350
+ if link:
351
+ link = link["href"]
352
+ try:
353
+ webpage = session.get(
354
+ link,
355
+ headers={
356
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
357
+ },
358
+ )
359
+ webpage.raise_for_status()
360
+ visible_text = extract_text_from_webpage(webpage.text)
361
+ if len(visible_text) > max_chars_per_page:
362
+ visible_text = visible_text[:max_chars_per_page]
363
+ all_results.append({"link": link, "text": visible_text})
364
+ except requests.exceptions.RequestException as e:
365
+ print(f"Error fetching or processing {link}: {e}")
366
+ pass
367
+ else:
368
+ pass
369
+ return all_results
370
+
371
 
372
  @spaces.GPU(duration=120)
373
  def generate(
374
  message: str,
375
  chat_history: list[tuple[str, str]],
376
+ allow_used_tools: bool = True,
377
+ system_prompt: str = "",
378
  max_new_tokens: int = 1536,
379
  temperature: float = 0.4,
380
  top_p: float = 0.95,
381
  top_k: int = 50,
382
  repetition_penalty: float = 1.0,
383
  ) -> Iterator[str]:
384
+ # print()
385
+ # print("allow_used_tools:\n", allow_used_tools)
386
+ # print("system_prompt:\n", system_prompt)
387
+ # print("max_new_tokens:\n", max_new_tokens)
388
+ # print("temperature:\n", temperature)
389
+
390
+ def build_input_ids(
391
+ apply_tools: bool = None,
392
+ references: list[str] = None,
393
+ ):
394
+ conversation = []
395
+ if system_prompt:
396
+ conversation.append({"role": "system", "content": system_prompt})
397
+ if apply_tools is True:
398
+ conversation.append({"role": "tools", "content": supported_tools})
399
+ if (
400
+ references is not None
401
+ and isinstance(references, list)
402
+ and len(references) > 0
403
+ ):
404
+ conversation.append(
405
+ {
406
+ "role": "refs",
407
+ "content": json.dumps(references, ensure_ascii=False),
408
+ }
409
+ )
410
+
411
+ for user, assistant in chat_history:
412
+ conversation.extend(
413
+ [
414
+ {"role": "user", "content": user},
415
+ {"role": "assistant", "content": assistant},
416
+ ]
417
+ )
418
+ conversation.append({"role": "user", "content": message})
419
+
420
+ input_ids = tokenizer.apply_chat_template(
421
+ conversation, add_generation_prompt=True, return_tensors="pt"
422
  )
423
+ input_ids = input_ids.to(model.device)
424
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
425
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
426
+ gr.Warning(
427
+ f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
428
+ )
429
+ return input_ids
430
 
431
+ def generate_chat_responses(
432
+ previous_response: str = None,
433
+ ):
434
+ document_references = []
435
+ if previous_response is not None:
436
+ scheduled_tools_runs = None
437
+ try:
438
+ scheduled_tools_runs = json.loads(previous_response)
439
+ if scheduled_tools_runs["type"] == "function" and scheduled_tools_runs[
440
+ "name"
441
+ ] in ["search_on_internet"]:
442
+ pass
443
+ else:
444
+ scheduled_tools_runs = None
445
+ except Exception as e:
446
+ print(e)
447
+ pass
448
+
449
+ if (
450
+ scheduled_tools_runs is not None
451
+ and scheduled_tools_runs["name"] == "search_on_internet"
452
+ ):
453
+ keyword = scheduled_tools_runs["arguments"]["keyword"]
454
+ search_type = scheduled_tools_runs["arguments"]["type"]
455
+ if search_type == "wikipedia":
456
+ gr.Info("Searching for information on the Wikipedia.")
457
+ document_references = search_with_wikipedia(keyword)
458
+ else:
459
+ gr.Info("Searching for information on the Google.")
460
+ document_references = search_with_google(keyword)
461
+
462
+ input_ids = build_input_ids(
463
+ apply_tools=(
464
+ True
465
+ if allow_used_tools is True and previous_response is None
466
+ else False
467
+ ),
468
+ references=document_references,
469
+ )
470
+ streamer = TextIteratorStreamer(
471
+ tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
472
+ )
473
+ generate_kwargs = dict(
474
+ input_ids=input_ids,
475
+ streamer=streamer,
476
+ max_new_tokens=max_new_tokens,
477
+ do_sample=True,
478
+ repetition_penalty=repetition_penalty,
479
  )
480
+ if temperature == 0:
481
+ generate_kwargs["do_sample"] = False
482
+ else:
483
+ generate_kwargs["temperature"] = temperature
484
+ generate_kwargs["top_p"] = top_p
485
+ generate_kwargs["top_k"] = top_k
486
 
487
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
488
+ t.start()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
489
 
490
+ state = {
491
+ "mark": None,
492
+ "respond": False,
493
+ }
494
+ outputs = []
495
+ for text in streamer:
496
+ if state["mark"] is None:
497
+ state["mark"] = time.time()
498
+ outputs.append(text)
499
+ if state["mark"] + waiting_tools_timeout < time.time():
500
+ state["respond"] = True
501
+ yield "".join(outputs)
502
 
503
+ if (
504
+ state["respond"] is False
505
+ and state["mark"] + waiting_tools_timeout > time.time()
506
+ ):
507
+ gr.Info("Searching for information on the internet.")
508
+ previous_response = "".join(outputs)
509
+ yield from generate_chat_responses(previous_response=previous_response)
510
 
511
+ yield from generate_chat_responses(previous_response=None)
512
 
513
+
514
+ chatbot = gr.Chatbot(
515
+ height=500, placeholder=PLACEHOLDER, label="Ghost 8B Beta", show_copy_button=True
516
+ )
517
 
518
  chat_interface = gr.ChatInterface(
519
  fn=generate,
520
  chatbot=chatbot,
521
  fill_height=True,
522
  additional_inputs=[
523
+ gr.Checkbox(
524
+ label="Allow used tools (available: search on internet)", value=True
525
+ ),
526
  gr.Textbox(label="System prompt", lines=6),
527
  gr.Slider(
528
  label="Max new tokens",
 
564
  cache_examples=False,
565
  examples=EXAMPLES,
566
  examples_per_page=9,
567
+ concurrency_limit=100,
568
  )
569
 
570
  with gr.Blocks(fill_height=True, css="style.css") as demo:
requirements.txt CHANGED
@@ -1,8 +1,10 @@
1
  accelerate==0.30.1
2
  bitsandbytes==0.43.1
3
- gradio==4.37.2
4
  scipy==1.13.0
5
  sentencepiece==0.2.0
6
  spaces==0.28.3
7
  torch==2.0.0
8
  transformers==4.41.0
 
 
 
1
  accelerate==0.30.1
2
  bitsandbytes==0.43.1
3
+ gradio==4.39.0
4
  scipy==1.13.0
5
  sentencepiece==0.2.0
6
  spaces==0.28.3
7
  torch==2.0.0
8
  transformers==4.41.0
9
+ beautifulsoup4>=4.9
10
+ wikipedia==1.4.0