davidberenstein1957 HF staff commited on
Commit
c9bd449
·
1 Parent(s): c0fa328

feat: use InferenceEndpointsLLM

Browse files
Files changed (1) hide show
  1. app.py +5 -28
app.py CHANGED
@@ -1,36 +1,13 @@
1
  import json
2
- import os
3
 
4
  import gradio as gr
5
- from distilabel.llms import LlamaCppLLM
6
  from distilabel.steps.tasks.argillalabeller import ArgillaLabeller
7
 
8
- file_path = os.path.join(os.path.dirname(__file__), "Qwen2-5-0.5B-Instruct-f16.gguf")
9
- download_url = "https://huggingface.co/gaianet/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/Qwen2.5-0.5B-Instruct-Q8_0.gguf?download=true"
10
-
11
-
12
- if not os.path.exists(file_path):
13
- import requests
14
- import tqdm
15
-
16
- response = requests.get(download_url, stream=True)
17
- total_length = int(response.headers.get("content-length"))
18
-
19
- with open(file_path, "wb") as f:
20
- for chunk in tqdm.tqdm(
21
- response.iter_content(chunk_size=1024 * 1024),
22
- total=total_length / (1024 * 1024),
23
- unit="KB",
24
- unit_scale=True,
25
- ):
26
- f.write(chunk)
27
-
28
- context_window = 1024 * 128
29
- llm = LlamaCppLLM(
30
- model_path=file_path,
31
- n_gpu_layers=-1,
32
- n_ctx=context_window,
33
- generation_kwargs={"max_new_tokens": context_window},
34
  )
35
  task = ArgillaLabeller(llm=llm)
36
  task.load()
 
1
  import json
 
2
 
3
  import gradio as gr
4
+ from distilabel.llms import InferenceEndpointsLLM
5
  from distilabel.steps.tasks.argillalabeller import ArgillaLabeller
6
 
7
+ llm = InferenceEndpointsLLM(
8
+ model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
9
+ tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
10
+ generation_kwargs={"max_new_tokens": 1000 * 128},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  )
12
  task = ArgillaLabeller(llm=llm)
13
  task.load()