happyme531 committed on
Commit
d09e8b9
·
verified ·
1 Parent(s): d8ab42f

Update rkllm_binding.py

Browse files
Files changed (1) hide show
  1. rkllm_binding.py +4 -5
rkllm_binding.py CHANGED
@@ -4,7 +4,7 @@ from enum import IntEnum
4
  from typing import Callable, Any
5
 
6
  # Load the shared library
7
- _lib = ctypes.CDLL("librkllmrt.so") # Adjust the library name if necessary
8
 
9
  # Define enums
10
  class LLMCallState(IntEnum):
@@ -181,12 +181,12 @@ def destroy(handle: ctypes.c_void_p) -> None:
181
  raise RuntimeError(f"Failed to destroy RKLLM: {status}")
182
 
183
  def run(handle: ctypes.c_void_p, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata: Any) -> None:
184
- status = _lib.rkllm_run(handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), userdata)
185
  if status != 0:
186
  raise RuntimeError(f"Failed to run RKLLM: {status}")
187
 
188
  def run_async(handle: ctypes.c_void_p, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata: Any) -> None:
189
- status = _lib.rkllm_run_async(handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), userdata)
190
  if status != 0:
191
  raise RuntimeError(f"Failed to run RKLLM asynchronously: {status}")
192
 
@@ -212,8 +212,7 @@ def create_rkllm_input(input_type: RKLLMInputType, **kwargs) -> RKLLMInput:
212
  elif input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
213
  embed = kwargs['embed']
214
  rkllm_input._input.embed_input.embed = numpy_to_c_array(embed, ctypes.c_float)
215
- # rkllm_input._input.embed_input.n_tokens = embed.shape[1]
216
- rkllm_input._input.embed_input.n_tokens = embed.shape[2]
217
  elif input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
218
  tokens = kwargs['tokens']
219
  rkllm_input._input.token_input.input_ids = numpy_to_c_array(tokens, ctypes.c_int32)
 
4
  from typing import Callable, Any
5
 
6
  # Load the shared library
7
+ _lib = ctypes.CDLL("./librkllmrt.so") # Adjust the library name if necessary
8
 
9
  # Define enums
10
  class LLMCallState(IntEnum):
 
181
  raise RuntimeError(f"Failed to destroy RKLLM: {status}")
182
 
183
  def run(handle: ctypes.c_void_p, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata: Any) -> None:
184
+ status = _lib.rkllm_run(handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), ctypes.c_void_p(userdata))
185
  if status != 0:
186
  raise RuntimeError(f"Failed to run RKLLM: {status}")
187
 
188
  def run_async(handle: ctypes.c_void_p, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata: Any) -> None:
189
+ status = _lib.rkllm_run_async(handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), ctypes.c_void_p(userdata))
190
  if status != 0:
191
  raise RuntimeError(f"Failed to run RKLLM asynchronously: {status}")
192
 
 
212
  elif input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
213
  embed = kwargs['embed']
214
  rkllm_input._input.embed_input.embed = numpy_to_c_array(embed, ctypes.c_float)
215
+ rkllm_input._input.embed_input.n_tokens = embed.shape[1]
 
216
  elif input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
217
  tokens = kwargs['tokens']
218
  rkllm_input._input.token_input.input_ids = numpy_to_c_array(tokens, ctypes.c_int32)