Update rkllm_binding.py
Browse files
- rkllm_binding.py +4 -5
rkllm_binding.py
CHANGED
@@ -4,7 +4,7 @@ from enum import IntEnum
|
|
4 |
from typing import Callable, Any
|
5 |
|
6 |
# Load the shared library
|
7 |
-
_lib = ctypes.CDLL("librkllmrt.so") # Adjust the library name if necessary
|
8 |
|
9 |
# Define enums
|
10 |
class LLMCallState(IntEnum):
|
@@ -181,12 +181,12 @@ def destroy(handle: ctypes.c_void_p) -> None:
|
|
181 |
raise RuntimeError(f"Failed to destroy RKLLM: {status}")
|
182 |
|
183 |
def run(handle: ctypes.c_void_p, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata: Any) -> None:
|
184 |
-
status = _lib.rkllm_run(handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), userdata)
|
185 |
if status != 0:
|
186 |
raise RuntimeError(f"Failed to run RKLLM: {status}")
|
187 |
|
188 |
def run_async(handle: ctypes.c_void_p, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata: Any) -> None:
|
189 |
-
status = _lib.rkllm_run_async(handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), userdata)
|
190 |
if status != 0:
|
191 |
raise RuntimeError(f"Failed to run RKLLM asynchronously: {status}")
|
192 |
|
@@ -212,8 +212,7 @@ def create_rkllm_input(input_type: RKLLMInputType, **kwargs) -> RKLLMInput:
|
|
212 |
elif input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
|
213 |
embed = kwargs['embed']
|
214 |
rkllm_input._input.embed_input.embed = numpy_to_c_array(embed, ctypes.c_float)
|
215 |
-
|
216 |
-
rkllm_input._input.embed_input.n_tokens = embed.shape[2]
|
217 |
elif input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
|
218 |
tokens = kwargs['tokens']
|
219 |
rkllm_input._input.token_input.input_ids = numpy_to_c_array(tokens, ctypes.c_int32)
|
|
|
4 |
from typing import Callable, Any
|
5 |
|
6 |
# Load the shared library
|
7 |
+
_lib = ctypes.CDLL("./librkllmrt.so") # Adjust the library name if necessary
|
8 |
|
9 |
# Define enums
|
10 |
class LLMCallState(IntEnum):
|
|
|
181 |
raise RuntimeError(f"Failed to destroy RKLLM: {status}")
|
182 |
|
183 |
def run(handle: ctypes.c_void_p, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata: Any) -> None:
|
184 |
+
status = _lib.rkllm_run(handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), ctypes.c_void_p(userdata))
|
185 |
if status != 0:
|
186 |
raise RuntimeError(f"Failed to run RKLLM: {status}")
|
187 |
|
188 |
def run_async(handle: ctypes.c_void_p, rkllm_input: RKLLMInput, rkllm_infer_params: RKLLMInferParam, userdata: Any) -> None:
|
189 |
+
status = _lib.rkllm_run_async(handle, ctypes.byref(rkllm_input), ctypes.byref(rkllm_infer_params), ctypes.c_void_p(userdata))
|
190 |
if status != 0:
|
191 |
raise RuntimeError(f"Failed to run RKLLM asynchronously: {status}")
|
192 |
|
|
|
212 |
elif input_type == RKLLMInputType.RKLLM_INPUT_EMBED:
|
213 |
embed = kwargs['embed']
|
214 |
rkllm_input._input.embed_input.embed = numpy_to_c_array(embed, ctypes.c_float)
|
215 |
+
rkllm_input._input.embed_input.n_tokens = embed.shape[1]
|
|
|
216 |
elif input_type == RKLLMInputType.RKLLM_INPUT_TOKEN:
|
217 |
tokens = kwargs['tokens']
|
218 |
rkllm_input._input.token_input.input_ids = numpy_to_c_array(tokens, ctypes.c_int32)
|