davidberenstein1957 (HF staff) committed
Commit 5ff9afc
Parent: 1853f75

fix: app.py with qwen backup

Files changed (1):
  1. app.py +46 -9
app.py CHANGED
@@ -1,16 +1,47 @@
 import json
+import os
 
 import gradio as gr
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.llms import InferenceEndpointsLLM, LlamaCppLLM
 from distilabel.steps.tasks.argillalabeller import ArgillaLabeller
 
-llm = InferenceEndpointsLLM(
+file_path = os.path.join(os.path.dirname(__file__), "Qwen2-5-0.5B-Instruct-f16.gguf")
+download_url = "https://huggingface.co/gaianet/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/Qwen2.5-0.5B-Instruct-Q5_K_S.gguf?download=true"
+
+
+if not os.path.exists(file_path):
+    import requests
+    import tqdm
+
+    response = requests.get(download_url, stream=True)
+    total_length = int(response.headers.get("content-length"))
+
+    with open(file_path, "wb") as f:
+        for chunk in tqdm.tqdm(
+            response.iter_content(chunk_size=1024 * 1024),
+            total=total_length / (1024 * 1024),
+            unit="MB",
+            unit_scale=True,
+        ):
+            f.write(chunk)
+
+
+llm_cpp = LlamaCppLLM(
+    model_path=file_path,
+    n_gpu_layers=-1,
+    n_ctx=1000 * 114,
+    generation_kwargs={"max_new_tokens": 1000 * 14},
+)
+task_cpp = ArgillaLabeller(llm=llm_cpp)
+task_cpp.load()
+
+llm_ep = InferenceEndpointsLLM(
     model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
     tokenizer_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
     generation_kwargs={"max_new_tokens": 1000},
 )
-task = ArgillaLabeller(llm=llm)
-task.load()
+task_ep = ArgillaLabeller(llm=llm_ep)
+task_ep.load()
 
 
 def load_examples():
@@ -49,11 +80,17 @@ def process_records_gradio(records, fields, question, example_records=None):
     if example_records:
         runtime_parameters["example_records"] = example_records
 
-    task.set_runtime_parameters(runtime_parameters)
-
+    task_ep.set_runtime_parameters(runtime_parameters)
+    task_cpp.set_runtime_parameters(runtime_parameters)
     results = []
-    output = task.process(inputs=[{"record": record} for record in records])
-    output = next(output)
+    try:
+        output = next(
+            task_ep.process(inputs=[{"record": record} for record in records])
+        )
+    except Exception:
+        output = next(
+            task_cpp.process(inputs=[{"record": record} for record in records])
+        )
     for idx in range(len(records)):
         entry = output[idx]
         if entry["suggestions"]:
@@ -115,7 +152,7 @@ interface = gr.Interface(
         gr.Code(label="Question (JSON, optional)", language="json"),
     ],
     examples=examples,
-    cache_examples=True,
+    cache_examples=False,
    outputs=gr.Code(label="Suggestions", language="json", lines=10),
    title="Distilabel - ArgillaLabeller - Record Processing Interface",
    description=description,
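
A note on the download step in the first hunk: below is a minimal hardened sketch of the same streamed-download pattern. The helper name ensure_gguf and the raise_for_status/content-length guard are illustrative additions, not part of this commit.

import os

import requests
import tqdm


def ensure_gguf(url: str, dest: str, chunk_size: int = 1024 * 1024) -> str:
    # Skip the download when the file is already on disk, as app.py does.
    if os.path.exists(dest):
        return dest
    response = requests.get(url, stream=True)
    # Illustrative addition: fail early on HTTP errors so a 4xx/5xx body
    # is never written to the .gguf file.
    response.raise_for_status()
    total = int(response.headers.get("content-length", 0))
    with open(dest, "wb") as f:
        # Stream in 1 MiB chunks with a progress bar, mirroring the commit.
        for chunk in tqdm.tqdm(
            response.iter_content(chunk_size=chunk_size),
            total=total / chunk_size if total else None,
            unit="MB",
        ):
            f.write(chunk)
    return dest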
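The second hunk's endpoint-first, local-backup control flow, factored into a standalone helper for clarity. This is a sketch: label_with_fallback is an illustrative name, and primary/backup are assumed to be loaded ArgillaLabeller tasks like task_ep and task_cpp in the diff.

def label_with_fallback(primary, backup, records):
    inputs = [{"record": record} for record in records]
    try:
        # Prefer the hosted Inference Endpoints model.
        return next(primary.process(inputs=inputs))
    except Exception:
        # Fall back to the local llama.cpp Qwen model when the endpoint
        # call fails (e.g. the endpoint is unavailable or rate limited).
        return next(backup.process(inputs=inputs))

Catching a bare Exception keeps the Space serving whenever the endpoint call fails for any reason, at the cost of also masking programming errors; catching the client's HTTP errors specifically would be stricter.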