Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .clang-format +161 -0
- .devops/nix/package-gguf-py.nix +36 -0
- .devops/nix/python-scripts.nix +66 -0
- .editorconfig +50 -0
- .gitattributes +20 -0
- .github/ISSUE_TEMPLATE/create-new-issue.md +14 -0
- .github/workflows/kcpp-build-release-arm64.yaml +87 -0
- .github/workflows/kcpp-build-release-linux-cuda12.yaml +34 -0
- .github/workflows/kcpp-build-release-linux.yaml +34 -0
- .github/workflows/kcpp-build-release-osx.yaml +41 -0
- .github/workflows/kcpp-build-release-win-full-cu12.yaml +91 -0
- .github/workflows/kcpp-build-release-win-full.yaml +92 -0
- .github/workflows/kcpp-build-release-win-oldcpu-full.yaml +91 -0
- .gitignore +140 -0
- CLINFO_LICENSE +19 -0
- CMakeLists.txt +543 -0
- LICENSE.md +661 -0
- MIT_LICENSE_GGML_LLAMACPP_ONLY +26 -0
- Makefile +758 -0
- OpenCL.dll +0 -0
- README.md +194 -0
- Remote-Link.cmd +18 -0
- build-info.h +12 -0
- build-xcframework.sh +519 -0
- clblast.dll +3 -0
- colab.ipynb +174 -0
- common/arg.cpp +0 -0
- common/arg.h +80 -0
- common/base64.hpp +392 -0
- common/build-info.cpp.in +4 -0
- common/chat.cpp +1779 -0
- common/chat.h +135 -0
- common/common.cpp +2058 -0
- common/common.h +681 -0
- common/console.cpp +504 -0
- common/console.h +19 -0
- common/json-schema-to-grammar.cpp +1024 -0
- common/json-schema-to-grammar.h +21 -0
- common/json.hpp +0 -0
- common/llguidance.cpp +270 -0
- common/log.cpp +393 -0
- common/log.h +103 -0
- common/minja/chat-template.hpp +529 -0
- common/minja/minja.hpp +0 -0
- common/ngram-cache.cpp +286 -0
- common/ngram-cache.h +101 -0
- common/sampling.cpp +570 -0
- common/sampling.h +107 -0
- common/speculative.cpp +278 -0
- common/speculative.h +28 -0
.clang-format
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
Language: Cpp
|
3 |
+
AlignAfterOpenBracket: Align
|
4 |
+
AlignArrayOfStructures: Left
|
5 |
+
AlignConsecutiveAssignments: AcrossComments
|
6 |
+
AlignConsecutiveBitFields: AcrossComments
|
7 |
+
AlignConsecutiveDeclarations: AcrossComments
|
8 |
+
AlignConsecutiveMacros: AcrossComments
|
9 |
+
# AlignConsecutiveShortCaseStatements: AcrossComments
|
10 |
+
AlignEscapedNewlines: Left # LeftWithLastLine
|
11 |
+
AlignOperands: Align
|
12 |
+
AlignTrailingComments:
|
13 |
+
Kind: Always
|
14 |
+
OverEmptyLines: 1
|
15 |
+
AllowAllArgumentsOnNextLine: true
|
16 |
+
AllowAllParametersOfDeclarationOnNextLine: false
|
17 |
+
# AllowBreakBeforeNoexceptSpecifier: OnlyWithParen
|
18 |
+
AllowShortBlocksOnASingleLine: Never
|
19 |
+
AllowShortCaseLabelsOnASingleLine: false
|
20 |
+
AllowShortFunctionsOnASingleLine: Inline
|
21 |
+
AllowShortIfStatementsOnASingleLine: Never
|
22 |
+
AllowShortLambdasOnASingleLine: Inline
|
23 |
+
AllowShortLoopsOnASingleLine: false
|
24 |
+
AlwaysBreakBeforeMultilineStrings: true
|
25 |
+
BinPackArguments: true
|
26 |
+
BinPackParameters: true # OnePerLine
|
27 |
+
BitFieldColonSpacing: Both
|
28 |
+
BreakBeforeBraces: Custom # Attach
|
29 |
+
BraceWrapping:
|
30 |
+
AfterCaseLabel: true
|
31 |
+
AfterClass: false
|
32 |
+
AfterControlStatement: false
|
33 |
+
AfterEnum: false
|
34 |
+
AfterFunction: false
|
35 |
+
AfterNamespace: false
|
36 |
+
AfterObjCDeclaration: false
|
37 |
+
AfterStruct: false
|
38 |
+
AfterUnion: false
|
39 |
+
AfterExternBlock: false
|
40 |
+
BeforeCatch: false
|
41 |
+
BeforeElse: false
|
42 |
+
BeforeLambdaBody: false
|
43 |
+
BeforeWhile: false
|
44 |
+
IndentBraces: false
|
45 |
+
SplitEmptyFunction: false
|
46 |
+
SplitEmptyRecord: false
|
47 |
+
SplitEmptyNamespace: false
|
48 |
+
# BreakAdjacentStringLiterals: true
|
49 |
+
BreakAfterAttributes: Never
|
50 |
+
BreakBeforeBinaryOperators: None
|
51 |
+
BreakBeforeInlineASMColon: OnlyMultiline
|
52 |
+
BreakBeforeTernaryOperators: false
|
53 |
+
# BreakBinaryOperations: Never
|
54 |
+
BreakConstructorInitializers: AfterColon
|
55 |
+
# BreakFunctionDefinitionParameters: false
|
56 |
+
BreakInheritanceList: AfterComma
|
57 |
+
BreakStringLiterals: true
|
58 |
+
# BreakTemplateDeclarations: Yes
|
59 |
+
ColumnLimit: 120
|
60 |
+
CommentPragmas: '^ IWYU pragma:'
|
61 |
+
CompactNamespaces: false
|
62 |
+
ConstructorInitializerIndentWidth: 4
|
63 |
+
ContinuationIndentWidth: 4
|
64 |
+
Cpp11BracedListStyle: false
|
65 |
+
DerivePointerAlignment: false
|
66 |
+
DisableFormat: false
|
67 |
+
EmptyLineBeforeAccessModifier: Leave
|
68 |
+
EmptyLineAfterAccessModifier: Never
|
69 |
+
ExperimentalAutoDetectBinPacking: false
|
70 |
+
FixNamespaceComments: true
|
71 |
+
IncludeBlocks: Regroup
|
72 |
+
IncludeCategories:
|
73 |
+
- Regex: '^<.*\.h>'
|
74 |
+
Priority: 1
|
75 |
+
SortPriority: 0
|
76 |
+
- Regex: '^<.*'
|
77 |
+
Priority: 2
|
78 |
+
SortPriority: 0
|
79 |
+
- Regex: '.*'
|
80 |
+
Priority: 3
|
81 |
+
SortPriority: 0
|
82 |
+
IncludeIsMainRegex: '([-_](test|unittest))?$'
|
83 |
+
IncludeIsMainSourceRegex: ''
|
84 |
+
IndentAccessModifiers: false
|
85 |
+
IndentCaseBlocks: true
|
86 |
+
IndentCaseLabels: true
|
87 |
+
IndentExternBlock: NoIndent
|
88 |
+
IndentGotoLabels: false
|
89 |
+
IndentPPDirectives: AfterHash
|
90 |
+
IndentWidth: 4
|
91 |
+
IndentWrappedFunctionNames: false
|
92 |
+
InsertBraces: true # NOTE: may lead to incorrect formatting
|
93 |
+
InsertNewlineAtEOF: true
|
94 |
+
JavaScriptQuotes: Leave
|
95 |
+
JavaScriptWrapImports: true
|
96 |
+
KeepEmptyLinesAtTheStartOfBlocks: false
|
97 |
+
LambdaBodyIndentation: Signature
|
98 |
+
LineEnding: LF
|
99 |
+
MacroBlockBegin: ''
|
100 |
+
MacroBlockEnd: ''
|
101 |
+
MaxEmptyLinesToKeep: 1
|
102 |
+
NamespaceIndentation: None
|
103 |
+
ObjCBinPackProtocolList: Auto
|
104 |
+
ObjCBlockIndentWidth: 4
|
105 |
+
ObjCSpaceAfterProperty: true
|
106 |
+
ObjCSpaceBeforeProtocolList: true
|
107 |
+
PPIndentWidth: -1
|
108 |
+
PackConstructorInitializers: CurrentLine
|
109 |
+
PenaltyBreakAssignment: 2
|
110 |
+
PenaltyBreakBeforeFirstCallParameter: 1
|
111 |
+
PenaltyBreakComment: 300
|
112 |
+
PenaltyBreakFirstLessLess: 120
|
113 |
+
PenaltyBreakString: 1000
|
114 |
+
PenaltyBreakTemplateDeclaration: 10
|
115 |
+
PenaltyExcessCharacter: 1000000
|
116 |
+
PenaltyReturnTypeOnItsOwnLine: 200
|
117 |
+
PointerAlignment: Middle
|
118 |
+
QualifierAlignment: Left
|
119 |
+
#QualifierOrder: ['static', 'inline', 'friend', 'constexpr', 'const', 'volatile', 'type', 'restrict']
|
120 |
+
RawStringFormats:
|
121 |
+
- Language: Cpp
|
122 |
+
Delimiters:
|
123 |
+
- cc
|
124 |
+
- CC
|
125 |
+
- cpp
|
126 |
+
- Cpp
|
127 |
+
- CPP
|
128 |
+
- 'c++'
|
129 |
+
- 'C++'
|
130 |
+
CanonicalDelimiter: ''
|
131 |
+
ReferenceAlignment: Middle
|
132 |
+
ReflowComments: false # IndentOnly
|
133 |
+
SeparateDefinitionBlocks: Always
|
134 |
+
SortIncludes: CaseInsensitive
|
135 |
+
SortUsingDeclarations: LexicographicNumeric
|
136 |
+
SpaceAfterCStyleCast: true
|
137 |
+
SpaceAfterLogicalNot: false
|
138 |
+
SpaceAfterTemplateKeyword: true
|
139 |
+
SpaceBeforeAssignmentOperators: true
|
140 |
+
SpaceBeforeCpp11BracedList: false
|
141 |
+
SpaceBeforeCtorInitializerColon: true
|
142 |
+
SpaceBeforeInheritanceColon: true
|
143 |
+
SpaceBeforeParens: ControlStatements
|
144 |
+
SpaceBeforeRangeBasedForLoopColon: true
|
145 |
+
SpaceInEmptyBlock: false
|
146 |
+
SpaceInEmptyParentheses: false
|
147 |
+
SpacesBeforeTrailingComments: 2
|
148 |
+
SpacesInAngles: Never
|
149 |
+
SpacesInContainerLiterals: true
|
150 |
+
SpacesInLineCommentPrefix:
|
151 |
+
Minimum: 1
|
152 |
+
Maximum: -1
|
153 |
+
SpacesInParentheses: false
|
154 |
+
SpacesInSquareBrackets: false
|
155 |
+
SpaceBeforeSquareBrackets: false
|
156 |
+
Standard: c++17
|
157 |
+
TabWidth: 4
|
158 |
+
UseTab: Never
|
159 |
+
WhitespaceSensitiveMacros: ['STRINGIZE']
|
160 |
+
...
|
161 |
+
|
.devops/nix/package-gguf-py.nix
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
lib,
|
3 |
+
llamaVersion,
|
4 |
+
numpy,
|
5 |
+
tqdm,
|
6 |
+
sentencepiece,
|
7 |
+
pyyaml,
|
8 |
+
poetry-core,
|
9 |
+
buildPythonPackage,
|
10 |
+
pytestCheckHook,
|
11 |
+
}:
|
12 |
+
|
13 |
+
buildPythonPackage {
|
14 |
+
pname = "gguf";
|
15 |
+
version = llamaVersion;
|
16 |
+
pyproject = true;
|
17 |
+
nativeBuildInputs = [ poetry-core ];
|
18 |
+
propagatedBuildInputs = [
|
19 |
+
numpy
|
20 |
+
tqdm
|
21 |
+
sentencepiece
|
22 |
+
pyyaml
|
23 |
+
];
|
24 |
+
src = lib.cleanSource ../../gguf-py;
|
25 |
+
pythonImportsCheck = [
|
26 |
+
"numpy"
|
27 |
+
"gguf"
|
28 |
+
];
|
29 |
+
nativeCheckInputs = [ pytestCheckHook ];
|
30 |
+
doCheck = true;
|
31 |
+
meta = with lib; {
|
32 |
+
description = "Python package for writing binary files in the GGUF format";
|
33 |
+
license = licenses.mit;
|
34 |
+
maintainers = [ maintainers.ditsuke ];
|
35 |
+
};
|
36 |
+
}
|
.devops/nix/python-scripts.nix
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
lib,
|
3 |
+
stdenv,
|
4 |
+
buildPythonPackage,
|
5 |
+
poetry-core,
|
6 |
+
mkShell,
|
7 |
+
python3Packages,
|
8 |
+
gguf-py,
|
9 |
+
}@inputs:
|
10 |
+
|
11 |
+
let
|
12 |
+
llama-python-deps = with python3Packages; [
|
13 |
+
numpy
|
14 |
+
sentencepiece
|
15 |
+
transformers
|
16 |
+
protobuf
|
17 |
+
torchWithoutCuda
|
18 |
+
gguf-py
|
19 |
+
tqdm
|
20 |
+
|
21 |
+
# for scripts/compare-llama-bench.py
|
22 |
+
gitpython
|
23 |
+
tabulate
|
24 |
+
|
25 |
+
# for examples/pydantic-models-to-grammar-examples.py
|
26 |
+
docstring-parser
|
27 |
+
pydantic
|
28 |
+
|
29 |
+
];
|
30 |
+
|
31 |
+
llama-python-test-deps = with python3Packages; [
|
32 |
+
# Server bench
|
33 |
+
matplotlib
|
34 |
+
|
35 |
+
# server tests
|
36 |
+
openai
|
37 |
+
pytest
|
38 |
+
prometheus-client
|
39 |
+
];
|
40 |
+
in
|
41 |
+
|
42 |
+
buildPythonPackage ({
|
43 |
+
pname = "llama-scripts";
|
44 |
+
version = "0.0.0";
|
45 |
+
pyproject = true;
|
46 |
+
|
47 |
+
# NOTE: The files filtered out here are not visible in the build sandbox, neither
|
48 |
+
# do they affect the output hash. They can be modified without triggering a rebuild.
|
49 |
+
src = lib.cleanSourceWith {
|
50 |
+
filter =
|
51 |
+
name: type:
|
52 |
+
let
|
53 |
+
any = builtins.any (x: x);
|
54 |
+
baseName = builtins.baseNameOf name;
|
55 |
+
in
|
56 |
+
any [
|
57 |
+
(lib.hasSuffix ".py" name)
|
58 |
+
(baseName == "README.md")
|
59 |
+
(baseName == "pyproject.toml")
|
60 |
+
];
|
61 |
+
src = lib.cleanSource ../../.;
|
62 |
+
};
|
63 |
+
nativeBuildInputs = [ poetry-core ];
|
64 |
+
nativeCheckInputs = llama-python-test-deps;
|
65 |
+
dependencies = llama-python-deps;
|
66 |
+
})
|
.editorconfig
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://EditorConfig.org
|
2 |
+
|
3 |
+
# Top-most EditorConfig file
|
4 |
+
root = true
|
5 |
+
|
6 |
+
# Unix-style newlines with a newline ending every file, utf-8 charset
|
7 |
+
[*]
|
8 |
+
end_of_line = lf
|
9 |
+
insert_final_newline = true
|
10 |
+
trim_trailing_whitespace = true
|
11 |
+
charset = utf-8
|
12 |
+
indent_style = space
|
13 |
+
indent_size = 4
|
14 |
+
|
15 |
+
[Makefile]
|
16 |
+
indent_style = tab
|
17 |
+
|
18 |
+
[scripts/*.mk]
|
19 |
+
indent_style = tab
|
20 |
+
|
21 |
+
[prompts/*.txt]
|
22 |
+
insert_final_newline = unset
|
23 |
+
|
24 |
+
[examples/server/public/*]
|
25 |
+
indent_size = 2
|
26 |
+
|
27 |
+
[examples/server/public/deps_*]
|
28 |
+
trim_trailing_whitespace = unset
|
29 |
+
indent_style = unset
|
30 |
+
indent_size = unset
|
31 |
+
|
32 |
+
[examples/server/deps_*]
|
33 |
+
trim_trailing_whitespace = unset
|
34 |
+
indent_style = unset
|
35 |
+
indent_size = unset
|
36 |
+
|
37 |
+
[examples/llama.swiftui/llama.swiftui.xcodeproj/*]
|
38 |
+
indent_style = tab
|
39 |
+
|
40 |
+
[examples/cvector-generator/*.txt]
|
41 |
+
trim_trailing_whitespace = unset
|
42 |
+
insert_final_newline = unset
|
43 |
+
|
44 |
+
[models/templates/*.jinja]
|
45 |
+
indent_style = unset
|
46 |
+
indent_size = unset
|
47 |
+
end_of_line = unset
|
48 |
+
charset = unset
|
49 |
+
trim_trailing_whitespace = unset
|
50 |
+
insert_final_newline = unset
|
.gitattributes
CHANGED
@@ -33,3 +33,23 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
clblast.dll filter=lfs diff=lfs merge=lfs -text
|
37 |
+
cudart64_110.dll filter=lfs diff=lfs merge=lfs -text
|
38 |
+
cudart64_12.dll filter=lfs diff=lfs merge=lfs -text
|
39 |
+
examples/server/themes/buttons-top/buttons_top.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
examples/server/themes/wild/llamapattern.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
examples/server/themes/wild/wild.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
ggml/src/ggml-vulkan-shaders.cpp filter=lfs diff=lfs merge=lfs -text
|
43 |
+
glslc-linux filter=lfs diff=lfs merge=lfs -text
|
44 |
+
glslc.exe filter=lfs diff=lfs merge=lfs -text
|
45 |
+
lib/clblast.lib filter=lfs diff=lfs merge=lfs -text
|
46 |
+
msvcp140.dll filter=lfs diff=lfs merge=lfs -text
|
47 |
+
nikogreen.ico filter=lfs diff=lfs merge=lfs -text
|
48 |
+
otherarch/sdcpp/vocab.hpp filter=lfs diff=lfs merge=lfs -text
|
49 |
+
taesd.embd filter=lfs diff=lfs merge=lfs -text
|
50 |
+
taesd_3.embd filter=lfs diff=lfs merge=lfs -text
|
51 |
+
taesd_f.embd filter=lfs diff=lfs merge=lfs -text
|
52 |
+
taesd_xl.embd filter=lfs diff=lfs merge=lfs -text
|
53 |
+
vcruntime140.dll filter=lfs diff=lfs merge=lfs -text
|
54 |
+
vulkan-1.dll filter=lfs diff=lfs merge=lfs -text
|
55 |
+
winclinfo.exe filter=lfs diff=lfs merge=lfs -text
|
.github/ISSUE_TEMPLATE/create-new-issue.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: Create New Issue
|
3 |
+
about: Please describe the issue in detail
|
4 |
+
title: ''
|
5 |
+
labels: ''
|
6 |
+
assignees: ''
|
7 |
+
|
8 |
+
---
|
9 |
+
|
10 |
+
**Describe the Issue**
|
11 |
+
A clear and detailed description of what the issue is, and how to duplicate it (if applicable).
|
12 |
+
|
13 |
+
**Additional Information:**
|
14 |
+
Please provide as much relevant information about your setup as possible, such as the Operating System, CPU, GPU, KoboldCpp Version, and relevant logs (helpful to include the launch params from the terminal output, flags and crash logs)
|
.github/workflows/kcpp-build-release-arm64.yaml
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Koboldcpp Linux ARM64
|
2 |
+
|
3 |
+
on: workflow_dispatch
|
4 |
+
env:
|
5 |
+
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
|
6 |
+
|
7 |
+
jobs:
|
8 |
+
linux-arm:
|
9 |
+
runs-on: ubuntu-latest
|
10 |
+
steps:
|
11 |
+
- name: Clone
|
12 |
+
id: checkout
|
13 |
+
uses: actions/checkout@v3
|
14 |
+
with:
|
15 |
+
ref: ${{ github.head_ref || github.ref_name }}
|
16 |
+
|
17 |
+
- name: Install Dependencies
|
18 |
+
id: depends
|
19 |
+
run: |
|
20 |
+
sudo apt-get update
|
21 |
+
sudo apt-get install -y python3-tk python3-pip python3-dev build-essential \
|
22 |
+
libffi-dev libssl-dev libbz2-dev libreadline-dev libsqlite3-dev \
|
23 |
+
crossbuild-essential-arm64 gcc-aarch64-linux-gnu g++-aarch64-linux-gnu
|
24 |
+
|
25 |
+
- name: Install New GCC for Cross-Compilation
|
26 |
+
run: |
|
27 |
+
sudo apt-get install -y software-properties-common
|
28 |
+
sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
|
29 |
+
sudo apt-get update
|
30 |
+
sudo apt-get install -y gcc-12 g++-12 gcc-12-aarch64-linux-gnu g++-12-aarch64-linux-gnu
|
31 |
+
export CC=/usr/bin/aarch64-linux-gnu-gcc-12
|
32 |
+
export CXX=/usr/bin/aarch64-linux-gnu-g++-12
|
33 |
+
export AR=aarch64-linux-gnu-ar
|
34 |
+
export UNAME_M=aarch64
|
35 |
+
export UNAME_S=Linux
|
36 |
+
export PATH=/usr/bin:$PATH
|
37 |
+
make LLAMA_PORTABLE=1
|
38 |
+
chmod +x './create_ver_file.sh'
|
39 |
+
. create_ver_file.sh
|
40 |
+
mkdir -p dist
|
41 |
+
cp './koboldcpp_default.so' dist
|
42 |
+
ls
|
43 |
+
|
44 |
+
- name: Install QEMU
|
45 |
+
run: |
|
46 |
+
sudo apt-get update
|
47 |
+
sudo apt-get install -y qemu-user-static binfmt-support
|
48 |
+
|
49 |
+
- name: Setup QEMU for ARM64
|
50 |
+
run: |
|
51 |
+
docker run --rm --privileged multiarch/qemu-user-static --reset -p yes
|
52 |
+
|
53 |
+
- name: Build ARM64 PyInstaller
|
54 |
+
run: |
|
55 |
+
docker run --rm \
|
56 |
+
--platform linux/arm64 \
|
57 |
+
-v "${PWD}:/src" \
|
58 |
+
python:3.9-slim \
|
59 |
+
/bin/bash -c "
|
60 |
+
apt-get update && apt-get install -y build-essential && \
|
61 |
+
apt-get update && apt-get install -y gcc-12 g++-12 && \
|
62 |
+
export LD_LIBRARY_PATH=/usr/lib/gcc/x86_64-linux-gnu/12:$LD_LIBRARY_PATH && \
|
63 |
+
pip install customtkinter pyinstaller tk && \
|
64 |
+
cd /src && \
|
65 |
+
pyinstaller --noconfirm --onefile --collect-all customtkinter --collect-all psutil \
|
66 |
+
--add-data './koboldcpp_default.so:.' \
|
67 |
+
--add-data './kcpp_adapters:./kcpp_adapters' \
|
68 |
+
--add-data './koboldcpp.py:.' \
|
69 |
+
--add-data './klite.embd:.' \
|
70 |
+
--add-data './kcpp_docs.embd:.' \
|
71 |
+
--add-data './kcpp_sdui.embd:.' \
|
72 |
+
--add-data './taesd.embd:.' \
|
73 |
+
--add-data './taesd_xl.embd:.' \
|
74 |
+
--add-data './taesd_f.embd:.' \
|
75 |
+
--add-data './taesd_3.embd:.' \
|
76 |
+
--add-data './rwkv_vocab.embd:.' \
|
77 |
+
--add-data './rwkv_world_vocab.embd:.' \
|
78 |
+
--version-file './version.txt' \
|
79 |
+
--clean --console koboldcpp.py -n 'koboldcpp-linux-arm64'
|
80 |
+
"
|
81 |
+
|
82 |
+
- name: Save artifact
|
83 |
+
uses: actions/upload-artifact@v4
|
84 |
+
with:
|
85 |
+
name: kcpp_linux_arm64_binary
|
86 |
+
path: dist/
|
87 |
+
|
.github/workflows/kcpp-build-release-linux-cuda12.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Koboldcpp Linux CUDA12
|
2 |
+
|
3 |
+
on: workflow_dispatch
|
4 |
+
env:
|
5 |
+
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
|
6 |
+
KCPP_CUDA: 12.1.0
|
7 |
+
REBUILD_VK_SHADERS: 1
|
8 |
+
|
9 |
+
jobs:
|
10 |
+
linux:
|
11 |
+
runs-on: ubuntu-22.04
|
12 |
+
steps:
|
13 |
+
- name: Clone
|
14 |
+
id: checkout
|
15 |
+
uses: actions/checkout@v3
|
16 |
+
with:
|
17 |
+
ref: ${{ github.head_ref || github.ref_name }}
|
18 |
+
|
19 |
+
- name: Dependencies
|
20 |
+
id: depends
|
21 |
+
run: |
|
22 |
+
sudo apt-get update
|
23 |
+
sudo apt-get install git curl bzip2
|
24 |
+
|
25 |
+
- name: Build
|
26 |
+
id: make_build
|
27 |
+
run: |
|
28 |
+
./koboldcpp.sh dist
|
29 |
+
|
30 |
+
- name: Save artifact
|
31 |
+
uses: actions/upload-artifact@v4
|
32 |
+
with:
|
33 |
+
name: kcpp_linux_binary
|
34 |
+
path: dist/
|
.github/workflows/kcpp-build-release-linux.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Koboldcpp Linux
|
2 |
+
|
3 |
+
on: workflow_dispatch
|
4 |
+
env:
|
5 |
+
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
|
6 |
+
NOAVX2: 1
|
7 |
+
REBUILD_VK_SHADERS: 1
|
8 |
+
|
9 |
+
jobs:
|
10 |
+
linux:
|
11 |
+
runs-on: ubuntu-22.04
|
12 |
+
steps:
|
13 |
+
- name: Clone
|
14 |
+
id: checkout
|
15 |
+
uses: actions/checkout@v3
|
16 |
+
with:
|
17 |
+
ref: ${{ github.head_ref || github.ref_name }}
|
18 |
+
|
19 |
+
- name: Dependencies
|
20 |
+
id: depends
|
21 |
+
run: |
|
22 |
+
sudo apt-get update
|
23 |
+
sudo apt-get install git curl bzip2
|
24 |
+
|
25 |
+
- name: Build
|
26 |
+
id: make_build
|
27 |
+
run: |
|
28 |
+
./koboldcpp.sh dist
|
29 |
+
|
30 |
+
- name: Save artifact
|
31 |
+
uses: actions/upload-artifact@v4
|
32 |
+
with:
|
33 |
+
name: kcpp_linux_binary
|
34 |
+
path: dist/
|
.github/workflows/kcpp-build-release-osx.yaml
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Koboldcpp Mac
|
2 |
+
|
3 |
+
on: workflow_dispatch
|
4 |
+
env:
|
5 |
+
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
|
6 |
+
|
7 |
+
jobs:
|
8 |
+
osx:
|
9 |
+
runs-on: macos-latest
|
10 |
+
steps:
|
11 |
+
- name: Clone
|
12 |
+
id: checkout
|
13 |
+
uses: actions/checkout@v3
|
14 |
+
with:
|
15 |
+
ref: ${{ github.head_ref || github.ref_name }}
|
16 |
+
|
17 |
+
- name: Dependencies
|
18 |
+
id: depends
|
19 |
+
run: |
|
20 |
+
pip install customtkinter pyinstaller tk
|
21 |
+
|
22 |
+
- name: Build
|
23 |
+
id: make_build
|
24 |
+
run: |
|
25 |
+
make LLAMA_METAL=1 LLAMA_PORTABLE=1
|
26 |
+
chmod +x './create_ver_file.sh'
|
27 |
+
. create_ver_file.sh
|
28 |
+
pyinstaller --noconfirm --onefile --collect-all customtkinter --collect-all psutil --add-data './koboldcpp_default.so:.' --add-data './ggml-metal-merged.metal:.' --add-data './kcpp_adapters:./kcpp_adapters' --add-data './koboldcpp.py:.' --add-data './klite.embd:.' --add-data './kcpp_docs.embd:.' --add-data './kcpp_sdui.embd:.' --add-data './taesd.embd:.' --add-data './taesd_xl.embd:.' --add-data './taesd_f.embd:.' --add-data './taesd_3.embd:.' --add-data './rwkv_vocab.embd:.' --add-data './rwkv_world_vocab.embd:.' --version-file './version.txt' --clean --console koboldcpp.py -n "koboldcpp-mac-arm64"
|
29 |
+
|
30 |
+
- name: Test
|
31 |
+
id: test
|
32 |
+
run: |
|
33 |
+
wget https://huggingface.co/concedo/koboldcpp/resolve/main/baby_llama.gguf
|
34 |
+
dist/koboldcpp-mac-arm64 --model baby_llama.gguf --gpulayers 99 --benchmark --prompt 'Hi, my name is'
|
35 |
+
|
36 |
+
- name: Save artifact
|
37 |
+
uses: actions/upload-artifact@v4
|
38 |
+
with:
|
39 |
+
name: kcpp_mac_binary
|
40 |
+
path: dist/
|
41 |
+
|
.github/workflows/kcpp-build-release-win-full-cu12.yaml
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Koboldcpp Windows Full Binaries CUDA 12
|
2 |
+
|
3 |
+
on: workflow_dispatch
|
4 |
+
env:
|
5 |
+
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
|
6 |
+
|
7 |
+
jobs:
|
8 |
+
windows:
|
9 |
+
runs-on: windows-2019
|
10 |
+
steps:
|
11 |
+
- name: Clone
|
12 |
+
id: checkout
|
13 |
+
uses: actions/checkout@v3
|
14 |
+
with:
|
15 |
+
ref: ${{ github.head_ref || github.ref_name }}
|
16 |
+
|
17 |
+
- name: Get Python
|
18 |
+
uses: actions/setup-python@v2
|
19 |
+
with:
|
20 |
+
python-version: 3.8.10
|
21 |
+
|
22 |
+
- name: Install python dependencies
|
23 |
+
run: |
|
24 |
+
python -m pip install --upgrade pip
|
25 |
+
pip install customtkinter==5.2.0 pyinstaller==5.11.0 psutil==5.9.5
|
26 |
+
|
27 |
+
- name: Download and install win64devkit
|
28 |
+
run: |
|
29 |
+
curl -L https://github.com/skeeto/w64devkit/releases/download/v1.22.0/w64devkit-1.22.0.zip --output w64devkit.zip
|
30 |
+
Expand-Archive w64devkit.zip -DestinationPath .
|
31 |
+
|
32 |
+
- name: Add w64devkit to PATH
|
33 |
+
run: |
|
34 |
+
echo "$(Get-Location)\w64devkit\bin" | Out-File -Append -FilePath $env:GITHUB_PATH -Encoding utf8
|
35 |
+
|
36 |
+
- name: Print System Environment Variables
|
37 |
+
id: printvars
|
38 |
+
run: |
|
39 |
+
echo "Number of processors: ${env:NUMBER_OF_PROCESSORS}"
|
40 |
+
echo "Processor Architecture: ${env:PROCESSOR_ARCHITECTURE}"
|
41 |
+
echo "Computer Name: ${env:COMPUTERNAME}"
|
42 |
+
wmic cpu get name
|
43 |
+
wmic os get TotalVisibleMemorySize, FreePhysicalMemory
|
44 |
+
|
45 |
+
- name: Rebuild Vulkan Shaders
|
46 |
+
id: make_vk_shaders
|
47 |
+
run: |
|
48 |
+
make vulkan_shaders_gen -j ${env:NUMBER_OF_PROCESSORS}
|
49 |
+
|
50 |
+
- name: Build Non-CUDA
|
51 |
+
id: make_build
|
52 |
+
run: |
|
53 |
+
make LLAMA_CLBLAST=1 LLAMA_VULKAN=1 LLAMA_PORTABLE=1 -j ${env:NUMBER_OF_PROCESSORS}
|
54 |
+
echo "Vulkan Shaders Rebuilt"
|
55 |
+
|
56 |
+
- uses: Jimver/[email protected]
|
57 |
+
id: cuda-toolkit
|
58 |
+
with:
|
59 |
+
cuda: '12.1.0'
|
60 |
+
|
61 |
+
- name: Build CUDA
|
62 |
+
id: cmake_build
|
63 |
+
run: |
|
64 |
+
mkdir build
|
65 |
+
cd build
|
66 |
+
cmake .. -DLLAMA_CUBLAS=ON -DCMAKE_SYSTEM_VERSION="10.0.19041.0"
|
67 |
+
cmake --build . --config Release -j 2
|
68 |
+
cd ..
|
69 |
+
|
70 |
+
# note: The libraries that come from the github cuda directory seem to be larger, so they are not recommended
|
71 |
+
# - name: Download CuBLAS Libraries
|
72 |
+
# run: |
|
73 |
+
# curl -L https://github.com/LostRuins/koboldcpp/releases/download/cuda11_cublas_libraries/cublas64_11.dll --output cublas64_11.dll
|
74 |
+
# curl -L https://github.com/LostRuins/koboldcpp/releases/download/cuda11_cublas_libraries/cublasLt64_11.dll --output cublasLt64_11.dll
|
75 |
+
# ls
|
76 |
+
- name: Copy CuBLAS Libraries
|
77 |
+
run: |
|
78 |
+
copy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\bin\cublasLt64_12.dll" .
|
79 |
+
copy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\bin\cublas64_12.dll" .
|
80 |
+
ls
|
81 |
+
|
82 |
+
- name: Package PyInstallers
|
83 |
+
id: make_pyinstaller
|
84 |
+
run: |
|
85 |
+
./make_pyinstaller_cuda12.bat
|
86 |
+
|
87 |
+
- name: Save artifact
|
88 |
+
uses: actions/upload-artifact@v4
|
89 |
+
with:
|
90 |
+
name: kcpp_windows_pyinstallers
|
91 |
+
path: dist/
|
.github/workflows/kcpp-build-release-win-full.yaml
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Koboldcpp Windows Full Binaries
|
2 |
+
|
3 |
+
on: workflow_dispatch
|
4 |
+
env:
|
5 |
+
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
|
6 |
+
|
7 |
+
jobs:
|
8 |
+
windows:
|
9 |
+
runs-on: windows-2019
|
10 |
+
steps:
|
11 |
+
- name: Clone
|
12 |
+
id: checkout
|
13 |
+
uses: actions/checkout@v3
|
14 |
+
with:
|
15 |
+
ref: ${{ github.head_ref || github.ref_name }}
|
16 |
+
|
17 |
+
- name: Get Python
|
18 |
+
uses: actions/setup-python@v2
|
19 |
+
with:
|
20 |
+
python-version: 3.8.10
|
21 |
+
|
22 |
+
- name: Install python dependencies
|
23 |
+
run: |
|
24 |
+
python -m pip install --upgrade pip
|
25 |
+
pip install customtkinter==5.2.0 pyinstaller==5.11.0 psutil==5.9.5
|
26 |
+
|
27 |
+
- name: Download and install win64devkit
|
28 |
+
run: |
|
29 |
+
curl -L https://github.com/skeeto/w64devkit/releases/download/v1.22.0/w64devkit-1.22.0.zip --output w64devkit.zip
|
30 |
+
Expand-Archive w64devkit.zip -DestinationPath .
|
31 |
+
|
32 |
+
- name: Add w64devkit to PATH
|
33 |
+
run: |
|
34 |
+
echo "$(Get-Location)\w64devkit\bin" | Out-File -Append -FilePath $env:GITHUB_PATH -Encoding utf8
|
35 |
+
|
36 |
+
- name: Print System Environment Variables
|
37 |
+
id: printvars
|
38 |
+
run: |
|
39 |
+
echo "Number of processors: ${env:NUMBER_OF_PROCESSORS}"
|
40 |
+
echo "Processor Architecture: ${env:PROCESSOR_ARCHITECTURE}"
|
41 |
+
echo "Computer Name: ${env:COMPUTERNAME}"
|
42 |
+
wmic cpu get name
|
43 |
+
wmic os get TotalVisibleMemorySize, FreePhysicalMemory
|
44 |
+
|
45 |
+
- name: Rebuild Vulkan Shaders
|
46 |
+
id: make_vk_shaders
|
47 |
+
run: |
|
48 |
+
make vulkan_shaders_gen -j ${env:NUMBER_OF_PROCESSORS}
|
49 |
+
echo "Vulkan Shaders Rebuilt"
|
50 |
+
|
51 |
+
- name: Build Non-CUDA
|
52 |
+
id: make_build
|
53 |
+
run: |
|
54 |
+
make LLAMA_CLBLAST=1 LLAMA_VULKAN=1 LLAMA_PORTABLE=1 -j ${env:NUMBER_OF_PROCESSORS}
|
55 |
+
|
56 |
+
- uses: Jimver/[email protected]
|
57 |
+
id: cuda-toolkit
|
58 |
+
with:
|
59 |
+
cuda: '11.4.4'
|
60 |
+
|
61 |
+
- name: Build CUDA
|
62 |
+
id: cmake_build
|
63 |
+
run: |
|
64 |
+
mkdir build
|
65 |
+
cd build
|
66 |
+
cmake .. -DLLAMA_CUBLAS=ON -DCMAKE_SYSTEM_VERSION="10.0.19041.0"
|
67 |
+
cmake --build . --config Release -j 2
|
68 |
+
cd ..
|
69 |
+
|
70 |
+
# note: The libraries that come from the github cuda directory seem to be larger, so they are not recommended
|
71 |
+
- name: Download CuBLAS Libraries
|
72 |
+
run: |
|
73 |
+
curl -L https://github.com/LostRuins/koboldcpp/releases/download/cuda11_cublas_libraries/cublas64_11.dll --output cublas64_11.dll
|
74 |
+
curl -L https://github.com/LostRuins/koboldcpp/releases/download/cuda11_cublas_libraries/cublasLt64_11.dll --output cublasLt64_11.dll
|
75 |
+
ls
|
76 |
+
# - name: Copy CuBLAS Libraries
|
77 |
+
# run: |
|
78 |
+
# copy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\bin\cublasLt64_11.dll" .
|
79 |
+
# copy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\bin\cublas64_11.dll" .
|
80 |
+
# ls
|
81 |
+
|
82 |
+
- name: Package PyInstallers
|
83 |
+
id: make_pyinstaller
|
84 |
+
run: |
|
85 |
+
./make_pyinstaller.bat
|
86 |
+
./make_pyinstaller_cuda.bat
|
87 |
+
|
88 |
+
- name: Save artifact
|
89 |
+
uses: actions/upload-artifact@v4
|
90 |
+
with:
|
91 |
+
name: kcpp_windows_pyinstallers
|
92 |
+
path: dist/
|
.github/workflows/kcpp-build-release-win-oldcpu-full.yaml
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Koboldcpp Windows Full OldCPU Binaries
|
2 |
+
|
3 |
+
on: workflow_dispatch
|
4 |
+
env:
|
5 |
+
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
|
6 |
+
|
7 |
+
jobs:
|
8 |
+
windows:
|
9 |
+
runs-on: windows-2019
|
10 |
+
steps:
|
11 |
+
- name: Clone
|
12 |
+
id: checkout
|
13 |
+
uses: actions/checkout@v3
|
14 |
+
with:
|
15 |
+
ref: ${{ github.head_ref || github.ref_name }}
|
16 |
+
|
17 |
+
- name: Get Python
|
18 |
+
uses: actions/setup-python@v2
|
19 |
+
with:
|
20 |
+
python-version: 3.8.10
|
21 |
+
|
22 |
+
- name: Install python dependencies
|
23 |
+
run: |
|
24 |
+
python -m pip install --upgrade pip
|
25 |
+
pip install customtkinter==5.2.0 pyinstaller==5.11.0 psutil==5.9.5
|
26 |
+
|
27 |
+
- name: Download and install win64devkit
|
28 |
+
run: |
|
29 |
+
curl -L https://github.com/skeeto/w64devkit/releases/download/v1.22.0/w64devkit-1.22.0.zip --output w64devkit.zip
|
30 |
+
Expand-Archive w64devkit.zip -DestinationPath .
|
31 |
+
|
32 |
+
- name: Add w64devkit to PATH
|
33 |
+
run: |
|
34 |
+
echo "$(Get-Location)\w64devkit\bin" | Out-File -Append -FilePath $env:GITHUB_PATH -Encoding utf8
|
35 |
+
|
36 |
+
- name: Print System Environment Variables
|
37 |
+
id: printvars
|
38 |
+
run: |
|
39 |
+
echo "Number of processors: ${env:NUMBER_OF_PROCESSORS}"
|
40 |
+
echo "Processor Architecture: ${env:PROCESSOR_ARCHITECTURE}"
|
41 |
+
echo "Computer Name: ${env:COMPUTERNAME}"
|
42 |
+
wmic cpu get name
|
43 |
+
wmic os get TotalVisibleMemorySize, FreePhysicalMemory
|
44 |
+
|
45 |
+
- name: Rebuild Vulkan Shaders
|
46 |
+
id: make_vk_shaders
|
47 |
+
run: |
|
48 |
+
make vulkan_shaders_gen -j ${env:NUMBER_OF_PROCESSORS}
|
49 |
+
echo "Vulkan Shaders Rebuilt"
|
50 |
+
|
51 |
+
- name: Build Non-CUDA
|
52 |
+
id: make_build
|
53 |
+
run: |
|
54 |
+
make LLAMA_CLBLAST=1 LLAMA_VULKAN=1 LLAMA_PORTABLE=1 -j ${env:NUMBER_OF_PROCESSORS} LLAMA_NOAVX2=1
|
55 |
+
|
56 |
+
- uses: Jimver/[email protected]
|
57 |
+
id: cuda-toolkit
|
58 |
+
with:
|
59 |
+
cuda: '11.4.4'
|
60 |
+
|
61 |
+
- name: Build CUDA
|
62 |
+
id: cmake_build
|
63 |
+
run: |
|
64 |
+
mkdir build
|
65 |
+
cd build
|
66 |
+
cmake .. -DLLAMA_CUBLAS=ON -DLLAMA_AVX2=OFF -DCMAKE_SYSTEM_VERSION="10.0.19041.0"
|
67 |
+
cmake --build . --config Release -j 2
|
68 |
+
cd ..
|
69 |
+
|
70 |
+
# note: The libraries that come from the github cuda directory seem to be larger, so they are not recommended
|
71 |
+
- name: Download CuBLAS Libraries
|
72 |
+
run: |
|
73 |
+
curl -L https://github.com/LostRuins/koboldcpp/releases/download/cuda11_cublas_libraries/cublas64_11.dll --output cublas64_11.dll
|
74 |
+
curl -L https://github.com/LostRuins/koboldcpp/releases/download/cuda11_cublas_libraries/cublasLt64_11.dll --output cublasLt64_11.dll
|
75 |
+
ls
|
76 |
+
# - name: Copy CuBLAS Libraries
|
77 |
+
# run: |
|
78 |
+
# copy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\bin\cublasLt64_11.dll" .
|
79 |
+
# copy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\bin\cublas64_11.dll" .
|
80 |
+
# ls
|
81 |
+
|
82 |
+
- name: Package PyInstallers
|
83 |
+
id: make_pyinstaller
|
84 |
+
run: |
|
85 |
+
./make_pyinstaller_cuda_oldcpu.bat
|
86 |
+
|
87 |
+
- name: Save artifact
|
88 |
+
uses: actions/upload-artifact@v4
|
89 |
+
with:
|
90 |
+
name: kcpp_windows_pyinstallers
|
91 |
+
path: dist/
|
.gitignore
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.o
|
2 |
+
*.a
|
3 |
+
*.bin
|
4 |
+
.DS_Store
|
5 |
+
.build/
|
6 |
+
.cache/
|
7 |
+
.ccls-cache/
|
8 |
+
.direnv/
|
9 |
+
.envrc
|
10 |
+
.swiftpm
|
11 |
+
.venv
|
12 |
+
.clang-tidy
|
13 |
+
.vs/
|
14 |
+
.vscode/
|
15 |
+
|
16 |
+
ggml-metal-embed.metal
|
17 |
+
|
18 |
+
lcov-report/
|
19 |
+
gcovr-report/
|
20 |
+
|
21 |
+
build*/
|
22 |
+
out/
|
23 |
+
tmp/
|
24 |
+
autogen-*.md
|
25 |
+
|
26 |
+
models/*
|
27 |
+
models-mnt
|
28 |
+
|
29 |
+
/Pipfile
|
30 |
+
/baby-llama
|
31 |
+
/beam-search
|
32 |
+
/benchmark-matmult
|
33 |
+
/convert-llama2c-to-ggml
|
34 |
+
/embd-input-test
|
35 |
+
/embedding
|
36 |
+
/eval-callback
|
37 |
+
/gguf
|
38 |
+
/gguf-llama-simple
|
39 |
+
/gritlm
|
40 |
+
/imatrix
|
41 |
+
/infill
|
42 |
+
/libllama.so
|
43 |
+
/llama-bench
|
44 |
+
/llava-cli
|
45 |
+
/lookahead
|
46 |
+
/lookup
|
47 |
+
/main
|
48 |
+
/metal
|
49 |
+
/passkey
|
50 |
+
/perplexity
|
51 |
+
/q8dot
|
52 |
+
/quantize
|
53 |
+
/quantize-stats
|
54 |
+
/result
|
55 |
+
/save-load-state
|
56 |
+
/server
|
57 |
+
/simple
|
58 |
+
/batched
|
59 |
+
/batched-bench
|
60 |
+
/export-lora
|
61 |
+
/finetune
|
62 |
+
/speculative
|
63 |
+
/parallel
|
64 |
+
/train-text-from-scratch
|
65 |
+
/tokenize
|
66 |
+
/vdot
|
67 |
+
/common/build-info.cpp
|
68 |
+
arm_neon.h
|
69 |
+
compile_commands.json
|
70 |
+
CMakeSettings.json
|
71 |
+
|
72 |
+
__pycache__
|
73 |
+
dist
|
74 |
+
|
75 |
+
dist/
|
76 |
+
*.spec
|
77 |
+
|
78 |
+
zig-out/
|
79 |
+
zig-cache/
|
80 |
+
|
81 |
+
ppl-*.txt
|
82 |
+
qnt-*.txt
|
83 |
+
perf-*.txt
|
84 |
+
|
85 |
+
examples/jeopardy/results.txt
|
86 |
+
|
87 |
+
poetry.lock
|
88 |
+
poetry.toml
|
89 |
+
|
90 |
+
ggml-metal-merged.metal
|
91 |
+
|
92 |
+
# Test binaries
|
93 |
+
/tests/test-llama-grammar
|
94 |
+
tests/test-double-float
|
95 |
+
tests/test-grad0
|
96 |
+
tests/test-opt
|
97 |
+
tests/test-quantize-fns
|
98 |
+
tests/test-quantize-perf
|
99 |
+
tests/test-sampling
|
100 |
+
tests/test-tokenizer-0
|
101 |
+
tests/test-tokenizer-0-llama
|
102 |
+
tests/test-tokenizer-0-falcon
|
103 |
+
tests/test-tokenizer-1-llama
|
104 |
+
tests/test-tokenizer-1-bpe
|
105 |
+
/tests/test-rope
|
106 |
+
/tests/test-backend-ops
|
107 |
+
|
108 |
+
/koboldcpp_default.so
|
109 |
+
/koboldcpp_failsafe.so
|
110 |
+
/koboldcpp_noavx2.so
|
111 |
+
/koboldcpp_clblast.so
|
112 |
+
/koboldcpp_clblast_noavx2.so
|
113 |
+
/koboldcpp_clblast_failsafe.so
|
114 |
+
/koboldcpp_cublas.so
|
115 |
+
/koboldcpp_vulkan.so
|
116 |
+
/koboldcpp_vulkan_noavx2.so
|
117 |
+
/koboldcpp_default.dll
|
118 |
+
/koboldcpp_failsafe.dll
|
119 |
+
/koboldcpp_noavx2.dll
|
120 |
+
/koboldcpp_clblast.dll
|
121 |
+
/koboldcpp_clblast_noavx2.dll
|
122 |
+
/koboldcpp_vulkan_noavx2.dll
|
123 |
+
/koboldcpp_clblast_failsafe.dll
|
124 |
+
/koboldcpp_cublas.dll
|
125 |
+
/koboldcpp_vulkan.dll
|
126 |
+
/cublas64_11.dll
|
127 |
+
/cublasLt64_11.dll
|
128 |
+
/cublas64_12.dll
|
129 |
+
/cublasLt64_12.dll
|
130 |
+
/rocblas/
|
131 |
+
rocblas.dll
|
132 |
+
hipblas.dll
|
133 |
+
koboldcpp_hipblas.so
|
134 |
+
koboldcpp_hipblas.dll
|
135 |
+
|
136 |
+
bin/
|
137 |
+
conda/
|
138 |
+
|
139 |
+
# Jetbrains idea folder
|
140 |
+
.idea/
|
CLINFO_LICENSE
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Windows binaries obtained from the clinfo repo fork here:
|
2 |
+
|
3 |
+
https://github.com/ahoylabs/clinfo/releases/tag/master-d2baa06
|
4 |
+
|
5 |
+
Source available here:
|
6 |
+
https://github.com/Oblomov/clinfo
|
7 |
+
|
8 |
+
see below LICENSE file for details on clinfo license
|
9 |
+
|
10 |
+
=======
|
11 |
+
|
12 |
+
clinfo by Giuseppe Bilotta
|
13 |
+
|
14 |
+
To the extent possible under law, the person who associated CC0 with
|
15 |
+
clinfo has waived all copyright and related or neighboring rights
|
16 |
+
to clinfo.
|
17 |
+
|
18 |
+
You should have received a copy of the CC0 legalcode along with this
|
19 |
+
work. If not, see <http://creativecommons.org/publicdomain/zero/1.0/>
|
CMakeLists.txt
ADDED
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# THIS FILE IS ONLY INTENDED CUBLAS BUILD PURPOSES ON WINDOWS VISUAL STUDIO.
|
2 |
+
# YOU'RE NOT RECOMMENDED TO USE IT
|
3 |
+
|
4 |
+
message(STATUS "============== ============== ==============")
|
5 |
+
message(STATUS "WARNING! Recommend NOT to use this file. It is UNSUPPORTED for normal users. Use MAKE instead.")
|
6 |
+
message(STATUS "It is ONLY for CUBLAS builds on windows visual studio. IT WILL OVERWRITE YOUR EXISTING MAKEFILE !!!")
|
7 |
+
message(STATUS "IF YOU ARE SEEING THIS, you MUST ONLY be building CUBLAS BUILDS! NOTHING ELSE WILL BE SUPPORTED !!!")
|
8 |
+
message(STATUS "============== ============== ==============")
|
9 |
+
|
10 |
+
cmake_minimum_required(VERSION 3.12) # Don't bump this version for no reason
|
11 |
+
project("llama.cpp" C CXX)
|
12 |
+
|
13 |
+
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
14 |
+
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS 1)
|
15 |
+
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
|
16 |
+
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Release")
|
17 |
+
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
18 |
+
set(LLAMA_STANDALONE ON)
|
19 |
+
set(BUILD_SHARED_LIBS_DEFAULT ON)
|
20 |
+
set(LLAMA_STATIC OFF)
|
21 |
+
set(LLAMA_NATIVE OFF)
|
22 |
+
set(LLAMA_LTO OFF)
|
23 |
+
set(LLAMA_ALL_WARNINGS OFF)
|
24 |
+
set(LLAMA_ALL_WARNINGS_3RD_PARTY OFF)
|
25 |
+
set(LLAMA_GPROF OFF)
|
26 |
+
set(LLAMA_SANITIZE_THREAD OFF)
|
27 |
+
set(LLAMA_SANITIZE_ADDRESS OFF)
|
28 |
+
set(LLAMA_SANITIZE_UNDEFINED OFF)
|
29 |
+
|
30 |
+
|
31 |
+
# instruction set specific
|
32 |
+
option(LLAMA_AVX "llama: enable AVX" ON)
|
33 |
+
option(LLAMA_AVX2 "llama: enable AVX2" ON)
|
34 |
+
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
|
35 |
+
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
|
36 |
+
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
|
37 |
+
option(LLAMA_FMA "llama: enable FMA" ON)
|
38 |
+
# in MSVC F16C is implied with AVX2/AVX512
|
39 |
+
if (NOT MSVC)
|
40 |
+
option(LLAMA_F16C "llama: enable F16C" ON)
|
41 |
+
endif()
|
42 |
+
|
43 |
+
# 3rd party libs
|
44 |
+
option(LLAMA_CUBLAS "llama: use CUDA" ON)
|
45 |
+
option(LLAMA_CUDA_F16 "llama: use 16 bit floats for dmmv CUDA kernels" OFF)
|
46 |
+
set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
|
47 |
+
"llama: max. batch size for using peer access")
|
48 |
+
|
49 |
+
option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
|
50 |
+
|
51 |
+
# Other
|
52 |
+
option(LLAMA_OPENMP "llama: use OpenMP" OFF)
|
53 |
+
|
54 |
+
#
|
55 |
+
# Compile flags
|
56 |
+
#
|
57 |
+
|
58 |
+
set(CMAKE_CXX_STANDARD 17)
|
59 |
+
set(CMAKE_CXX_STANDARD_REQUIRED true)
|
60 |
+
set(CMAKE_C_STANDARD 11)
|
61 |
+
set(CMAKE_C_STANDARD_REQUIRED true)
|
62 |
+
set(THREADS_PREFER_PTHREAD_FLAG ON)
|
63 |
+
find_package(Threads REQUIRED)
|
64 |
+
|
65 |
+
add_compile_definitions(LOG_DISABLE_LOGS)
|
66 |
+
add_compile_definitions(GGML_USE_CPU)
|
67 |
+
add_compile_definitions(GGML_USE_CPU_AARCH64)
|
68 |
+
|
69 |
+
if (MSVC)
|
70 |
+
add_compile_options("$<$<COMPILE_LANGUAGE:C>:/utf-8>")
|
71 |
+
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/utf-8>")
|
72 |
+
add_compile_options("$<$<COMPILE_LANGUAGE:C>:/bigobj>")
|
73 |
+
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/bigobj>")
|
74 |
+
endif()
|
75 |
+
|
76 |
+
file(GLOB GGML_SOURCES_CUDA "ggml/src/ggml-cuda/*.cu")
|
77 |
+
list(APPEND GGML_SOURCES_CUDA "ggml/src/ggml-cuda/ggml-cuda.cu")
|
78 |
+
file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/fattn-mma*.cu")
|
79 |
+
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
80 |
+
file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq*.cu")
|
81 |
+
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
82 |
+
set(GGML_V3_CUDA_SOURCES otherarch/ggml_v3-cuda.cu otherarch/ggml_v3-cuda.h)
|
83 |
+
set(GGML_V2_CUDA_SOURCES otherarch/ggml_v2-cuda.cu otherarch/ggml_v2-cuda.h)
|
84 |
+
set(GGML_V2_LEGACY_CUDA_SOURCES otherarch/ggml_v2-cuda-legacy.cu otherarch/ggml_v2-cuda-legacy.h)
|
85 |
+
|
86 |
+
|
87 |
+
if (LLAMA_CUBLAS)
|
88 |
+
cmake_minimum_required(VERSION 3.17)
|
89 |
+
|
90 |
+
find_package(CUDAToolkit)
|
91 |
+
if (CUDAToolkit_FOUND)
|
92 |
+
message(STATUS "cuBLAS found")
|
93 |
+
|
94 |
+
enable_language(CUDA)
|
95 |
+
|
96 |
+
add_compile_definitions(GGML_USE_LLAMAFILE)
|
97 |
+
add_compile_definitions(GGML_USE_CUDA)
|
98 |
+
add_compile_definitions(SD_USE_CUBLAS)
|
99 |
+
|
100 |
+
if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16)
|
101 |
+
add_compile_definitions(GGML_CUDA_F16)
|
102 |
+
endif()
|
103 |
+
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${LLAMA_CUDA_PEER_MAX_BATCH_SIZE})
|
104 |
+
|
105 |
+
# only build minimal quants required for fattn quant kv
|
106 |
+
file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
|
107 |
+
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
108 |
+
file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
|
109 |
+
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
110 |
+
file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
|
111 |
+
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
112 |
+
|
113 |
+
if (LLAMA_STATIC)
|
114 |
+
if (WIN32)
|
115 |
+
# As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library
|
116 |
+
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
|
117 |
+
else ()
|
118 |
+
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
119 |
+
endif()
|
120 |
+
else()
|
121 |
+
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
|
122 |
+
endif()
|
123 |
+
|
124 |
+
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver)
|
125 |
+
|
126 |
+
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
|
127 |
+
# 50 == lowest CUDA 12 standard
|
128 |
+
# 60 == f16 CUDA intrinsics
|
129 |
+
# 61 == integer CUDA intrinsics
|
130 |
+
# 70 == (assumed) compute capability at which unrolling a loop in mul_mat_q kernels is faster
|
131 |
+
# 75 == int8 tensor cores
|
132 |
+
if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16)
|
133 |
+
set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75") # needed for f16 CUDA intrinsics
|
134 |
+
else()
|
135 |
+
message("CUDA Toolkit Version: ${CUDAToolkit_VERSION}")
|
136 |
+
if(CUDAToolkit_VERSION VERSION_GREATER 12)
|
137 |
+
add_compile_definitions(GGML_CUDA_USE_GRAPHS) #try enable cuda graphs on cu12 build
|
138 |
+
set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75") # lowest CUDA 12 standard + lowest for integer intrinsics
|
139 |
+
else()
|
140 |
+
set(CMAKE_CUDA_ARCHITECTURES "37;50;61;70;75") # lowest CUDA 12 standard + lowest for integer intrinsics
|
141 |
+
endif()
|
142 |
+
endif()
|
143 |
+
endif()
|
144 |
+
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
145 |
+
|
146 |
+
else()
|
147 |
+
message(WARNING "cuBLAS not found")
|
148 |
+
endif()
|
149 |
+
endif()
|
150 |
+
|
151 |
+
if (LLAMA_HIPBLAS)
|
152 |
+
if (MSVC)
|
153 |
+
list(APPEND CMAKE_PREFIX_PATH "C:/Program Files/AMD/ROCm/5.5")
|
154 |
+
else()
|
155 |
+
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
|
156 |
+
endif()
|
157 |
+
|
158 |
+
|
159 |
+
if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
|
160 |
+
message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
|
161 |
+
endif()
|
162 |
+
if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
|
163 |
+
message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
|
164 |
+
endif()
|
165 |
+
|
166 |
+
find_package(hip)
|
167 |
+
find_package(hipblas)
|
168 |
+
find_package(rocblas)
|
169 |
+
|
170 |
+
if (${hipblas_FOUND} AND ${hip_FOUND})
|
171 |
+
message(STATUS "HIP and hipBLAS found")
|
172 |
+
file(GLOB GGML_SOURCES_ROCM "ggml/src/ggml-cuda/*.cu")
|
173 |
+
list(APPEND GGML_SOURCES_ROCM "ggml/src/ggml-cuda/ggml-cuda.cu")
|
174 |
+
file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/fattn-mma*.cu")
|
175 |
+
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
176 |
+
file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq*.cu")
|
177 |
+
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
178 |
+
add_compile_definitions(GGML_USE_HIP GGML_USE_CUDA SD_USE_CUBLAS)
|
179 |
+
add_library(ggml-rocm ${GGML_SOURCES_CUDA})
|
180 |
+
|
181 |
+
file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
|
182 |
+
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
183 |
+
file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
|
184 |
+
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
185 |
+
file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
|
186 |
+
list(APPEND GGML_SOURCES_ROCM ${SRCS})
|
187 |
+
|
188 |
+
# only build minimal quants required for fattn quant kv
|
189 |
+
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
|
190 |
+
target_link_libraries(ggml-rocm PUBLIC hip::device hip::host roc::rocblas roc::hipblas)
|
191 |
+
|
192 |
+
add_library(ggml-v2-rocm ${GGML_V2_CUDA_SOURCES})
|
193 |
+
set_source_files_properties(otherarch/ggml_v2-cuda.cu PROPERTIES LANGUAGE CXX)
|
194 |
+
target_link_libraries(ggml-v2-rocm PUBLIC hip::device hip::host roc::rocblas roc::hipblas)
|
195 |
+
|
196 |
+
add_library(ggml-v3-rocm ${GGML_V3_CUDA_SOURCES})
|
197 |
+
set_source_files_properties(otherarch/ggml_v3-cuda.cu PROPERTIES LANGUAGE CXX)
|
198 |
+
target_link_libraries(ggml-v3-rocm PUBLIC hip::device hip::host roc::rocblas roc::hipblas)
|
199 |
+
|
200 |
+
add_library(ggml-v2-legacy-rocm ${GGML_V2_LEGACY_CUDA_SOURCES})
|
201 |
+
set_source_files_properties(otherarch/ggml_v2-cuda-legacy.cu PROPERTIES LANGUAGE CXX)
|
202 |
+
target_link_libraries(ggml-v2-legacy-rocm PUBLIC hip::device hip::host roc::rocblas roc::hipblas)
|
203 |
+
|
204 |
+
if (LLAMA_STATIC)
|
205 |
+
message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
|
206 |
+
endif()
|
207 |
+
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-rocm ggml-v2-rocm ggml-v2-legacy-rocm)
|
208 |
+
else()
|
209 |
+
message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
|
210 |
+
endif()
|
211 |
+
endif()
|
212 |
+
|
213 |
+
if (LLAMA_ALL_WARNINGS)
|
214 |
+
if (NOT MSVC)
|
215 |
+
set(warning_flags -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function)
|
216 |
+
set(c_flags -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Werror=implicit-int
|
217 |
+
-Werror=implicit-function-declaration)
|
218 |
+
set(cxx_flags -Wmissing-declarations -Wmissing-noreturn)
|
219 |
+
|
220 |
+
if (CMAKE_C_COMPILER_ID MATCHES "Clang")
|
221 |
+
set(warning_flags ${warning_flags} -Wunreachable-code-break -Wunreachable-code-return)
|
222 |
+
set(cxx_flags ${cxx_flags} -Wmissing-prototypes -Wextra-semi)
|
223 |
+
|
224 |
+
if (
|
225 |
+
(CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 3.8.0) OR
|
226 |
+
(CMAKE_C_COMPILER_ID STREQUAL "AppleClang" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 7.3.0)
|
227 |
+
)
|
228 |
+
set(c_flags ${c_flags} -Wdouble-promotion)
|
229 |
+
endif()
|
230 |
+
elseif (CMAKE_C_COMPILER_ID STREQUAL "GNU")
|
231 |
+
set(c_flags ${c_flags} -Wdouble-promotion)
|
232 |
+
set(cxx_flags ${cxx_flags} -Wno-array-bounds)
|
233 |
+
|
234 |
+
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 7.1.0)
|
235 |
+
set(cxx_flags ${cxx_flags} -Wno-format-truncation)
|
236 |
+
endif()
|
237 |
+
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 8.1.0)
|
238 |
+
set(cxx_flags ${cxx_flags} -Wextra-semi)
|
239 |
+
endif()
|
240 |
+
endif()
|
241 |
+
else()
|
242 |
+
# todo : msvc
|
243 |
+
endif()
|
244 |
+
|
245 |
+
add_compile_options(
|
246 |
+
${warning_flags}
|
247 |
+
"$<$<COMPILE_LANGUAGE:C>:${c_flags}>"
|
248 |
+
"$<$<COMPILE_LANGUAGE:CXX>:${cxx_flags}>"
|
249 |
+
)
|
250 |
+
|
251 |
+
endif()
|
252 |
+
|
253 |
+
if (WIN32)
|
254 |
+
add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
|
255 |
+
|
256 |
+
if (BUILD_SHARED_LIBS)
|
257 |
+
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
|
258 |
+
endif()
|
259 |
+
endif()
|
260 |
+
|
261 |
+
if (LLAMA_LTO)
|
262 |
+
include(CheckIPOSupported)
|
263 |
+
check_ipo_supported(RESULT result OUTPUT output)
|
264 |
+
if (result)
|
265 |
+
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
|
266 |
+
else()
|
267 |
+
message(WARNING "IPO is not supported: ${output}")
|
268 |
+
endif()
|
269 |
+
endif()
|
270 |
+
|
271 |
+
if (LLAMA_OPENMP)
|
272 |
+
find_package(OpenMP)
|
273 |
+
if (OpenMP_FOUND)
|
274 |
+
message(STATUS "OpenMP found")
|
275 |
+
add_compile_definitions(GGML_USE_OPENMP)
|
276 |
+
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
|
277 |
+
else()
|
278 |
+
message(WARNING "OpenMP not found")
|
279 |
+
endif()
|
280 |
+
endif()
|
281 |
+
|
282 |
+
# this version of Apple ld64 is buggy
|
283 |
+
execute_process(
|
284 |
+
COMMAND ${CMAKE_C_COMPILER} ${CMAKE_EXE_LINKER_FLAGS} -Wl,-v
|
285 |
+
ERROR_VARIABLE output
|
286 |
+
)
|
287 |
+
if (output MATCHES "dyld-1015\.7")
|
288 |
+
add_compile_definitions(HAVE_BUGGY_APPLE_LINKER)
|
289 |
+
endif()
|
290 |
+
|
291 |
+
# Architecture specific
|
292 |
+
# TODO: probably these flags need to be tweaked on some architectures
|
293 |
+
# feel free to update the Makefile for your architecture and send a pull request or issue
|
294 |
+
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
295 |
+
if (NOT MSVC)
|
296 |
+
if (LLAMA_STATIC)
|
297 |
+
add_link_options(-static)
|
298 |
+
if (MINGW)
|
299 |
+
add_link_options(-static-libgcc -static-libstdc++)
|
300 |
+
endif()
|
301 |
+
endif()
|
302 |
+
if (LLAMA_GPROF)
|
303 |
+
add_compile_options(-pg)
|
304 |
+
endif()
|
305 |
+
if (LLAMA_NATIVE)
|
306 |
+
add_compile_options(-march=native)
|
307 |
+
endif()
|
308 |
+
endif()
|
309 |
+
|
310 |
+
if ((${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") OR (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64"))
|
311 |
+
message(STATUS "ARM detected")
|
312 |
+
if (MSVC)
|
313 |
+
# TODO: arm msvc?
|
314 |
+
else()
|
315 |
+
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6")
|
316 |
+
# Raspberry Pi 1, Zero
|
317 |
+
add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access)
|
318 |
+
endif()
|
319 |
+
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7")
|
320 |
+
# Raspberry Pi 2
|
321 |
+
add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations)
|
322 |
+
endif()
|
323 |
+
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8")
|
324 |
+
# Raspberry Pi 3, 4, Zero 2 (32-bit)
|
325 |
+
add_compile_options(-mfp16-format=ieee -mno-unaligned-access)
|
326 |
+
endif()
|
327 |
+
endif()
|
328 |
+
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
|
329 |
+
message(STATUS "x86 detected")
|
330 |
+
if (MSVC)
|
331 |
+
if (LLAMA_AVX512)
|
332 |
+
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
|
333 |
+
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
|
334 |
+
# MSVC has no compile-time flags enabling specific
|
335 |
+
# AVX512 extensions, neither it defines the
|
336 |
+
# macros corresponding to the extensions.
|
337 |
+
# Do it manually.
|
338 |
+
if (LLAMA_AVX512_VBMI)
|
339 |
+
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
|
340 |
+
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
|
341 |
+
endif()
|
342 |
+
if (LLAMA_AVX512_VNNI)
|
343 |
+
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
|
344 |
+
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
|
345 |
+
endif()
|
346 |
+
elseif (LLAMA_AVX2)
|
347 |
+
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
|
348 |
+
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
|
349 |
+
elseif (LLAMA_AVX)
|
350 |
+
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
|
351 |
+
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
|
352 |
+
endif()
|
353 |
+
else()
|
354 |
+
if (LLAMA_F16C)
|
355 |
+
add_compile_options(-mf16c)
|
356 |
+
endif()
|
357 |
+
if (LLAMA_FMA)
|
358 |
+
add_compile_options(-mfma)
|
359 |
+
endif()
|
360 |
+
if (LLAMA_AVX)
|
361 |
+
add_compile_options(-mavx)
|
362 |
+
endif()
|
363 |
+
if (LLAMA_AVX2)
|
364 |
+
add_compile_options(-mavx2)
|
365 |
+
endif()
|
366 |
+
if (LLAMA_AVX512)
|
367 |
+
add_compile_options(-mavx512f)
|
368 |
+
add_compile_options(-mavx512bw)
|
369 |
+
endif()
|
370 |
+
if (LLAMA_AVX512_VBMI)
|
371 |
+
add_compile_options(-mavx512vbmi)
|
372 |
+
endif()
|
373 |
+
if (LLAMA_AVX512_VNNI)
|
374 |
+
add_compile_options(-mavx512vnni)
|
375 |
+
endif()
|
376 |
+
endif()
|
377 |
+
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
|
378 |
+
message(STATUS "PowerPC detected")
|
379 |
+
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
|
380 |
+
add_compile_options(-mcpu=powerpc64le)
|
381 |
+
else()
|
382 |
+
add_compile_options(-mcpu=native -mtune=native)
|
383 |
+
#TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
|
384 |
+
endif()
|
385 |
+
else()
|
386 |
+
message(STATUS "Unknown architecture")
|
387 |
+
endif()
|
388 |
+
|
389 |
+
if (MINGW)
|
390 |
+
# Target Windows 8 for PrefetchVirtualMemory
|
391 |
+
add_compile_definitions(_WIN32_WINNT=0x602)
|
392 |
+
endif()
|
393 |
+
|
394 |
+
#
|
395 |
+
# Build libraries
|
396 |
+
#
|
397 |
+
|
398 |
+
add_library(ggml
|
399 |
+
ggml/src/ggml.c
|
400 |
+
ggml/include/ggml.h
|
401 |
+
ggml/src/ggml-cpu/ggml-cpu.c
|
402 |
+
ggml/include/ggml-cpu.h
|
403 |
+
ggml/src/ggml-alloc.c
|
404 |
+
ggml/include/ggml-alloc.h
|
405 |
+
ggml/src/ggml-backend.cpp
|
406 |
+
ggml/src/ggml-backend-impl.h
|
407 |
+
ggml/include/ggml-backend.h
|
408 |
+
ggml/include/ggml-cpp.h
|
409 |
+
ggml/src/ggml-quants.c
|
410 |
+
ggml/src/ggml-quants.h
|
411 |
+
ggml/src/ggml-cpu/llamafile/sgemm.cpp
|
412 |
+
ggml/src/ggml-cpu/llamafile/sgemm.h
|
413 |
+
ggml/src/ggml-cpu/ggml-cpu-traits.cpp
|
414 |
+
ggml/src/ggml-cpu/ggml-cpu-traits.h
|
415 |
+
ggml/src/ggml-threading.cpp
|
416 |
+
ggml/src/ggml-cpu/ggml-cpu.cpp
|
417 |
+
ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
|
418 |
+
ggml/src/ggml-cpu/ggml-cpu-aarch64.h
|
419 |
+
ggml/src/ggml-cpu/ggml-cpu-quants.c
|
420 |
+
ggml/src/ggml-cpu/ggml-cpu-quants.h
|
421 |
+
ggml/src/ggml-backend-reg.cpp
|
422 |
+
ggml/include/gguf.h
|
423 |
+
ggml/src/gguf.cpp
|
424 |
+
${GGML_SOURCES_CUDA})
|
425 |
+
target_include_directories(ggml PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools)
|
426 |
+
target_compile_features(ggml PUBLIC c_std_11) # don't bump
|
427 |
+
target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
|
428 |
+
set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
429 |
+
|
430 |
+
add_library(ggml_v1
|
431 |
+
otherarch/ggml_v1.c
|
432 |
+
otherarch/ggml_v1.h)
|
433 |
+
target_include_directories(ggml_v1 PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools)
|
434 |
+
target_compile_features(ggml_v1 PUBLIC c_std_11) # don't bump
|
435 |
+
target_link_libraries(ggml_v1 PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
|
436 |
+
set_target_properties(ggml_v1 PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
437 |
+
|
438 |
+
add_library(ggml_v2
|
439 |
+
otherarch/ggml_v2.c
|
440 |
+
otherarch/ggml_v2.h
|
441 |
+
${GGML_V2_CUDA_SOURCES}
|
442 |
+
${GGML_V2_LEGACY_CUDA_SOURCES})
|
443 |
+
target_include_directories(ggml_v2 PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools)
|
444 |
+
target_compile_features(ggml_v2 PUBLIC c_std_11) # don't bump
|
445 |
+
target_link_libraries(ggml_v2 PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
|
446 |
+
set_target_properties(ggml_v2 PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
447 |
+
|
448 |
+
add_library(ggml_v3
|
449 |
+
otherarch/ggml_v3.c
|
450 |
+
otherarch/ggml_v3.h
|
451 |
+
${GGML_V3_CUDA_SOURCES})
|
452 |
+
target_include_directories(ggml_v3 PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools)
|
453 |
+
target_compile_features(ggml_v3 PUBLIC c_std_11) # don't bump
|
454 |
+
target_link_libraries(ggml_v3 PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
|
455 |
+
set_target_properties(ggml_v3 PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
456 |
+
|
457 |
+
add_library(common2
|
458 |
+
common/common.cpp
|
459 |
+
common/common.h
|
460 |
+
common/sampling.cpp
|
461 |
+
common/sampling.h
|
462 |
+
examples/llava/llava.cpp
|
463 |
+
examples/llava/llava.h
|
464 |
+
examples/llava/clip.cpp
|
465 |
+
examples/llava/clip.h
|
466 |
+
src/unicode.h
|
467 |
+
src/unicode.cpp
|
468 |
+
src/unicode-data.cpp
|
469 |
+
otherarch/utils.cpp
|
470 |
+
otherarch/utils.h)
|
471 |
+
target_include_directories(common2 PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
|
472 |
+
target_compile_features(common2 PUBLIC cxx_std_17) # don't bump
|
473 |
+
target_link_libraries(common2 PRIVATE ggml ${LLAMA_EXTRA_LIBS})
|
474 |
+
set_target_properties(common2 PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
475 |
+
|
476 |
+
add_library(sdtype_adapter
|
477 |
+
otherarch/sdcpp/sdtype_adapter.cpp)
|
478 |
+
target_include_directories(sdtype_adapter PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
|
479 |
+
target_compile_features(sdtype_adapter PUBLIC cxx_std_17) # don't bump
|
480 |
+
target_link_libraries(sdtype_adapter PRIVATE common2 ggml ${LLAMA_EXTRA_LIBS})
|
481 |
+
set_target_properties(sdtype_adapter PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
482 |
+
|
483 |
+
add_library(whisper_adapter
|
484 |
+
otherarch/whispercpp/whisper_adapter.cpp)
|
485 |
+
target_include_directories(whisper_adapter PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools ./otherarch/whispercpp ./examples ./common)
|
486 |
+
target_compile_features(whisper_adapter PUBLIC cxx_std_17) # don't bump
|
487 |
+
target_link_libraries(whisper_adapter PRIVATE common2 ggml ${LLAMA_EXTRA_LIBS})
|
488 |
+
set_target_properties(whisper_adapter PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
489 |
+
|
490 |
+
add_library(tts_adapter
|
491 |
+
otherarch/tts_adapter.cpp)
|
492 |
+
target_include_directories(tts_adapter PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools ./examples ./common)
|
493 |
+
target_compile_features(tts_adapter PUBLIC cxx_std_17) # don't bump
|
494 |
+
target_link_libraries(tts_adapter PRIVATE common2 ggml ${LLAMA_EXTRA_LIBS})
|
495 |
+
set_target_properties(tts_adapter PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
496 |
+
|
497 |
+
add_library(gpttype_adapter
|
498 |
+
gpttype_adapter.cpp)
|
499 |
+
target_include_directories(gpttype_adapter PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
|
500 |
+
target_compile_features(gpttype_adapter PUBLIC cxx_std_17) # don't bump
|
501 |
+
target_link_libraries(gpttype_adapter PRIVATE common2 ggml ggml_v1 ggml_v2 ggml_v3 ${LLAMA_EXTRA_LIBS})
|
502 |
+
set_target_properties(gpttype_adapter PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
503 |
+
|
504 |
+
if (LLAMA_CUBLAS)
|
505 |
+
set(TARGET koboldcpp_cublas)
|
506 |
+
add_library(${TARGET} SHARED expose.cpp expose.h)
|
507 |
+
target_include_directories(${TARGET} PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
|
508 |
+
target_compile_features(${TARGET} PUBLIC cxx_std_17) # don't bump
|
509 |
+
set_target_properties(${TARGET} PROPERTIES PREFIX "")
|
510 |
+
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME "koboldcpp_cublas")
|
511 |
+
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
512 |
+
target_link_libraries(${TARGET} PUBLIC Threads::Threads ggml ggml_v1 ggml_v2 ggml_v3 common2 gpttype_adapter whisper_adapter tts_adapter sdtype_adapter ${LLAMA_EXTRA_LIBS})
|
513 |
+
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
514 |
+
|
515 |
+
add_custom_command(
|
516 |
+
TARGET koboldcpp_cublas POST_BUILD
|
517 |
+
COMMAND ${CMAKE_COMMAND} -E copy
|
518 |
+
$<TARGET_FILE:koboldcpp_cublas> # The generated DLL
|
519 |
+
${CMAKE_SOURCE_DIR}/ # Destination directory
|
520 |
+
COMMENT "Copying DLL to parent directory"
|
521 |
+
)
|
522 |
+
endif()
|
523 |
+
|
524 |
+
if (LLAMA_HIPBLAS)
|
525 |
+
set(TARGET koboldcpp_hipblas)
|
526 |
+
add_library(${TARGET} SHARED expose.cpp expose.h)
|
527 |
+
target_include_directories(${TARGET} PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
|
528 |
+
target_compile_features(${TARGET} PUBLIC cxx_std_17) # don't bump
|
529 |
+
set_target_properties(${TARGET} PROPERTIES PREFIX "")
|
530 |
+
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME "koboldcpp_hipblas")
|
531 |
+
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
532 |
+
target_link_libraries(${TARGET} PUBLIC Threads::Threads ggml ggml_v1 ggml_v2 ggml_v3 common2 gpttype_adapter whisper_adapter tts_adapter sdtype_adapter ${LLAMA_EXTRA_LIBS})
|
533 |
+
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
534 |
+
|
535 |
+
add_custom_command(
|
536 |
+
TARGET koboldcpp_hipblas POST_BUILD
|
537 |
+
COMMAND ${CMAKE_COMMAND} -E copy
|
538 |
+
$<TARGET_FILE:koboldcpp_hipblas> # The generated DLL
|
539 |
+
${CMAKE_SOURCE_DIR}/ # Destination directory
|
540 |
+
COMMENT "Copying DLL to parent directory"
|
541 |
+
)
|
542 |
+
endif()
|
543 |
+
|
LICENSE.md
ADDED
@@ -0,0 +1,661 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
GNU AFFERO GENERAL PUBLIC LICENSE
|
2 |
+
Version 3, 19 November 2007
|
3 |
+
|
4 |
+
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
5 |
+
Everyone is permitted to copy and distribute verbatim copies
|
6 |
+
of this license document, but changing it is not allowed.
|
7 |
+
|
8 |
+
Preamble
|
9 |
+
|
10 |
+
The GNU Affero General Public License is a free, copyleft license for
|
11 |
+
software and other kinds of works, specifically designed to ensure
|
12 |
+
cooperation with the community in the case of network server software.
|
13 |
+
|
14 |
+
The licenses for most software and other practical works are designed
|
15 |
+
to take away your freedom to share and change the works. By contrast,
|
16 |
+
our General Public Licenses are intended to guarantee your freedom to
|
17 |
+
share and change all versions of a program--to make sure it remains free
|
18 |
+
software for all its users.
|
19 |
+
|
20 |
+
When we speak of free software, we are referring to freedom, not
|
21 |
+
price. Our General Public Licenses are designed to make sure that you
|
22 |
+
have the freedom to distribute copies of free software (and charge for
|
23 |
+
them if you wish), that you receive source code or can get it if you
|
24 |
+
want it, that you can change the software or use pieces of it in new
|
25 |
+
free programs, and that you know you can do these things.
|
26 |
+
|
27 |
+
Developers that use our General Public Licenses protect your rights
|
28 |
+
with two steps: (1) assert copyright on the software, and (2) offer
|
29 |
+
you this License which gives you legal permission to copy, distribute
|
30 |
+
and/or modify the software.
|
31 |
+
|
32 |
+
A secondary benefit of defending all users' freedom is that
|
33 |
+
improvements made in alternate versions of the program, if they
|
34 |
+
receive widespread use, become available for other developers to
|
35 |
+
incorporate. Many developers of free software are heartened and
|
36 |
+
encouraged by the resulting cooperation. However, in the case of
|
37 |
+
software used on network servers, this result may fail to come about.
|
38 |
+
The GNU General Public License permits making a modified version and
|
39 |
+
letting the public access it on a server without ever releasing its
|
40 |
+
source code to the public.
|
41 |
+
|
42 |
+
The GNU Affero General Public License is designed specifically to
|
43 |
+
ensure that, in such cases, the modified source code becomes available
|
44 |
+
to the community. It requires the operator of a network server to
|
45 |
+
provide the source code of the modified version running there to the
|
46 |
+
users of that server. Therefore, public use of a modified version, on
|
47 |
+
a publicly accessible server, gives the public access to the source
|
48 |
+
code of the modified version.
|
49 |
+
|
50 |
+
An older license, called the Affero General Public License and
|
51 |
+
published by Affero, was designed to accomplish similar goals. This is
|
52 |
+
a different license, not a version of the Affero GPL, but Affero has
|
53 |
+
released a new version of the Affero GPL which permits relicensing under
|
54 |
+
this license.
|
55 |
+
|
56 |
+
The precise terms and conditions for copying, distribution and
|
57 |
+
modification follow.
|
58 |
+
|
59 |
+
TERMS AND CONDITIONS
|
60 |
+
|
61 |
+
0. Definitions.
|
62 |
+
|
63 |
+
"This License" refers to version 3 of the GNU Affero General Public License.
|
64 |
+
|
65 |
+
"Copyright" also means copyright-like laws that apply to other kinds of
|
66 |
+
works, such as semiconductor masks.
|
67 |
+
|
68 |
+
"The Program" refers to any copyrightable work licensed under this
|
69 |
+
License. Each licensee is addressed as "you". "Licensees" and
|
70 |
+
"recipients" may be individuals or organizations.
|
71 |
+
|
72 |
+
To "modify" a work means to copy from or adapt all or part of the work
|
73 |
+
in a fashion requiring copyright permission, other than the making of an
|
74 |
+
exact copy. The resulting work is called a "modified version" of the
|
75 |
+
earlier work or a work "based on" the earlier work.
|
76 |
+
|
77 |
+
A "covered work" means either the unmodified Program or a work based
|
78 |
+
on the Program.
|
79 |
+
|
80 |
+
To "propagate" a work means to do anything with it that, without
|
81 |
+
permission, would make you directly or secondarily liable for
|
82 |
+
infringement under applicable copyright law, except executing it on a
|
83 |
+
computer or modifying a private copy. Propagation includes copying,
|
84 |
+
distribution (with or without modification), making available to the
|
85 |
+
public, and in some countries other activities as well.
|
86 |
+
|
87 |
+
To "convey" a work means any kind of propagation that enables other
|
88 |
+
parties to make or receive copies. Mere interaction with a user through
|
89 |
+
a computer network, with no transfer of a copy, is not conveying.
|
90 |
+
|
91 |
+
An interactive user interface displays "Appropriate Legal Notices"
|
92 |
+
to the extent that it includes a convenient and prominently visible
|
93 |
+
feature that (1) displays an appropriate copyright notice, and (2)
|
94 |
+
tells the user that there is no warranty for the work (except to the
|
95 |
+
extent that warranties are provided), that licensees may convey the
|
96 |
+
work under this License, and how to view a copy of this License. If
|
97 |
+
the interface presents a list of user commands or options, such as a
|
98 |
+
menu, a prominent item in the list meets this criterion.
|
99 |
+
|
100 |
+
1. Source Code.
|
101 |
+
|
102 |
+
The "source code" for a work means the preferred form of the work
|
103 |
+
for making modifications to it. "Object code" means any non-source
|
104 |
+
form of a work.
|
105 |
+
|
106 |
+
A "Standard Interface" means an interface that either is an official
|
107 |
+
standard defined by a recognized standards body, or, in the case of
|
108 |
+
interfaces specified for a particular programming language, one that
|
109 |
+
is widely used among developers working in that language.
|
110 |
+
|
111 |
+
The "System Libraries" of an executable work include anything, other
|
112 |
+
than the work as a whole, that (a) is included in the normal form of
|
113 |
+
packaging a Major Component, but which is not part of that Major
|
114 |
+
Component, and (b) serves only to enable use of the work with that
|
115 |
+
Major Component, or to implement a Standard Interface for which an
|
116 |
+
implementation is available to the public in source code form. A
|
117 |
+
"Major Component", in this context, means a major essential component
|
118 |
+
(kernel, window system, and so on) of the specific operating system
|
119 |
+
(if any) on which the executable work runs, or a compiler used to
|
120 |
+
produce the work, or an object code interpreter used to run it.
|
121 |
+
|
122 |
+
The "Corresponding Source" for a work in object code form means all
|
123 |
+
the source code needed to generate, install, and (for an executable
|
124 |
+
work) run the object code and to modify the work, including scripts to
|
125 |
+
control those activities. However, it does not include the work's
|
126 |
+
System Libraries, or general-purpose tools or generally available free
|
127 |
+
programs which are used unmodified in performing those activities but
|
128 |
+
which are not part of the work. For example, Corresponding Source
|
129 |
+
includes interface definition files associated with source files for
|
130 |
+
the work, and the source code for shared libraries and dynamically
|
131 |
+
linked subprograms that the work is specifically designed to require,
|
132 |
+
such as by intimate data communication or control flow between those
|
133 |
+
subprograms and other parts of the work.
|
134 |
+
|
135 |
+
The Corresponding Source need not include anything that users
|
136 |
+
can regenerate automatically from other parts of the Corresponding
|
137 |
+
Source.
|
138 |
+
|
139 |
+
The Corresponding Source for a work in source code form is that
|
140 |
+
same work.
|
141 |
+
|
142 |
+
2. Basic Permissions.
|
143 |
+
|
144 |
+
All rights granted under this License are granted for the term of
|
145 |
+
copyright on the Program, and are irrevocable provided the stated
|
146 |
+
conditions are met. This License explicitly affirms your unlimited
|
147 |
+
permission to run the unmodified Program. The output from running a
|
148 |
+
covered work is covered by this License only if the output, given its
|
149 |
+
content, constitutes a covered work. This License acknowledges your
|
150 |
+
rights of fair use or other equivalent, as provided by copyright law.
|
151 |
+
|
152 |
+
You may make, run and propagate covered works that you do not
|
153 |
+
convey, without conditions so long as your license otherwise remains
|
154 |
+
in force. You may convey covered works to others for the sole purpose
|
155 |
+
of having them make modifications exclusively for you, or provide you
|
156 |
+
with facilities for running those works, provided that you comply with
|
157 |
+
the terms of this License in conveying all material for which you do
|
158 |
+
not control copyright. Those thus making or running the covered works
|
159 |
+
for you must do so exclusively on your behalf, under your direction
|
160 |
+
and control, on terms that prohibit them from making any copies of
|
161 |
+
your copyrighted material outside their relationship with you.
|
162 |
+
|
163 |
+
Conveying under any other circumstances is permitted solely under
|
164 |
+
the conditions stated below. Sublicensing is not allowed; section 10
|
165 |
+
makes it unnecessary.
|
166 |
+
|
167 |
+
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
168 |
+
|
169 |
+
No covered work shall be deemed part of an effective technological
|
170 |
+
measure under any applicable law fulfilling obligations under article
|
171 |
+
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
172 |
+
similar laws prohibiting or restricting circumvention of such
|
173 |
+
measures.
|
174 |
+
|
175 |
+
When you convey a covered work, you waive any legal power to forbid
|
176 |
+
circumvention of technological measures to the extent such circumvention
|
177 |
+
is effected by exercising rights under this License with respect to
|
178 |
+
the covered work, and you disclaim any intention to limit operation or
|
179 |
+
modification of the work as a means of enforcing, against the work's
|
180 |
+
users, your or third parties' legal rights to forbid circumvention of
|
181 |
+
technological measures.
|
182 |
+
|
183 |
+
4. Conveying Verbatim Copies.
|
184 |
+
|
185 |
+
You may convey verbatim copies of the Program's source code as you
|
186 |
+
receive it, in any medium, provided that you conspicuously and
|
187 |
+
appropriately publish on each copy an appropriate copyright notice;
|
188 |
+
keep intact all notices stating that this License and any
|
189 |
+
non-permissive terms added in accord with section 7 apply to the code;
|
190 |
+
keep intact all notices of the absence of any warranty; and give all
|
191 |
+
recipients a copy of this License along with the Program.
|
192 |
+
|
193 |
+
You may charge any price or no price for each copy that you convey,
|
194 |
+
and you may offer support or warranty protection for a fee.
|
195 |
+
|
196 |
+
5. Conveying Modified Source Versions.
|
197 |
+
|
198 |
+
You may convey a work based on the Program, or the modifications to
|
199 |
+
produce it from the Program, in the form of source code under the
|
200 |
+
terms of section 4, provided that you also meet all of these conditions:
|
201 |
+
|
202 |
+
a) The work must carry prominent notices stating that you modified
|
203 |
+
it, and giving a relevant date.
|
204 |
+
|
205 |
+
b) The work must carry prominent notices stating that it is
|
206 |
+
released under this License and any conditions added under section
|
207 |
+
7. This requirement modifies the requirement in section 4 to
|
208 |
+
"keep intact all notices".
|
209 |
+
|
210 |
+
c) You must license the entire work, as a whole, under this
|
211 |
+
License to anyone who comes into possession of a copy. This
|
212 |
+
License will therefore apply, along with any applicable section 7
|
213 |
+
additional terms, to the whole of the work, and all its parts,
|
214 |
+
regardless of how they are packaged. This License gives no
|
215 |
+
permission to license the work in any other way, but it does not
|
216 |
+
invalidate such permission if you have separately received it.
|
217 |
+
|
218 |
+
d) If the work has interactive user interfaces, each must display
|
219 |
+
Appropriate Legal Notices; however, if the Program has interactive
|
220 |
+
interfaces that do not display Appropriate Legal Notices, your
|
221 |
+
work need not make them do so.
|
222 |
+
|
223 |
+
A compilation of a covered work with other separate and independent
|
224 |
+
works, which are not by their nature extensions of the covered work,
|
225 |
+
and which are not combined with it such as to form a larger program,
|
226 |
+
in or on a volume of a storage or distribution medium, is called an
|
227 |
+
"aggregate" if the compilation and its resulting copyright are not
|
228 |
+
used to limit the access or legal rights of the compilation's users
|
229 |
+
beyond what the individual works permit. Inclusion of a covered work
|
230 |
+
in an aggregate does not cause this License to apply to the other
|
231 |
+
parts of the aggregate.
|
232 |
+
|
233 |
+
6. Conveying Non-Source Forms.
|
234 |
+
|
235 |
+
You may convey a covered work in object code form under the terms
|
236 |
+
of sections 4 and 5, provided that you also convey the
|
237 |
+
machine-readable Corresponding Source under the terms of this License,
|
238 |
+
in one of these ways:
|
239 |
+
|
240 |
+
a) Convey the object code in, or embodied in, a physical product
|
241 |
+
(including a physical distribution medium), accompanied by the
|
242 |
+
Corresponding Source fixed on a durable physical medium
|
243 |
+
customarily used for software interchange.
|
244 |
+
|
245 |
+
b) Convey the object code in, or embodied in, a physical product
|
246 |
+
(including a physical distribution medium), accompanied by a
|
247 |
+
written offer, valid for at least three years and valid for as
|
248 |
+
long as you offer spare parts or customer support for that product
|
249 |
+
model, to give anyone who possesses the object code either (1) a
|
250 |
+
copy of the Corresponding Source for all the software in the
|
251 |
+
product that is covered by this License, on a durable physical
|
252 |
+
medium customarily used for software interchange, for a price no
|
253 |
+
more than your reasonable cost of physically performing this
|
254 |
+
conveying of source, or (2) access to copy the
|
255 |
+
Corresponding Source from a network server at no charge.
|
256 |
+
|
257 |
+
c) Convey individual copies of the object code with a copy of the
|
258 |
+
written offer to provide the Corresponding Source. This
|
259 |
+
alternative is allowed only occasionally and noncommercially, and
|
260 |
+
only if you received the object code with such an offer, in accord
|
261 |
+
with subsection 6b.
|
262 |
+
|
263 |
+
d) Convey the object code by offering access from a designated
|
264 |
+
place (gratis or for a charge), and offer equivalent access to the
|
265 |
+
Corresponding Source in the same way through the same place at no
|
266 |
+
further charge. You need not require recipients to copy the
|
267 |
+
Corresponding Source along with the object code. If the place to
|
268 |
+
copy the object code is a network server, the Corresponding Source
|
269 |
+
may be on a different server (operated by you or a third party)
|
270 |
+
that supports equivalent copying facilities, provided you maintain
|
271 |
+
clear directions next to the object code saying where to find the
|
272 |
+
Corresponding Source. Regardless of what server hosts the
|
273 |
+
Corresponding Source, you remain obligated to ensure that it is
|
274 |
+
available for as long as needed to satisfy these requirements.
|
275 |
+
|
276 |
+
e) Convey the object code using peer-to-peer transmission, provided
|
277 |
+
you inform other peers where the object code and Corresponding
|
278 |
+
Source of the work are being offered to the general public at no
|
279 |
+
charge under subsection 6d.
|
280 |
+
|
281 |
+
A separable portion of the object code, whose source code is excluded
|
282 |
+
from the Corresponding Source as a System Library, need not be
|
283 |
+
included in conveying the object code work.
|
284 |
+
|
285 |
+
A "User Product" is either (1) a "consumer product", which means any
|
286 |
+
tangible personal property which is normally used for personal, family,
|
287 |
+
or household purposes, or (2) anything designed or sold for incorporation
|
288 |
+
into a dwelling. In determining whether a product is a consumer product,
|
289 |
+
doubtful cases shall be resolved in favor of coverage. For a particular
|
290 |
+
product received by a particular user, "normally used" refers to a
|
291 |
+
typical or common use of that class of product, regardless of the status
|
292 |
+
of the particular user or of the way in which the particular user
|
293 |
+
actually uses, or expects or is expected to use, the product. A product
|
294 |
+
is a consumer product regardless of whether the product has substantial
|
295 |
+
commercial, industrial or non-consumer uses, unless such uses represent
|
296 |
+
the only significant mode of use of the product.
|
297 |
+
|
298 |
+
"Installation Information" for a User Product means any methods,
|
299 |
+
procedures, authorization keys, or other information required to install
|
300 |
+
and execute modified versions of a covered work in that User Product from
|
301 |
+
a modified version of its Corresponding Source. The information must
|
302 |
+
suffice to ensure that the continued functioning of the modified object
|
303 |
+
code is in no case prevented or interfered with solely because
|
304 |
+
modification has been made.
|
305 |
+
|
306 |
+
If you convey an object code work under this section in, or with, or
|
307 |
+
specifically for use in, a User Product, and the conveying occurs as
|
308 |
+
part of a transaction in which the right of possession and use of the
|
309 |
+
User Product is transferred to the recipient in perpetuity or for a
|
310 |
+
fixed term (regardless of how the transaction is characterized), the
|
311 |
+
Corresponding Source conveyed under this section must be accompanied
|
312 |
+
by the Installation Information. But this requirement does not apply
|
313 |
+
if neither you nor any third party retains the ability to install
|
314 |
+
modified object code on the User Product (for example, the work has
|
315 |
+
been installed in ROM).
|
316 |
+
|
317 |
+
The requirement to provide Installation Information does not include a
|
318 |
+
requirement to continue to provide support service, warranty, or updates
|
319 |
+
for a work that has been modified or installed by the recipient, or for
|
320 |
+
the User Product in which it has been modified or installed. Access to a
|
321 |
+
network may be denied when the modification itself materially and
|
322 |
+
adversely affects the operation of the network or violates the rules and
|
323 |
+
protocols for communication across the network.
|
324 |
+
|
325 |
+
Corresponding Source conveyed, and Installation Information provided,
|
326 |
+
in accord with this section must be in a format that is publicly
|
327 |
+
documented (and with an implementation available to the public in
|
328 |
+
source code form), and must require no special password or key for
|
329 |
+
unpacking, reading or copying.
|
330 |
+
|
331 |
+
7. Additional Terms.
|
332 |
+
|
333 |
+
"Additional permissions" are terms that supplement the terms of this
|
334 |
+
License by making exceptions from one or more of its conditions.
|
335 |
+
Additional permissions that are applicable to the entire Program shall
|
336 |
+
be treated as though they were included in this License, to the extent
|
337 |
+
that they are valid under applicable law. If additional permissions
|
338 |
+
apply only to part of the Program, that part may be used separately
|
339 |
+
under those permissions, but the entire Program remains governed by
|
340 |
+
this License without regard to the additional permissions.
|
341 |
+
|
342 |
+
When you convey a copy of a covered work, you may at your option
|
343 |
+
remove any additional permissions from that copy, or from any part of
|
344 |
+
it. (Additional permissions may be written to require their own
|
345 |
+
removal in certain cases when you modify the work.) You may place
|
346 |
+
additional permissions on material, added by you to a covered work,
|
347 |
+
for which you have or can give appropriate copyright permission.
|
348 |
+
|
349 |
+
Notwithstanding any other provision of this License, for material you
|
350 |
+
add to a covered work, you may (if authorized by the copyright holders of
|
351 |
+
that material) supplement the terms of this License with terms:
|
352 |
+
|
353 |
+
a) Disclaiming warranty or limiting liability differently from the
|
354 |
+
terms of sections 15 and 16 of this License; or
|
355 |
+
|
356 |
+
b) Requiring preservation of specified reasonable legal notices or
|
357 |
+
author attributions in that material or in the Appropriate Legal
|
358 |
+
Notices displayed by works containing it; or
|
359 |
+
|
360 |
+
c) Prohibiting misrepresentation of the origin of that material, or
|
361 |
+
requiring that modified versions of such material be marked in
|
362 |
+
reasonable ways as different from the original version; or
|
363 |
+
|
364 |
+
d) Limiting the use for publicity purposes of names of licensors or
|
365 |
+
authors of the material; or
|
366 |
+
|
367 |
+
e) Declining to grant rights under trademark law for use of some
|
368 |
+
trade names, trademarks, or service marks; or
|
369 |
+
|
370 |
+
f) Requiring indemnification of licensors and authors of that
|
371 |
+
material by anyone who conveys the material (or modified versions of
|
372 |
+
it) with contractual assumptions of liability to the recipient, for
|
373 |
+
any liability that these contractual assumptions directly impose on
|
374 |
+
those licensors and authors.
|
375 |
+
|
376 |
+
All other non-permissive additional terms are considered "further
|
377 |
+
restrictions" within the meaning of section 10. If the Program as you
|
378 |
+
received it, or any part of it, contains a notice stating that it is
|
379 |
+
governed by this License along with a term that is a further
|
380 |
+
restriction, you may remove that term. If a license document contains
|
381 |
+
a further restriction but permits relicensing or conveying under this
|
382 |
+
License, you may add to a covered work material governed by the terms
|
383 |
+
of that license document, provided that the further restriction does
|
384 |
+
not survive such relicensing or conveying.
|
385 |
+
|
386 |
+
If you add terms to a covered work in accord with this section, you
|
387 |
+
must place, in the relevant source files, a statement of the
|
388 |
+
additional terms that apply to those files, or a notice indicating
|
389 |
+
where to find the applicable terms.
|
390 |
+
|
391 |
+
Additional terms, permissive or non-permissive, may be stated in the
|
392 |
+
form of a separately written license, or stated as exceptions;
|
393 |
+
the above requirements apply either way.
|
394 |
+
|
395 |
+
8. Termination.
|
396 |
+
|
397 |
+
You may not propagate or modify a covered work except as expressly
|
398 |
+
provided under this License. Any attempt otherwise to propagate or
|
399 |
+
modify it is void, and will automatically terminate your rights under
|
400 |
+
this License (including any patent licenses granted under the third
|
401 |
+
paragraph of section 11).
|
402 |
+
|
403 |
+
However, if you cease all violation of this License, then your
|
404 |
+
license from a particular copyright holder is reinstated (a)
|
405 |
+
provisionally, unless and until the copyright holder explicitly and
|
406 |
+
finally terminates your license, and (b) permanently, if the copyright
|
407 |
+
holder fails to notify you of the violation by some reasonable means
|
408 |
+
prior to 60 days after the cessation.
|
409 |
+
|
410 |
+
Moreover, your license from a particular copyright holder is
|
411 |
+
reinstated permanently if the copyright holder notifies you of the
|
412 |
+
violation by some reasonable means, this is the first time you have
|
413 |
+
received notice of violation of this License (for any work) from that
|
414 |
+
copyright holder, and you cure the violation prior to 30 days after
|
415 |
+
your receipt of the notice.
|
416 |
+
|
417 |
+
Termination of your rights under this section does not terminate the
|
418 |
+
licenses of parties who have received copies or rights from you under
|
419 |
+
this License. If your rights have been terminated and not permanently
|
420 |
+
reinstated, you do not qualify to receive new licenses for the same
|
421 |
+
material under section 10.
|
422 |
+
|
423 |
+
9. Acceptance Not Required for Having Copies.
|
424 |
+
|
425 |
+
You are not required to accept this License in order to receive or
|
426 |
+
run a copy of the Program. Ancillary propagation of a covered work
|
427 |
+
occurring solely as a consequence of using peer-to-peer transmission
|
428 |
+
to receive a copy likewise does not require acceptance. However,
|
429 |
+
nothing other than this License grants you permission to propagate or
|
430 |
+
modify any covered work. These actions infringe copyright if you do
|
431 |
+
not accept this License. Therefore, by modifying or propagating a
|
432 |
+
covered work, you indicate your acceptance of this License to do so.
|
433 |
+
|
434 |
+
10. Automatic Licensing of Downstream Recipients.
|
435 |
+
|
436 |
+
Each time you convey a covered work, the recipient automatically
|
437 |
+
receives a license from the original licensors, to run, modify and
|
438 |
+
propagate that work, subject to this License. You are not responsible
|
439 |
+
for enforcing compliance by third parties with this License.
|
440 |
+
|
441 |
+
An "entity transaction" is a transaction transferring control of an
|
442 |
+
organization, or substantially all assets of one, or subdividing an
|
443 |
+
organization, or merging organizations. If propagation of a covered
|
444 |
+
work results from an entity transaction, each party to that
|
445 |
+
transaction who receives a copy of the work also receives whatever
|
446 |
+
licenses to the work the party's predecessor in interest had or could
|
447 |
+
give under the previous paragraph, plus a right to possession of the
|
448 |
+
Corresponding Source of the work from the predecessor in interest, if
|
449 |
+
the predecessor has it or can get it with reasonable efforts.
|
450 |
+
|
451 |
+
You may not impose any further restrictions on the exercise of the
|
452 |
+
rights granted or affirmed under this License. For example, you may
|
453 |
+
not impose a license fee, royalty, or other charge for exercise of
|
454 |
+
rights granted under this License, and you may not initiate litigation
|
455 |
+
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
456 |
+
any patent claim is infringed by making, using, selling, offering for
|
457 |
+
sale, or importing the Program or any portion of it.
|
458 |
+
|
459 |
+
11. Patents.
|
460 |
+
|
461 |
+
A "contributor" is a copyright holder who authorizes use under this
|
462 |
+
License of the Program or a work on which the Program is based. The
|
463 |
+
work thus licensed is called the contributor's "contributor version".
|
464 |
+
|
465 |
+
A contributor's "essential patent claims" are all patent claims
|
466 |
+
owned or controlled by the contributor, whether already acquired or
|
467 |
+
hereafter acquired, that would be infringed by some manner, permitted
|
468 |
+
by this License, of making, using, or selling its contributor version,
|
469 |
+
but do not include claims that would be infringed only as a
|
470 |
+
consequence of further modification of the contributor version. For
|
471 |
+
purposes of this definition, "control" includes the right to grant
|
472 |
+
patent sublicenses in a manner consistent with the requirements of
|
473 |
+
this License.
|
474 |
+
|
475 |
+
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
476 |
+
patent license under the contributor's essential patent claims, to
|
477 |
+
make, use, sell, offer for sale, import and otherwise run, modify and
|
478 |
+
propagate the contents of its contributor version.
|
479 |
+
|
480 |
+
In the following three paragraphs, a "patent license" is any express
|
481 |
+
agreement or commitment, however denominated, not to enforce a patent
|
482 |
+
(such as an express permission to practice a patent or covenant not to
|
483 |
+
sue for patent infringement). To "grant" such a patent license to a
|
484 |
+
party means to make such an agreement or commitment not to enforce a
|
485 |
+
patent against the party.
|
486 |
+
|
487 |
+
If you convey a covered work, knowingly relying on a patent license,
|
488 |
+
and the Corresponding Source of the work is not available for anyone
|
489 |
+
to copy, free of charge and under the terms of this License, through a
|
490 |
+
publicly available network server or other readily accessible means,
|
491 |
+
then you must either (1) cause the Corresponding Source to be so
|
492 |
+
available, or (2) arrange to deprive yourself of the benefit of the
|
493 |
+
patent license for this particular work, or (3) arrange, in a manner
|
494 |
+
consistent with the requirements of this License, to extend the patent
|
495 |
+
license to downstream recipients. "Knowingly relying" means you have
|
496 |
+
actual knowledge that, but for the patent license, your conveying the
|
497 |
+
covered work in a country, or your recipient's use of the covered work
|
498 |
+
in a country, would infringe one or more identifiable patents in that
|
499 |
+
country that you have reason to believe are valid.
|
500 |
+
|
501 |
+
If, pursuant to or in connection with a single transaction or
|
502 |
+
arrangement, you convey, or propagate by procuring conveyance of, a
|
503 |
+
covered work, and grant a patent license to some of the parties
|
504 |
+
receiving the covered work authorizing them to use, propagate, modify
|
505 |
+
or convey a specific copy of the covered work, then the patent license
|
506 |
+
you grant is automatically extended to all recipients of the covered
|
507 |
+
work and works based on it.
|
508 |
+
|
509 |
+
A patent license is "discriminatory" if it does not include within
|
510 |
+
the scope of its coverage, prohibits the exercise of, or is
|
511 |
+
conditioned on the non-exercise of one or more of the rights that are
|
512 |
+
specifically granted under this License. You may not convey a covered
|
513 |
+
work if you are a party to an arrangement with a third party that is
|
514 |
+
in the business of distributing software, under which you make payment
|
515 |
+
to the third party based on the extent of your activity of conveying
|
516 |
+
the work, and under which the third party grants, to any of the
|
517 |
+
parties who would receive the covered work from you, a discriminatory
|
518 |
+
patent license (a) in connection with copies of the covered work
|
519 |
+
conveyed by you (or copies made from those copies), or (b) primarily
|
520 |
+
for and in connection with specific products or compilations that
|
521 |
+
contain the covered work, unless you entered into that arrangement,
|
522 |
+
or that patent license was granted, prior to 28 March 2007.
|
523 |
+
|
524 |
+
Nothing in this License shall be construed as excluding or limiting
|
525 |
+
any implied license or other defenses to infringement that may
|
526 |
+
otherwise be available to you under applicable patent law.
|
527 |
+
|
528 |
+
12. No Surrender of Others' Freedom.
|
529 |
+
|
530 |
+
If conditions are imposed on you (whether by court order, agreement or
|
531 |
+
otherwise) that contradict the conditions of this License, they do not
|
532 |
+
excuse you from the conditions of this License. If you cannot convey a
|
533 |
+
covered work so as to satisfy simultaneously your obligations under this
|
534 |
+
License and any other pertinent obligations, then as a consequence you may
|
535 |
+
not convey it at all. For example, if you agree to terms that obligate you
|
536 |
+
to collect a royalty for further conveying from those to whom you convey
|
537 |
+
the Program, the only way you could satisfy both those terms and this
|
538 |
+
License would be to refrain entirely from conveying the Program.
|
539 |
+
|
540 |
+
13. Remote Network Interaction; Use with the GNU General Public License.
|
541 |
+
|
542 |
+
Notwithstanding any other provision of this License, if you modify the
|
543 |
+
Program, your modified version must prominently offer all users
|
544 |
+
interacting with it remotely through a computer network (if your version
|
545 |
+
supports such interaction) an opportunity to receive the Corresponding
|
546 |
+
Source of your version by providing access to the Corresponding Source
|
547 |
+
from a network server at no charge, through some standard or customary
|
548 |
+
means of facilitating copying of software. This Corresponding Source
|
549 |
+
shall include the Corresponding Source for any work covered by version 3
|
550 |
+
of the GNU General Public License that is incorporated pursuant to the
|
551 |
+
following paragraph.
|
552 |
+
|
553 |
+
Notwithstanding any other provision of this License, you have
|
554 |
+
permission to link or combine any covered work with a work licensed
|
555 |
+
under version 3 of the GNU General Public License into a single
|
556 |
+
combined work, and to convey the resulting work. The terms of this
|
557 |
+
License will continue to apply to the part which is the covered work,
|
558 |
+
but the work with which it is combined will remain governed by version
|
559 |
+
3 of the GNU General Public License.
|
560 |
+
|
561 |
+
14. Revised Versions of this License.
|
562 |
+
|
563 |
+
The Free Software Foundation may publish revised and/or new versions of
|
564 |
+
the GNU Affero General Public License from time to time. Such new versions
|
565 |
+
will be similar in spirit to the present version, but may differ in detail to
|
566 |
+
address new problems or concerns.
|
567 |
+
|
568 |
+
Each version is given a distinguishing version number. If the
|
569 |
+
Program specifies that a certain numbered version of the GNU Affero General
|
570 |
+
Public License "or any later version" applies to it, you have the
|
571 |
+
option of following the terms and conditions either of that numbered
|
572 |
+
version or of any later version published by the Free Software
|
573 |
+
Foundation. If the Program does not specify a version number of the
|
574 |
+
GNU Affero General Public License, you may choose any version ever published
|
575 |
+
by the Free Software Foundation.
|
576 |
+
|
577 |
+
If the Program specifies that a proxy can decide which future
|
578 |
+
versions of the GNU Affero General Public License can be used, that proxy's
|
579 |
+
public statement of acceptance of a version permanently authorizes you
|
580 |
+
to choose that version for the Program.
|
581 |
+
|
582 |
+
Later license versions may give you additional or different
|
583 |
+
permissions. However, no additional obligations are imposed on any
|
584 |
+
author or copyright holder as a result of your choosing to follow a
|
585 |
+
later version.
|
586 |
+
|
587 |
+
15. Disclaimer of Warranty.
|
588 |
+
|
589 |
+
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
590 |
+
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
591 |
+
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
592 |
+
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
593 |
+
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
594 |
+
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
595 |
+
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
596 |
+
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
597 |
+
|
598 |
+
16. Limitation of Liability.
|
599 |
+
|
600 |
+
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
601 |
+
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
602 |
+
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
603 |
+
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
604 |
+
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
605 |
+
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
606 |
+
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
607 |
+
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
608 |
+
SUCH DAMAGES.
|
609 |
+
|
610 |
+
17. Interpretation of Sections 15 and 16.
|
611 |
+
|
612 |
+
If the disclaimer of warranty and limitation of liability provided
|
613 |
+
above cannot be given local legal effect according to their terms,
|
614 |
+
reviewing courts shall apply local law that most closely approximates
|
615 |
+
an absolute waiver of all civil liability in connection with the
|
616 |
+
Program, unless a warranty or assumption of liability accompanies a
|
617 |
+
copy of the Program in return for a fee.
|
618 |
+
|
619 |
+
END OF TERMS AND CONDITIONS
|
620 |
+
|
621 |
+
How to Apply These Terms to Your New Programs
|
622 |
+
|
623 |
+
If you develop a new program, and you want it to be of the greatest
|
624 |
+
possible use to the public, the best way to achieve this is to make it
|
625 |
+
free software which everyone can redistribute and change under these terms.
|
626 |
+
|
627 |
+
To do so, attach the following notices to the program. It is safest
|
628 |
+
to attach them to the start of each source file to most effectively
|
629 |
+
state the exclusion of warranty; and each file should have at least
|
630 |
+
the "copyright" line and a pointer to where the full notice is found.
|
631 |
+
|
632 |
+
<one line to give the program's name and a brief idea of what it does.>
|
633 |
+
Copyright (C) <year> <name of author>
|
634 |
+
|
635 |
+
This program is free software: you can redistribute it and/or modify
|
636 |
+
it under the terms of the GNU Affero General Public License as published
|
637 |
+
by the Free Software Foundation, either version 3 of the License, or
|
638 |
+
(at your option) any later version.
|
639 |
+
|
640 |
+
This program is distributed in the hope that it will be useful,
|
641 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
642 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
643 |
+
GNU Affero General Public License for more details.
|
644 |
+
|
645 |
+
You should have received a copy of the GNU Affero General Public License
|
646 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
647 |
+
|
648 |
+
Also add information on how to contact you by electronic and paper mail.
|
649 |
+
|
650 |
+
If your software can interact with users remotely through a computer
|
651 |
+
network, you should also make sure that it provides a way for users to
|
652 |
+
get its source. For example, if your program is a web application, its
|
653 |
+
interface could display a "Source" link that leads users to an archive
|
654 |
+
of the code. There are many ways you could offer source, and different
|
655 |
+
solutions will be better for different programs; see section 13 for the
|
656 |
+
specific requirements.
|
657 |
+
|
658 |
+
You should also get your employer (if you work as a programmer) or school,
|
659 |
+
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
660 |
+
For more information on this, and how to apply and follow the GNU AGPL, see
|
661 |
+
<https://www.gnu.org/licenses/>.
|
MIT_LICENSE_GGML_LLAMACPP_ONLY
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Georgi Gerganov
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
22 |
+
|
23 |
+
===================================
|
24 |
+
|
25 |
+
Note that the above license applies ONLY to the GGML library and llama.cpp by ggerganov which are licensed under the MIT License
|
26 |
+
KoboldAI Lite by Concedo and the provided python ctypes bindings in koboldcpp dlls are licensed under the AGPL v3.0 License
|
Makefile
ADDED
@@ -0,0 +1,758 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Add custom options to Makefile.local rather than editing this file.
|
2 |
+
-include $(abspath $(lastword ${MAKEFILE_LIST})).local
|
3 |
+
|
4 |
+
.PHONY: finishedmsg
|
5 |
+
|
6 |
+
default: koboldcpp_default koboldcpp_failsafe koboldcpp_noavx2 koboldcpp_clblast koboldcpp_clblast_noavx2 koboldcpp_clblast_failsafe koboldcpp_cublas koboldcpp_hipblas koboldcpp_vulkan koboldcpp_vulkan_noavx2 finishedmsg
|
7 |
+
tools: quantize_gpt2 quantize_gptj quantize_gguf quantize_neox quantize_mpt quantize_clip ttsmain whispermain sdmain gguf-split
|
8 |
+
|
9 |
+
ifndef UNAME_S
|
10 |
+
UNAME_S := $(shell uname -s)
|
11 |
+
endif
|
12 |
+
|
13 |
+
ifndef UNAME_P
|
14 |
+
UNAME_P := $(shell uname -p)
|
15 |
+
endif
|
16 |
+
|
17 |
+
ifndef UNAME_M
|
18 |
+
UNAME_M := $(shell uname -m)
|
19 |
+
endif
|
20 |
+
|
21 |
+
ifndef UNAME_O
|
22 |
+
UNAME_O := $(shell uname -o)
|
23 |
+
endif
|
24 |
+
|
25 |
+
ifneq ($(shell grep -e "Arch Linux" -e "ID_LIKE=arch" /etc/os-release 2>/dev/null),)
|
26 |
+
ARCH_ADD = -lcblas
|
27 |
+
endif
|
28 |
+
|
29 |
+
|
30 |
+
# Mac OS + Arm can report x86_64
|
31 |
+
# ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
|
32 |
+
ifeq ($(UNAME_S),Darwin)
|
33 |
+
ifneq ($(UNAME_P),arm)
|
34 |
+
SYSCTL_M := $(shell sysctl -n hw.optional.arm64 2>/dev/null)
|
35 |
+
ifeq ($(SYSCTL_M),1)
|
36 |
+
# UNAME_P := arm
|
37 |
+
# UNAME_M := arm64
|
38 |
+
warn := $(warning Your arch is announced as x86_64, but it seems to actually be ARM64. Not fixing that can lead to bad performance. For more info see: https://github.com/ggerganov/whisper.cpp/issues/66\#issuecomment-1282546789)
|
39 |
+
endif
|
40 |
+
endif
|
41 |
+
endif
|
42 |
+
|
43 |
+
#
|
44 |
+
# Compile flags
|
45 |
+
#
|
46 |
+
|
47 |
+
# keep standard at C11 and C++17
|
48 |
+
CFLAGS =
|
49 |
+
CXXFLAGS =
|
50 |
+
ifdef KCPP_DEBUG
|
51 |
+
CFLAGS = -g -O0
|
52 |
+
CXXFLAGS = -g -O0
|
53 |
+
endif
|
54 |
+
CFLAGS += -I. -Iggml/include -Iggml/src -Iggml/src/ggml-cpu -Iinclude -Isrc -I./include -I./include/CL -I./otherarch -I./otherarch/tools -I./otherarch/sdcpp -I./otherarch/sdcpp/thirdparty -I./include/vulkan -O3 -fno-finite-math-only -std=c11 -fPIC -DLOG_DISABLE_LOGS -D_GNU_SOURCE -DGGML_USE_CPU -DGGML_USE_CPU_AARCH64
|
55 |
+
CXXFLAGS += -I. -Iggml/include -Iggml/src -Iggml/src/ggml-cpu -Iinclude -Isrc -I./common -I./include -I./include/CL -I./otherarch -I./otherarch/tools -I./otherarch/sdcpp -I./otherarch/sdcpp/thirdparty -I./include/vulkan -O3 -fno-finite-math-only -std=c++17 -fPIC -DLOG_DISABLE_LOGS -D_GNU_SOURCE -DGGML_USE_CPU -DGGML_USE_CPU_AARCH64
|
56 |
+
ifndef KCPP_DEBUG
|
57 |
+
CFLAGS += -DNDEBUG -s
|
58 |
+
CXXFLAGS += -DNDEBUG -s
|
59 |
+
endif
|
60 |
+
ifdef LLAMA_NO_LLAMAFILE
|
61 |
+
GGML_NO_LLAMAFILE := 1
|
62 |
+
endif
|
63 |
+
ifndef GGML_NO_LLAMAFILE
|
64 |
+
CFLAGS += -DGGML_USE_LLAMAFILE
|
65 |
+
CXXFLAGS += -DGGML_USE_LLAMAFILE
|
66 |
+
endif
|
67 |
+
|
68 |
+
#lets try enabling everything
|
69 |
+
CFLAGS += -pthread -Wno-deprecated -Wno-deprecated-declarations -Wno-unused-variable
|
70 |
+
CXXFLAGS += -pthread -Wno-multichar -Wno-write-strings -Wno-deprecated -Wno-deprecated-declarations -Wno-unused-variable
|
71 |
+
|
72 |
+
LDFLAGS =
|
73 |
+
FASTCFLAGS = $(subst -O3,-Ofast,$(CFLAGS))
|
74 |
+
FASTCXXFLAGS = $(subst -O3,-Ofast,$(CXXFLAGS))
|
75 |
+
|
76 |
+
# these are used on windows, to build some libraries with extra old device compatibility
|
77 |
+
SIMPLECFLAGS =
|
78 |
+
SIMPLERCFLAGS =
|
79 |
+
FULLCFLAGS =
|
80 |
+
NONECFLAGS =
|
81 |
+
|
82 |
+
CLBLAST_FLAGS = -DGGML_USE_CLBLAST
|
83 |
+
FAILSAFE_FLAGS = -DUSE_FAILSAFE
|
84 |
+
VULKAN_FLAGS = -DGGML_USE_VULKAN -DSD_USE_VULKAN
|
85 |
+
ifdef LLAMA_CUBLAS
|
86 |
+
CUBLAS_FLAGS = -DGGML_USE_CUDA -DSD_USE_CUBLAS
|
87 |
+
else
|
88 |
+
CUBLAS_FLAGS =
|
89 |
+
endif
|
90 |
+
CUBLASLD_FLAGS =
|
91 |
+
CUBLAS_OBJS =
|
92 |
+
|
93 |
+
OBJS_FULL += ggml-alloc.o ggml-cpu-traits.o ggml-quants.o ggml-cpu-quants.o ggml-cpu-aarch64.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm.o common.o sampling.o kcpputils.o
|
94 |
+
OBJS_SIMPLE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_noavx2.o ggml-cpu-quants_noavx2.o ggml-cpu-aarch64_noavx2.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_noavx2.o common.o sampling.o kcpputils.o
|
95 |
+
OBJS_SIMPLER += ggml-alloc.o ggml-cpu-traits.o ggml-quants_noavx1.o ggml-cpu-quants_noavx1.o ggml-cpu-aarch64_noavx1.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_noavx1.o common.o sampling.o kcpputils.o
|
96 |
+
OBJS_FAILSAFE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_failsafe.o ggml-cpu-quants_failsafe.o ggml-cpu-aarch64_failsafe.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_failsafe.o common.o sampling.o kcpputils.o
|
97 |
+
|
98 |
+
# OS specific
|
99 |
+
ifeq ($(UNAME_S),Linux)
|
100 |
+
CFLAGS += -pthread
|
101 |
+
CXXFLAGS += -pthread
|
102 |
+
LDFLAGS += -ldl
|
103 |
+
endif
|
104 |
+
|
105 |
+
ifeq ($(UNAME_S),Darwin)
|
106 |
+
CFLAGS += -pthread
|
107 |
+
CXXFLAGS += -pthread
|
108 |
+
CLANG_VER = $(shell clang -v 2>&1 | head -n 1 | awk 'BEGIN {FS="[. ]"};{print $$1 $$2 $$4}')
|
109 |
+
ifeq ($(CLANG_VER),Appleclang15)
|
110 |
+
LDFLAGS += -ld_classic
|
111 |
+
endif
|
112 |
+
endif
|
113 |
+
ifeq ($(UNAME_S),FreeBSD)
|
114 |
+
CFLAGS += -pthread
|
115 |
+
CXXFLAGS += -pthread
|
116 |
+
endif
|
117 |
+
ifeq ($(UNAME_S),NetBSD)
|
118 |
+
CFLAGS += -pthread
|
119 |
+
CXXFLAGS += -pthread
|
120 |
+
endif
|
121 |
+
ifeq ($(UNAME_S),OpenBSD)
|
122 |
+
CFLAGS += -pthread
|
123 |
+
CXXFLAGS += -pthread
|
124 |
+
endif
|
125 |
+
ifeq ($(UNAME_S),Haiku)
|
126 |
+
CFLAGS += -pthread
|
127 |
+
CXXFLAGS += -pthread
|
128 |
+
endif
|
129 |
+
|
130 |
+
ifdef LLAMA_GPROF
|
131 |
+
CFLAGS += -pg
|
132 |
+
CXXFLAGS += -pg
|
133 |
+
endif
|
134 |
+
ifdef LLAMA_PERF
|
135 |
+
CFLAGS += -DGGML_PERF
|
136 |
+
CXXFLAGS += -DGGML_PERF
|
137 |
+
endif
|
138 |
+
|
139 |
+
CCV := $(shell $(CC) --version | head -n 1)
|
140 |
+
CXXV := $(shell $(CXX) --version | head -n 1)
|
141 |
+
|
142 |
+
# Architecture specific
|
143 |
+
# For x86 based architectures
|
144 |
+
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64))
|
145 |
+
ifdef LLAMA_PORTABLE
|
146 |
+
SIMPLECFLAGS += -mavx -msse3 -mssse3
|
147 |
+
SIMPLERCFLAGS += -msse3 -mssse3
|
148 |
+
ifdef LLAMA_NOAVX2
|
149 |
+
FULLCFLAGS += -msse3 -mssse3 -mavx
|
150 |
+
else
|
151 |
+
FULLCFLAGS += -mavx2 -msse3 -mssse3 -mfma -mf16c -mavx
|
152 |
+
endif # LLAMA_NOAVX2
|
153 |
+
else
|
154 |
+
CFLAGS += -march=native -mtune=native
|
155 |
+
endif # LLAMA_PORTABLE
|
156 |
+
endif # if x86
|
157 |
+
|
158 |
+
ifndef LLAMA_NO_ACCELERATE
|
159 |
+
# Mac M1 - include Accelerate framework.
|
160 |
+
# `-framework Accelerate` works on Mac Intel as well, with negliable performance boost (as of the predict time).
|
161 |
+
ifeq ($(UNAME_S),Darwin)
|
162 |
+
CFLAGS += -DGGML_USE_ACCELERATE -DGGML_USE_BLAS -DGGML_BLAS_USE_ACCELERATE
|
163 |
+
CXXFLAGS += -DGGML_USE_ACCELERATE -DGGML_USE_BLAS -DGGML_BLAS_USE_ACCELERATE
|
164 |
+
LDFLAGS += -framework Accelerate
|
165 |
+
OBJS += ggml-blas.o
|
166 |
+
endif
|
167 |
+
endif
|
168 |
+
|
169 |
+
# it is recommended to use the CMAKE file to build for cublas if you can - will likely work better
|
170 |
+
OBJS_CUDA_TEMP_INST = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-mma*.cu))
|
171 |
+
OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/mmq*.cu))
|
172 |
+
OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu))
|
173 |
+
OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu))
|
174 |
+
OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*f16-f16.cu))
|
175 |
+
|
176 |
+
ifdef LLAMA_CUBLAS
|
177 |
+
CUBLAS_FLAGS = -DGGML_USE_CUDA -DSD_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
|
178 |
+
CUBLASLD_FLAGS = -lcuda -lcublas -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/local/cuda/targets/aarch64-linux/lib -L/usr/local/cuda/targets/sbsa-linux/lib -L/usr/lib/wsl/lib
|
179 |
+
CUBLAS_OBJS = ggml-cuda.o ggml_v3-cuda.o ggml_v2-cuda.o ggml_v2-cuda-legacy.o
|
180 |
+
CUBLAS_OBJS += $(patsubst %.cu,%.o,$(filter-out ggml/src/ggml-cuda/ggml-cuda.cu, $(wildcard ggml/src/ggml-cuda/*.cu)))
|
181 |
+
CUBLAS_OBJS += $(OBJS_CUDA_TEMP_INST)
|
182 |
+
NVCC = nvcc
|
183 |
+
NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
|
184 |
+
|
185 |
+
ifdef LLAMA_ADD_CONDA_PATHS
|
186 |
+
CUBLASLD_FLAGS += -Lconda/envs/linux/lib -Lconda/envs/linux/lib/stubs
|
187 |
+
endif
|
188 |
+
|
189 |
+
ifdef CUDA_DOCKER_ARCH
|
190 |
+
NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
|
191 |
+
else
|
192 |
+
ifdef LLAMA_PORTABLE
|
193 |
+
ifdef LLAMA_COLAB #colab does not need all targets, all-major doesnt work correctly with pascal
|
194 |
+
NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=all-major
|
195 |
+
else
|
196 |
+
NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=all
|
197 |
+
endif #LLAMA_COLAB
|
198 |
+
else
|
199 |
+
NVCCFLAGS += -arch=native
|
200 |
+
endif #LLAMA_PORTABLE
|
201 |
+
endif # CUDA_DOCKER_ARCH
|
202 |
+
|
203 |
+
ifdef LLAMA_CUDA_F16
|
204 |
+
NVCCFLAGS += -DGGML_CUDA_F16
|
205 |
+
endif # LLAMA_CUDA_F16
|
206 |
+
ifdef LLAMA_CUDA_DMMV_F16
|
207 |
+
NVCCFLAGS += -DGGML_CUDA_F16
|
208 |
+
endif # LLAMA_CUDA_DMMV_F16
|
209 |
+
|
210 |
+
ifdef LLAMA_CUDA_CCBIN
|
211 |
+
NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
|
212 |
+
endif
|
213 |
+
|
214 |
+
ggml/src/ggml-cuda/%.o: ggml/src/ggml-cuda/%.cu ggml/include/ggml.h ggml/src/ggml-common.h ggml/src/ggml-cuda/common.cuh
|
215 |
+
$(NVCC) $(NVCCFLAGS) $(subst -Ofast,-O3,$(CXXFLAGS)) $(CUBLAS_FLAGS) $(HIPFLAGS) $(CUBLAS_CXXFLAGS) -Wno-pedantic -c $< -o $@
|
216 |
+
ggml-cuda.o: ggml/src/ggml-cuda/ggml-cuda.cu ggml/include/ggml-cuda.h ggml/include/ggml.h ggml/include/ggml-backend.h ggml/src/ggml-backend-impl.h ggml/src/ggml-common.h $(wildcard ggml/src/ggml-cuda/*.cuh)
|
217 |
+
$(NVCC) $(NVCCFLAGS) $(subst -Ofast,-O3,$(CXXFLAGS)) $(CUBLAS_FLAGS) $(HIPFLAGS) $(CUBLAS_CXXFLAGS) -Wno-pedantic -c $< -o $@
|
218 |
+
ggml_v2-cuda.o: otherarch/ggml_v2-cuda.cu otherarch/ggml_v2-cuda.h
|
219 |
+
$(NVCC) $(NVCCFLAGS) $(subst -Ofast,-O3,$(CXXFLAGS)) $(CUBLAS_FLAGS) $(HIPFLAGS) $(CUBLAS_CXXFLAGS) -Wno-pedantic -c $< -o $@
|
220 |
+
ggml_v2-cuda-legacy.o: otherarch/ggml_v2-cuda-legacy.cu otherarch/ggml_v2-cuda-legacy.h
|
221 |
+
$(NVCC) $(NVCCFLAGS) $(subst -Ofast,-O3,$(CXXFLAGS)) $(CUBLAS_FLAGS) $(HIPFLAGS) $(CUBLAS_CXXFLAGS) -Wno-pedantic -c $< -o $@
|
222 |
+
ggml_v3-cuda.o: otherarch/ggml_v3-cuda.cu otherarch/ggml_v3-cuda.h
|
223 |
+
$(NVCC) $(NVCCFLAGS) $(subst -Ofast,-O3,$(CXXFLAGS)) $(CUBLAS_FLAGS) $(HIPFLAGS) $(CUBLAS_CXXFLAGS) -Wno-pedantic -c $< -o $@
|
224 |
+
endif # LLAMA_CUBLAS
|
225 |
+
|
226 |
+
ifdef LLAMA_HIPBLAS
|
227 |
+
ifeq ($(wildcard /opt/rocm),)
|
228 |
+
ROCM_PATH ?= /usr
|
229 |
+
GPU_TARGETS ?= $(shell $(shell which amdgpu-arch))
|
230 |
+
HCC := $(ROCM_PATH)/bin/hipcc
|
231 |
+
HCXX := $(ROCM_PATH)/bin/hipcc
|
232 |
+
else
|
233 |
+
ROCM_PATH ?= /opt/rocm
|
234 |
+
GPU_TARGETS ?= gfx803 gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
|
235 |
+
HCC := $(ROCM_PATH)/llvm/bin/clang
|
236 |
+
HCXX := $(ROCM_PATH)/llvm/bin/clang++
|
237 |
+
endif
|
238 |
+
HIPFLAGS += -DGGML_USE_HIP -DGGML_HIP_NO_VMM -DGGML_HIP_ROCWMMA_FATTN -DGGML_USE_CUDA -DSD_USE_CUBLAS $(shell $(ROCM_PATH)/bin/hipconfig -C)
|
239 |
+
HIPLDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
|
240 |
+
HIPLDFLAGS += -L$(ROCM_PATH)/lib64 -Wl,-rpath=$(ROCM_PATH)/lib64
|
241 |
+
HIPLDFLAGS += -lhipblas -lamdhip64 -lrocblas
|
242 |
+
HIP_OBJS += ggml-cuda.o ggml_v3-cuda.o ggml_v2-cuda.o ggml_v2-cuda-legacy.o
|
243 |
+
HIP_OBJS += $(patsubst %.cu,%.o,$(filter-out ggml/src/ggml-cuda/ggml-cuda.cu, $(wildcard ggml/src/ggml-cuda/*.cu)))
|
244 |
+
HIP_OBJS += $(OBJS_CUDA_TEMP_INST)
|
245 |
+
|
246 |
+
HIPFLAGS2 += $(addprefix --offload-arch=,$(GPU_TARGETS))
|
247 |
+
|
248 |
+
ggml/src/ggml-cuda/%.o: ggml/src/ggml-cuda/%.cu ggml/include/ggml.h ggml/src/ggml-common.h ggml/src/ggml-cuda/common.cuh
|
249 |
+
$(HCXX) $(CXXFLAGS) $(HIPFLAGS) $(HIPFLAGS2) -x hip -c -o $@ $<
|
250 |
+
ggml-cuda.o: ggml/src/ggml-cuda/ggml-cuda.cu ggml/include/ggml-cuda.h ggml/include/ggml.h ggml/include/ggml-backend.h ggml/src/ggml-backend-impl.h ggml/src/ggml-common.h $(wildcard ggml/src/ggml-cuda/*.cuh)
|
251 |
+
$(HCXX) $(CXXFLAGS) $(HIPFLAGS) $(HIPFLAGS2) -x hip -c -o $@ $<
|
252 |
+
ggml_v2-cuda.o: otherarch/ggml_v2-cuda.cu otherarch/ggml_v2-cuda.h
|
253 |
+
$(HCXX) $(CXXFLAGS) $(HIPFLAGS) $(HIPFLAGS2) -x hip -c -o $@ $<
|
254 |
+
ggml_v2-cuda-legacy.o: otherarch/ggml_v2-cuda-legacy.cu otherarch/ggml_v2-cuda-legacy.h
|
255 |
+
$(HCXX) $(CXXFLAGS) $(HIPFLAGS) $(HIPFLAGS2) -x hip -c -o $@ $<
|
256 |
+
ggml_v3-cuda.o: otherarch/ggml_v3-cuda.cu otherarch/ggml_v3-cuda.h
|
257 |
+
$(HCXX) $(CXXFLAGS) $(HIPFLAGS) $(HIPFLAGS2) -x hip -c -o $@ $<
|
258 |
+
endif # LLAMA_HIPBLAS
|
259 |
+
|
260 |
+
|
261 |
+
ifdef LLAMA_METAL
|
262 |
+
CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG -DSD_USE_METAL
|
263 |
+
CXXFLAGS += -DGGML_USE_METAL -DSD_USE_METAL
|
264 |
+
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
|
265 |
+
OBJS += ggml-metal.o
|
266 |
+
|
267 |
+
ggml-metal.o: ggml/src/ggml-metal/ggml-metal.m ggml/src/ggml-metal/ggml-metal-impl.h ggml/include/ggml-metal.h
|
268 |
+
@echo "== Preparing merged Metal file =="
|
269 |
+
@sed -e '/#include "ggml-common.h"/r ggml/src/ggml-common.h' -e '/#include "ggml-common.h"/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal.tmp
|
270 |
+
@sed -e '/#include "ggml-metal-impl.h"/r ggml/src/ggml-metal/ggml-metal-impl.h' -e '/#include "ggml-metal-impl.h"/d' < ggml/src/ggml-metal/ggml-metal-embed.metal.tmp > ggml/src/ggml-metal/ggml-metal-merged.metal
|
271 |
+
@cp ggml/src/ggml-metal/ggml-metal-merged.metal ./ggml-metal-merged.metal
|
272 |
+
$(CC) $(CFLAGS) -c $< -o $@
|
273 |
+
endif # LLAMA_METAL
|
274 |
+
|
275 |
+
ifneq ($(filter aarch64%,$(UNAME_M)),)
|
276 |
+
# Apple M1, M2, etc.
|
277 |
+
# Raspberry Pi 3, 4, Zero 2 (64-bit)
|
278 |
+
ifdef LLAMA_PORTABLE
|
279 |
+
CFLAGS +=
|
280 |
+
CXXFLAGS +=
|
281 |
+
else
|
282 |
+
# sve is cooked on termux so we are disabling it
|
283 |
+
ifeq ($(UNAME_O), Android)
|
284 |
+
ifneq ($(findstring clang, $(CCV)), )
|
285 |
+
CFLAGS += -mcpu=native+nosve
|
286 |
+
CXXFLAGS += -mcpu=native+nosve
|
287 |
+
else
|
288 |
+
CFLAGS += -mcpu=native
|
289 |
+
CXXFLAGS += -mcpu=native
|
290 |
+
endif
|
291 |
+
else
|
292 |
+
CFLAGS += -mcpu=native
|
293 |
+
CXXFLAGS += -mcpu=native
|
294 |
+
endif
|
295 |
+
endif
|
296 |
+
endif
|
297 |
+
|
298 |
+
ifneq ($(filter armv6%,$(UNAME_M)),)
|
299 |
+
# Raspberry Pi 1, Zero
|
300 |
+
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
|
301 |
+
CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
|
302 |
+
endif
|
303 |
+
ifneq ($(filter armv7%,$(UNAME_M)),)
|
304 |
+
# Raspberry Pi 2
|
305 |
+
CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
|
306 |
+
CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
|
307 |
+
endif
|
308 |
+
ifneq ($(filter armv8%,$(UNAME_M)),)
|
309 |
+
# Raspberry Pi 3, 4, Zero 2 (32-bit)
|
310 |
+
CFLAGS += -mfp16-format=ieee -mno-unaligned-access
|
311 |
+
CXXFLAGS += -mfp16-format=ieee -mno-unaligned-access
|
312 |
+
endif
|
313 |
+
ifneq ($(filter ppc64%,$(UNAME_M)),)
|
314 |
+
POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
|
315 |
+
ifneq (,$(findstring POWER9,$(POWER9_M)))
|
316 |
+
CFLAGS += -mcpu=power9
|
317 |
+
CXXFLAGS += -mcpu=power9
|
318 |
+
endif
|
319 |
+
endif
|
320 |
+
|
321 |
+
|
322 |
+
DEFAULT_BUILD =
|
323 |
+
FAILSAFE_BUILD =
|
324 |
+
NOAVX2_BUILD =
|
325 |
+
CLBLAST_BUILD =
|
326 |
+
CUBLAS_BUILD =
|
327 |
+
HIPBLAS_BUILD =
|
328 |
+
VULKAN_BUILD =
|
329 |
+
NOTIFY_MSG =
|
330 |
+
|
331 |
+
ifeq ($(OS),Windows_NT)
|
332 |
+
DEFAULT_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o [email protected] $(LDFLAGS)
|
333 |
+
ifdef LLAMA_PORTABLE
|
334 |
+
FAILSAFE_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o [email protected] $(LDFLAGS)
|
335 |
+
NOAVX2_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o [email protected] $(LDFLAGS)
|
336 |
+
endif
|
337 |
+
|
338 |
+
ifdef LLAMA_CLBLAST
|
339 |
+
CLBLAST_BUILD = $(CXX) $(CXXFLAGS) $^ lib/OpenCL.lib lib/clblast.lib -shared -o [email protected] $(LDFLAGS)
|
340 |
+
endif
|
341 |
+
ifdef LLAMA_VULKAN
|
342 |
+
VULKAN_BUILD = $(CXX) $(CXXFLAGS) $^ lib/vulkan-1.lib -shared -o [email protected] $(LDFLAGS)
|
343 |
+
endif
|
344 |
+
|
345 |
+
ifdef LLAMA_CUBLAS
|
346 |
+
CUBLAS_BUILD = $(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $^ -shared -o [email protected] $(CUBLASLD_FLAGS) $(LDFLAGS)
|
347 |
+
endif
|
348 |
+
ifdef LLAMA_HIPBLAS
|
349 |
+
HIPBLAS_BUILD = $(HCXX) $(CXXFLAGS) $(HIPFLAGS) $^ -shared -o [email protected] $(HIPLDFLAGS) $(LDFLAGS)
|
350 |
+
endif
|
351 |
+
else
|
352 |
+
DEFAULT_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o [email protected] $(LDFLAGS)
|
353 |
+
ifdef LLAMA_PORTABLE
|
354 |
+
ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64))
|
355 |
+
FAILSAFE_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o [email protected] $(LDFLAGS)
|
356 |
+
NOAVX2_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o [email protected] $(LDFLAGS)
|
357 |
+
endif
|
358 |
+
endif
|
359 |
+
|
360 |
+
ifdef LLAMA_CLBLAST
|
361 |
+
ifeq ($(UNAME_S),Darwin)
|
362 |
+
CLBLAST_BUILD = $(CXX) $(CXXFLAGS) $^ -lclblast -framework OpenCL $(ARCH_ADD) -shared -o [email protected] $(LDFLAGS)
|
363 |
+
else
|
364 |
+
CLBLAST_BUILD = $(CXX) $(CXXFLAGS) $^ -lclblast -lOpenCL $(ARCH_ADD) -shared -o [email protected] $(LDFLAGS)
|
365 |
+
endif
|
366 |
+
endif
|
367 |
+
ifdef LLAMA_CUBLAS
|
368 |
+
CUBLAS_BUILD = $(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $^ -shared -o [email protected] $(CUBLASLD_FLAGS) $(LDFLAGS)
|
369 |
+
endif
|
370 |
+
ifdef LLAMA_HIPBLAS
|
371 |
+
HIPBLAS_BUILD = $(HCXX) $(CXXFLAGS) $(HIPFLAGS) $^ -shared -o [email protected] $(HIPLDFLAGS) $(LDFLAGS)
|
372 |
+
endif
|
373 |
+
ifdef LLAMA_VULKAN
|
374 |
+
VULKAN_BUILD = $(CXX) $(CXXFLAGS) $^ -lvulkan -shared -o [email protected] $(LDFLAGS)
|
375 |
+
endif
|
376 |
+
endif
|
377 |
+
|
378 |
+
ifndef LLAMA_CLBLAST
|
379 |
+
ifndef LLAMA_CUBLAS
|
380 |
+
ifndef LLAMA_HIPBLAS
|
381 |
+
ifndef LLAMA_VULKAN
|
382 |
+
ifndef LLAMA_METAL
|
383 |
+
NOTIFY_MSG = @echo -e '\n***\nYou did a basic CPU build. For faster speeds, consider installing and linking a GPU BLAS library. For example, set LLAMA_CLBLAST=1 LLAMA_VULKAN=1 to compile with Vulkan and CLBlast support. Add LLAMA_PORTABLE=1 to make a sharable build that other devices can use. Read the KoboldCpp Wiki for more information. This is just a reminder, not an error.\n***\n'
|
384 |
+
endif
|
385 |
+
endif
|
386 |
+
endif
|
387 |
+
endif
|
388 |
+
endif
|
389 |
+
|
390 |
+
|
391 |
+
#
|
392 |
+
# Print build information
|
393 |
+
#
|
394 |
+
|
395 |
+
$(info I koboldcpp build info: )
|
396 |
+
$(info I UNAME_S: $(UNAME_S))
|
397 |
+
$(info I UNAME_P: $(UNAME_P))
|
398 |
+
$(info I UNAME_M: $(UNAME_M))
|
399 |
+
$(info I UNAME_O: $(UNAME_O))
|
400 |
+
$(info I CFLAGS: $(CFLAGS))
|
401 |
+
$(info I CXXFLAGS: $(CXXFLAGS))
|
402 |
+
$(info I LDFLAGS: $(LDFLAGS))
|
403 |
+
$(info I CC: $(CCV))
|
404 |
+
$(info I CXX: $(CXXV))
|
405 |
+
$(info )
|
406 |
+
|
407 |
+
#
|
408 |
+
# Build library
|
409 |
+
#
|
410 |
+
|
411 |
+
ggml.o: ggml/src/ggml.c ggml/include/ggml.h
|
412 |
+
$(CC) $(FASTCFLAGS) $(FULLCFLAGS) -c $< -o $@
|
413 |
+
ggml_v4_failsafe.o: ggml/src/ggml.c ggml/include/ggml.h
|
414 |
+
$(CC) $(FASTCFLAGS) $(NONECFLAGS) -c $< -o $@
|
415 |
+
ggml_v4_noavx2.o: ggml/src/ggml.c ggml/include/ggml.h
|
416 |
+
$(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) -c $< -o $@
|
417 |
+
ggml_v4_clblast.o: ggml/src/ggml.c ggml/include/ggml.h
|
418 |
+
$(CC) $(FASTCFLAGS) $(FULLCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
|
419 |
+
ggml_v4_cublas.o: ggml/src/ggml.c ggml/include/ggml.h
|
420 |
+
$(CC) $(FASTCFLAGS) $(FULLCFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
|
421 |
+
ggml_v4_clblast_noavx2.o: ggml/src/ggml.c ggml/include/ggml.h
|
422 |
+
$(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
|
423 |
+
ggml_v4_clblast_failsafe.o: ggml/src/ggml.c ggml/include/ggml.h
|
424 |
+
$(CC) $(FASTCFLAGS) $(SIMPLERCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
|
425 |
+
ggml_v4_vulkan.o: ggml/src/ggml.c ggml/include/ggml.h
|
426 |
+
$(CC) $(FASTCFLAGS) $(FULLCFLAGS) $(VULKAN_FLAGS) -c $< -o $@
|
427 |
+
ggml_v4_vulkan_noavx2.o: ggml/src/ggml.c ggml/include/ggml.h
|
428 |
+
$(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) $(VULKAN_FLAGS) -c $< -o $@
|
429 |
+
|
430 |
+
# cpu and clblast separated
|
431 |
+
ggml-cpu.o: ggml/src/ggml-cpu/ggml-cpu.c ggml/include/ggml-cpu.h
|
432 |
+
$(CC) $(FASTCFLAGS) $(FULLCFLAGS) -c $< -o $@
|
433 |
+
ggml-cpu_v4_failsafe.o: ggml/src/ggml-cpu/ggml-cpu.c ggml/include/ggml-cpu.h
|
434 |
+
$(CC) $(FASTCFLAGS) $(NONECFLAGS) -c $< -o $@
|
435 |
+
ggml-cpu_v4_noavx2.o: ggml/src/ggml-cpu/ggml-cpu.c ggml/include/ggml-cpu.h
|
436 |
+
$(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) -c $< -o $@
|
437 |
+
ggml-cpu_v4_clblast.o: ggml/src/ggml-cpu/ggml-cpu.c ggml/include/ggml-cpu.h
|
438 |
+
$(CC) $(FASTCFLAGS) $(FULLCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
|
439 |
+
ggml-cpu_v4_clblast_noavx2.o: ggml/src/ggml-cpu/ggml-cpu.c ggml/include/ggml-cpu.h
|
440 |
+
$(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
|
441 |
+
ggml-cpu_v4_clblast_failsafe.o: ggml/src/ggml-cpu/ggml-cpu.c ggml/include/ggml-cpu.h
|
442 |
+
$(CC) $(FASTCFLAGS) $(SIMPLERCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
|
443 |
+
|
444 |
+
#quants
|
445 |
+
ggml-quants.o: ggml/src/ggml-quants.c ggml/include/ggml.h ggml/src/ggml-quants.h ggml/src/ggml-common.h
|
446 |
+
$(CC) $(CFLAGS) $(FULLCFLAGS) -c $< -o $@
|
447 |
+
ggml-quants_noavx2.o: ggml/src/ggml-quants.c ggml/include/ggml.h ggml/src/ggml-quants.h ggml/src/ggml-common.h
|
448 |
+
$(CC) $(CFLAGS) $(SIMPLECFLAGS) -c $< -o $@
|
449 |
+
ggml-quants_noavx1.o: ggml/src/ggml-quants.c ggml/include/ggml.h ggml/src/ggml-quants.h ggml/src/ggml-common.h
|
450 |
+
$(CC) $(CFLAGS) $(SIMPLERCFLAGS) -c $< -o $@
|
451 |
+
ggml-quants_failsafe.o: ggml/src/ggml-quants.c ggml/include/ggml.h ggml/src/ggml-quants.h ggml/src/ggml-common.h
|
452 |
+
$(CC) $(CFLAGS) $(NONECFLAGS) -c $< -o $@
|
453 |
+
ggml-cpu-quants.o: ggml/src/ggml-cpu/ggml-cpu-quants.c ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-quants.h ggml/src/ggml-common.h
|
454 |
+
$(CC) $(CFLAGS) $(FULLCFLAGS) -c $< -o $@
|
455 |
+
ggml-cpu-quants_noavx2.o: ggml/src/ggml-cpu/ggml-cpu-quants.c ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-quants.h ggml/src/ggml-common.h
|
456 |
+
$(CC) $(CFLAGS) $(SIMPLECFLAGS) -c $< -o $@
|
457 |
+
ggml-cpu-quants_noavx1.o: ggml/src/ggml-cpu/ggml-cpu-quants.c ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-quants.h ggml/src/ggml-common.h
|
458 |
+
$(CC) $(CFLAGS) $(SIMPLERCFLAGS) -c $< -o $@
|
459 |
+
ggml-cpu-quants_failsafe.o: ggml/src/ggml-cpu/ggml-cpu-quants.c ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-quants.h ggml/src/ggml-common.h
|
460 |
+
$(CC) $(CFLAGS) $(NONECFLAGS) -c $< -o $@
|
461 |
+
|
462 |
+
#aarch64
|
463 |
+
ggml-cpu-aarch64.o: ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-aarch64.h
|
464 |
+
$(CXX) $(CXXFLAGS) $(FULLCFLAGS) -c $< -o $@
|
465 |
+
ggml-cpu-aarch64_noavx2.o: ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-aarch64.h
|
466 |
+
$(CXX) $(CXXFLAGS) $(SIMPLECFLAGS) -c $< -o $@
|
467 |
+
ggml-cpu-aarch64_noavx1.o: ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-aarch64.h
|
468 |
+
$(CXX) $(CXXFLAGS) $(SIMPLERCFLAGS) -c $< -o $@
|
469 |
+
ggml-cpu-aarch64_failsafe.o: ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-aarch64.h
|
470 |
+
$(CXX) $(CXXFLAGS) $(NONECFLAGS) -c $< -o $@
|
471 |
+
|
472 |
+
#sgemm
|
473 |
+
sgemm.o: ggml/src/ggml-cpu/llamafile/sgemm.cpp ggml/src/ggml-cpu/llamafile/sgemm.h ggml/include/ggml.h
|
474 |
+
$(CXX) $(CXXFLAGS) $(FULLCFLAGS) -c $< -o $@
|
475 |
+
sgemm_noavx2.o: ggml/src/ggml-cpu/llamafile/sgemm.cpp ggml/src/ggml-cpu/llamafile/sgemm.h ggml/include/ggml.h
|
476 |
+
$(CXX) $(CXXFLAGS) $(SIMPLECFLAGS) -c $< -o $@
|
477 |
+
sgemm_noavx1.o: ggml/src/ggml-cpu/llamafile/sgemm.cpp ggml/src/ggml-cpu/llamafile/sgemm.h ggml/include/ggml.h
|
478 |
+
$(CXX) $(CXXFLAGS) $(SIMPLERCFLAGS) -c $< -o $@
|
479 |
+
sgemm_failsafe.o: ggml/src/ggml-cpu/llamafile/sgemm.cpp ggml/src/ggml-cpu/llamafile/sgemm.h ggml/include/ggml.h
|
480 |
+
$(CXX) $(CXXFLAGS) $(NONECFLAGS) -c $< -o $@
|
481 |
+
|
482 |
+
#there's no intrinsics or special gpu ops used here, so we can have a universal object
|
483 |
+
ggml-alloc.o: ggml/src/ggml-alloc.c ggml/include/ggml.h ggml/include/ggml-alloc.h
|
484 |
+
$(CC) $(CFLAGS) -c $< -o $@
|
485 |
+
llava.o: examples/llava/llava.cpp examples/llava/llava.h
|
486 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
487 |
+
unicode.o: src/unicode.cpp src/unicode.h
|
488 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
489 |
+
unicode-data.o: src/unicode-data.cpp src/unicode-data.h
|
490 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
491 |
+
ggml-cpu-traits.o: ggml/src/ggml-cpu/ggml-cpu-traits.cpp ggml/src/ggml-cpu/ggml-cpu-traits.h ggml/include/ggml.h
|
492 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
493 |
+
ggml-threading.o: ggml/src/ggml-threading.cpp ggml/include/ggml.h
|
494 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
495 |
+
ggml-cpu-cpp.o: ggml/src/ggml-cpu/ggml-cpu.cpp ggml/include/ggml.h ggml/src/ggml-common.h
|
496 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
497 |
+
gguf.o: ggml/src/gguf.cpp ggml/include/gguf.h
|
498 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
499 |
+
kcpputils.o: otherarch/utils.cpp otherarch/utils.h
|
500 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
501 |
+
|
502 |
+
#these have special gpu defines
|
503 |
+
ggml-backend_default.o: ggml/src/ggml-backend.cpp ggml/src/ggml-backend-impl.h ggml/include/ggml.h ggml/include/ggml-backend.h
|
504 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
505 |
+
ggml-backend_vulkan.o: ggml/src/ggml-backend.cpp ggml/src/ggml-backend-impl.h ggml/include/ggml.h ggml/include/ggml-backend.h
|
506 |
+
$(CXX) $(CXXFLAGS) $(VULKAN_FLAGS) -c $< -o $@
|
507 |
+
ggml-backend_cublas.o: ggml/src/ggml-backend.cpp ggml/src/ggml-backend-impl.h ggml/include/ggml.h ggml/include/ggml-backend.h
|
508 |
+
$(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
|
509 |
+
ggml-backend-reg_default.o: ggml/src/ggml-backend-reg.cpp ggml/src/ggml-backend-impl.h ggml/include/ggml.h ggml/include/ggml-backend.h ggml/include/ggml-cpu.h
|
510 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
511 |
+
ggml-backend-reg_vulkan.o: ggml/src/ggml-backend-reg.cpp ggml/src/ggml-backend-impl.h ggml/include/ggml.h ggml/include/ggml-backend.h ggml/include/ggml-cpu.h
|
512 |
+
$(CXX) $(CXXFLAGS) $(VULKAN_FLAGS) -c $< -o $@
|
513 |
+
ggml-backend-reg_cublas.o: ggml/src/ggml-backend-reg.cpp ggml/src/ggml-backend-impl.h ggml/include/ggml.h ggml/include/ggml-backend.h ggml/include/ggml-cpu.h
|
514 |
+
$(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
|
515 |
+
llavaclip_default.o: examples/llava/clip.cpp examples/llava/clip.h
|
516 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
517 |
+
llavaclip_cublas.o: examples/llava/clip.cpp examples/llava/clip.h
|
518 |
+
$(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
|
519 |
+
llavaclip_vulkan.o: examples/llava/clip.cpp examples/llava/clip.h
|
520 |
+
$(CXX) $(CXXFLAGS) $(VULKAN_FLAGS) -c $< -o $@
|
521 |
+
|
522 |
+
#this is only used for accelerate
|
523 |
+
ggml-blas.o: ggml/src/ggml-blas/ggml-blas.cpp ggml/include/ggml-blas.h
|
524 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
525 |
+
|
526 |
+
#version 3 libs
|
527 |
+
ggml_v3.o: otherarch/ggml_v3.c otherarch/ggml_v3.h
|
528 |
+
$(CC) $(FASTCFLAGS) $(FULLCFLAGS) -c $< -o $@
|
529 |
+
ggml_v3_failsafe.o: otherarch/ggml_v3.c otherarch/ggml_v3.h
|
530 |
+
$(CC) $(FASTCFLAGS) $(NONECFLAGS) -c $< -o $@
|
531 |
+
ggml_v3_noavx2.o: otherarch/ggml_v3.c otherarch/ggml_v3.h
|
532 |
+
$(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) -c $< -o $@
|
533 |
+
ggml_v3_clblast.o: otherarch/ggml_v3.c otherarch/ggml_v3.h
|
534 |
+
$(CC) $(FASTCFLAGS) $(FULLCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
|
535 |
+
ggml_v3_cublas.o: otherarch/ggml_v3.c otherarch/ggml_v3.h
|
536 |
+
$(CC) $(FASTCFLAGS) $(FULLCFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
|
537 |
+
ggml_v3_clblast_noavx2.o: otherarch/ggml_v3.c otherarch/ggml_v3.h
|
538 |
+
$(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
|
539 |
+
ggml_v3_clblast_failsafe.o: otherarch/ggml_v3.c otherarch/ggml_v3.h
|
540 |
+
$(CC) $(FASTCFLAGS) $(SIMPLERCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
|
541 |
+
|
542 |
+
#version 2 libs
|
543 |
+
ggml_v2.o: otherarch/ggml_v2.c otherarch/ggml_v2.h
|
544 |
+
$(CC) $(FASTCFLAGS) $(FULLCFLAGS) -c $< -o $@
|
545 |
+
ggml_v2_failsafe.o: otherarch/ggml_v2.c otherarch/ggml_v2.h
|
546 |
+
$(CC) $(FASTCFLAGS) $(NONECFLAGS) -c $< -o $@
|
547 |
+
ggml_v2_noavx2.o: otherarch/ggml_v2.c otherarch/ggml_v2.h
|
548 |
+
$(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) -c $< -o $@
|
549 |
+
ggml_v2_clblast.o: otherarch/ggml_v2.c otherarch/ggml_v2.h
|
550 |
+
$(CC) $(FASTCFLAGS) $(FULLCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
|
551 |
+
ggml_v2_cublas.o: otherarch/ggml_v2.c otherarch/ggml_v2.h
|
552 |
+
$(CC) $(FASTCFLAGS) $(FULLCFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
|
553 |
+
ggml_v2_clblast_noavx2.o: otherarch/ggml_v2.c otherarch/ggml_v2.h
|
554 |
+
$(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
|
555 |
+
ggml_v2_clblast_failsafe.o: otherarch/ggml_v2.c otherarch/ggml_v2.h
|
556 |
+
$(CC) $(FASTCFLAGS) $(SIMPLERCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
|
557 |
+
|
558 |
+
#extreme old version compat
|
559 |
+
ggml_v1.o: otherarch/ggml_v1.c otherarch/ggml_v1.h
|
560 |
+
$(CC) $(FASTCFLAGS) $(FULLCFLAGS) -c $< -o $@
|
561 |
+
ggml_v1_failsafe.o: otherarch/ggml_v1.c otherarch/ggml_v1.h
|
562 |
+
$(CC) $(FASTCFLAGS) $(NONECFLAGS) -c $< -o $@
|
563 |
+
|
564 |
+
#opencl
|
565 |
+
ggml-opencl.o: otherarch/ggml_v3b-opencl.cpp otherarch/ggml_v3b-opencl.h
|
566 |
+
$(CXX) $(CXXFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
|
567 |
+
ggml_v2-opencl.o: otherarch/ggml_v2-opencl.cpp otherarch/ggml_v2-opencl.h
|
568 |
+
$(CXX) $(CXXFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
|
569 |
+
ggml_v2-opencl-legacy.o: otherarch/ggml_v2-opencl-legacy.c otherarch/ggml_v2-opencl-legacy.h
|
570 |
+
$(CC) $(CFLAGS) -c $< -o $@
|
571 |
+
ggml_v3-opencl.o: otherarch/ggml_v3-opencl.cpp otherarch/ggml_v3-opencl.h
|
572 |
+
$(CXX) $(CXXFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
|
573 |
+
|
574 |
+
#vulkan
|
575 |
+
ggml-vulkan.o: ggml/src/ggml-vulkan/ggml-vulkan.cpp ggml/include/ggml-vulkan.h ggml/src/ggml-vulkan-shaders.hpp ggml/src/ggml-vulkan-shaders.cpp
|
576 |
+
$(CXX) $(CXXFLAGS) $(VULKAN_FLAGS) -c $< -o $@
|
577 |
+
|
578 |
+
# intermediate objects
|
579 |
+
llama.o: src/llama.cpp ggml/include/ggml.h ggml/include/ggml-alloc.h ggml/include/ggml-backend.h ggml/include/ggml-cuda.h ggml/include/ggml-metal.h include/llama.h otherarch/llama-util.h
|
580 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
581 |
+
common.o: common/common.cpp common/common.h common/log.h
|
582 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
583 |
+
sampling.o: common/sampling.cpp common/common.h common/sampling.h common/log.h
|
584 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
585 |
+
console.o: common/console.cpp common/console.h
|
586 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
587 |
+
expose.o: expose.cpp expose.h
|
588 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
589 |
+
|
590 |
+
# sd.cpp objects
|
591 |
+
sdcpp_default.o: otherarch/sdcpp/sdtype_adapter.cpp otherarch/sdcpp/stable-diffusion.h otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/util.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/thirdparty/zip.c
|
592 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
593 |
+
sdcpp_cublas.o: otherarch/sdcpp/sdtype_adapter.cpp otherarch/sdcpp/stable-diffusion.h otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/util.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/thirdparty/zip.c
|
594 |
+
$(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
|
595 |
+
sdcpp_vulkan.o: otherarch/sdcpp/sdtype_adapter.cpp otherarch/sdcpp/stable-diffusion.h otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/util.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/thirdparty/zip.c
|
596 |
+
$(CXX) $(CXXFLAGS) $(VULKAN_FLAGS) -c $< -o $@
|
597 |
+
|
598 |
+
|
599 |
+
#whisper objects
|
600 |
+
whispercpp_default.o: otherarch/whispercpp/whisper_adapter.cpp
|
601 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
602 |
+
whispercpp_cublas.o: otherarch/whispercpp/whisper_adapter.cpp
|
603 |
+
$(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
|
604 |
+
|
605 |
+
#tts objects
|
606 |
+
tts_default.o: otherarch/tts_adapter.cpp
|
607 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
608 |
+
|
609 |
+
# idiotic "for easier compilation"
|
610 |
+
GPTTYPE_ADAPTER = gpttype_adapter.cpp otherarch/llama_v2.cpp otherarch/llama_v3.cpp src/llama.cpp src/llama-impl.cpp src/llama-chat.cpp src/llama-mmap.cpp src/llama-context.cpp src/llama-adapter.cpp src/llama-arch.cpp src/llama-batch.cpp src/llama-vocab.cpp src/llama-grammar.cpp src/llama-sampling.cpp src/llama-kv-cache.cpp src/llama-model-loader.cpp src/llama-model.cpp src/llama-quant.cpp src/llama-hparams.cpp otherarch/gptj_v1.cpp otherarch/gptj_v2.cpp otherarch/gptj_v3.cpp otherarch/gpt2_v1.cpp otherarch/gpt2_v2.cpp otherarch/gpt2_v3.cpp otherarch/rwkv_v2.cpp otherarch/rwkv_v3.cpp otherarch/neox_v2.cpp otherarch/neox_v3.cpp otherarch/mpt_v3.cpp ggml/include/ggml.h ggml/include/ggml-cpu.h ggml/include/ggml-cuda.h include/llama.h otherarch/llama-util.h
|
611 |
+
gpttype_adapter_failsafe.o: $(GPTTYPE_ADAPTER)
|
612 |
+
$(CXX) $(CXXFLAGS) $(FAILSAFE_FLAGS) -c $< -o $@
|
613 |
+
gpttype_adapter.o: $(GPTTYPE_ADAPTER)
|
614 |
+
$(CXX) $(CXXFLAGS) -c $< -o $@
|
615 |
+
gpttype_adapter_clblast.o: $(GPTTYPE_ADAPTER)
|
616 |
+
$(CXX) $(CXXFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
|
617 |
+
gpttype_adapter_cublas.o: $(GPTTYPE_ADAPTER)
|
618 |
+
$(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
|
619 |
+
gpttype_adapter_clblast_noavx2.o: $(GPTTYPE_ADAPTER)
|
620 |
+
$(CXX) $(CXXFLAGS) $(FAILSAFE_FLAGS) $(CLBLAST_FLAGS) -c $< -o $@
|
621 |
+
gpttype_adapter_vulkan.o: $(GPTTYPE_ADAPTER)
|
622 |
+
$(CXX) $(CXXFLAGS) $(VULKAN_FLAGS) -c $< -o $@
|
623 |
+
gpttype_adapter_vulkan_noavx2.o: $(GPTTYPE_ADAPTER)
|
624 |
+
$(CXX) $(CXXFLAGS) $(FAILSAFE_FLAGS) $(VULKAN_FLAGS) -c $< -o $@
|
625 |
+
|
626 |
+
clean:
|
627 |
+
rm -vf *.o main sdmain whispermain quantize_gguf quantize_clip quantize_gpt2 quantize_gptj quantize_neox quantize_mpt vulkan-shaders-gen gguf-split gguf-split.exe vulkan-shaders-gen.exe main.exe sdmain.exe whispermain.exe quantize_clip.exe quantize_gguf.exe quantize_gptj.exe quantize_gpt2.exe quantize_neox.exe quantize_mpt.exe koboldcpp_default.dll koboldcpp_failsafe.dll koboldcpp_noavx2.dll koboldcpp_clblast.dll koboldcpp_clblast_noavx2.dll koboldcpp_clblast_failsafe.dll koboldcpp_cublas.dll koboldcpp_hipblas.dll koboldcpp_vulkan.dll koboldcpp_vulkan_noavx2.dll koboldcpp_default.so koboldcpp_failsafe.so koboldcpp_noavx2.so koboldcpp_clblast.so koboldcpp_clblast_noavx2.so koboldcpp_clblast_failsafe.so koboldcpp_cublas.so koboldcpp_hipblas.so koboldcpp_vulkan.so koboldcpp_vulkan_noavx2.so
|
628 |
+
rm -vrf ggml/src/ggml-cuda/*.o
|
629 |
+
rm -vrf ggml/src/ggml-cuda/template-instances/*.o
|
630 |
+
|
631 |
+
# useful tools
|
632 |
+
main: examples/main/main.cpp common/arg.cpp build-info.h ggml.o ggml-cpu.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
|
633 |
+
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
634 |
+
sdmain: otherarch/sdcpp/util.cpp otherarch/sdcpp/main.cpp otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/thirdparty/zip.c build-info.h ggml.o ggml-cpu.o llama.o console.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
|
635 |
+
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
636 |
+
whispermain: otherarch/whispercpp/main.cpp otherarch/whispercpp/whisper.cpp build-info.h ggml.o ggml-cpu.o llama.o console.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
|
637 |
+
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
638 |
+
ttsmain: examples/tts/tts.cpp common/arg.cpp build-info.h ggml.o ggml-cpu.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
|
639 |
+
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
640 |
+
gguf-split: examples/gguf-split/gguf-split.cpp ggml.o ggml-cpu.o llama.o build-info.h llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
|
641 |
+
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
642 |
+
gemma3-cli: examples/llava/gemma3-cli.cpp common/arg.cpp build-info.h ggml.o ggml-cpu.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
|
643 |
+
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
644 |
+
|
645 |
+
vulkan-shaders-gen: ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
|
646 |
+
@echo 'This command can be MANUALLY run to regenerate vulkan shaders. Normally concedo will do it, so you do not have to.'
|
647 |
+
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
|
648 |
+
ifeq ($(OS),Windows_NT)
|
649 |
+
@echo 'Now rebuilding vulkan shaders for Windows...'
|
650 |
+
$(shell) vulkan-shaders-gen --glslc glslc --input-dir ggml/src/ggml-vulkan/vulkan-shaders --target-hpp ggml/src/ggml-vulkan-shaders.hpp --target-cpp ggml/src/ggml-vulkan-shaders.cpp
|
651 |
+
else
|
652 |
+
@echo 'Now rebuilding vulkan shaders for Linux...'
|
653 |
+
${shell} chmod +x vulkan-shaders-gen
|
654 |
+
${shell} chmod +x glslc-linux
|
655 |
+
$(shell) ./vulkan-shaders-gen --glslc ./glslc-linux --input-dir ggml/src/ggml-vulkan/vulkan-shaders --target-hpp ggml/src/ggml-vulkan-shaders.hpp --target-cpp ggml/src/ggml-vulkan-shaders.cpp
|
656 |
+
endif
|
657 |
+
|
658 |
+
#generated libraries
|
659 |
+
koboldcpp_default: ggml.o ggml-cpu.o ggml_v3.o ggml_v2.o ggml_v1.o expose.o gpttype_adapter.o sdcpp_default.o whispercpp_default.o tts_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
|
660 |
+
$(DEFAULT_BUILD)
|
661 |
+
|
662 |
+
ifdef FAILSAFE_BUILD
|
663 |
+
koboldcpp_failsafe: ggml_v4_failsafe.o ggml-cpu_v4_failsafe.o ggml_v3_failsafe.o ggml_v2_failsafe.o ggml_v1_failsafe.o expose.o gpttype_adapter_failsafe.o sdcpp_default.o whispercpp_default.o tts_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FAILSAFE) $(OBJS)
|
664 |
+
$(FAILSAFE_BUILD)
|
665 |
+
else
|
666 |
+
koboldcpp_failsafe:
|
667 |
+
$(DONOTHING)
|
668 |
+
endif
|
669 |
+
|
670 |
+
ifdef NOAVX2_BUILD
|
671 |
+
koboldcpp_noavx2: ggml_v4_noavx2.o ggml-cpu_v4_noavx2.o ggml_v3_noavx2.o ggml_v2_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_failsafe.o sdcpp_default.o whispercpp_default.o tts_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_SIMPLE) $(OBJS)
|
672 |
+
$(NOAVX2_BUILD)
|
673 |
+
else
|
674 |
+
koboldcpp_noavx2:
|
675 |
+
$(DONOTHING)
|
676 |
+
endif
|
677 |
+
|
678 |
+
ifdef CLBLAST_BUILD
|
679 |
+
koboldcpp_clblast: ggml_v4_clblast.o ggml-cpu_v4_clblast.o ggml_v3_clblast.o ggml_v2_clblast.o ggml_v1.o expose.o gpttype_adapter_clblast.o ggml-opencl.o ggml_v3-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o sdcpp_default.o whispercpp_default.o tts_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
|
680 |
+
$(CLBLAST_BUILD)
|
681 |
+
ifdef NOAVX2_BUILD
|
682 |
+
koboldcpp_clblast_noavx2: ggml_v4_clblast_noavx2.o ggml-cpu_v4_clblast_noavx2.o ggml_v3_clblast_noavx2.o ggml_v2_clblast_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_clblast_noavx2.o ggml-opencl.o ggml_v3-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o sdcpp_default.o whispercpp_default.o tts_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_SIMPLE) $(OBJS)
|
683 |
+
$(CLBLAST_BUILD)
|
684 |
+
koboldcpp_clblast_failsafe: ggml_v4_clblast_failsafe.o ggml-cpu_v4_clblast_failsafe.o ggml_v3_clblast_failsafe.o ggml_v2_clblast_failsafe.o ggml_v1_failsafe.o expose.o gpttype_adapter_clblast_noavx2.o ggml-opencl.o ggml_v3-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o sdcpp_default.o whispercpp_default.o tts_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_SIMPLER) $(OBJS)
|
685 |
+
$(CLBLAST_BUILD)
|
686 |
+
else
|
687 |
+
koboldcpp_clblast_noavx2:
|
688 |
+
$(DONOTHING)
|
689 |
+
koboldcpp_clblast_failsafe:
|
690 |
+
$(DONOTHING)
|
691 |
+
endif
|
692 |
+
else
|
693 |
+
koboldcpp_clblast:
|
694 |
+
$(DONOTHING)
|
695 |
+
koboldcpp_clblast_noavx2:
|
696 |
+
$(DONOTHING)
|
697 |
+
koboldcpp_clblast_failsafe:
|
698 |
+
$(DONOTHING)
|
699 |
+
endif
|
700 |
+
|
701 |
+
ifdef CUBLAS_BUILD
|
702 |
+
koboldcpp_cublas: ggml_v4_cublas.o ggml-cpu.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o gpttype_adapter_cublas.o sdcpp_cublas.o whispercpp_cublas.o tts_default.o llavaclip_cublas.o llava.o ggml-backend_cublas.o ggml-backend-reg_cublas.o $(CUBLAS_OBJS) $(OBJS_FULL) $(OBJS)
|
703 |
+
$(CUBLAS_BUILD)
|
704 |
+
else
|
705 |
+
koboldcpp_cublas:
|
706 |
+
$(DONOTHING)
|
707 |
+
endif
|
708 |
+
|
709 |
+
ifdef HIPBLAS_BUILD
|
710 |
+
koboldcpp_hipblas: ggml_v4_cublas.o ggml-cpu.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o gpttype_adapter_cublas.o sdcpp_cublas.o whispercpp_cublas.o tts_default.o llavaclip_cublas.o llava.o ggml-backend_cublas.o ggml-backend-reg_cublas.o $(HIP_OBJS) $(OBJS_FULL) $(OBJS)
|
711 |
+
$(HIPBLAS_BUILD)
|
712 |
+
else
|
713 |
+
koboldcpp_hipblas:
|
714 |
+
$(DONOTHING)
|
715 |
+
endif
|
716 |
+
|
717 |
+
ifdef VULKAN_BUILD
|
718 |
+
koboldcpp_vulkan: ggml_v4_vulkan.o ggml-cpu.o ggml_v3.o ggml_v2.o ggml_v1.o expose.o gpttype_adapter_vulkan.o ggml-vulkan.o sdcpp_vulkan.o whispercpp_default.o tts_default.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o $(OBJS_FULL) $(OBJS)
|
719 |
+
$(VULKAN_BUILD)
|
720 |
+
ifdef NOAVX2_BUILD
|
721 |
+
koboldcpp_vulkan_noavx2: ggml_v4_vulkan_noavx2.o ggml-cpu_v4_noavx2.o ggml_v3_noavx2.o ggml_v2_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_vulkan_noavx2.o ggml-vulkan.o sdcpp_vulkan.o whispercpp_default.o tts_default.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o $(OBJS_SIMPLE) $(OBJS)
|
722 |
+
$(VULKAN_BUILD)
|
723 |
+
else
|
724 |
+
koboldcpp_vulkan_noavx2:
|
725 |
+
$(DONOTHING)
|
726 |
+
endif
|
727 |
+
else
|
728 |
+
koboldcpp_vulkan:
|
729 |
+
$(DONOTHING)
|
730 |
+
koboldcpp_vulkan_noavx2:
|
731 |
+
$(DONOTHING)
|
732 |
+
endif
|
733 |
+
|
734 |
+
# tools
|
735 |
+
quantize_gguf: examples/quantize/quantize.cpp ggml.o ggml-cpu.o llama.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL)
|
736 |
+
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
|
737 |
+
quantize_gptj: otherarch/tools/gptj_quantize.cpp otherarch/tools/common-ggml.cpp ggml_v3.o ggml.o ggml-cpu.o llama.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL)
|
738 |
+
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
|
739 |
+
quantize_gpt2: otherarch/tools/gpt2_quantize.cpp otherarch/tools/common-ggml.cpp ggml_v3.o ggml.o ggml-cpu.o llama.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL)
|
740 |
+
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
|
741 |
+
quantize_neox: otherarch/tools/neox_quantize.cpp otherarch/tools/common-ggml.cpp ggml_v3.o ggml.o ggml-cpu.o llama.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL)
|
742 |
+
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
|
743 |
+
quantize_mpt: otherarch/tools/mpt_quantize.cpp otherarch/tools/common-ggml.cpp ggml_v3.o ggml.o ggml-cpu.o llama.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL)
|
744 |
+
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
|
745 |
+
quantize_clip: examples/llava/clip.cpp examples/llava/clip.h examples/llava/quantclip.cpp ggml_v3.o ggml.o ggml-cpu.o llama.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL)
|
746 |
+
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
|
747 |
+
|
748 |
+
#window simple clinfo
|
749 |
+
simpleclinfo: simpleclinfo.cpp
|
750 |
+
$(CXX) $(CXXFLAGS) $^ lib/OpenCL.lib lib/clblast.lib -o $@ $(LDFLAGS)
|
751 |
+
|
752 |
+
build-info.h:
|
753 |
+
$(DONOTHING)
|
754 |
+
|
755 |
+
#phony for printing messages
|
756 |
+
finishedmsg:
|
757 |
+
$(NOTIFY_MSG)
|
758 |
+
$(DONOTHING)
|
OpenCL.dll
ADDED
Binary file (55.8 kB). View file
|
|
README.md
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# koboldcpp
|
2 |
+
|
3 |
+
KoboldCpp is an easy-to-use AI text-generation software for GGML and GGUF models, inspired by the original **KoboldAI**. It's a single self-contained distributable from Concedo, that builds off llama.cpp and adds many additional powerful features.
|
4 |
+
|
5 |
+

|
6 |
+

|
7 |
+

|
8 |
+

|
9 |
+
|
10 |
+
### Features
|
11 |
+
- Single file executable, with no installation required and no external dependencies
|
12 |
+
- Runs on CPU or GPU, supports full or partial offloaded
|
13 |
+
- LLM text generation (Supports all GGML and GGUF models, backwards compatibility with ALL past models)
|
14 |
+
- Image Generation (Stable Diffusion 1.5, SDXL, SD3, Flux)
|
15 |
+
- Speech-To-Text (Voice Recognition) via Whisper
|
16 |
+
- Text-To-Speech (Voice Generation) via OuteTTS
|
17 |
+
- Provides many compatible APIs endpoints for many popular webservices (KoboldCppApi OpenAiApi OllamaApi A1111ForgeApi ComfyUiApi WhisperTranscribeApi XttsApi OpenAiSpeechApi)
|
18 |
+
- Bundled KoboldAI Lite UI with editing tools, save formats, memory, world info, author's note, characters, scenarios.
|
19 |
+
- Includes multiple modes (chat, adventure, instruct, storywriter) and UI Themes (aesthetic roleplay, classic writer, corporate assistant, messsenger)
|
20 |
+
- Supports loading Tavern Character Cards, importing many different data formats from various sites, reading or exporting JSON savefiles and persistent stories.
|
21 |
+
- Many other features including new samplers, regex support, websearch, RAG via TextDB and more.
|
22 |
+
- Ready-to-use binaries for Windows, MacOS, Linux, Android (via Termux), Colab, Docker, also supports other platforms if self-compiled (like Raspberry PI).
|
23 |
+
- [Need help finding a model? Read this!](https://github.com/LostRuins/koboldcpp/wiki#getting-an-ai-model-file)
|
24 |
+
|
25 |
+
## Windows Usage (Precompiled Binary, Recommended)
|
26 |
+
- Windows binaries are provided in the form of **koboldcpp.exe**, which is a pyinstaller wrapper containing all necessary files. **[Download the latest koboldcpp.exe release here](https://github.com/LostRuins/koboldcpp/releases/latest)**
|
27 |
+
- To run, simply execute **koboldcpp.exe**.
|
28 |
+
- Launching with no command line arguments displays a GUI containing a subset of configurable settings. Generally you dont have to change much besides the `Presets` and `GPU Layers`. Read the `--help` for more info about each settings.
|
29 |
+
- Obtain and load a GGUF model. See [here](#Obtaining-a-GGUF-model)
|
30 |
+
- By default, you can connect to http://localhost:5001
|
31 |
+
- You can also run it using the command line. For info, please check `koboldcpp.exe --help`
|
32 |
+
|
33 |
+
## Linux Usage (Precompiled Binary, Recommended)
|
34 |
+
On modern Linux systems, you should download the `koboldcpp-linux-x64-cuda1150` prebuilt PyInstaller binary on the **[releases page](https://github.com/LostRuins/koboldcpp/releases/latest)**. Simply download and run the binary (You may have to `chmod +x` it first).
|
35 |
+
|
36 |
+
Alternatively, you can also install koboldcpp to the current directory by running the following terminal command:
|
37 |
+
```
|
38 |
+
curl -fLo koboldcpp https://github.com/LostRuins/koboldcpp/releases/latest/download/koboldcpp-linux-x64-cuda1150 && chmod +x koboldcpp
|
39 |
+
```
|
40 |
+
After running this command you can launch Koboldcpp from the current directory using `./koboldcpp` in the terminal (for CLI usage, run with `--help`).
|
41 |
+
Finally, obtain and load a GGUF model. See [here](#Obtaining-a-GGUF-model)
|
42 |
+
|
43 |
+
## MacOS (Precompiled Binary)
|
44 |
+
- PyInstaller binaries for Modern ARM64 MacOS (M1, M2, M3) are now available! **[Simply download the MacOS binary](https://github.com/LostRuins/koboldcpp/releases/latest)**
|
45 |
+
- In a MacOS terminal window, set the file to executable `chmod +x koboldcpp-mac-arm64` and run it with `./koboldcpp-mac-arm64`.
|
46 |
+
- In newer MacOS you may also have to whitelist it in security settings if it's blocked. [Here's a video guide](https://youtube.com/watch?v=NOW5dyA_JgY).
|
47 |
+
- Alternatively, or for older x86 MacOS computers, you can clone the repo and compile from source code, see Compiling for MacOS below.
|
48 |
+
- Finally, obtain and load a GGUF model. See [here](#Obtaining-a-GGUF-model)
|
49 |
+
|
50 |
+
## Run on Colab
|
51 |
+
- KoboldCpp now has an **official Colab GPU Notebook**! This is an easy way to get started without installing anything in a minute or two. [Try it here!](https://colab.research.google.com/github/LostRuins/koboldcpp/blob/concedo/colab.ipynb).
|
52 |
+
- Note that KoboldCpp is not responsible for your usage of this Colab Notebook, you should ensure that your own usage complies with Google Colab's terms of use.
|
53 |
+
|
54 |
+
## Run on RunPod
|
55 |
+
- KoboldCpp can now be used on RunPod cloud GPUs! This is an easy way to get started without installing anything in a minute or two, and is very scalable, capable of running 70B+ models at afforable cost. [Try our RunPod image here!](https://koboldai.org/runpodcpp).
|
56 |
+
|
57 |
+
## Run on Novita AI
|
58 |
+
KoboldCpp can now also be run on Novita AI, a newer alternative GPU cloud provider which has a quick launch KoboldCpp template for as well. [Check it out here!](https://koboldai.org/novitacpp)
|
59 |
+
|
60 |
+
## Docker
|
61 |
+
- The official docker can be found at https://hub.docker.com/r/koboldai/koboldcpp
|
62 |
+
- If you're building your own docker, remember to set CUDA_DOCKER_ARCH or enable LLAMA_PORTABLE
|
63 |
+
|
64 |
+
## Obtaining a GGUF model
|
65 |
+
- KoboldCpp uses GGUF models. They are not included with KoboldCpp, but you can download GGUF files from other places such as [TheBloke's Huggingface](https://huggingface.co/TheBloke). Search for "GGUF" on huggingface.co for plenty of compatible models in the `.gguf` format.
|
66 |
+
- For beginners, we recommend the models [Airoboros Mistral 7B](https://huggingface.co/TheBloke/airoboros-mistral2.2-7B-GGUF/resolve/main/airoboros-mistral2.2-7b.Q4_K_S.gguf) (smaller and weaker) or [Tiefighter 13B](https://huggingface.co/KoboldAI/LLaMA2-13B-Tiefighter-GGUF/resolve/main/LLaMA2-13B-Tiefighter.Q4_K_S.gguf) (larger model) or [Beepo 22B](https://huggingface.co/concedo/Beepo-22B-GGUF/resolve/main/Beepo-22B-Q4_K_S.gguf) (largest and most powerful)
|
67 |
+
- [Alternatively, you can download the tools to convert models to the GGUF format yourself here](https://kcpptools.concedo.workers.dev). Run `convert-hf-to-gguf.py` to convert them, then `quantize_gguf.exe` to quantize the result.
|
68 |
+
- Other models for Whisper (speech recognition), Image Generation, Text to Speech or Image Recognition [can be found on the Wiki](https://github.com/LostRuins/koboldcpp/wiki#what-models-does-koboldcpp-support-what-architectures-are-supported)
|
69 |
+
|
70 |
+
## Improving Performance
|
71 |
+
- **GPU Acceleration**: If you're on Windows with an Nvidia GPU you can get CUDA support out of the box using the `--usecublas` flag (Nvidia Only), or `--usevulkan` (Any GPU), make sure you select the correct .exe with CUDA support.
|
72 |
+
- **GPU Layer Offloading**: Add `--gpulayers` to offload model layers to the GPU. The more layers you offload to VRAM, the faster generation speed will become. Experiment to determine number of layers to offload, and reduce by a few if you run out of memory.
|
73 |
+
- **Increasing Context Size**: Use `--contextsize (number)` to increase context size, allowing the model to read more text. Note that you may also need to increase the max context in the KoboldAI Lite UI as well (click and edit the number text field).
|
74 |
+
- **Old CPU Compatibility**: If you are having crashes or issues, you can try running in a non-avx2 compatibility mode by adding the `--noavx2` flag. You can also try turning off mmap with `--nommap` or reducing your `--blasbatchssize` (set -1 to avoid batching)
|
75 |
+
|
76 |
+
For more information, be sure to run the program with the `--help` flag, or **[check the wiki](https://github.com/LostRuins/koboldcpp/wiki).**
|
77 |
+
|
78 |
+
## Compiling KoboldCpp From Source Code
|
79 |
+
|
80 |
+
### Compiling on Linux (Using koboldcpp.sh automated compiler script)
|
81 |
+
when you can't use the precompiled binary directly, we provide an automated build script which uses conda to obtain all dependencies, and generates (from source) a ready-to-use a pyinstaller binary for linux users.
|
82 |
+
- Clone the repo with `git clone https://github.com/LostRuins/koboldcpp.git`
|
83 |
+
- Simply execute the build script with `./koboldcpp.sh dist` and run the generated binary. (Not recommended for systems that already have an existing installation of conda. Dependencies: curl, bzip2)
|
84 |
+
```
|
85 |
+
./koboldcpp.sh # This launches the GUI for easy configuration and launching (X11 required).
|
86 |
+
./koboldcpp.sh --help # List all available terminal commands for using Koboldcpp, you can use koboldcpp.sh the same way as our python script and binaries.
|
87 |
+
./koboldcpp.sh rebuild # Automatically generates a new conda runtime and compiles a fresh copy of the libraries. Do this after updating Koboldcpp to keep everything functional.
|
88 |
+
./koboldcpp.sh dist # Generate your own precompiled binary (Due to the nature of Linux compiling these will only work on distributions equal or newer than your own.)
|
89 |
+
```
|
90 |
+
|
91 |
+
### Compiling on Linux (Manual Method)
|
92 |
+
- To compile your binaries from source, clone the repo with `git clone https://github.com/LostRuins/koboldcpp.git`
|
93 |
+
- A makefile is provided, simply run `make`.
|
94 |
+
- Optional Vulkan: Link your own install of Vulkan SDK manually with `make LLAMA_VULKAN=1`
|
95 |
+
- Optional CLBlast: Link your own install of CLBlast manually with `make LLAMA_CLBLAST=1`
|
96 |
+
- Note: for these you will need to obtain and link OpenCL and CLBlast libraries.
|
97 |
+
- For Arch Linux: Install `cblas` and `clblast`.
|
98 |
+
- For Debian: Install `libclblast-dev`.
|
99 |
+
- You can attempt a CuBLAS build with `LLAMA_CUBLAS=1`, (or `LLAMA_HIPBLAS=1` for AMD). You will need CUDA Toolkit installed. Some have also reported success with the CMake file, though that is more for windows.
|
100 |
+
- For a full featured build (all backends), do `make LLAMA_CLBLAST=1 LLAMA_CUBLAS=1 LLAMA_VULKAN=1`. (Note that `LLAMA_CUBLAS=1` will not work on windows, you need visual studio)
|
101 |
+
- To make your build sharable and capable of working on other devices, you must use `LLAMA_PORTABLE=1`
|
102 |
+
- After all binaries are built, you can run the python script with the command `koboldcpp.py [ggml_model.gguf] [port]`
|
103 |
+
|
104 |
+
### Compiling on Windows
|
105 |
+
- You're encouraged to use the .exe released, but if you want to compile your binaries from source at Windows, the easiest way is:
|
106 |
+
- Get the latest release of w64devkit (https://github.com/skeeto/w64devkit). Be sure to use the "vanilla one", not i686 or other different stuff. If you try they will conflit with the precompiled libs!
|
107 |
+
- Clone the repo with `git clone https://github.com/LostRuins/koboldcpp.git`
|
108 |
+
- Make sure you are using the w64devkit integrated terminal, then run `make` at the KoboldCpp source folder. This will create the .dll files for a pure CPU native build.
|
109 |
+
- For a full featured build (all backends), do `make LLAMA_CLBLAST=1 LLAMA_VULKAN=1`. (Note that `LLAMA_CUBLAS=1` will not work on windows, you need visual studio)
|
110 |
+
- To make your build sharable and capable of working on other devices, you must use `LLAMA_PORTABLE=1`
|
111 |
+
- If you want to generate the .exe file, make sure you have the python module PyInstaller installed with pip (`pip install PyInstaller`). Then run the script `make_pyinstaller.bat`
|
112 |
+
- The koboldcpp.exe file will be at your dist folder.
|
113 |
+
- **Building with CUDA**: Visual Studio, CMake and CUDA Toolkit is required. Clone the repo, then open the CMake file and compile it in Visual Studio. Copy the `koboldcpp_cublas.dll` generated into the same directory as the `koboldcpp.py` file. If you are bundling executables, you may need to include CUDA dynamic libraries (such as `cublasLt64_11.dll` and `cublas64_11.dll`) in order for the executable to work correctly on a different PC.
|
114 |
+
- **Replacing Libraries (Not Recommended)**: If you wish to use your own version of the additional Windows libraries (OpenCL, CLBlast, Vulkan), you can do it with:
|
115 |
+
- OpenCL - tested with https://github.com/KhronosGroup/OpenCL-SDK . If you wish to compile it, follow the repository instructions. You will need vcpkg.
|
116 |
+
- CLBlast - tested with https://github.com/CNugteren/CLBlast . If you wish to compile it you will need to reference the OpenCL files. It will only generate the ".lib" file if you compile using MSVC.
|
117 |
+
- Move the respectives .lib files to the /lib folder of your project, overwriting the older files.
|
118 |
+
- Also, replace the existing versions of the corresponding .dll files located in the project directory root (e.g. clblast.dll).
|
119 |
+
- Make the KoboldCpp project using the instructions above.
|
120 |
+
|
121 |
+
### Compiling on MacOS
|
122 |
+
- You can compile your binaries from source. You can clone the repo with `git clone https://github.com/LostRuins/koboldcpp.git`
|
123 |
+
- A makefile is provided, simply run `make`.
|
124 |
+
- If you want Metal GPU support, instead run `make LLAMA_METAL=1`, note that MacOS metal libraries need to be installed.
|
125 |
+
- To make your build sharable and capable of working on other devices, you must use `LLAMA_PORTABLE=1`
|
126 |
+
- After all binaries are built, you can run the python script with the command `koboldcpp.py --model [ggml_model.gguf]` (and add `--gpulayers (number of layer)` if you wish to offload layers to GPU).
|
127 |
+
|
128 |
+
### Compiling on Android (Termux Installation)
|
129 |
+
- [Install and run Termux from F-Droid](https://f-droid.org/en/packages/com.termux/)
|
130 |
+
- Enter the command `termux-change-repo` and choose `Mirror by BFSU`
|
131 |
+
- Install dependencies with `pkg install wget git python` (plus any other missing packages)
|
132 |
+
- Install dependencies `apt install openssl` (if needed)
|
133 |
+
- Clone the repo `git clone https://github.com/LostRuins/koboldcpp.git`
|
134 |
+
- Navigate to the koboldcpp folder `cd koboldcpp`
|
135 |
+
- Build the project `make`
|
136 |
+
- To make your build sharable and capable of working on other devices, you must use `LLAMA_PORTABLE=1`, this disables usage of ARM instrinsics.
|
137 |
+
- Grab a small GGUF model, such as `wget https://huggingface.co/concedo/KobbleTinyV2-1.1B-GGUF/resolve/main/KobbleTiny-Q4_K.gguf`
|
138 |
+
- Start the python server `python koboldcpp.py --model KobbleTiny-Q4_K.gguf`
|
139 |
+
- Connect to `http://localhost:5001` on your mobile browser
|
140 |
+
- If you encounter any errors, make sure your packages are up-to-date with `pkg up`
|
141 |
+
- GPU acceleration for Termux may be possible but I have not explored it. If you find a good cross-device solution, do share or PR it.
|
142 |
+
|
143 |
+
## AMD Users
|
144 |
+
- For most users, you can get very decent speeds by selecting the **Vulkan** option instead, which supports both Nvidia and AMD GPUs.
|
145 |
+
- Alternatively, you can try the ROCM fork at https://github.com/YellowRoseCx/koboldcpp-rocm
|
146 |
+
|
147 |
+
## Third Party Resources
|
148 |
+
- These unofficial resources have been contributed by the community, and may be outdated or unmaintained. No official support will be provided for them!
|
149 |
+
- Arch Linux Packages: [CUBLAS](https://aur.archlinux.org/packages/koboldcpp-cuda), and [HIPBLAS](https://aur.archlinux.org/packages/koboldcpp-hipblas).
|
150 |
+
- Unofficial Dockers: [korewaChino](https://github.com/korewaChino/koboldCppDocker) and [noneabove1182](https://github.com/noneabove1182/koboldcpp-docker)
|
151 |
+
- Nix & NixOS: KoboldCpp is available on Nixpkgs and can be installed by adding just `koboldcpp` to your `environment.systemPackages` *(or it can also be placed in `home.packages`)*.
|
152 |
+
- [Example Nix Setup and further information](examples/nix_example.md)
|
153 |
+
- If you face any issues with running KoboldCpp on Nix, please open an issue [here](https://github.com/NixOS/nixpkgs/issues/new?assignees=&labels=0.kind%3A+bug&projects=&template=bug_report.md&title=).
|
154 |
+
- [GPTLocalhost](https://gptlocalhost.com/demo#KoboldCpp) - KoboldCpp is supported by GPTLocalhost, a local Word Add-in for you to use KoboldCpp in Microsoft Word. A local alternative to "Copilot in Word."
|
155 |
+
|
156 |
+
## Questions and Help Wiki
|
157 |
+
- **First, please check out [The KoboldCpp FAQ and Knowledgebase](https://github.com/LostRuins/koboldcpp/wiki) which may already have answers to your questions! Also please search through past issues and discussions.**
|
158 |
+
- If you cannot find an answer, open an issue on this github, or find us on the [KoboldAI Discord](https://koboldai.org/discord).
|
159 |
+
|
160 |
+
## KoboldCpp and KoboldAI API Documentation
|
161 |
+
- [Documentation for KoboldAI and KoboldCpp endpoints can be found here](https://lite.koboldai.net/koboldcpp_api)
|
162 |
+
|
163 |
+
## KoboldCpp Public Demo
|
164 |
+
- [A public KoboldCpp demo can be found at our Huggingface Space. Please do not abuse it.](https://koboldai-koboldcpp-tiefighter.hf.space/)
|
165 |
+
|
166 |
+
## Considerations
|
167 |
+
- For Windows: No installation, single file executable, (It Just Works)
|
168 |
+
- Since v1.15, requires CLBlast if enabled, the prebuilt windows binaries are included in this repo. If not found, it will fall back to a mode without CLBlast.
|
169 |
+
- Since v1.33, you can set the context size to be above what the model supports officially. It does increases perplexity but should still work well below 4096 even on untuned models. (For GPT-NeoX, GPT-J, and Llama models) Customize this with `--ropeconfig`.
|
170 |
+
- Since v1.42, supports GGUF models for LLAMA and Falcon
|
171 |
+
- Since v1.55, lcuda paths on Linux are hardcoded and may require manual changes to the makefile if you do not use koboldcpp.sh for the compilation.
|
172 |
+
- Since v1.60, provides native image generation with StableDiffusion.cpp, you can load any SD1.5 or SDXL .safetensors model and it will provide an A1111 compatible API to use.
|
173 |
+
- **I try to keep backwards compatibility with ALL past llama.cpp models**. But you are also encouraged to reconvert/update your models if possible for best results.
|
174 |
+
- Since v1.75, openblas has been deprecated and removed in favor of the native CPU implementation.
|
175 |
+
|
176 |
+
## License
|
177 |
+
- The original GGML library and llama.cpp by ggerganov are licensed under the MIT License
|
178 |
+
- However, KoboldAI Lite is licensed under the AGPL v3.0 License
|
179 |
+
- KoboldCpp code and other files are also under the AGPL v3.0 License unless otherwise stated
|
180 |
+
|
181 |
+
## Notes
|
182 |
+
- If you wish, after building the koboldcpp libraries with `make`, you can rebuild the exe yourself with pyinstaller by using `make_pyinstaller.bat`
|
183 |
+
- API documentation available at `/api` (e.g. `http://localhost:5001/api`) and https://lite.koboldai.net/koboldcpp_api. An OpenAI compatible API is also provided at `/v1` route (e.g. `http://localhost:5001/v1`).
|
184 |
+
- **All up-to-date GGUF models are supported**, and KoboldCpp also includes backward compatibility for older versions/legacy GGML `.bin` models, though some newer features might be unavailable.
|
185 |
+
- An incomplete list of architectures is listed, but there are *many hundreds of other GGUF models*. In general, if it's GGUF, it should work.
|
186 |
+
- Llama / Llama2 / Llama3 / Alpaca / GPT4All / Vicuna / Koala / Pygmalion / Metharme / WizardLM / Mistral / Mixtral / Miqu / Qwen / Qwen2 / Yi / Gemma / Gemma2 / GPT-2 / Cerebras / Phi-2 / Phi-3 / GPT-NeoX / Pythia / StableLM / Dolly / RedPajama / GPT-J / RWKV4 / MPT / Falcon / Starcoder / Deepseek and many, **many** more.
|
187 |
+
|
188 |
+
# Where can I download AI model files?
|
189 |
+
- The best place to get GGUF text models is huggingface. For image models, CivitAI has a good selection. Here are some to get started.
|
190 |
+
- Text Generation: [Airoboros Mistral 7B](https://huggingface.co/TheBloke/airoboros-mistral2.2-7B-GGUF/resolve/main/airoboros-mistral2.2-7b.Q4_K_S.gguf) (smaller and weaker) or [Tiefighter 13B](https://huggingface.co/KoboldAI/LLaMA2-13B-Tiefighter-GGUF/resolve/main/LLaMA2-13B-Tiefighter.Q4_K_S.gguf) (larger model) or [Beepo 22B](https://huggingface.co/concedo/Beepo-22B-GGUF/resolve/main/Beepo-22B-Q4_K_S.gguf) (largest and most powerful)
|
191 |
+
- Image Generation: [Anything v3](https://huggingface.co/admruul/anything-v3.0/resolve/main/Anything-V3.0-pruned-fp16.safetensors) or [Deliberate V2](https://huggingface.co/Yntec/Deliberate2/resolve/main/Deliberate_v2.safetensors) or [Dreamshaper SDXL](https://huggingface.co/Lykon/dreamshaper-xl-v2-turbo/resolve/main/DreamShaperXL_Turbo_v2_1.safetensors)
|
192 |
+
- Image Recognition MMproj: [Pick the correct one for your model architecture here](https://huggingface.co/koboldcpp/mmproj/tree/main)
|
193 |
+
- Speech Recognition: [Whisper models for Speech-To-Text](https://huggingface.co/koboldcpp/whisper/tree/main)
|
194 |
+
- Text-To-Speech: [TTS models for Narration](https://huggingface.co/koboldcpp/tts/tree/main)
|
Remote-Link.cmd
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
: # This script will help setup a cloudflared tunnel for accessing KoboldCpp over the internet
|
2 |
+
: # It should work out of the box on both linux and windows
|
3 |
+
: # ======
|
4 |
+
: # WINDOWS PORTION
|
5 |
+
:<<BATCH
|
6 |
+
@echo off
|
7 |
+
echo Starting Cloudflare Tunnel for Windows
|
8 |
+
curl -L https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-windows-amd64.exe -o cloudflared.exe
|
9 |
+
cloudflared.exe tunnel --url localhost:5001
|
10 |
+
GOTO ENDING
|
11 |
+
BATCH
|
12 |
+
: # LINUX PORTION
|
13 |
+
echo 'Starting Cloudflare Tunnel for Linux'
|
14 |
+
curl -L https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 -o 'cloudflared-linux-amd64' #
|
15 |
+
chmod +x 'cloudflared-linux-amd64' #
|
16 |
+
./cloudflared-linux-amd64 tunnel --url http://localhost:5001 #
|
17 |
+
exit #
|
18 |
+
:ENDING
|
build-info.h
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#ifndef BUILD_INFO_H
|
2 |
+
#define BUILD_INFO_H
|
3 |
+
|
4 |
+
#define LLAMA_BUILD_NUMBER 999
|
5 |
+
#define LLAMA_COMMIT "KOBOLDCPP"
|
6 |
+
#define LLAMA_COMPILER "KCPP"
|
7 |
+
#define LLAMA_TARGET "KCPP"
|
8 |
+
#define LLAMA_BUILD_COMMIT "KOBOLDCPP"
|
9 |
+
#define LLAMA_BUILD_COMPILER "KCPP"
|
10 |
+
#define LLAMA_BUILD_TARGET "KCPP"
|
11 |
+
|
12 |
+
#endif // BUILD_INFO_H
|
build-xcframework.sh
ADDED
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#
|
3 |
+
# Options
|
4 |
+
IOS_MIN_OS_VERSION=16.4
|
5 |
+
MACOS_MIN_OS_VERSION=13.3
|
6 |
+
VISIONOS_MIN_OS_VERSION=1.0
|
7 |
+
TVOS_MIN_OS_VERSION=16.4
|
8 |
+
|
9 |
+
BUILD_SHARED_LIBS=OFF
|
10 |
+
LLAMA_BUILD_EXAMPLES=OFF
|
11 |
+
LLAMA_BUILD_TESTS=OFF
|
12 |
+
LLAMA_BUILD_SERVER=OFF
|
13 |
+
GGML_METAL=ON
|
14 |
+
GGML_METAL_EMBED_LIBRARY=ON
|
15 |
+
GGML_BLAS_DEFAULT=ON
|
16 |
+
GGML_METAL_USE_BF16=ON
|
17 |
+
GGML_OPENMP=OFF
|
18 |
+
|
19 |
+
COMMON_C_FLAGS="-Wno-macro-redefined -Wno-shorten-64-to-32 -Wno-unused-command-line-argument -g"
|
20 |
+
COMMON_CXX_FLAGS="-Wno-macro-redefined -Wno-shorten-64-to-32 -Wno-unused-command-line-argument -g"
|
21 |
+
|
22 |
+
# Common options for all builds
|
23 |
+
COMMON_CMAKE_ARGS=(
|
24 |
+
-DCMAKE_XCODE_ATTRIBUTE_CODE_SIGNING_REQUIRED=NO
|
25 |
+
-DCMAKE_XCODE_ATTRIBUTE_CODE_SIGN_IDENTITY=""
|
26 |
+
-DCMAKE_XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED=NO
|
27 |
+
-DCMAKE_XCODE_ATTRIBUTE_DEBUG_INFORMATION_FORMAT="dwarf-with-dsym"
|
28 |
+
-DCMAKE_XCODE_ATTRIBUTE_GCC_GENERATE_DEBUGGING_SYMBOLS=YES
|
29 |
+
-DCMAKE_XCODE_ATTRIBUTE_COPY_PHASE_STRIP=NO
|
30 |
+
-DCMAKE_XCODE_ATTRIBUTE_STRIP_INSTALLED_PRODUCT=NO
|
31 |
+
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
|
32 |
+
-DBUILD_SHARED_LIBS=${BUILD_SHARED_LIBS}
|
33 |
+
-DLLAMA_BUILD_EXAMPLES=${LLAMA_BUILD_EXAMPLES}
|
34 |
+
-DLLAMA_BUILD_TESTS=${LLAMA_BUILD_TESTS}
|
35 |
+
-DLLAMA_BUILD_SERVER=${LLAMA_BUILD_SERVER}
|
36 |
+
-DGGML_METAL_EMBED_LIBRARY=${GGML_METAL_EMBED_LIBRARY}
|
37 |
+
-DGGML_BLAS_DEFAULT=${GGML_BLAS_DEFAULT}
|
38 |
+
-DGGML_METAL=${GGML_METAL}
|
39 |
+
-DGGML_METAL_USE_BF16=${GGML_METAL_USE_BF16}
|
40 |
+
-DGGML_NATIVE=OFF
|
41 |
+
-DGGML_OPENMP=${GGML_OPENMP}
|
42 |
+
)
|
43 |
+
|
44 |
+
check_required_tool() {
|
45 |
+
local tool=$1
|
46 |
+
local install_message=$2
|
47 |
+
|
48 |
+
if ! command -v $tool &> /dev/null; then
|
49 |
+
echo "Error: $tool is required but not found."
|
50 |
+
echo "$install_message"
|
51 |
+
exit 1
|
52 |
+
fi
|
53 |
+
}
|
54 |
+
echo "Checking for required tools..."
|
55 |
+
check_required_tool "cmake" "Please install CMake 3.28.0 or later (brew install cmake)"
|
56 |
+
check_required_tool "xcodebuild" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
|
57 |
+
check_required_tool "libtool" "Please install libtool which should be available with Xcode Command Line Tools (CLT). Make sure Xcode CLT is installed (xcode-select --install)"
|
58 |
+
check_required_tool "dsymutil" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
|
59 |
+
|
60 |
+
set -e
|
61 |
+
|
62 |
+
## Clean up previous builds
|
63 |
+
rm -rf build-apple
|
64 |
+
rm -rf build-ios-sim
|
65 |
+
rm -rf build-ios-device
|
66 |
+
rm -rf build-macos
|
67 |
+
rm -rf build-visionos
|
68 |
+
rm -rf build-visionos-sim
|
69 |
+
rm -rf build-tvos-sim
|
70 |
+
rm -rf build-tvos-device
|
71 |
+
|
72 |
+
# Setup the xcframework build directory structure
|
73 |
+
setup_framework_structure() {
|
74 |
+
local build_dir=$1
|
75 |
+
local min_os_version=$2
|
76 |
+
local platform=$3 # "ios", "macos", "visionos", or "tvos"
|
77 |
+
local framework_name="llama"
|
78 |
+
|
79 |
+
echo "Creating ${platform}-style framework structure for ${build_dir}"
|
80 |
+
|
81 |
+
if [[ "$platform" == "macos" ]]; then
|
82 |
+
# macOS versioned structure uses versioned directories
|
83 |
+
mkdir -p ${build_dir}/framework/${framework_name}.framework/Versions/A/Headers
|
84 |
+
mkdir -p ${build_dir}/framework/${framework_name}.framework/Versions/A/Modules
|
85 |
+
mkdir -p ${build_dir}/framework/${framework_name}.framework/Versions/A/Resources
|
86 |
+
|
87 |
+
# Create symbolic links
|
88 |
+
ln -sf A ${build_dir}/framework/${framework_name}.framework/Versions/Current
|
89 |
+
ln -sf Versions/Current/Headers ${build_dir}/framework/${framework_name}.framework/Headers
|
90 |
+
ln -sf Versions/Current/Modules ${build_dir}/framework/${framework_name}.framework/Modules
|
91 |
+
ln -sf Versions/Current/Resources ${build_dir}/framework/${framework_name}.framework/Resources
|
92 |
+
ln -sf Versions/Current/${framework_name} ${build_dir}/framework/${framework_name}.framework/${framework_name}
|
93 |
+
|
94 |
+
# Set header and module paths
|
95 |
+
local header_path=${build_dir}/framework/${framework_name}.framework/Versions/A/Headers/
|
96 |
+
local module_path=${build_dir}/framework/${framework_name}.framework/Versions/A/Modules/
|
97 |
+
else
|
98 |
+
# iOS/VisionOS/tvOS use a flat structure
|
99 |
+
mkdir -p ${build_dir}/framework/${framework_name}.framework/Headers
|
100 |
+
mkdir -p ${build_dir}/framework/${framework_name}.framework/Modules
|
101 |
+
|
102 |
+
# Remove any existing structure to ensure clean build
|
103 |
+
rm -rf ${build_dir}/framework/${framework_name}.framework/Versions
|
104 |
+
|
105 |
+
# Set header and module paths
|
106 |
+
local header_path=${build_dir}/framework/${framework_name}.framework/Headers/
|
107 |
+
local module_path=${build_dir}/framework/${framework_name}.framework/Modules/
|
108 |
+
fi
|
109 |
+
|
110 |
+
# Copy all required headers (common for all platforms)
|
111 |
+
cp include/llama.h ${header_path}
|
112 |
+
cp ggml/include/ggml.h ${header_path}
|
113 |
+
cp ggml/include/ggml-alloc.h ${header_path}
|
114 |
+
cp ggml/include/ggml-backend.h ${header_path}
|
115 |
+
cp ggml/include/ggml-metal.h ${header_path}
|
116 |
+
cp ggml/include/ggml-cpu.h ${header_path}
|
117 |
+
cp ggml/include/ggml-blas.h ${header_path}
|
118 |
+
cp ggml/include/gguf.h ${header_path}
|
119 |
+
|
120 |
+
# Create module map (common for all platforms)
|
121 |
+
cat > ${module_path}module.modulemap << EOF
|
122 |
+
framework module llama {
|
123 |
+
header "llama.h"
|
124 |
+
header "ggml.h"
|
125 |
+
header "ggml-alloc.h"
|
126 |
+
header "ggml-backend.h"
|
127 |
+
header "ggml-metal.h"
|
128 |
+
header "ggml-cpu.h"
|
129 |
+
header "ggml-blas.h"
|
130 |
+
header "gguf.h"
|
131 |
+
|
132 |
+
link "c++"
|
133 |
+
link framework "Accelerate"
|
134 |
+
link framework "Metal"
|
135 |
+
link framework "Foundation"
|
136 |
+
|
137 |
+
export *
|
138 |
+
}
|
139 |
+
EOF
|
140 |
+
|
141 |
+
# Platform-specific settings for Info.plist
|
142 |
+
local platform_name=""
|
143 |
+
local sdk_name=""
|
144 |
+
local supported_platform=""
|
145 |
+
|
146 |
+
case "$platform" in
|
147 |
+
"ios")
|
148 |
+
platform_name="iphoneos"
|
149 |
+
sdk_name="iphoneos${min_os_version}"
|
150 |
+
supported_platform="iPhoneOS"
|
151 |
+
local plist_path="${build_dir}/framework/${framework_name}.framework/Info.plist"
|
152 |
+
local device_family=' <key>UIDeviceFamily</key>
|
153 |
+
<array>
|
154 |
+
<integer>1</integer>
|
155 |
+
<integer>2</integer>
|
156 |
+
</array>'
|
157 |
+
;;
|
158 |
+
"macos")
|
159 |
+
platform_name="macosx"
|
160 |
+
sdk_name="macosx${min_os_version}"
|
161 |
+
supported_platform="MacOSX"
|
162 |
+
local plist_path="${build_dir}/framework/${framework_name}.framework/Versions/A/Resources/Info.plist"
|
163 |
+
local device_family=""
|
164 |
+
;;
|
165 |
+
"visionos")
|
166 |
+
platform_name="xros"
|
167 |
+
sdk_name="xros${min_os_version}"
|
168 |
+
supported_platform="XRPlatform"
|
169 |
+
local plist_path="${build_dir}/framework/${framework_name}.framework/Info.plist"
|
170 |
+
local device_family=""
|
171 |
+
;;
|
172 |
+
"tvos")
|
173 |
+
platform_name="appletvos"
|
174 |
+
sdk_name="appletvos${min_os_version}"
|
175 |
+
supported_platform="AppleTVOS"
|
176 |
+
local plist_path="${build_dir}/framework/${framework_name}.framework/Info.plist"
|
177 |
+
local device_family=' <key>UIDeviceFamily</key>
|
178 |
+
<array>
|
179 |
+
<integer>3</integer>
|
180 |
+
</array>'
|
181 |
+
;;
|
182 |
+
esac
|
183 |
+
|
184 |
+
# Create Info.plist
|
185 |
+
cat > ${plist_path} << EOF
|
186 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
187 |
+
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
188 |
+
<plist version="1.0">
|
189 |
+
<dict>
|
190 |
+
<key>CFBundleDevelopmentRegion</key>
|
191 |
+
<string>en</string>
|
192 |
+
<key>CFBundleExecutable</key>
|
193 |
+
<string>llama</string>
|
194 |
+
<key>CFBundleIdentifier</key>
|
195 |
+
<string>org.ggml.llama</string>
|
196 |
+
<key>CFBundleInfoDictionaryVersion</key>
|
197 |
+
<string>6.0</string>
|
198 |
+
<key>CFBundleName</key>
|
199 |
+
<string>llama</string>
|
200 |
+
<key>CFBundlePackageType</key>
|
201 |
+
<string>FMWK</string>
|
202 |
+
<key>CFBundleShortVersionString</key>
|
203 |
+
<string>1.0</string>
|
204 |
+
<key>CFBundleVersion</key>
|
205 |
+
<string>1</string>
|
206 |
+
<key>MinimumOSVersion</key>
|
207 |
+
<string>${min_os_version}</string>
|
208 |
+
<key>CFBundleSupportedPlatforms</key>
|
209 |
+
<array>
|
210 |
+
<string>${supported_platform}</string>
|
211 |
+
</array>${device_family}
|
212 |
+
<key>DTPlatformName</key>
|
213 |
+
<string>${platform_name}</string>
|
214 |
+
<key>DTSDKName</key>
|
215 |
+
<string>${sdk_name}</string>
|
216 |
+
</dict>
|
217 |
+
</plist>
|
218 |
+
EOF
|
219 |
+
}
|
220 |
+
|
221 |
+
# Create dynamic libraries from static libraries.
|
222 |
+
combine_static_libraries() {
|
223 |
+
local build_dir="$1"
|
224 |
+
local release_dir="$2"
|
225 |
+
local platform="$3" # "ios", "macos", "visionos", or "tvos"
|
226 |
+
local is_simulator="$4"
|
227 |
+
local base_dir="$(pwd)"
|
228 |
+
local framework_name="llama"
|
229 |
+
|
230 |
+
# Determine output path based on platform
|
231 |
+
local output_lib=""
|
232 |
+
if [[ "$platform" == "macos" ]]; then
|
233 |
+
# macOS uses versioned structure
|
234 |
+
output_lib="${build_dir}/framework/${framework_name}.framework/Versions/A/${framework_name}"
|
235 |
+
else
|
236 |
+
# iOS, visionOS, and tvOS use a directory flat structure
|
237 |
+
output_lib="${build_dir}/framework/${framework_name}.framework/${framework_name}"
|
238 |
+
fi
|
239 |
+
|
240 |
+
local libs=(
|
241 |
+
"${base_dir}/${build_dir}/src/${release_dir}/libllama.a"
|
242 |
+
"${base_dir}/${build_dir}/ggml/src/${release_dir}/libggml.a"
|
243 |
+
"${base_dir}/${build_dir}/ggml/src/${release_dir}/libggml-base.a"
|
244 |
+
"${base_dir}/${build_dir}/ggml/src/${release_dir}/libggml-cpu.a"
|
245 |
+
"${base_dir}/${build_dir}/ggml/src/ggml-metal/${release_dir}/libggml-metal.a"
|
246 |
+
"${base_dir}/${build_dir}/ggml/src/ggml-blas/${release_dir}/libggml-blas.a"
|
247 |
+
)
|
248 |
+
|
249 |
+
# Create temporary directory for processing
|
250 |
+
local temp_dir="${base_dir}/${build_dir}/temp"
|
251 |
+
mkdir -p "${temp_dir}"
|
252 |
+
|
253 |
+
# Since we have multiple architectures libtool will find object files that do not
|
254 |
+
# match the target architecture. We suppress these warnings.
|
255 |
+
libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null
|
256 |
+
|
257 |
+
# Determine SDK, architectures, and install_name based on platform and simulator flag.
|
258 |
+
local sdk=""
|
259 |
+
local archs=""
|
260 |
+
local min_version_flag=""
|
261 |
+
local install_name=""
|
262 |
+
|
263 |
+
case "$platform" in
|
264 |
+
"ios")
|
265 |
+
if [[ "$is_simulator" == "true" ]]; then
|
266 |
+
sdk="iphonesimulator"
|
267 |
+
archs="arm64 x86_64"
|
268 |
+
min_version_flag="-mios-simulator-version-min=${IOS_MIN_OS_VERSION}"
|
269 |
+
else
|
270 |
+
sdk="iphoneos"
|
271 |
+
archs="arm64"
|
272 |
+
min_version_flag="-mios-version-min=${IOS_MIN_OS_VERSION}"
|
273 |
+
fi
|
274 |
+
install_name="@rpath/llama.framework/llama"
|
275 |
+
;;
|
276 |
+
"macos")
|
277 |
+
sdk="macosx"
|
278 |
+
archs="arm64 x86_64"
|
279 |
+
min_version_flag="-mmacosx-version-min=${MACOS_MIN_OS_VERSION}"
|
280 |
+
install_name="@rpath/llama.framework/Versions/Current/llama"
|
281 |
+
;;
|
282 |
+
"visionos")
|
283 |
+
if [[ "$is_simulator" == "true" ]]; then
|
284 |
+
sdk="xrsimulator"
|
285 |
+
archs="arm64 x86_64"
|
286 |
+
min_version_flag="-mtargetos=xros${VISIONOS_MIN_OS_VERSION}-simulator"
|
287 |
+
else
|
288 |
+
sdk="xros"
|
289 |
+
archs="arm64"
|
290 |
+
min_version_flag="-mtargetos=xros${VISIONOS_MIN_OS_VERSION}"
|
291 |
+
fi
|
292 |
+
# Use flat structure for visionOS, same as iOS
|
293 |
+
install_name="@rpath/llama.framework/llama"
|
294 |
+
;;
|
295 |
+
"tvos")
|
296 |
+
if [[ "$is_simulator" == "true" ]]; then
|
297 |
+
sdk="appletvsimulator"
|
298 |
+
archs="arm64 x86_64"
|
299 |
+
min_version_flag="-mtvos-simulator-version-min=${TVOS_MIN_OS_VERSION}"
|
300 |
+
else
|
301 |
+
sdk="appletvos"
|
302 |
+
archs="arm64"
|
303 |
+
min_version_flag="-mtvos-version-min=${TVOS_MIN_OS_VERSION}"
|
304 |
+
fi
|
305 |
+
install_name="@rpath/llama.framework/llama"
|
306 |
+
;;
|
307 |
+
esac
|
308 |
+
|
309 |
+
# Build architecture flags
|
310 |
+
local arch_flags=""
|
311 |
+
for arch in $archs; do
|
312 |
+
arch_flags+=" -arch $arch"
|
313 |
+
done
|
314 |
+
|
315 |
+
# Create dynamic library
|
316 |
+
echo "Creating dynamic library for ${platform}."
|
317 |
+
xcrun -sdk $sdk clang++ -dynamiclib \
|
318 |
+
-isysroot $(xcrun --sdk $sdk --show-sdk-path) \
|
319 |
+
$arch_flags \
|
320 |
+
$min_version_flag \
|
321 |
+
-Wl,-force_load,"${temp_dir}/combined.a" \
|
322 |
+
-framework Foundation -framework Metal -framework Accelerate \
|
323 |
+
-install_name "$install_name" \
|
324 |
+
-o "${base_dir}/${output_lib}"
|
325 |
+
|
326 |
+
# Platform-specific post-processing for device builds
|
327 |
+
if [[ "$is_simulator" == "false" ]]; then
|
328 |
+
if command -v vtool &>/dev/null; then
|
329 |
+
case "$platform" in
|
330 |
+
"ios")
|
331 |
+
echo "Marking binary as a framework binary for iOS..."
|
332 |
+
vtool -set-build-version ios ${IOS_MIN_OS_VERSION} ${IOS_MIN_OS_VERSION} -replace \
|
333 |
+
-output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
|
334 |
+
;;
|
335 |
+
"visionos")
|
336 |
+
echo "Marking binary as a framework binary for visionOS..."
|
337 |
+
vtool -set-build-version xros ${VISIONOS_MIN_OS_VERSION} ${VISIONOS_MIN_OS_VERSION} -replace \
|
338 |
+
-output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
|
339 |
+
;;
|
340 |
+
"tvos")
|
341 |
+
echo "Marking binary as a framework binary for tvOS..."
|
342 |
+
vtool -set-build-version tvos ${TVOS_MIN_OS_VERSION} ${TVOS_MIN_OS_VERSION} -replace \
|
343 |
+
-output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
|
344 |
+
;;
|
345 |
+
esac
|
346 |
+
else
|
347 |
+
echo "Warning: vtool not found. Binary may not pass App Store validation."
|
348 |
+
fi
|
349 |
+
fi
|
350 |
+
|
351 |
+
echo "Creating properly formatted dSYM..."
|
352 |
+
# Create a separate directory for dSYMs for all platforms
|
353 |
+
mkdir -p "${base_dir}/${build_dir}/dSYMs"
|
354 |
+
|
355 |
+
# iOS and visionOS style dSYM (flat structure)
|
356 |
+
if [[ "$platform" == "ios" || "$platform" == "visionos" || "$platform" == "tvos" ]]; then
|
357 |
+
# Generate dSYM in the dSYMs directory
|
358 |
+
xcrun dsymutil "${base_dir}/${output_lib}" -o "${base_dir}/${build_dir}/dSYMs/llama.dSYM"
|
359 |
+
|
360 |
+
# Create a copy of the binary that will be stripped
|
361 |
+
cp "${base_dir}/${output_lib}" "${temp_dir}/binary_to_strip"
|
362 |
+
|
363 |
+
# Strip debug symbols from the copy
|
364 |
+
xcrun strip -S "${temp_dir}/binary_to_strip" -o "${temp_dir}/stripped_lib"
|
365 |
+
|
366 |
+
# Replace the original with the stripped version
|
367 |
+
mv "${temp_dir}/stripped_lib" "${base_dir}/${output_lib}"
|
368 |
+
else
|
369 |
+
# macOS style dSYM
|
370 |
+
# First strip debug info to a separate file
|
371 |
+
xcrun strip -S "${base_dir}/${output_lib}" -o "${temp_dir}/stripped_lib"
|
372 |
+
|
373 |
+
# Generate dSYM in the dSYMs directory
|
374 |
+
xcrun dsymutil "${base_dir}/${output_lib}" -o "${base_dir}/${build_dir}/dSYMs/llama.dSYM"
|
375 |
+
|
376 |
+
# Replace original binary with stripped version
|
377 |
+
mv "${temp_dir}/stripped_lib" "${base_dir}/${output_lib}"
|
378 |
+
fi
|
379 |
+
|
380 |
+
# Remove any automatically generated dSYM files in the framework structure as they will
|
381 |
+
# otherwise case Invalid Bundle Structure validation errors.
|
382 |
+
if [ -d "${base_dir}/${output_lib}.dSYM" ]; then
|
383 |
+
echo "Removing generated dSYM file in framework structure: ${base_dir}/${output_lib}.dSYM"
|
384 |
+
rm -rf "${base_dir}/${output_lib}.dSYM"
|
385 |
+
fi
|
386 |
+
|
387 |
+
# Clean up
|
388 |
+
rm -rf "${temp_dir}"
|
389 |
+
}
|
390 |
+
|
391 |
+
echo "Building for iOS simulator..."
|
392 |
+
cmake -B build-ios-sim -G Xcode \
|
393 |
+
"${COMMON_CMAKE_ARGS[@]}" \
|
394 |
+
-DCMAKE_OSX_DEPLOYMENT_TARGET=${IOS_MIN_OS_VERSION} \
|
395 |
+
-DIOS=ON \
|
396 |
+
-DCMAKE_SYSTEM_NAME=iOS \
|
397 |
+
-DCMAKE_OSX_SYSROOT=iphonesimulator \
|
398 |
+
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
|
399 |
+
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=iphonesimulator \
|
400 |
+
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
401 |
+
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
402 |
+
-S .
|
403 |
+
cmake --build build-ios-sim --config Release -- -quiet
|
404 |
+
|
405 |
+
echo "Building for iOS devices..."
|
406 |
+
cmake -B build-ios-device -G Xcode \
|
407 |
+
"${COMMON_CMAKE_ARGS[@]}" \
|
408 |
+
-DCMAKE_OSX_DEPLOYMENT_TARGET=${IOS_MIN_OS_VERSION} \
|
409 |
+
-DCMAKE_OSX_SYSROOT=iphoneos \
|
410 |
+
-DCMAKE_OSX_ARCHITECTURES="arm64" \
|
411 |
+
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=iphoneos \
|
412 |
+
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
413 |
+
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
414 |
+
-S .
|
415 |
+
cmake --build build-ios-device --config Release -- -quiet
|
416 |
+
|
417 |
+
echo "Building for macOS..."
|
418 |
+
cmake -B build-macos -G Xcode \
|
419 |
+
"${COMMON_CMAKE_ARGS[@]}" \
|
420 |
+
-DCMAKE_OSX_DEPLOYMENT_TARGET=${MACOS_MIN_OS_VERSION} \
|
421 |
+
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
|
422 |
+
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
423 |
+
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
424 |
+
-S .
|
425 |
+
cmake --build build-macos --config Release -- -quiet
|
426 |
+
|
427 |
+
echo "Building for visionOS..."
|
428 |
+
cmake -B build-visionos -G Xcode \
|
429 |
+
"${COMMON_CMAKE_ARGS[@]}" \
|
430 |
+
-DCMAKE_OSX_DEPLOYMENT_TARGET=${VISIONOS_MIN_OS_VERSION} \
|
431 |
+
-DCMAKE_OSX_ARCHITECTURES="arm64" \
|
432 |
+
-DCMAKE_SYSTEM_NAME=visionOS \
|
433 |
+
-DCMAKE_OSX_SYSROOT=xros \
|
434 |
+
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \
|
435 |
+
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_C_FLAGS}" \
|
436 |
+
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_CXX_FLAGS}" \
|
437 |
+
-S .
|
438 |
+
cmake --build build-visionos --config Release -- -quiet
|
439 |
+
|
440 |
+
echo "Building for visionOS simulator..."
|
441 |
+
cmake -B build-visionos-sim -G Xcode \
|
442 |
+
"${COMMON_CMAKE_ARGS[@]}" \
|
443 |
+
-DCMAKE_OSX_DEPLOYMENT_TARGET=${VISIONOS_MIN_OS_VERSION} \
|
444 |
+
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
|
445 |
+
-DCMAKE_SYSTEM_NAME=visionOS \
|
446 |
+
-DCMAKE_OSX_SYSROOT=xrsimulator \
|
447 |
+
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \
|
448 |
+
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_C_FLAGS}" \
|
449 |
+
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_CXX_FLAGS}" \
|
450 |
+
-S .
|
451 |
+
cmake --build build-visionos-sim --config Release -- -quiet
|
452 |
+
|
453 |
+
# Add tvOS builds (might need the same u_int definitions as watchOS and visionOS)
|
454 |
+
echo "Building for tvOS simulator..."
|
455 |
+
cmake -B build-tvos-sim -G Xcode \
|
456 |
+
"${COMMON_CMAKE_ARGS[@]}" \
|
457 |
+
-DCMAKE_OSX_DEPLOYMENT_TARGET=${TVOS_MIN_OS_VERSION} \
|
458 |
+
-DCMAKE_SYSTEM_NAME=tvOS \
|
459 |
+
-DCMAKE_OSX_SYSROOT=appletvsimulator \
|
460 |
+
-DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
|
461 |
+
-DGGML_METAL=ON \
|
462 |
+
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=appletvsimulator \
|
463 |
+
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
464 |
+
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
465 |
+
-S .
|
466 |
+
cmake --build build-tvos-sim --config Release -- -quiet
|
467 |
+
|
468 |
+
echo "Building for tvOS devices..."
|
469 |
+
cmake -B build-tvos-device -G Xcode \
|
470 |
+
"${COMMON_CMAKE_ARGS[@]}" \
|
471 |
+
-DCMAKE_OSX_DEPLOYMENT_TARGET=${TVOS_MIN_OS_VERSION} \
|
472 |
+
-DCMAKE_SYSTEM_NAME=tvOS \
|
473 |
+
-DCMAKE_OSX_SYSROOT=appletvos \
|
474 |
+
-DCMAKE_OSX_ARCHITECTURES="arm64" \
|
475 |
+
-DGGML_METAL=ON \
|
476 |
+
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=appletvos \
|
477 |
+
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
478 |
+
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
479 |
+
-S .
|
480 |
+
cmake --build build-tvos-device --config Release -- -quiet
|
481 |
+
|
482 |
+
# Setup frameworks and copy binaries and headers
|
483 |
+
echo "Setting up framework structures..."
|
484 |
+
setup_framework_structure "build-ios-sim" ${IOS_MIN_OS_VERSION} "ios"
|
485 |
+
setup_framework_structure "build-ios-device" ${IOS_MIN_OS_VERSION} "ios"
|
486 |
+
setup_framework_structure "build-macos" ${MACOS_MIN_OS_VERSION} "macos"
|
487 |
+
setup_framework_structure "build-visionos" ${VISIONOS_MIN_OS_VERSION} "visionos"
|
488 |
+
setup_framework_structure "build-visionos-sim" ${VISIONOS_MIN_OS_VERSION} "visionos"
|
489 |
+
setup_framework_structure "build-tvos-sim" ${TVOS_MIN_OS_VERSION} "tvos"
|
490 |
+
setup_framework_structure "build-tvos-device" ${TVOS_MIN_OS_VERSION} "tvos"
|
491 |
+
|
492 |
+
# Create dynamic libraries from static libraries
|
493 |
+
echo "Creating dynamic libraries from static libraries..."
|
494 |
+
combine_static_libraries "build-ios-sim" "Release-iphonesimulator" "ios" "true"
|
495 |
+
combine_static_libraries "build-ios-device" "Release-iphoneos" "ios" "false"
|
496 |
+
combine_static_libraries "build-macos" "Release" "macos" "false"
|
497 |
+
combine_static_libraries "build-visionos" "Release-xros" "visionos" "false"
|
498 |
+
combine_static_libraries "build-visionos-sim" "Release-xrsimulator" "visionos" "true"
|
499 |
+
combine_static_libraries "build-tvos-sim" "Release-appletvsimulator" "tvos" "true"
|
500 |
+
combine_static_libraries "build-tvos-device" "Release-appletvos" "tvos" "false"
|
501 |
+
|
502 |
+
# Create XCFramework with correct debug symbols paths
|
503 |
+
echo "Creating XCFramework..."
|
504 |
+
xcodebuild -create-xcframework \
|
505 |
+
-framework $(pwd)/build-ios-sim/framework/llama.framework \
|
506 |
+
-debug-symbols $(pwd)/build-ios-sim/dSYMs/llama.dSYM \
|
507 |
+
-framework $(pwd)/build-ios-device/framework/llama.framework \
|
508 |
+
-debug-symbols $(pwd)/build-ios-device/dSYMs/llama.dSYM \
|
509 |
+
-framework $(pwd)/build-macos/framework/llama.framework \
|
510 |
+
-debug-symbols $(pwd)/build-macos/dSYMS/llama.dSYM \
|
511 |
+
-framework $(pwd)/build-visionos/framework/llama.framework \
|
512 |
+
-debug-symbols $(pwd)/build-visionos/dSYMs/llama.dSYM \
|
513 |
+
-framework $(pwd)/build-visionos-sim/framework/llama.framework \
|
514 |
+
-debug-symbols $(pwd)/build-visionos-sim/dSYMs/llama.dSYM \
|
515 |
+
-framework $(pwd)/build-tvos-device/framework/llama.framework \
|
516 |
+
-debug-symbols $(pwd)/build-tvos-device/dSYMs/llama.dSYM \
|
517 |
+
-framework $(pwd)/build-tvos-sim/framework/llama.framework \
|
518 |
+
-debug-symbols $(pwd)/build-tvos-sim/dSYMs/llama.dSYM \
|
519 |
+
-output $(pwd)/build-apple/llama.xcframework
|
clblast.dll
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0611442b931691d9b3c9bc5ebe7625f17a5c5902e1a2b9e98cbad440d1459625
|
3 |
+
size 5450752
|
colab.ipynb
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"colab_type": "text",
|
7 |
+
"id": "view-in-github"
|
8 |
+
},
|
9 |
+
"source": [
|
10 |
+
"<a href=\"https://colab.research.google.com/github/LostRuins/koboldcpp/blob/concedo/colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "markdown",
|
15 |
+
"metadata": {
|
16 |
+
"id": "2FCn5tmpn3UV"
|
17 |
+
},
|
18 |
+
"source": [
|
19 |
+
"## Welcome to the Official KoboldCpp Colab Notebook\n",
|
20 |
+
"It's really easy to get started. Just press the two **Play** buttons below, and then connect to the **Cloudflare URL** shown at the end.\n",
|
21 |
+
"You can select a model from the dropdown, or enter a **custom URL** to a GGUF model (Example: `https://huggingface.co/KoboldAI/LLaMA2-13B-Tiefighter-GGUF/resolve/main/LLaMA2-13B-Tiefighter.Q4_K_M.gguf`)\n",
|
22 |
+
"\n",
|
23 |
+
"**Keep this page open and occationally check for captcha's so that your AI is not shut down**"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "code",
|
28 |
+
"execution_count": null,
|
29 |
+
"metadata": {
|
30 |
+
"id": "QNaj3u0jn3UW"
|
31 |
+
},
|
32 |
+
"outputs": [],
|
33 |
+
"source": [
|
34 |
+
"#@title <-- Tap this if you play on Mobile { display-mode: \"form\" }\n",
|
35 |
+
"%%html\n",
|
36 |
+
"<b>Press play on the music player to keep the tab alive, then start KoboldCpp below</b><br/>\n",
|
37 |
+
"<audio autoplay=\"\" src=\"https://raw.githubusercontent.com/KoboldAI/KoboldAI-Client/main/colab/silence.m4a\" loop controls>"
|
38 |
+
]
|
39 |
+
},
|
40 |
+
{
|
41 |
+
"cell_type": "code",
|
42 |
+
"execution_count": null,
|
43 |
+
"metadata": {
|
44 |
+
"cellView": "form",
|
45 |
+
"id": "uJS9i_Dltv8Y"
|
46 |
+
},
|
47 |
+
"outputs": [],
|
48 |
+
"source": [
|
49 |
+
"#@title <b>v-- Enter your model below and then click this to start Koboldcpp</b>\n",
|
50 |
+
"\n",
|
51 |
+
"Model = \"https://huggingface.co/KoboldAI/LLaMA2-13B-Tiefighter-GGUF/resolve/main/LLaMA2-13B-Tiefighter.Q4_K_S.gguf\" #@param [\"https://huggingface.co/KoboldAI/LLaMA2-13B-Tiefighter-GGUF/resolve/main/LLaMA2-13B-Tiefighter.Q4_K_S.gguf\",\"https://huggingface.co/KoboldAI/LLaMA2-13B-Estopia-GGUF/resolve/main/LLaMA2-13B-Estopia.Q4_K_S.gguf\",\"https://huggingface.co/mradermacher/Fimbulvetr-11B-v2-GGUF/resolve/main/Fimbulvetr-11B-v2.Q4_K_S.gguf\",\"https://huggingface.co/TheBloke/MythoMax-L2-13B-GGUF/resolve/main/mythomax-l2-13b.Q4_K_M.gguf\",\"https://huggingface.co/TheBloke/ReMM-SLERP-L2-13B-GGUF/resolve/main/remm-slerp-l2-13b.Q4_K_M.gguf\",\"https://huggingface.co/TheBloke/Xwin-LM-13B-v0.2-GGUF/resolve/main/xwin-lm-13b-v0.2.Q4_K_M.gguf\",\"https://huggingface.co/mradermacher/mini-magnum-12b-v1.1-GGUF/resolve/main/mini-magnum-12b-v1.1.Q4_K_S.gguf\",\"https://huggingface.co/TheBloke/Stheno-L2-13B-GGUF/resolve/main/stheno-l2-13b.Q4_K_M.gguf\",\"https://huggingface.co/TheBloke/MythoMax-L2-Kimiko-v2-13B-GGUF/resolve/main/mythomax-l2-kimiko-v2-13b.Q4_K_M.gguf\",\"https://huggingface.co/bartowski/Rocinante-12B-v1.1-GGUF/resolve/main/Rocinante-12B-v1.1-Q4_K_S.gguf\",\"https://huggingface.co/KoboldAI/Llama-3.1-8B-BookAdventures-GGUF/resolve/main/Llama-3.1-8B-BookAdventures.Q4_K_S.gguf\",\"https://huggingface.co/TheBloke/MistRP-Airoboros-7B-GGUF/resolve/main/mistrp-airoboros-7b.Q4_K_S.gguf\",\"https://huggingface.co/TheBloke/airoboros-mistral2.2-7B-GGUF/resolve/main/airoboros-mistral2.2-7b.Q4_K_S.gguf\",\"https://huggingface.co/concedo/KobbleTinyV2-1.1B-GGUF/resolve/main/KobbleTiny-Q4_K.gguf\",\"https://huggingface.co/grimjim/kukulemon-7B-GGUF/resolve/main/kukulemon-7B.Q8_0.gguf\",\"https://huggingface.co/mradermacher/LemonKunoichiWizardV3-GGUF/resolve/main/LemonKunoichiWizardV3.Q4_K_M.gguf\",\"https://huggingface.co/Lewdiculous/Kunoichi-DPO-v2-7B-GGUF-Imatrix/resolve/main/Kunoichi-DPO-v2-7B-Q4_K_M-imatrix.gguf\",\"https://huggingface.co/mradermacher/L3-8B-Stheno-v3.2-i1-GGUF/resolve/main/L3-8B-Stheno-v3.2.i1-Q4_K_M.gguf\",\"https://huggingface.co/Lewdiculous/Llama-3-Lumimaid-8B-v0.1-OAS-GGUF-IQ-Imatrix/resolve/main/v2-Llama-3-Lumimaid-8B-v0.1-OAS-Q4_K_M-imat.gguf\",\"https://huggingface.co/bartowski/NeuralDaredevil-8B-abliterated-GGUF/resolve/main/NeuralDaredevil-8B-abliterated-Q4_K_M.gguf\",\"https://huggingface.co/bartowski/L3-8B-Lunaris-v1-GGUF/resolve/main/L3-8B-Lunaris-v1-Q4_K_M.gguf\",\"https://huggingface.co/mradermacher/L3-Umbral-Mind-RP-v2.0-8B-GGUF/resolve/main/L3-Umbral-Mind-RP-v2.0-8B.Q4_K_M.gguf\",\"https://huggingface.co/bartowski/TheDrummer_Cydonia-24B-v2-GGUF/resolve/main/TheDrummer_Cydonia-24B-v2-Q4_K_S.gguf\",\"https://huggingface.co/bartowski/PocketDoc_Dans-PersonalityEngine-V1.2.0-24b-GGUF/resolve/main/PocketDoc_Dans-PersonalityEngine-V1.2.0-24b-IQ4_XS.gguf\"]{allow-input: true}\n",
|
52 |
+
"Layers = 99 #@param [99]{allow-input: true}\n",
|
53 |
+
"ContextSize = 4096 #@param [4096,8192] {allow-input: true}\n",
|
54 |
+
"FlashAttention = True #@param {type:\"boolean\"}\n",
|
55 |
+
"Multiplayer = False #@param {type:\"boolean\"}\n",
|
56 |
+
"FACommand = \"\"\n",
|
57 |
+
"MPCommand = \"\"\n",
|
58 |
+
"#@markdown <hr>\n",
|
59 |
+
"LoadVisionMMProjector = False #@param {type:\"boolean\"}\n",
|
60 |
+
"Mmproj = \"https://huggingface.co/koboldcpp/mmproj/resolve/main/llama-13b-mmproj-v1.5.Q4_1.gguf\" #@param [\"https://huggingface.co/koboldcpp/mmproj/resolve/main/llama-13b-mmproj-v1.5.Q4_1.gguf\",\"https://huggingface.co/koboldcpp/mmproj/resolve/main/mistral-7b-mmproj-v1.5-Q4_1.gguf\",\"https://huggingface.co/koboldcpp/mmproj/resolve/main/llama-7b-mmproj-v1.5-Q4_0.gguf\",\"https://huggingface.co/koboldcpp/mmproj/resolve/main/LLaMA3-8B_mmproj-Q4_1.gguf\"]{allow-input: true}\n",
|
61 |
+
"VCommand = \"\"\n",
|
62 |
+
"#@markdown <hr>\n",
|
63 |
+
"LoadImgModel = False #@param {type:\"boolean\"}\n",
|
64 |
+
"ImgModel = \"https://huggingface.co/koboldcpp/imgmodel/resolve/main/imgmodel_ftuned_q4_0.gguf\" #@param [\"https://huggingface.co/koboldcpp/imgmodel/resolve/main/imgmodel_ftuned_q4_0.gguf\"]{allow-input: true}\n",
|
65 |
+
"SCommand = \"\"\n",
|
66 |
+
"#@markdown <hr>\n",
|
67 |
+
"LoadSpeechModel = False #@param {type:\"boolean\"}\n",
|
68 |
+
"SpeechModel = \"https://huggingface.co/koboldcpp/whisper/resolve/main/whisper-base.en-q5_1.bin\" #@param [\"https://huggingface.co/koboldcpp/whisper/resolve/main/whisper-base.en-q5_1.bin\"]{allow-input: true}\n",
|
69 |
+
"WCommand = \"\"\n",
|
70 |
+
"#@markdown <hr>\n",
|
71 |
+
"LoadTTSModel = False #@param {type:\"boolean\"}\n",
|
72 |
+
"TTSModel = \"https://huggingface.co/koboldcpp/tts/resolve/main/OuteTTS-0.2-500M-Q4_0.gguf\" #@param [\"https://huggingface.co/koboldcpp/tts/resolve/main/OuteTTS-0.2-500M-Q4_0.gguf\"]{allow-input: true}\n",
|
73 |
+
"WavTokModel = \"https://huggingface.co/koboldcpp/tts/resolve/main/WavTokenizer-Large-75-Q4_0.gguf\" #@param [\"https://huggingface.co/koboldcpp/tts/resolve/main/WavTokenizer-Large-75-Q4_0.gguf\"]{allow-input: true}\n",
|
74 |
+
"TTSCommand = \"\"\n",
|
75 |
+
"#@markdown <hr>\n",
|
76 |
+
"AllowSaveToGoogleDrive = False #@param {type:\"boolean\"}\n",
|
77 |
+
"SavGdriveCommand = \"\"\n",
|
78 |
+
"\n",
|
79 |
+
"import os\n",
|
80 |
+
"if not os.path.isfile(\"/opt/bin/nvidia-smi\"):\n",
|
81 |
+
" raise RuntimeError(\"⚠️Colab did not give you a GPU due to usage limits, this can take a few hours before they let you back in. Check out https://lite.koboldai.net for a free alternative (that does not provide an API link but can load KoboldAI saves and chat cards) or subscribe to Colab Pro for immediate access.⚠️\")\n",
|
82 |
+
"\n",
|
83 |
+
"if AllowSaveToGoogleDrive:\n",
|
84 |
+
" print(\"Attempting to request access to save to your google drive...\")\n",
|
85 |
+
" try:\n",
|
86 |
+
" from google.colab import drive\n",
|
87 |
+
" import os, json\n",
|
88 |
+
" drive.mount('/content/drive', force_remount=True)\n",
|
89 |
+
" if not os.path.exists(\"/content/drive/MyDrive\"):\n",
|
90 |
+
" raise RuntimeError(\"Google Drive mount failed. Please grant permissions and try again.\")\n",
|
91 |
+
" kcppdir = '/content/drive/MyDrive/koboldcpp_data'\n",
|
92 |
+
" os.makedirs(kcppdir, exist_ok=True)\n",
|
93 |
+
" savedatapath = os.path.join(kcppdir, \"koboldcpp_save_db.jsondb\")\n",
|
94 |
+
" if not os.path.exists(savedatapath):\n",
|
95 |
+
" settings_data = {}\n",
|
96 |
+
" with open(savedatapath, \"w\") as json_file:\n",
|
97 |
+
" json.dump(settings_data, json_file, indent=4)\n",
|
98 |
+
" print(f\"Created new koboldcpp_save_db.jsondb at {savedatapath}\")\n",
|
99 |
+
" else:\n",
|
100 |
+
" print(f\"Loading saved data at {savedatapath}\")\n",
|
101 |
+
" SavGdriveCommand = f\" --savedatafile {savedatapath}\"\n",
|
102 |
+
" except Exception as e:\n",
|
103 |
+
" print(f\"⚠️ Error: {e}\")\n",
|
104 |
+
" print(\"Please ensure you grant Google Drive permissions and try again.\")\n",
|
105 |
+
"\n",
|
106 |
+
"%cd /content\n",
|
107 |
+
"if Mmproj and LoadVisionMMProjector:\n",
|
108 |
+
" VCommand = \"--mmproj vmodel.gguf\"\n",
|
109 |
+
"else:\n",
|
110 |
+
" SCommand = \"\"\n",
|
111 |
+
"if ImgModel and LoadImgModel:\n",
|
112 |
+
" SCommand = \"--sdmodel imodel.gguf --sdthreads 4 --sdquant --sdclamped\"\n",
|
113 |
+
"else:\n",
|
114 |
+
" SCommand = \"\"\n",
|
115 |
+
"if SpeechModel and LoadSpeechModel:\n",
|
116 |
+
" WCommand = \"--whispermodel wmodel.bin\"\n",
|
117 |
+
"else:\n",
|
118 |
+
" WCommand = \"\"\n",
|
119 |
+
"if TTSModel and WavTokModel and LoadTTSModel:\n",
|
120 |
+
" TTSCommand = \"--ttsmodel ttsmodel.bin --ttswavtokenizer ttswavtok.bin --ttsgpu\"\n",
|
121 |
+
"else:\n",
|
122 |
+
" TTSCommand = \"\"\n",
|
123 |
+
"if FlashAttention:\n",
|
124 |
+
" FACommand = \"--flashattention\"\n",
|
125 |
+
"else:\n",
|
126 |
+
" FACommand = \"\"\n",
|
127 |
+
"if Multiplayer:\n",
|
128 |
+
" MPCommand = \"--multiplayer\"\n",
|
129 |
+
"else:\n",
|
130 |
+
" MPCommand = \"\"\n",
|
131 |
+
"\n",
|
132 |
+
"!echo Downloading KoboldCpp, please wait...\n",
|
133 |
+
"!wget -O dlfile.tmp https://kcpplinux.concedo.workers.dev && mv dlfile.tmp koboldcpp_linux\n",
|
134 |
+
"!test -f koboldcpp_linux && echo Download Successful || echo Download Failed\n",
|
135 |
+
"!chmod +x ./koboldcpp_linux\n",
|
136 |
+
"!apt update\n",
|
137 |
+
"!apt install aria2 -y\n",
|
138 |
+
"# simple fix for a common URL mistake\n",
|
139 |
+
"if \"https://huggingface.co/\" in Model and \"/blob/main/\" in Model:\n",
|
140 |
+
" Model = Model.replace(\"/blob/main/\", \"/resolve/main/\")\n",
|
141 |
+
"!aria2c -x 10 -o model.gguf --summary-interval=5 --download-result=default --allow-overwrite=true --file-allocation=none $Model\n",
|
142 |
+
"if VCommand:\n",
|
143 |
+
" !aria2c -x 10 -o vmodel.gguf --summary-interval=5 --download-result=default --allow-overwrite=true --file-allocation=none $Mmproj\n",
|
144 |
+
"if SCommand:\n",
|
145 |
+
" !aria2c -x 10 -o imodel.gguf --summary-interval=5 --download-result=default --allow-overwrite=true --file-allocation=none $ImgModel\n",
|
146 |
+
"if WCommand:\n",
|
147 |
+
" !aria2c -x 10 -o wmodel.bin --summary-interval=5 --download-result=default --allow-overwrite=true --file-allocation=none $SpeechModel\n",
|
148 |
+
"if TTSCommand:\n",
|
149 |
+
" !aria2c -x 10 -o ttsmodel.bin --summary-interval=5 --download-result=default --allow-overwrite=true --file-allocation=none $TTSModel\n",
|
150 |
+
" !aria2c -x 10 -o ttswavtok.bin --summary-interval=5 --download-result=default --allow-overwrite=true --file-allocation=none $WavTokModel\n",
|
151 |
+
"!./koboldcpp_linux model.gguf --usecublas 0 mmq --chatcompletionsadapter AutoGuess --multiuser --gpulayers $Layers --contextsize $ContextSize --websearch --quiet --remotetunnel $FACommand $MPCommand $VCommand $SCommand $WCommand $TTSCommand $SavGdriveCommand\n"
|
152 |
+
]
|
153 |
+
}
|
154 |
+
],
|
155 |
+
"metadata": {
|
156 |
+
"accelerator": "GPU",
|
157 |
+
"colab": {
|
158 |
+
"cell_execution_strategy": "setup",
|
159 |
+
"gpuType": "T4",
|
160 |
+
"include_colab_link": true,
|
161 |
+
"private_outputs": true,
|
162 |
+
"provenance": []
|
163 |
+
},
|
164 |
+
"kernelspec": {
|
165 |
+
"display_name": "Python 3",
|
166 |
+
"name": "python3"
|
167 |
+
},
|
168 |
+
"language_info": {
|
169 |
+
"name": "python"
|
170 |
+
}
|
171 |
+
},
|
172 |
+
"nbformat": 4,
|
173 |
+
"nbformat_minor": 0
|
174 |
+
}
|
common/arg.cpp
ADDED
The diff for this file is too large to render.
See raw diff
|
|
common/arg.h
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include "common.h"
|
4 |
+
|
5 |
+
#include <set>
|
6 |
+
#include <string>
|
7 |
+
#include <vector>
|
8 |
+
|
9 |
+
//
|
10 |
+
// CLI argument parsing
|
11 |
+
//
|
12 |
+
|
13 |
+
struct common_arg {
|
14 |
+
std::set<enum llama_example> examples = {LLAMA_EXAMPLE_COMMON};
|
15 |
+
std::set<enum llama_example> excludes = {};
|
16 |
+
std::vector<const char *> args;
|
17 |
+
const char * value_hint = nullptr; // help text or example for arg value
|
18 |
+
const char * value_hint_2 = nullptr; // for second arg value
|
19 |
+
const char * env = nullptr;
|
20 |
+
std::string help;
|
21 |
+
bool is_sparam = false; // is current arg a sampling param?
|
22 |
+
void (*handler_void) (common_params & params) = nullptr;
|
23 |
+
void (*handler_string) (common_params & params, const std::string &) = nullptr;
|
24 |
+
void (*handler_str_str)(common_params & params, const std::string &, const std::string &) = nullptr;
|
25 |
+
void (*handler_int) (common_params & params, int) = nullptr;
|
26 |
+
|
27 |
+
common_arg(
|
28 |
+
const std::initializer_list<const char *> & args,
|
29 |
+
const char * value_hint,
|
30 |
+
const std::string & help,
|
31 |
+
void (*handler)(common_params & params, const std::string &)
|
32 |
+
) : args(args), value_hint(value_hint), help(help), handler_string(handler) {}
|
33 |
+
|
34 |
+
common_arg(
|
35 |
+
const std::initializer_list<const char *> & args,
|
36 |
+
const char * value_hint,
|
37 |
+
const std::string & help,
|
38 |
+
void (*handler)(common_params & params, int)
|
39 |
+
) : args(args), value_hint(value_hint), help(help), handler_int(handler) {}
|
40 |
+
|
41 |
+
common_arg(
|
42 |
+
const std::initializer_list<const char *> & args,
|
43 |
+
const std::string & help,
|
44 |
+
void (*handler)(common_params & params)
|
45 |
+
) : args(args), help(help), handler_void(handler) {}
|
46 |
+
|
47 |
+
// support 2 values for arg
|
48 |
+
common_arg(
|
49 |
+
const std::initializer_list<const char *> & args,
|
50 |
+
const char * value_hint,
|
51 |
+
const char * value_hint_2,
|
52 |
+
const std::string & help,
|
53 |
+
void (*handler)(common_params & params, const std::string &, const std::string &)
|
54 |
+
) : args(args), value_hint(value_hint), value_hint_2(value_hint_2), help(help), handler_str_str(handler) {}
|
55 |
+
|
56 |
+
common_arg & set_examples(std::initializer_list<enum llama_example> examples);
|
57 |
+
common_arg & set_excludes(std::initializer_list<enum llama_example> excludes);
|
58 |
+
common_arg & set_env(const char * env);
|
59 |
+
common_arg & set_sparam();
|
60 |
+
bool in_example(enum llama_example ex);
|
61 |
+
bool is_exclude(enum llama_example ex);
|
62 |
+
bool get_value_from_env(std::string & output);
|
63 |
+
bool has_value_from_env();
|
64 |
+
std::string to_string();
|
65 |
+
};
|
66 |
+
|
67 |
+
struct common_params_context {
|
68 |
+
enum llama_example ex = LLAMA_EXAMPLE_COMMON;
|
69 |
+
common_params & params;
|
70 |
+
std::vector<common_arg> options;
|
71 |
+
void(*print_usage)(int, char **) = nullptr;
|
72 |
+
common_params_context(common_params & params) : params(params) {}
|
73 |
+
};
|
74 |
+
|
75 |
+
// parse input arguments from CLI
|
76 |
+
// if one argument has invalid value, it will automatically display usage of the specific argument (and not the full usage message)
|
77 |
+
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
78 |
+
|
79 |
+
// function to be used by test-arg-parser
|
80 |
+
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
common/base64.hpp
ADDED
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
This is free and unencumbered software released into the public domain.
|
3 |
+
|
4 |
+
Anyone is free to copy, modify, publish, use, compile, sell, or
|
5 |
+
distribute this software, either in source code form or as a compiled
|
6 |
+
binary, for any purpose, commercial or non-commercial, and by any
|
7 |
+
means.
|
8 |
+
|
9 |
+
In jurisdictions that recognize copyright laws, the author or authors
|
10 |
+
of this software dedicate any and all copyright interest in the
|
11 |
+
software to the public domain. We make this dedication for the benefit
|
12 |
+
of the public at large and to the detriment of our heirs and
|
13 |
+
successors. We intend this dedication to be an overt act of
|
14 |
+
relinquishment in perpetuity of all present and future rights to this
|
15 |
+
software under copyright law.
|
16 |
+
|
17 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
18 |
+
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
19 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
20 |
+
IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
|
21 |
+
OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
22 |
+
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
23 |
+
OTHER DEALINGS IN THE SOFTWARE.
|
24 |
+
|
25 |
+
For more information, please refer to <http://unlicense.org>
|
26 |
+
*/
|
27 |
+
|
28 |
+
#ifndef PUBLIC_DOMAIN_BASE64_HPP_
|
29 |
+
#define PUBLIC_DOMAIN_BASE64_HPP_
|
30 |
+
|
31 |
+
#include <cstdint>
|
32 |
+
#include <iterator>
|
33 |
+
#include <stdexcept>
|
34 |
+
#include <string>
|
35 |
+
|
36 |
+
class base64_error : public std::runtime_error
|
37 |
+
{
|
38 |
+
public:
|
39 |
+
using std::runtime_error::runtime_error;
|
40 |
+
};
|
41 |
+
|
42 |
+
class base64
|
43 |
+
{
|
44 |
+
public:
|
45 |
+
enum class alphabet
|
46 |
+
{
|
47 |
+
/** the alphabet is detected automatically */
|
48 |
+
auto_,
|
49 |
+
/** the standard base64 alphabet is used */
|
50 |
+
standard,
|
51 |
+
/** like `standard` except that the characters `+` and `/` are replaced by `-` and `_` respectively*/
|
52 |
+
url_filename_safe
|
53 |
+
};
|
54 |
+
|
55 |
+
enum class decoding_behavior
|
56 |
+
{
|
57 |
+
/** if the input is not padded, the remaining bits are ignored */
|
58 |
+
moderate,
|
59 |
+
/** if a padding character is encounter decoding is finished */
|
60 |
+
loose
|
61 |
+
};
|
62 |
+
|
63 |
+
/**
|
64 |
+
Encodes all the elements from `in_begin` to `in_end` to `out`.
|
65 |
+
|
66 |
+
@warning The source and destination cannot overlap. The destination must be able to hold at least
|
67 |
+
`required_encode_size(std::distance(in_begin, in_end))`, otherwise the behavior depends on the output iterator.
|
68 |
+
|
69 |
+
@tparam Input_iterator the source; the returned elements are cast to `std::uint8_t` and should not be greater than
|
70 |
+
8 bits
|
71 |
+
@tparam Output_iterator the destination; the elements written to it are from the type `char`
|
72 |
+
@param in_begin the beginning of the source
|
73 |
+
@param in_end the ending of the source
|
74 |
+
@param out the destination iterator
|
75 |
+
@param alphabet which alphabet should be used
|
76 |
+
@returns the iterator to the next element past the last element copied
|
77 |
+
@throws see `Input_iterator` and `Output_iterator`
|
78 |
+
*/
|
79 |
+
template<typename Input_iterator, typename Output_iterator>
|
80 |
+
static Output_iterator encode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out,
|
81 |
+
alphabet alphabet = alphabet::standard)
|
82 |
+
{
|
83 |
+
constexpr auto pad = '=';
|
84 |
+
const char* alpha = alphabet == alphabet::url_filename_safe
|
85 |
+
? "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
|
86 |
+
: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
87 |
+
|
88 |
+
while (in_begin != in_end) {
|
89 |
+
std::uint8_t i0 = 0, i1 = 0, i2 = 0;
|
90 |
+
|
91 |
+
// first character
|
92 |
+
i0 = static_cast<std::uint8_t>(*in_begin);
|
93 |
+
++in_begin;
|
94 |
+
|
95 |
+
*out = alpha[i0 >> 2 & 0x3f];
|
96 |
+
++out;
|
97 |
+
|
98 |
+
// part of first character and second
|
99 |
+
if (in_begin != in_end) {
|
100 |
+
i1 = static_cast<std::uint8_t>(*in_begin);
|
101 |
+
++in_begin;
|
102 |
+
|
103 |
+
*out = alpha[((i0 & 0x3) << 4) | (i1 >> 4 & 0x0f)];
|
104 |
+
++out;
|
105 |
+
} else {
|
106 |
+
*out = alpha[(i0 & 0x3) << 4];
|
107 |
+
++out;
|
108 |
+
|
109 |
+
// last padding
|
110 |
+
*out = pad;
|
111 |
+
++out;
|
112 |
+
|
113 |
+
// last padding
|
114 |
+
*out = pad;
|
115 |
+
++out;
|
116 |
+
|
117 |
+
break;
|
118 |
+
}
|
119 |
+
|
120 |
+
// part of second character and third
|
121 |
+
if (in_begin != in_end) {
|
122 |
+
i2 = static_cast<std::uint8_t>(*in_begin);
|
123 |
+
++in_begin;
|
124 |
+
|
125 |
+
*out = alpha[((i1 & 0xf) << 2) | (i2 >> 6 & 0x03)];
|
126 |
+
++out;
|
127 |
+
} else {
|
128 |
+
*out = alpha[(i1 & 0xf) << 2];
|
129 |
+
++out;
|
130 |
+
|
131 |
+
// last padding
|
132 |
+
*out = pad;
|
133 |
+
++out;
|
134 |
+
|
135 |
+
break;
|
136 |
+
}
|
137 |
+
|
138 |
+
// rest of third
|
139 |
+
*out = alpha[i2 & 0x3f];
|
140 |
+
++out;
|
141 |
+
}
|
142 |
+
|
143 |
+
return out;
|
144 |
+
}
|
145 |
+
/**
|
146 |
+
Encodes a string.
|
147 |
+
|
148 |
+
@param str the string that should be encoded
|
149 |
+
@param alphabet which alphabet should be used
|
150 |
+
@returns the encoded base64 string
|
151 |
+
@throws see base64::encode()
|
152 |
+
*/
|
153 |
+
static std::string encode(const std::string& str, alphabet alphabet = alphabet::standard)
|
154 |
+
{
|
155 |
+
std::string result;
|
156 |
+
|
157 |
+
result.reserve(required_encode_size(str.length()) + 1);
|
158 |
+
|
159 |
+
encode(str.begin(), str.end(), std::back_inserter(result), alphabet);
|
160 |
+
|
161 |
+
return result;
|
162 |
+
}
|
163 |
+
/**
|
164 |
+
Encodes a char array.
|
165 |
+
|
166 |
+
@param buffer the char array
|
167 |
+
@param size the size of the array
|
168 |
+
@param alphabet which alphabet should be used
|
169 |
+
@returns the encoded string
|
170 |
+
*/
|
171 |
+
static std::string encode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::standard)
|
172 |
+
{
|
173 |
+
std::string result;
|
174 |
+
|
175 |
+
result.reserve(required_encode_size(size) + 1);
|
176 |
+
|
177 |
+
encode(buffer, buffer + size, std::back_inserter(result), alphabet);
|
178 |
+
|
179 |
+
return result;
|
180 |
+
}
|
181 |
+
/**
|
182 |
+
Decodes all the elements from `in_begin` to `in_end` to `out`. `in_begin` may point to the same location as `out`,
|
183 |
+
in other words: inplace decoding is possible.
|
184 |
+
|
185 |
+
@warning The destination must be able to hold at least `required_decode_size(std::distance(in_begin, in_end))`,
|
186 |
+
otherwise the behavior depends on the output iterator.
|
187 |
+
|
188 |
+
@tparam Input_iterator the source; the returned elements are cast to `char`
|
189 |
+
@tparam Output_iterator the destination; the elements written to it are from the type `std::uint8_t`
|
190 |
+
@param in_begin the beginning of the source
|
191 |
+
@param in_end the ending of the source
|
192 |
+
@param out the destination iterator
|
193 |
+
@param alphabet which alphabet should be used
|
194 |
+
@param behavior the behavior when an error was detected
|
195 |
+
@returns the iterator to the next element past the last element copied
|
196 |
+
@throws base64_error depending on the set behavior
|
197 |
+
@throws see `Input_iterator` and `Output_iterator`
|
198 |
+
*/
|
199 |
+
template<typename Input_iterator, typename Output_iterator>
|
200 |
+
static Output_iterator decode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out,
|
201 |
+
alphabet alphabet = alphabet::auto_,
|
202 |
+
decoding_behavior behavior = decoding_behavior::moderate)
|
203 |
+
{
|
204 |
+
//constexpr auto pad = '=';
|
205 |
+
std::uint8_t last = 0;
|
206 |
+
auto bits = 0;
|
207 |
+
|
208 |
+
while (in_begin != in_end) {
|
209 |
+
auto c = *in_begin;
|
210 |
+
++in_begin;
|
211 |
+
|
212 |
+
if (c == '=') {
|
213 |
+
break;
|
214 |
+
}
|
215 |
+
|
216 |
+
auto part = _base64_value(alphabet, c);
|
217 |
+
|
218 |
+
// enough bits for one byte
|
219 |
+
if (bits + 6 >= 8) {
|
220 |
+
*out = (last << (8 - bits)) | (part >> (bits - 2));
|
221 |
+
++out;
|
222 |
+
|
223 |
+
bits -= 2;
|
224 |
+
} else {
|
225 |
+
bits += 6;
|
226 |
+
}
|
227 |
+
|
228 |
+
last = part;
|
229 |
+
}
|
230 |
+
|
231 |
+
// check padding
|
232 |
+
if (behavior != decoding_behavior::loose) {
|
233 |
+
while (in_begin != in_end) {
|
234 |
+
auto c = *in_begin;
|
235 |
+
++in_begin;
|
236 |
+
|
237 |
+
if (c != '=') {
|
238 |
+
throw base64_error("invalid base64 character.");
|
239 |
+
}
|
240 |
+
}
|
241 |
+
}
|
242 |
+
|
243 |
+
return out;
|
244 |
+
}
|
245 |
+
/**
|
246 |
+
Decodes a string.
|
247 |
+
|
248 |
+
@param str the base64 encoded string
|
249 |
+
@param alphabet which alphabet should be used
|
250 |
+
@param behavior the behavior when an error was detected
|
251 |
+
@returns the decoded string
|
252 |
+
@throws see base64::decode()
|
253 |
+
*/
|
254 |
+
static std::string decode(const std::string& str, alphabet alphabet = alphabet::auto_,
|
255 |
+
decoding_behavior behavior = decoding_behavior::moderate)
|
256 |
+
{
|
257 |
+
std::string result;
|
258 |
+
|
259 |
+
result.reserve(max_decode_size(str.length()));
|
260 |
+
|
261 |
+
decode(str.begin(), str.end(), std::back_inserter(result), alphabet, behavior);
|
262 |
+
|
263 |
+
return result;
|
264 |
+
}
|
265 |
+
/**
|
266 |
+
Decodes a string.
|
267 |
+
|
268 |
+
@param buffer the base64 encoded buffer
|
269 |
+
@param size the size of the buffer
|
270 |
+
@param alphabet which alphabet should be used
|
271 |
+
@param behavior the behavior when an error was detected
|
272 |
+
@returns the decoded string
|
273 |
+
@throws see base64::decode()
|
274 |
+
*/
|
275 |
+
static std::string decode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::auto_,
|
276 |
+
decoding_behavior behavior = decoding_behavior::moderate)
|
277 |
+
{
|
278 |
+
std::string result;
|
279 |
+
|
280 |
+
result.reserve(max_decode_size(size));
|
281 |
+
|
282 |
+
decode(buffer, buffer + size, std::back_inserter(result), alphabet, behavior);
|
283 |
+
|
284 |
+
return result;
|
285 |
+
}
|
286 |
+
/**
|
287 |
+
Decodes a string inplace.
|
288 |
+
|
289 |
+
@param[in,out] str the base64 encoded string
|
290 |
+
@param alphabet which alphabet should be used
|
291 |
+
@param behavior the behavior when an error was detected
|
292 |
+
@throws base64::decode_inplace()
|
293 |
+
*/
|
294 |
+
static void decode_inplace(std::string& str, alphabet alphabet = alphabet::auto_,
|
295 |
+
decoding_behavior behavior = decoding_behavior::moderate)
|
296 |
+
{
|
297 |
+
str.resize(decode(str.begin(), str.end(), str.begin(), alphabet, behavior) - str.begin());
|
298 |
+
}
|
299 |
+
/**
|
300 |
+
Decodes a char array inplace.
|
301 |
+
|
302 |
+
@param[in,out] str the string array
|
303 |
+
@param size the length of the array
|
304 |
+
@param alphabet which alphabet should be used
|
305 |
+
@param behavior the behavior when an error was detected
|
306 |
+
@returns the pointer to the next element past the last element decoded
|
307 |
+
@throws base64::decode_inplace()
|
308 |
+
*/
|
309 |
+
static char* decode_inplace(char* str, std::size_t size, alphabet alphabet = alphabet::auto_,
|
310 |
+
decoding_behavior behavior = decoding_behavior::moderate)
|
311 |
+
{
|
312 |
+
return decode(str, str + size, str, alphabet, behavior);
|
313 |
+
}
|
314 |
+
/**
|
315 |
+
Returns the required decoding size for a given size. The value is calculated with the following formula:
|
316 |
+
|
317 |
+
$$
|
318 |
+
\lceil \frac{size}{4} \rceil \cdot 3
|
319 |
+
$$
|
320 |
+
|
321 |
+
@param size the size of the encoded input
|
322 |
+
@returns the size of the resulting decoded buffer; this the absolute maximum
|
323 |
+
*/
|
324 |
+
static std::size_t max_decode_size(std::size_t size) noexcept
|
325 |
+
{
|
326 |
+
return (size / 4 + (size % 4 ? 1 : 0)) * 3;
|
327 |
+
}
|
328 |
+
/**
|
329 |
+
Returns the required encoding size for a given size. The value is calculated with the following formula:
|
330 |
+
|
331 |
+
$$
|
332 |
+
\lceil \frac{size}{3} \rceil \cdot 4
|
333 |
+
$$
|
334 |
+
|
335 |
+
@param size the size of the decoded input
|
336 |
+
@returns the size of the resulting encoded buffer
|
337 |
+
*/
|
338 |
+
static std::size_t required_encode_size(std::size_t size) noexcept
|
339 |
+
{
|
340 |
+
return (size / 3 + (size % 3 ? 1 : 0)) * 4;
|
341 |
+
}
|
342 |
+
|
343 |
+
private:
|
344 |
+
static std::uint8_t _base64_value(alphabet& alphabet, char c)
|
345 |
+
{
|
346 |
+
if (c >= 'A' && c <= 'Z') {
|
347 |
+
return c - 'A';
|
348 |
+
} else if (c >= 'a' && c <= 'z') {
|
349 |
+
return c - 'a' + 26;
|
350 |
+
} else if (c >= '0' && c <= '9') {
|
351 |
+
return c - '0' + 52;
|
352 |
+
}
|
353 |
+
|
354 |
+
// comes down to alphabet
|
355 |
+
if (alphabet == alphabet::standard) {
|
356 |
+
if (c == '+') {
|
357 |
+
return 62;
|
358 |
+
} else if (c == '/') {
|
359 |
+
return 63;
|
360 |
+
}
|
361 |
+
} else if (alphabet == alphabet::url_filename_safe) {
|
362 |
+
if (c == '-') {
|
363 |
+
return 62;
|
364 |
+
} else if (c == '_') {
|
365 |
+
return 63;
|
366 |
+
}
|
367 |
+
} // auto detect
|
368 |
+
else {
|
369 |
+
if (c == '+') {
|
370 |
+
alphabet = alphabet::standard;
|
371 |
+
|
372 |
+
return 62;
|
373 |
+
} else if (c == '/') {
|
374 |
+
alphabet = alphabet::standard;
|
375 |
+
|
376 |
+
return 63;
|
377 |
+
} else if (c == '-') {
|
378 |
+
alphabet = alphabet::url_filename_safe;
|
379 |
+
|
380 |
+
return 62;
|
381 |
+
} else if (c == '_') {
|
382 |
+
alphabet = alphabet::url_filename_safe;
|
383 |
+
|
384 |
+
return 63;
|
385 |
+
}
|
386 |
+
}
|
387 |
+
|
388 |
+
throw base64_error("invalid base64 character.");
|
389 |
+
}
|
390 |
+
};
|
391 |
+
|
392 |
+
#endif // !PUBLIC_DOMAIN_BASE64_HPP_
|
common/build-info.cpp.in
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
int LLAMA_BUILD_NUMBER = @BUILD_NUMBER@;
|
2 |
+
char const *LLAMA_COMMIT = "@BUILD_COMMIT@";
|
3 |
+
char const *LLAMA_COMPILER = "@BUILD_COMPILER@";
|
4 |
+
char const *LLAMA_BUILD_TARGET = "@BUILD_TARGET@";
|
common/chat.cpp
ADDED
@@ -0,0 +1,1779 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "chat.h"
|
2 |
+
#include "json-schema-to-grammar.h"
|
3 |
+
#include "log.h"
|
4 |
+
#include "minja/chat-template.hpp"
|
5 |
+
#include "minja/minja.hpp"
|
6 |
+
|
7 |
+
#include <optional>
|
8 |
+
|
9 |
+
typedef minja::chat_template common_chat_template;
|
10 |
+
|
11 |
+
struct common_chat_templates {
|
12 |
+
bool has_explicit_template; // Model had builtin template or template overridde was specified.
|
13 |
+
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
|
14 |
+
std::unique_ptr<common_chat_template> template_tool_use;
|
15 |
+
};
|
16 |
+
|
17 |
+
struct templates_params {
|
18 |
+
json messages;
|
19 |
+
json tools;
|
20 |
+
common_chat_tool_choice tool_choice;
|
21 |
+
json json_schema;
|
22 |
+
bool parallel_tool_calls;
|
23 |
+
bool stream;
|
24 |
+
std::string grammar;
|
25 |
+
bool add_generation_prompt = true;
|
26 |
+
bool extract_reasoning = true;
|
27 |
+
};
|
28 |
+
|
29 |
+
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
|
30 |
+
if (tool_choice == "auto") {
|
31 |
+
return COMMON_CHAT_TOOL_CHOICE_AUTO;
|
32 |
+
}
|
33 |
+
if (tool_choice == "none") {
|
34 |
+
return COMMON_CHAT_TOOL_CHOICE_NONE;
|
35 |
+
}
|
36 |
+
if (tool_choice == "required") {
|
37 |
+
return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
38 |
+
}
|
39 |
+
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
|
40 |
+
}
|
41 |
+
|
42 |
+
template <>
|
43 |
+
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
|
44 |
+
std::vector<common_chat_msg> msgs;
|
45 |
+
|
46 |
+
try {
|
47 |
+
|
48 |
+
if (!messages.is_array()) {
|
49 |
+
throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
|
50 |
+
}
|
51 |
+
|
52 |
+
for (const auto & message : messages) {
|
53 |
+
if (!message.is_object()) {
|
54 |
+
throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
|
55 |
+
}
|
56 |
+
|
57 |
+
common_chat_msg msg;
|
58 |
+
if (!message.contains("role")) {
|
59 |
+
throw std::runtime_error("Missing 'role' in message: " + message.dump());
|
60 |
+
}
|
61 |
+
msg.role = message.at("role");
|
62 |
+
|
63 |
+
auto has_content = message.contains("content");
|
64 |
+
auto has_tool_calls = message.contains("tool_calls");
|
65 |
+
if (has_content) {
|
66 |
+
const auto & content = message.at("content");
|
67 |
+
if (content.is_string()) {
|
68 |
+
msg.content = content;
|
69 |
+
} else if (content.is_array()) {
|
70 |
+
for (const auto & part : content) {
|
71 |
+
if (!part.contains("type")) {
|
72 |
+
throw std::runtime_error("Missing content part type: " + part.dump());
|
73 |
+
}
|
74 |
+
const auto & type = part.at("type");
|
75 |
+
if (type != "text") {
|
76 |
+
throw std::runtime_error("Unsupported content part type: " + type.dump());
|
77 |
+
}
|
78 |
+
common_chat_msg_content_part msg_part;
|
79 |
+
msg_part.type = type;
|
80 |
+
msg_part.text = part.at("text");
|
81 |
+
msg.content_parts.push_back(msg_part);
|
82 |
+
}
|
83 |
+
} else if (!content.is_null()) {
|
84 |
+
throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
|
85 |
+
}
|
86 |
+
}
|
87 |
+
if (has_tool_calls) {
|
88 |
+
for (const auto & tool_call : message.at("tool_calls")) {
|
89 |
+
common_chat_tool_call tc;
|
90 |
+
if (!tool_call.contains("type")) {
|
91 |
+
throw std::runtime_error("Missing tool call type: " + tool_call.dump());
|
92 |
+
}
|
93 |
+
const auto & type = tool_call.at("type");
|
94 |
+
if (type != "function") {
|
95 |
+
throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
|
96 |
+
}
|
97 |
+
if (!tool_call.contains("function")) {
|
98 |
+
throw std::runtime_error("Missing tool call function: " + tool_call.dump());
|
99 |
+
}
|
100 |
+
const auto & fc = tool_call.at("function");
|
101 |
+
if (!fc.contains("name")) {
|
102 |
+
throw std::runtime_error("Missing tool call name: " + tool_call.dump());
|
103 |
+
}
|
104 |
+
tc.name = fc.at("name");
|
105 |
+
tc.arguments = fc.at("arguments");
|
106 |
+
if (tool_call.contains("id")) {
|
107 |
+
tc.id = tool_call.at("id");
|
108 |
+
}
|
109 |
+
msg.tool_calls.push_back(tc);
|
110 |
+
}
|
111 |
+
}
|
112 |
+
if (!has_content && !has_tool_calls) {
|
113 |
+
throw std::runtime_error("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)");
|
114 |
+
}
|
115 |
+
if (message.contains("reasoning_content")) {
|
116 |
+
msg.reasoning_content = message.at("reasoning_content");
|
117 |
+
}
|
118 |
+
if (message.contains("name")) {
|
119 |
+
msg.tool_name = message.at("name");
|
120 |
+
}
|
121 |
+
if (message.contains("tool_call_id")) {
|
122 |
+
msg.tool_call_id = message.at("tool_call_id");
|
123 |
+
}
|
124 |
+
|
125 |
+
msgs.push_back(msg);
|
126 |
+
}
|
127 |
+
} catch (const std::exception & e) {
|
128 |
+
throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
|
129 |
+
}
|
130 |
+
|
131 |
+
return msgs;
|
132 |
+
}
|
133 |
+
|
134 |
+
template <>
|
135 |
+
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
|
136 |
+
json messages = json::array();
|
137 |
+
for (const auto & msg : msgs) {
|
138 |
+
if (!msg.content.empty() && !msg.content_parts.empty()) {
|
139 |
+
throw std::runtime_error("Cannot specify both content and content_parts");
|
140 |
+
}
|
141 |
+
json jmsg {
|
142 |
+
{"role", msg.role},
|
143 |
+
};
|
144 |
+
if (!msg.content.empty()) {
|
145 |
+
jmsg["content"] = msg.content;
|
146 |
+
} else if (!msg.content_parts.empty()) {
|
147 |
+
if (concat_typed_text) {
|
148 |
+
std::string text;
|
149 |
+
for (const auto & part : msg.content_parts) {
|
150 |
+
if (part.type != "text") {
|
151 |
+
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
|
152 |
+
continue;
|
153 |
+
}
|
154 |
+
if (!text.empty()) {
|
155 |
+
text += '\n';
|
156 |
+
}
|
157 |
+
text += part.text;
|
158 |
+
}
|
159 |
+
jmsg["content"] = text;
|
160 |
+
} else {
|
161 |
+
auto & parts = jmsg["content"] = json::array();
|
162 |
+
for (const auto & part : msg.content_parts) {
|
163 |
+
parts.push_back({
|
164 |
+
{"type", part.type},
|
165 |
+
{"text", part.text},
|
166 |
+
});
|
167 |
+
}
|
168 |
+
}
|
169 |
+
} else {
|
170 |
+
jmsg["content"] = json(); // null
|
171 |
+
}
|
172 |
+
if (!msg.reasoning_content.empty()) {
|
173 |
+
jmsg["reasoning_content"] = msg.reasoning_content;
|
174 |
+
}
|
175 |
+
if (!msg.tool_name.empty()) {
|
176 |
+
jmsg["name"] = msg.tool_name;
|
177 |
+
}
|
178 |
+
if (!msg.tool_call_id.empty()) {
|
179 |
+
jmsg["tool_call_id"] = msg.tool_call_id;
|
180 |
+
}
|
181 |
+
if (!msg.tool_calls.empty()) {
|
182 |
+
auto & tool_calls = jmsg["tool_calls"] = json::array();
|
183 |
+
for (const auto & tool_call : msg.tool_calls) {
|
184 |
+
json tc {
|
185 |
+
{"type", "function"},
|
186 |
+
{"function", {
|
187 |
+
{"name", tool_call.name},
|
188 |
+
{"arguments", tool_call.arguments},
|
189 |
+
}},
|
190 |
+
};
|
191 |
+
if (!tool_call.id.empty()) {
|
192 |
+
tc["id"] = tool_call.id;
|
193 |
+
}
|
194 |
+
tool_calls.push_back(tc);
|
195 |
+
}
|
196 |
+
}
|
197 |
+
messages.push_back(jmsg);
|
198 |
+
}
|
199 |
+
return messages;
|
200 |
+
}
|
201 |
+
|
202 |
+
template <>
|
203 |
+
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
|
204 |
+
return common_chat_msgs_parse_oaicompat(json::parse(messages));
|
205 |
+
}
|
206 |
+
|
207 |
+
template <>
|
208 |
+
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
|
209 |
+
std::vector<common_chat_tool> result;
|
210 |
+
|
211 |
+
try {
|
212 |
+
if (!tools.is_null()) {
|
213 |
+
if (!tools.is_array()) {
|
214 |
+
throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
|
215 |
+
}
|
216 |
+
for (const auto & tool : tools) {
|
217 |
+
if (!tool.contains("type")) {
|
218 |
+
throw std::runtime_error("Missing tool type: " + tool.dump());
|
219 |
+
}
|
220 |
+
const auto & type = tool.at("type");
|
221 |
+
if (!type.is_string() || type != "function") {
|
222 |
+
throw std::runtime_error("Unsupported tool type: " + tool.dump());
|
223 |
+
}
|
224 |
+
if (!tool.contains("function")) {
|
225 |
+
throw std::runtime_error("Missing tool function: " + tool.dump());
|
226 |
+
}
|
227 |
+
|
228 |
+
const auto & function = tool.at("function");
|
229 |
+
result.push_back({
|
230 |
+
/* .name = */ function.at("name"),
|
231 |
+
/* .description = */ function.at("description"),
|
232 |
+
/* .parameters = */ function.at("parameters").dump(),
|
233 |
+
});
|
234 |
+
}
|
235 |
+
}
|
236 |
+
} catch (const std::exception & e) {
|
237 |
+
throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2));
|
238 |
+
}
|
239 |
+
|
240 |
+
return result;
|
241 |
+
}
|
242 |
+
|
243 |
+
template <>
|
244 |
+
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
|
245 |
+
return common_chat_tools_parse_oaicompat(json::parse(tools));
|
246 |
+
}
|
247 |
+
|
248 |
+
template <>
|
249 |
+
json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
|
250 |
+
if (tools.empty()) {
|
251 |
+
return json();
|
252 |
+
}
|
253 |
+
|
254 |
+
auto result = json::array();
|
255 |
+
for (const auto & tool : tools) {
|
256 |
+
result.push_back({
|
257 |
+
{"type", "function"},
|
258 |
+
{"function", {
|
259 |
+
{"name", tool.name},
|
260 |
+
{"description", tool.description},
|
261 |
+
{"parameters", json::parse(tool.parameters)},
|
262 |
+
}},
|
263 |
+
});
|
264 |
+
}
|
265 |
+
return result;
|
266 |
+
}
|
267 |
+
|
268 |
+
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
|
269 |
+
if (use_jinja) {
|
270 |
+
try {
|
271 |
+
common_chat_msg msg;
|
272 |
+
msg.role = "user";
|
273 |
+
msg.content = "test";
|
274 |
+
|
275 |
+
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);
|
276 |
+
|
277 |
+
common_chat_templates_inputs inputs;
|
278 |
+
inputs.messages = {msg};
|
279 |
+
|
280 |
+
common_chat_templates_apply(tmpls.get(), inputs);
|
281 |
+
return true;
|
282 |
+
} catch (const std::exception & e) {
|
283 |
+
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
|
284 |
+
return false;
|
285 |
+
}
|
286 |
+
}
|
287 |
+
llama_chat_message chat[] = {{"user", "test"}};
|
288 |
+
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
|
289 |
+
return res >= 0;
|
290 |
+
}
|
291 |
+
|
292 |
+
std::string common_chat_format_single(
|
293 |
+
const struct common_chat_templates * tmpls,
|
294 |
+
const std::vector<common_chat_msg> & past_msg,
|
295 |
+
const common_chat_msg & new_msg,
|
296 |
+
bool add_ass,
|
297 |
+
bool use_jinja) {
|
298 |
+
|
299 |
+
common_chat_templates_inputs inputs;
|
300 |
+
inputs.use_jinja = use_jinja;
|
301 |
+
|
302 |
+
std::string fmt_past_msg;
|
303 |
+
if (!past_msg.empty()) {
|
304 |
+
inputs.messages = past_msg;
|
305 |
+
inputs.add_generation_prompt = false;
|
306 |
+
fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt;
|
307 |
+
}
|
308 |
+
std::ostringstream ss;
|
309 |
+
// if the past_msg ends with a newline, we must preserve it in the formatted version
|
310 |
+
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
|
311 |
+
ss << "\n";
|
312 |
+
};
|
313 |
+
// format chat with new_msg
|
314 |
+
inputs.messages.push_back(new_msg);
|
315 |
+
inputs.add_generation_prompt = add_ass;
|
316 |
+
auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
|
317 |
+
// get the diff part
|
318 |
+
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
|
319 |
+
return ss.str();
|
320 |
+
}
|
321 |
+
|
322 |
+
std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
|
323 |
+
common_chat_templates_inputs inputs;
|
324 |
+
inputs.use_jinja = use_jinja;
|
325 |
+
auto add_simple_msg = [&](auto role, auto content) {
|
326 |
+
common_chat_msg msg;
|
327 |
+
msg.role = role;
|
328 |
+
msg.content = content;
|
329 |
+
inputs.messages.push_back(msg);
|
330 |
+
};
|
331 |
+
add_simple_msg("system", "You are a helpful assistant");
|
332 |
+
add_simple_msg("user", "Hello");
|
333 |
+
add_simple_msg("assistant", "Hi there");
|
334 |
+
add_simple_msg("user", "How are you?");
|
335 |
+
return common_chat_templates_apply(tmpls, inputs).prompt;
|
336 |
+
}
|
337 |
+
|
338 |
+
#define CHATML_TEMPLATE_SRC \
|
339 |
+
"{%- for message in messages -%}\n" \
|
340 |
+
" {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
|
341 |
+
"{%- endfor -%}\n" \
|
342 |
+
"{%- if add_generation_prompt -%}\n" \
|
343 |
+
" {{- '<|im_start|>assistant\n' -}}\n" \
|
344 |
+
"{%- endif -%}"
|
345 |
+
|
346 |
+
void common_chat_templates_free(struct common_chat_templates * tmpls) {
|
347 |
+
delete tmpls;
|
348 |
+
}
|
349 |
+
|
350 |
+
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) {
|
351 |
+
return tmpls->has_explicit_template;
|
352 |
+
}
|
353 |
+
|
354 |
+
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
|
355 |
+
if (variant != nullptr) {
|
356 |
+
if (strcmp(variant, "tool_use") == 0) {
|
357 |
+
if (tmpls->template_tool_use) {
|
358 |
+
return tmpls->template_tool_use->source().c_str();
|
359 |
+
}
|
360 |
+
return nullptr;
|
361 |
+
} else {
|
362 |
+
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
|
363 |
+
}
|
364 |
+
}
|
365 |
+
return tmpls->template_default->source().c_str();
|
366 |
+
}
|
367 |
+
|
368 |
+
common_chat_templates_ptr common_chat_templates_init(
|
369 |
+
const struct llama_model * model,
|
370 |
+
const std::string & chat_template_override,
|
371 |
+
const std::string & bos_token_override,
|
372 |
+
const std::string & eos_token_override)
|
373 |
+
{
|
374 |
+
std::string default_template_src;
|
375 |
+
std::string template_tool_use_src;
|
376 |
+
|
377 |
+
bool has_explicit_template = !chat_template_override.empty();
|
378 |
+
if (chat_template_override.empty()) {
|
379 |
+
GGML_ASSERT(model != nullptr);
|
380 |
+
const auto * str = llama_model_chat_template(model, /* name */ nullptr);
|
381 |
+
if (str) {
|
382 |
+
default_template_src = str;
|
383 |
+
has_explicit_template = true;
|
384 |
+
}
|
385 |
+
str = llama_model_chat_template(model, /* name */ "tool_use");
|
386 |
+
if (str) {
|
387 |
+
template_tool_use_src = str;
|
388 |
+
has_explicit_template = true;
|
389 |
+
}
|
390 |
+
} else {
|
391 |
+
default_template_src = chat_template_override;
|
392 |
+
}
|
393 |
+
if (default_template_src.empty() || default_template_src == "chatml") {
|
394 |
+
if (!template_tool_use_src.empty()) {
|
395 |
+
default_template_src = template_tool_use_src;
|
396 |
+
} else {
|
397 |
+
default_template_src = CHATML_TEMPLATE_SRC;
|
398 |
+
}
|
399 |
+
}
|
400 |
+
std::string token_bos = bos_token_override;
|
401 |
+
std::string token_eos = eos_token_override;
|
402 |
+
if (model) {
|
403 |
+
const auto * vocab = llama_model_get_vocab(model);
|
404 |
+
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
|
405 |
+
if (token == LLAMA_TOKEN_NULL) {
|
406 |
+
if (default_template_src.find(jinja_variable_name) != std::string::npos
|
407 |
+
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
|
408 |
+
LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name);
|
409 |
+
}
|
410 |
+
return std::string();
|
411 |
+
}
|
412 |
+
return common_token_to_piece(vocab, token, true);
|
413 |
+
};
|
414 |
+
token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
|
415 |
+
token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
|
416 |
+
}
|
417 |
+
common_chat_templates_ptr tmpls(new common_chat_templates());
|
418 |
+
tmpls->has_explicit_template = has_explicit_template;
|
419 |
+
try {
|
420 |
+
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
|
421 |
+
} catch (const std::exception & e) {
|
422 |
+
LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
|
423 |
+
tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
|
424 |
+
}
|
425 |
+
if (!template_tool_use_src.empty()) {
|
426 |
+
try {
|
427 |
+
tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
|
428 |
+
} catch (const std::exception & e) {
|
429 |
+
LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
|
430 |
+
}
|
431 |
+
}
|
432 |
+
return tmpls;
|
433 |
+
}
|
434 |
+
|
435 |
+
std::string common_chat_format_name(common_chat_format format) {
|
436 |
+
switch (format) {
|
437 |
+
case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
|
438 |
+
case COMMON_CHAT_FORMAT_GENERIC: return "Generic";
|
439 |
+
case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
|
440 |
+
case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
|
441 |
+
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
|
442 |
+
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
|
443 |
+
case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING: return "DeepSeek R1 (extract reasoning)";
|
444 |
+
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
|
445 |
+
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
|
446 |
+
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
|
447 |
+
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
|
448 |
+
case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING: return "Hermes 2 Pro (extract reasoning)";
|
449 |
+
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
|
450 |
+
case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)";
|
451 |
+
default:
|
452 |
+
throw std::runtime_error("Unknown chat format");
|
453 |
+
}
|
454 |
+
}
|
455 |
+
|
456 |
+
static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
|
457 |
+
// // https://json.nlohmann.me/features/parsing/sax_interface/
|
458 |
+
struct json_error_locator : public nlohmann::json_sax<json> {
|
459 |
+
std::size_t position;
|
460 |
+
bool found_error;
|
461 |
+
|
462 |
+
json_error_locator() : position(0), found_error(false) {}
|
463 |
+
|
464 |
+
bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT
|
465 |
+
this->position = position - 1;
|
466 |
+
this->found_error = true;
|
467 |
+
return false;
|
468 |
+
}
|
469 |
+
bool null() override { return true; } // NOLINT
|
470 |
+
bool boolean(bool) override { return true; } // NOLINT
|
471 |
+
bool number_integer(number_integer_t) override { return true; } // NOLINT
|
472 |
+
bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT
|
473 |
+
bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT
|
474 |
+
bool string(string_t &) override { return true; } // NOLINT
|
475 |
+
bool binary(binary_t &) override { return true; } // NOLINT
|
476 |
+
bool start_object(std::size_t) override { return true; } // NOLINT
|
477 |
+
bool key(string_t &) override { return true; } // NOLINT
|
478 |
+
bool end_object() override { return true; }
|
479 |
+
bool start_array(std::size_t) override { return true; } // NOLINT
|
480 |
+
bool end_array() override { return true; }
|
481 |
+
};
|
482 |
+
json_error_locator err_loc;
|
483 |
+
json::sax_parse(it, end, &err_loc);
|
484 |
+
|
485 |
+
std::string::const_iterator temptative_end;
|
486 |
+
if (err_loc.found_error) {
|
487 |
+
temptative_end = it + err_loc.position;
|
488 |
+
} else {
|
489 |
+
temptative_end = end;
|
490 |
+
}
|
491 |
+
std::string json_sub {it, temptative_end};
|
492 |
+
try {
|
493 |
+
out = json::parse(json_sub);
|
494 |
+
it = temptative_end;
|
495 |
+
return true;
|
496 |
+
} catch (const std::exception &) {
|
497 |
+
return false;
|
498 |
+
}
|
499 |
+
}
|
500 |
+
|
501 |
+
static bool parse_literal(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
|
502 |
+
auto expected_it = expected.begin();
|
503 |
+
auto tmp_it = it;
|
504 |
+
while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
|
505 |
+
++tmp_it;
|
506 |
+
++expected_it;
|
507 |
+
}
|
508 |
+
if (expected_it == expected.end()) {
|
509 |
+
it = tmp_it;
|
510 |
+
return true;
|
511 |
+
}
|
512 |
+
return false;
|
513 |
+
}
|
514 |
+
|
515 |
+
static std::optional<std::smatch> parse_pattern(std::string::const_iterator & it, const std::string::const_iterator & end, const std::regex & expected) {
|
516 |
+
std::smatch match;
|
517 |
+
if (std::regex_match(it, end, match, expected)) {
|
518 |
+
it = match.suffix().first;
|
519 |
+
return match;
|
520 |
+
}
|
521 |
+
return std::nullopt;
|
522 |
+
}
|
523 |
+
|
524 |
+
static void consume_spaces(std::string::const_iterator & it, const std::string::const_iterator & end) {
|
525 |
+
while (it != end && std::isspace(*it)) {
|
526 |
+
++it;
|
527 |
+
}
|
528 |
+
}
|
529 |
+
|
530 |
+
/**
|
531 |
+
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
|
532 |
+
* Aggregates the prefix, suffix and in-between text into the content.
|
533 |
+
*/
|
534 |
+
static common_chat_msg parse_json_tool_calls(
|
535 |
+
const std::string& input,
|
536 |
+
const std::optional<std::regex> & trigger_opt,
|
537 |
+
const std::regex & function_regex,
|
538 |
+
const std::regex & close_regex,
|
539 |
+
bool allow_raw_python = false) {
|
540 |
+
std::smatch match;
|
541 |
+
|
542 |
+
common_chat_msg result;
|
543 |
+
result.role = "assistant";
|
544 |
+
|
545 |
+
|
546 |
+
auto end = input.end();
|
547 |
+
auto it = input.begin();
|
548 |
+
|
549 |
+
if (trigger_opt) {
|
550 |
+
if (!std::regex_search(it, end, match, *trigger_opt)) {
|
551 |
+
result.content = input;
|
552 |
+
return result;
|
553 |
+
}
|
554 |
+
result.content = match.prefix().str();
|
555 |
+
it = match.suffix().first;
|
556 |
+
}
|
557 |
+
|
558 |
+
while (it != end) {
|
559 |
+
std::sregex_iterator rend;
|
560 |
+
std::sregex_iterator rit(it, end, function_regex);
|
561 |
+
if (rit == rend) {
|
562 |
+
result.content += std::string(it, end);
|
563 |
+
break;
|
564 |
+
}
|
565 |
+
auto name = rit->str(1);
|
566 |
+
result.content += std::string(it, rit->prefix().second);
|
567 |
+
it = rit->suffix().first;
|
568 |
+
|
569 |
+
json arguments;
|
570 |
+
if (parse_json(it, end, arguments)) {
|
571 |
+
if (!std::regex_search(it, end, match, close_regex)) {
|
572 |
+
throw std::runtime_error("Malformed input, missing closing pattern: " + input);
|
573 |
+
}
|
574 |
+
it = match.suffix().first;
|
575 |
+
result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
|
576 |
+
} else {
|
577 |
+
if (allow_raw_python && name == "python") {
|
578 |
+
result.tool_calls.push_back({name, json({{"code", std::string(it, end)}}).dump(), /* id= */ ""});
|
579 |
+
break;
|
580 |
+
}
|
581 |
+
throw std::runtime_error("Failed to parse json tool call arguments: " + input);
|
582 |
+
}
|
583 |
+
}
|
584 |
+
|
585 |
+
if (!result.tool_calls.empty()) {
|
586 |
+
if (!string_strip(result.content).empty()) {
|
587 |
+
LOG_WRN("Content found with tool calls: %s\n", result.content.c_str());
|
588 |
+
}
|
589 |
+
result.content = "";
|
590 |
+
}
|
591 |
+
return result;
|
592 |
+
}
|
593 |
+
|
594 |
+
static common_chat_tool_call process_tool_call(const json & tool_call) {
|
595 |
+
const auto & arguments = tool_call.at("arguments");
|
596 |
+
return {
|
597 |
+
/* .name = */ tool_call.at("name"),
|
598 |
+
/* .arguments = */ arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
|
599 |
+
/* .id = */ tool_call.contains("id") ? tool_call.at("id") : "",
|
600 |
+
};
|
601 |
+
}
|
602 |
+
static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
|
603 |
+
auto content_end = input.find(prefix);
|
604 |
+
size_t tc_start = std::string::npos;
|
605 |
+
|
606 |
+
common_chat_msg result;
|
607 |
+
result.role = "assistant";
|
608 |
+
if (content_end == std::string::npos) {
|
609 |
+
result.content = input;
|
610 |
+
} else {
|
611 |
+
tc_start = content_end + prefix.size() - rstrip_prefix;
|
612 |
+
result.content = input.substr(0, content_end);
|
613 |
+
auto tool_calls = json::parse(input.substr(tc_start));
|
614 |
+
for (const auto & tool_call : tool_calls) {
|
615 |
+
result.tool_calls.emplace_back(process_tool_call(tool_call));
|
616 |
+
}
|
617 |
+
}
|
618 |
+
return result;
|
619 |
+
}
|
620 |
+
|
621 |
+
static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
|
622 |
+
for (const auto & tool : tools) {
|
623 |
+
if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
|
624 |
+
LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str());
|
625 |
+
continue;
|
626 |
+
}
|
627 |
+
fn(tool);
|
628 |
+
}
|
629 |
+
}
|
630 |
+
|
631 |
+
static std::string apply(
|
632 |
+
const common_chat_template & tmpl,
|
633 |
+
const nlohmann::ordered_json & messages,
|
634 |
+
const nlohmann::ordered_json & tools,
|
635 |
+
bool add_generation_prompt,
|
636 |
+
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json())
|
637 |
+
{
|
638 |
+
minja::chat_template_inputs tmpl_inputs;
|
639 |
+
tmpl_inputs.messages = messages;
|
640 |
+
tmpl_inputs.tools = tools;
|
641 |
+
tmpl_inputs.add_generation_prompt = add_generation_prompt;
|
642 |
+
tmpl_inputs.extra_context = extra_context;
|
643 |
+
// TODO: add flag to control date/time, if only for testing purposes.
|
644 |
+
// tmpl_inputs.now = std::chrono::system_clock::now();
|
645 |
+
|
646 |
+
minja::chat_template_options tmpl_opts;
|
647 |
+
// To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
|
648 |
+
// instead of using `chat_template_options.use_bos_token = false`, since these tokens
|
649 |
+
// may be needed inside the template / between messages too.
|
650 |
+
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
|
651 |
+
if (string_starts_with(result, tmpl.bos_token())) {
|
652 |
+
result = result.substr(tmpl.bos_token().size());
|
653 |
+
}
|
654 |
+
if (string_ends_with(result, tmpl.eos_token())) {
|
655 |
+
result = result.substr(0, result.size() - tmpl.eos_token().size());
|
656 |
+
}
|
657 |
+
return result;
|
658 |
+
}
|
659 |
+
|
660 |
+
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
661 |
+
common_chat_params data;
|
662 |
+
|
663 |
+
auto tool_call_schemas = json::array();
|
664 |
+
foreach_function(inputs.tools, [&](const json & tool) {
|
665 |
+
const auto & function = tool.at("function");
|
666 |
+
auto tool_schema = json {
|
667 |
+
{"type", "object"},
|
668 |
+
{"properties", {
|
669 |
+
{"name", {
|
670 |
+
{"type", "string"},
|
671 |
+
{"const", function.at("name")},
|
672 |
+
}},
|
673 |
+
{"arguments", function.at("parameters")},
|
674 |
+
}},
|
675 |
+
{"required", json::array({"name", "arguments"})},
|
676 |
+
};
|
677 |
+
if (function.contains("description")) {
|
678 |
+
tool_schema["description"] = function.at("description");
|
679 |
+
}
|
680 |
+
if (inputs.parallel_tool_calls) {
|
681 |
+
tool_schema.at("properties")["id"] = {
|
682 |
+
{"type", "string"},
|
683 |
+
{"minLength", 4},
|
684 |
+
};
|
685 |
+
tool_schema.at("required").push_back("id");
|
686 |
+
}
|
687 |
+
tool_call_schemas.emplace_back(tool_schema);
|
688 |
+
});
|
689 |
+
const auto tool_call =
|
690 |
+
inputs.parallel_tool_calls
|
691 |
+
? json {
|
692 |
+
{"type", "object"},
|
693 |
+
{"properties", {
|
694 |
+
{"tool_calls", {
|
695 |
+
{"type", "array"},
|
696 |
+
{"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
|
697 |
+
{"anyOf", tool_call_schemas},
|
698 |
+
}},
|
699 |
+
{"minItems", 1},
|
700 |
+
}},
|
701 |
+
}},
|
702 |
+
{"required", json::array({"tool_calls"})},
|
703 |
+
}
|
704 |
+
: json {
|
705 |
+
{"type", "object"},
|
706 |
+
{"properties", {
|
707 |
+
{"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
|
708 |
+
{"anyOf", tool_call_schemas},
|
709 |
+
}},
|
710 |
+
}},
|
711 |
+
{"required", json::array({"tool_call"})},
|
712 |
+
};
|
713 |
+
const auto schema =
|
714 |
+
inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED
|
715 |
+
? json {
|
716 |
+
{"anyOf", json::array({
|
717 |
+
tool_call,
|
718 |
+
{
|
719 |
+
{"type", "object"},
|
720 |
+
{"properties", {
|
721 |
+
{"response", inputs.json_schema.is_null()
|
722 |
+
? json {{"type", "string"}}
|
723 |
+
: inputs.json_schema
|
724 |
+
},
|
725 |
+
}},
|
726 |
+
{"required", json::array({"response"})},
|
727 |
+
},
|
728 |
+
})}
|
729 |
+
}
|
730 |
+
: tool_call;
|
731 |
+
|
732 |
+
data.grammar_lazy = false;
|
733 |
+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
734 |
+
builder.add_schema("root", schema);
|
735 |
+
});
|
736 |
+
|
737 |
+
auto tweaked_messages = common_chat_template::add_system(
|
738 |
+
inputs.messages,
|
739 |
+
"Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request");
|
740 |
+
|
741 |
+
data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
742 |
+
data.format = COMMON_CHAT_FORMAT_GENERIC;
|
743 |
+
return data;
|
744 |
+
}
|
745 |
+
static common_chat_msg common_chat_parse_generic(const std::string & input) {
|
746 |
+
json data = json::parse(input);
|
747 |
+
common_chat_msg result;
|
748 |
+
result.role = "assistant";
|
749 |
+
if (data.contains("tool_calls")) {
|
750 |
+
for (const auto & tool_call : data.at("tool_calls")) {
|
751 |
+
result.tool_calls.push_back({
|
752 |
+
tool_call.at("name"),
|
753 |
+
tool_call.at("arguments").dump(),
|
754 |
+
tool_call.contains("id") ? tool_call.at("id") : "",
|
755 |
+
});
|
756 |
+
}
|
757 |
+
} else if (data.contains("tool_call")) {
|
758 |
+
result.tool_calls.push_back({
|
759 |
+
data.at("tool_call").at("name"),
|
760 |
+
data.at("tool_call").at("arguments").dump(),
|
761 |
+
/* id= */ "",
|
762 |
+
});
|
763 |
+
} else if (data.contains("response")) {
|
764 |
+
const auto & response = data.at("response");
|
765 |
+
result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
|
766 |
+
}
|
767 |
+
return result;
|
768 |
+
}
|
769 |
+
|
770 |
+
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
771 |
+
common_chat_params data;
|
772 |
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
773 |
+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
774 |
+
auto schemas = json::array();
|
775 |
+
foreach_function(inputs.tools, [&](const json & tool) {
|
776 |
+
const auto & function = tool.at("function");
|
777 |
+
schemas.push_back({
|
778 |
+
{"type", "object"},
|
779 |
+
{"properties", {
|
780 |
+
// Important note: the model is probably trained to take a JSON stringified arguments value.
|
781 |
+
// It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
|
782 |
+
{"name", {
|
783 |
+
{"type", "string"},
|
784 |
+
{"const", function.at("name")},
|
785 |
+
}},
|
786 |
+
{"arguments", function.at("parameters")},
|
787 |
+
{"id", {
|
788 |
+
{"type", "string"},
|
789 |
+
// Nemo's template expects a 9-character alphanumeric ID.
|
790 |
+
{"pattern", "^[a-zA-Z0-9]{9}$"},
|
791 |
+
}},
|
792 |
+
}},
|
793 |
+
{"required", json::array({"name", "arguments", "id"})},
|
794 |
+
});
|
795 |
+
});
|
796 |
+
auto schema = json {
|
797 |
+
{"type", "array"},
|
798 |
+
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
799 |
+
{"minItems", 1},
|
800 |
+
};
|
801 |
+
if (!inputs.parallel_tool_calls) {
|
802 |
+
schema["maxItems"] = 1;
|
803 |
+
}
|
804 |
+
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
|
805 |
+
});
|
806 |
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"});
|
807 |
+
data.preserved_tokens = {
|
808 |
+
"[TOOL_CALLS]",
|
809 |
+
};
|
810 |
+
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
811 |
+
data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
|
812 |
+
return data;
|
813 |
+
}
|
814 |
+
static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) {
|
815 |
+
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
|
816 |
+
}
|
817 |
+
|
818 |
+
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
819 |
+
common_chat_params data;
|
820 |
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
821 |
+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
822 |
+
auto schemas = json::array();
|
823 |
+
foreach_function(inputs.tools, [&](const json & tool) {
|
824 |
+
const auto & function = tool.at("function");
|
825 |
+
schemas.push_back({
|
826 |
+
{"type", "object"},
|
827 |
+
{"properties", {
|
828 |
+
{"tool_call_id", {
|
829 |
+
{"type", "string"},
|
830 |
+
// Command-R's template expects an integer string.
|
831 |
+
{"pattern", "^[0-9]{1,10}$"},
|
832 |
+
}},
|
833 |
+
{"tool_name", {
|
834 |
+
{"type", "string"},
|
835 |
+
{"const", function.at("name")},
|
836 |
+
}},
|
837 |
+
{"parameters", function.at("parameters")},
|
838 |
+
}},
|
839 |
+
{"required", json::array({"tool_call_id", "tool_name", "parameters"})},
|
840 |
+
});
|
841 |
+
});
|
842 |
+
auto schema = json {
|
843 |
+
{"type", "array"},
|
844 |
+
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
845 |
+
{"minItems", 1},
|
846 |
+
};
|
847 |
+
if (!inputs.parallel_tool_calls) {
|
848 |
+
schema["maxItems"] = 1;
|
849 |
+
}
|
850 |
+
builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
|
851 |
+
});
|
852 |
+
data.grammar_triggers.push_back({
|
853 |
+
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
854 |
+
"<|START_ACTION|>",
|
855 |
+
});
|
856 |
+
data.preserved_tokens = {
|
857 |
+
"<|START_ACTION|>",
|
858 |
+
"<|END_ACTION|>",
|
859 |
+
"<|START_RESPONSE|>",
|
860 |
+
"<|END_RESPONSE|>",
|
861 |
+
"<|START_THINKING|>",
|
862 |
+
"<|END_THINKING|>",
|
863 |
+
};
|
864 |
+
auto adjusted_messages = json::array();
|
865 |
+
for (const auto & msg : inputs.messages) {
|
866 |
+
auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
|
867 |
+
auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
|
868 |
+
if (has_reasoning_content && has_tool_calls) {
|
869 |
+
auto adjusted_message = msg;
|
870 |
+
adjusted_message["tool_plan"] = msg.at("reasoning_content");
|
871 |
+
adjusted_message.erase("reasoning_content");
|
872 |
+
adjusted_messages.push_back(adjusted_message);
|
873 |
+
} else {
|
874 |
+
adjusted_messages.push_back(msg);
|
875 |
+
}
|
876 |
+
}
|
877 |
+
data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
|
878 |
+
data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING : COMMON_CHAT_FORMAT_COMMAND_R7B;
|
879 |
+
return data;
|
880 |
+
}
|
881 |
+
static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) {
|
882 |
+
static const std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)");
|
883 |
+
static const std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>");
|
884 |
+
static const std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>");
|
885 |
+
|
886 |
+
std::smatch match;
|
887 |
+
|
888 |
+
common_chat_msg result;
|
889 |
+
result.role = "assistant";
|
890 |
+
|
891 |
+
std::string rest = input;
|
892 |
+
|
893 |
+
if (std::regex_match(rest, match, thought_regex)) {
|
894 |
+
if (extract_reasoning) {
|
895 |
+
result.reasoning_content = match[2].str();
|
896 |
+
} else if (!match[2].str().empty()) {
|
897 |
+
// Let the unparsed thinking tags through in content only if their insides aren't empty.
|
898 |
+
result.content = match[1].str();
|
899 |
+
}
|
900 |
+
rest = match[3].str();
|
901 |
+
}
|
902 |
+
if (std::regex_match(rest, match, action_regex)) {
|
903 |
+
auto actions_str = match[1].str();
|
904 |
+
auto actions = json::parse(actions_str);
|
905 |
+
for (const auto & action : actions) {
|
906 |
+
result.tool_calls.push_back({
|
907 |
+
/* .name = */ action.at("tool_name"),
|
908 |
+
/* .arguments = */ action.at("parameters").dump(),
|
909 |
+
/* .id = */ action.at("tool_call_id"),
|
910 |
+
});
|
911 |
+
}
|
912 |
+
} else if (std::regex_match(rest, match, response_regex)) {
|
913 |
+
auto response = match[1].str();
|
914 |
+
result.content += response;
|
915 |
+
} else {
|
916 |
+
result.content += rest;
|
917 |
+
}
|
918 |
+
return result;
|
919 |
+
}
|
920 |
+
|
921 |
+
static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
|
922 |
+
if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
|
923 |
+
throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
|
924 |
+
}
|
925 |
+
const auto & parameters_properties = parameters.at("properties");
|
926 |
+
const auto & parameters_required = parameters.at("required");
|
927 |
+
for (const auto & prop : expected_properties) {
|
928 |
+
if (!parameters_properties.contains(prop)) {
|
929 |
+
throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT
|
930 |
+
}
|
931 |
+
if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
|
932 |
+
throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT
|
933 |
+
}
|
934 |
+
}
|
935 |
+
if (parameters_properties.size() != expected_properties.size()) {
|
936 |
+
throw std::runtime_error("Parameters of tool " + name + " must only have these properties:" + string_join(expected_properties, ", "));
|
937 |
+
}
|
938 |
+
}
|
939 |
+
|
940 |
+
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
|
941 |
+
auto builtin_tools = json::array();
|
942 |
+
common_chat_params data;
|
943 |
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
944 |
+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
945 |
+
std::vector<std::string> tool_rules;
|
946 |
+
|
947 |
+
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
|
948 |
+
if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
|
949 |
+
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
|
950 |
+
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
|
951 |
+
expect_tool_parameters(name, parameters, {"query"});
|
952 |
+
} else if (name == "python" || name == "code_interpreter") {
|
953 |
+
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
|
954 |
+
expect_tool_parameters(name, parameters, {"code"});
|
955 |
+
} else {
|
956 |
+
return false;
|
957 |
+
}
|
958 |
+
|
959 |
+
std::vector<std::string> kvs;
|
960 |
+
for (const auto & [key, value] : parameters.at("properties").items()) {
|
961 |
+
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
|
962 |
+
}
|
963 |
+
|
964 |
+
tool_rules.push_back(
|
965 |
+
builder.add_rule(
|
966 |
+
name + "-call",
|
967 |
+
"\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
|
968 |
+
builtin_tools.push_back(name);
|
969 |
+
|
970 |
+
return true;
|
971 |
+
};
|
972 |
+
|
973 |
+
foreach_function(inputs.tools, [&](const json & tool) {
|
974 |
+
const auto & function = tool.at("function");
|
975 |
+
std::string name = function.at("name");
|
976 |
+
auto parameters = function.at("parameters");
|
977 |
+
builder.resolve_refs(parameters);
|
978 |
+
|
979 |
+
// https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
|
980 |
+
if (allow_python_tag_builtin_tools) {
|
981 |
+
handle_builtin_tool(name, parameters);
|
982 |
+
}
|
983 |
+
tool_rules.push_back(
|
984 |
+
builder.add_rule(
|
985 |
+
name + "-call",
|
986 |
+
"\"{\" space "
|
987 |
+
"( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
|
988 |
+
" \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
|
989 |
+
" \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
|
990 |
+
"\"}\" space"));
|
991 |
+
});
|
992 |
+
// Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
|
993 |
+
data.grammar_triggers.push_back({
|
994 |
+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
995 |
+
"\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
|
996 |
+
});
|
997 |
+
if (!builtin_tools.empty()) {
|
998 |
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
999 |
+
data.preserved_tokens.push_back("<|python_tag|>");
|
1000 |
+
}
|
1001 |
+
// Allow a few empty lines on top of the usual constrained json schema space rule.
|
1002 |
+
builder.add_rule("root", string_join(tool_rules, " | "));
|
1003 |
+
});
|
1004 |
+
data.additional_stops.push_back("<|eom_id|>");
|
1005 |
+
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
|
1006 |
+
{"tools_in_user_message", false},
|
1007 |
+
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
1008 |
+
});
|
1009 |
+
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
|
1010 |
+
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
|
1011 |
+
: COMMON_CHAT_FORMAT_LLAMA_3_X;
|
1012 |
+
return data;
|
1013 |
+
}
|
1014 |
+
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
|
1015 |
+
// TODO: tighten & simplify the parser, don't accept leading text context.
|
1016 |
+
static const std::regex function_regex(
|
1017 |
+
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
|
1018 |
+
static const std::regex close_regex("\\}\\s*");
|
1019 |
+
static const std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)");
|
1020 |
+
|
1021 |
+
if (with_builtin_tools) {
|
1022 |
+
std::smatch match;
|
1023 |
+
if (std::regex_match(input, match, builtin_call_regex)) {
|
1024 |
+
try {
|
1025 |
+
auto name = match[1].str();
|
1026 |
+
auto arg_name = match[2].str();
|
1027 |
+
auto arg_value_str = match[3].str();
|
1028 |
+
auto arg_value = json::parse(arg_value_str);
|
1029 |
+
|
1030 |
+
common_chat_msg msg;
|
1031 |
+
msg.role = "assistant";
|
1032 |
+
msg.tool_calls.push_back({
|
1033 |
+
/* .name = */ name,
|
1034 |
+
/* .arguments = */ (json {
|
1035 |
+
{arg_name, arg_value},
|
1036 |
+
}).dump(),
|
1037 |
+
/* .id = */ "",
|
1038 |
+
});
|
1039 |
+
return msg;
|
1040 |
+
} catch (const std::exception & e) {
|
1041 |
+
LOG_WRN("Failed to parse builtin tool call arguments (%s): %s", e.what(), input.c_str());
|
1042 |
+
}
|
1043 |
+
}
|
1044 |
+
}
|
1045 |
+
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
|
1046 |
+
}
|
1047 |
+
|
1048 |
+
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
1049 |
+
common_chat_params data;
|
1050 |
+
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
1051 |
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
|
1052 |
+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
1053 |
+
std::vector<std::string> tool_rules;
|
1054 |
+
foreach_function(inputs.tools, [&](const json & tool) {
|
1055 |
+
const auto & function = tool.at("function");
|
1056 |
+
std::string name = function.at("name");
|
1057 |
+
auto parameters = function.at("parameters");
|
1058 |
+
builder.resolve_refs(parameters);
|
1059 |
+
tool_rules.push_back(builder.add_rule(name + "-call",
|
1060 |
+
"\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
|
1061 |
+
"```json\\n\" " + builder.add_schema(name + "-args", parameters) + " "
|
1062 |
+
"\"```<|tool▁call▁end|>\""));
|
1063 |
+
});
|
1064 |
+
// Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
|
1065 |
+
// so we accept common variants (then it's all constrained)
|
1066 |
+
builder.add_rule("root",
|
1067 |
+
"( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) "
|
1068 |
+
"(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
|
1069 |
+
"\"<|tool▁calls▁end|>\""
|
1070 |
+
" space");
|
1071 |
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool▁calls▁begin|>"});
|
1072 |
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls_begin|>"});
|
1073 |
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool calls begin|>"});
|
1074 |
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool\\_calls\\_begin|>"});
|
1075 |
+
data.preserved_tokens = {
|
1076 |
+
"<think>",
|
1077 |
+
"</think>",
|
1078 |
+
"<|tool▁calls▁begin|>",
|
1079 |
+
"<|tool▁call▁begin|>",
|
1080 |
+
"<|tool▁sep|>",
|
1081 |
+
"<|tool▁call▁end|>",
|
1082 |
+
"<|tool▁calls▁end|",
|
1083 |
+
};
|
1084 |
+
});
|
1085 |
+
}
|
1086 |
+
auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
1087 |
+
|
1088 |
+
// Hacks to fix the official (broken) prompt.
|
1089 |
+
// It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
|
1090 |
+
// until the official template is fixed.
|
1091 |
+
if (tmpl.source().find("{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}") != std::string::npos) {
|
1092 |
+
// Don't leave the chat dangling after tool results
|
1093 |
+
if (string_ends_with(prompt, "<|tool▁outputs▁end|>")) {
|
1094 |
+
prompt += "<|end▁of▁sentence|>";
|
1095 |
+
if (inputs.add_generation_prompt) {
|
1096 |
+
prompt += "<|Assistant|>";
|
1097 |
+
}
|
1098 |
+
}
|
1099 |
+
// Fix up tool call delta example added by Minja
|
1100 |
+
prompt = std::regex_replace(
|
1101 |
+
prompt,
|
1102 |
+
std::regex("(<|tool▁call▁end|>)[\\s\\r\\n]*(<|tool▁outputs▁begin|>|<|User|>)"),
|
1103 |
+
"$1<|tool▁calls▁end|><|end▁of▁sentence|>$2");
|
1104 |
+
}
|
1105 |
+
data.prompt = prompt;
|
1106 |
+
data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1;
|
1107 |
+
return data;
|
1108 |
+
}
|
1109 |
+
static common_chat_msg handle_think_tag_prelude(const std::string & input, bool extract_reasoning, const std::function<common_chat_msg(const std::string &)> & rest_parser) {
|
1110 |
+
std::smatch match;
|
1111 |
+
static const std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
|
1112 |
+
if (std::regex_match(input, match, reasoning_content_regex)) {
|
1113 |
+
auto rest = match[3].str();
|
1114 |
+
auto msg = rest_parser(rest);
|
1115 |
+
auto reasoning_content = string_strip(match[2].str());
|
1116 |
+
if (extract_reasoning) {
|
1117 |
+
msg.reasoning_content = reasoning_content;
|
1118 |
+
} else if (!reasoning_content.empty()) {
|
1119 |
+
std::ostringstream content;
|
1120 |
+
content << "<think>" << reasoning_content << "</think>" << msg.content;
|
1121 |
+
msg.content = content.str();
|
1122 |
+
}
|
1123 |
+
return msg;
|
1124 |
+
}
|
1125 |
+
return rest_parser(input);
|
1126 |
+
}
|
1127 |
+
static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
|
1128 |
+
return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
|
1129 |
+
static const std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
|
1130 |
+
static const std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
|
1131 |
+
static const std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>");
|
1132 |
+
|
1133 |
+
common_chat_msg msg;
|
1134 |
+
msg.role = "assistant";
|
1135 |
+
std::smatch match;
|
1136 |
+
if (std::regex_search(input, match, tool_calls_regex)) {
|
1137 |
+
auto tool_calls = match[1].str();
|
1138 |
+
auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
|
1139 |
+
msg.tool_calls = std::move(msg2.tool_calls);
|
1140 |
+
} else {
|
1141 |
+
msg.content = input;
|
1142 |
+
}
|
1143 |
+
return msg;
|
1144 |
+
});
|
1145 |
+
}
|
1146 |
+
|
1147 |
+
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
1148 |
+
LOG_DBG("%s\n", __func__);
|
1149 |
+
common_chat_params data;
|
1150 |
+
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
|
1151 |
+
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
1152 |
+
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
1153 |
+
});
|
1154 |
+
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
1155 |
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
1156 |
+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
1157 |
+
auto schemas = json::array();
|
1158 |
+
foreach_function(inputs.tools, [&](const json & tool) {
|
1159 |
+
const auto & function = tool.at("function");
|
1160 |
+
schemas.push_back({
|
1161 |
+
{"type", "object"},
|
1162 |
+
{"properties", {
|
1163 |
+
{"name", {
|
1164 |
+
{"type", "string"},
|
1165 |
+
{"const", function.at("name")},
|
1166 |
+
}},
|
1167 |
+
{"arguments", function.at("parameters")},
|
1168 |
+
}},
|
1169 |
+
{"required", json::array({"name", "arguments", "id"})},
|
1170 |
+
});
|
1171 |
+
});
|
1172 |
+
auto schema = json {
|
1173 |
+
{"type", "array"},
|
1174 |
+
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
|
1175 |
+
{"minItems", 1},
|
1176 |
+
};
|
1177 |
+
if (!inputs.parallel_tool_calls) {
|
1178 |
+
schema["maxItems"] = 1;
|
1179 |
+
}
|
1180 |
+
builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
|
1181 |
+
});
|
1182 |
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, " functools["});
|
1183 |
+
data.preserved_tokens = {
|
1184 |
+
" functools[",
|
1185 |
+
};
|
1186 |
+
data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2;
|
1187 |
+
} else {
|
1188 |
+
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
1189 |
+
}
|
1190 |
+
return data;
|
1191 |
+
}
|
1192 |
+
static common_chat_msg common_chat_parse_firefunction_v2(const std::string & input) {
|
1193 |
+
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
|
1194 |
+
}
|
1195 |
+
|
1196 |
+
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
1197 |
+
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
1198 |
+
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
1199 |
+
common_chat_params data;
|
1200 |
+
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
1201 |
+
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
|
1202 |
+
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
1203 |
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
1204 |
+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
1205 |
+
std::vector<std::string> first_tool_rules;
|
1206 |
+
std::vector<std::string> subsequent_tool_rules;
|
1207 |
+
foreach_function(inputs.tools, [&](const json & tool) {
|
1208 |
+
const auto & function = tool.at("function");
|
1209 |
+
std::string name = function.at("name");
|
1210 |
+
auto parameters = function.at("parameters");
|
1211 |
+
builder.resolve_refs(parameters);
|
1212 |
+
auto args_rule = builder.add_schema(name + "-args", parameters);
|
1213 |
+
first_tool_rules.push_back(builder.add_rule(name + "-call", "( \"assistant<|end_header_id|>\\n\" )? \"" + name + "\\n\" " + args_rule));
|
1214 |
+
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
|
1215 |
+
data.grammar_triggers.push_back({
|
1216 |
+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
1217 |
+
regex_escape(name + "\n"),
|
1218 |
+
});
|
1219 |
+
data.grammar_triggers.push_back({
|
1220 |
+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
1221 |
+
regex_escape("assistant<|end_header_id|>\n" + name + "\n"),
|
1222 |
+
});
|
1223 |
+
data.grammar_triggers.push_back({
|
1224 |
+
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
1225 |
+
regex_escape(">>>" + name + "\n"),
|
1226 |
+
});
|
1227 |
+
data.grammar_triggers.push_back({
|
1228 |
+
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
1229 |
+
">>>assistant<|end_header_id|>\n" + name,
|
1230 |
+
});
|
1231 |
+
});
|
1232 |
+
data.preserved_tokens = {
|
1233 |
+
"<|end_header_id|>",
|
1234 |
+
};
|
1235 |
+
auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
|
1236 |
+
if (inputs.parallel_tool_calls) {
|
1237 |
+
auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
|
1238 |
+
builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
|
1239 |
+
} else {
|
1240 |
+
builder.add_rule("root", first_rule);
|
1241 |
+
}
|
1242 |
+
|
1243 |
+
});
|
1244 |
+
}
|
1245 |
+
return data;
|
1246 |
+
}
|
1247 |
+
|
1248 |
+
static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
|
1249 |
+
static const std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)");
|
1250 |
+
static const std::regex close_regex(R"($|(?=>>>))");
|
1251 |
+
|
1252 |
+
std::string content;
|
1253 |
+
auto it = input.begin();
|
1254 |
+
const auto end = input.end();
|
1255 |
+
|
1256 |
+
if (parse_literal(it, end, "all\n")) {
|
1257 |
+
std::smatch match;
|
1258 |
+
if (std::regex_search(it, end, match, function_regex)) {
|
1259 |
+
auto fun_it = match.prefix().second;
|
1260 |
+
content = std::string(it, fun_it);
|
1261 |
+
it = fun_it;
|
1262 |
+
} else {
|
1263 |
+
common_chat_msg res;
|
1264 |
+
res.role = "assistant";
|
1265 |
+
res.content = std::string(it, end);
|
1266 |
+
return res;
|
1267 |
+
}
|
1268 |
+
}
|
1269 |
+
// TODO: tighten & simplify.
|
1270 |
+
try {
|
1271 |
+
auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex, /* allow_raw_python= */ true);
|
1272 |
+
res.content = content + res.content;
|
1273 |
+
return res;
|
1274 |
+
} catch (const std::exception & e) {
|
1275 |
+
LOG_ERR("Failed to parse functionary v3.2 input: %s\n", e.what());
|
1276 |
+
common_chat_msg res;
|
1277 |
+
res.role = "assistant";
|
1278 |
+
res.content = input;
|
1279 |
+
return res;
|
1280 |
+
}
|
1281 |
+
}
|
1282 |
+
|
1283 |
+
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
1284 |
+
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
1285 |
+
common_chat_params data;
|
1286 |
+
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
|
1287 |
+
std::string python_code_argument_name;
|
1288 |
+
auto has_raw_python = false;
|
1289 |
+
|
1290 |
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
1291 |
+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
1292 |
+
std::vector<std::string> tool_rules;
|
1293 |
+
foreach_function(inputs.tools, [&](const json & tool) {
|
1294 |
+
const auto & function = tool.at("function");
|
1295 |
+
const auto & parameters = function.at("parameters");
|
1296 |
+
std::string name = function.at("name");
|
1297 |
+
if (name == "python" || name == "ipython") {
|
1298 |
+
if (!parameters.contains("type")) {
|
1299 |
+
throw std::runtime_error("Missing type in python tool");
|
1300 |
+
}
|
1301 |
+
has_raw_python = true;
|
1302 |
+
const auto & type = parameters.at("type");
|
1303 |
+
if (type == "object") {
|
1304 |
+
auto properties = parameters.at("properties");
|
1305 |
+
for (auto it = properties.begin(); it != properties.end(); ++it) {
|
1306 |
+
if (it.value().at("type") == "string") {
|
1307 |
+
if (!python_code_argument_name.empty()) {
|
1308 |
+
throw std::runtime_error("Multiple string arguments found in python tool");
|
1309 |
+
}
|
1310 |
+
python_code_argument_name = it.key();
|
1311 |
+
}
|
1312 |
+
}
|
1313 |
+
if (python_code_argument_name.empty()) {
|
1314 |
+
throw std::runtime_error("No string argument found in python tool");
|
1315 |
+
}
|
1316 |
+
} else if (type != "string") {
|
1317 |
+
throw std::runtime_error("Invalid type in python tool: " + type.dump());
|
1318 |
+
}
|
1319 |
+
}
|
1320 |
+
tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
|
1321 |
+
});
|
1322 |
+
if (has_raw_python) {
|
1323 |
+
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
|
1324 |
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
|
1325 |
+
data.preserved_tokens.push_back("<|python_tag|>");
|
1326 |
+
}
|
1327 |
+
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
|
1328 |
+
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
1329 |
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
|
1330 |
+
});
|
1331 |
+
|
1332 |
+
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
1333 |
+
// TODO: if (has_raw_python)
|
1334 |
+
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
|
1335 |
+
return data;
|
1336 |
+
}
|
1337 |
+
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
|
1338 |
+
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
|
1339 |
+
static const std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
|
1340 |
+
std::smatch match;
|
1341 |
+
if (std::regex_search(input, match, python_tag_regex)) {
|
1342 |
+
auto code = match[1].str();
|
1343 |
+
common_chat_msg msg;
|
1344 |
+
msg.role = "assistant";
|
1345 |
+
msg.content = match.prefix().str();
|
1346 |
+
msg.tool_calls.push_back({
|
1347 |
+
/* .name = */ "python",
|
1348 |
+
/* .arguments = */ (json {{"code", code}}).dump(),
|
1349 |
+
/* .id = */ "",
|
1350 |
+
});
|
1351 |
+
return msg;
|
1352 |
+
}
|
1353 |
+
static const std::regex function_regex(R"(<function=(\w+)>)");
|
1354 |
+
static const std::regex close_regex(R"(</function>)");
|
1355 |
+
// TODO: tighten & simplify.
|
1356 |
+
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
|
1357 |
+
}
|
1358 |
+
|
1359 |
+
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
1360 |
+
common_chat_params data;
|
1361 |
+
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
1362 |
+
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
1363 |
+
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
1364 |
+
std::vector<std::string> tool_rules;
|
1365 |
+
std::vector<std::string> tool_call_alts;
|
1366 |
+
foreach_function(inputs.tools, [&](const json & tool) {
|
1367 |
+
const auto & function = tool.at("function");
|
1368 |
+
std::string name = function.at("name");
|
1369 |
+
auto parameters = function.at("parameters");
|
1370 |
+
builder.resolve_refs(parameters);
|
1371 |
+
tool_rules.push_back(builder.add_schema(name + "-call", {
|
1372 |
+
{"type", "object"},
|
1373 |
+
{"properties", json {
|
1374 |
+
{"name", json {{"const", name}}},
|
1375 |
+
{"arguments", parameters},
|
1376 |
+
}},
|
1377 |
+
{"required", json::array({"name", "arguments"})},
|
1378 |
+
}));
|
1379 |
+
tool_call_alts.push_back(builder.add_rule(
|
1380 |
+
name + "-function-tag",
|
1381 |
+
"\"<function\" ( \"=" + name + "\" | \" name=\\\"" + name + "\\\"\" ) \">\" space " +
|
1382 |
+
builder.add_schema(name + "-args", parameters) + " "
|
1383 |
+
"\"</function>\" space"));
|
1384 |
+
|
1385 |
+
data.grammar_triggers.push_back({
|
1386 |
+
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
1387 |
+
"<function=" + name + ">",
|
1388 |
+
});
|
1389 |
+
auto escaped_name = regex_escape(name);
|
1390 |
+
data.grammar_triggers.push_back({
|
1391 |
+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
1392 |
+
"<function\\s+name\\s*=\\s*\"" + escaped_name + "\"",
|
1393 |
+
});
|
1394 |
+
});
|
1395 |
+
auto any_tool_call = builder.add_rule("any_tool_call", "( " + string_join(tool_rules, " | ") + " ) space");
|
1396 |
+
std::vector<std::string> alt_tags {
|
1397 |
+
any_tool_call,
|
1398 |
+
"\"<tool_call>\" space " + any_tool_call + " \"</tool_call>\"",
|
1399 |
+
// The rest is just to accommodate common "good bad" outputs.
|
1400 |
+
"\"<function_call>\" space " + any_tool_call + " \"</function_call>\"",
|
1401 |
+
"\"<response>\" space " + any_tool_call + " \"</response>\"",
|
1402 |
+
"\"<tools>\" space " + any_tool_call + " \"</tools>\"",
|
1403 |
+
"\"<json>\" space " + any_tool_call + " \"</json>\"",
|
1404 |
+
"\"<xml>\" space " + any_tool_call + " \"</xml>\"",
|
1405 |
+
"\"<JSON>\" space " + any_tool_call + " \"</JSON>\"",
|
1406 |
+
};
|
1407 |
+
auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space");
|
1408 |
+
tool_call_alts.push_back(wrappable_tool_call);
|
1409 |
+
tool_call_alts.push_back(
|
1410 |
+
"( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space ");
|
1411 |
+
auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | "));
|
1412 |
+
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
1413 |
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<tool_call>"});
|
1414 |
+
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function"});
|
1415 |
+
// Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
|
1416 |
+
data.grammar_triggers.push_back({
|
1417 |
+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
1418 |
+
"(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?\\s*\\{\\s*\"", //name\"\\s*:\\s*\"" + escaped_name + "\"",
|
1419 |
+
});
|
1420 |
+
data.preserved_tokens = {
|
1421 |
+
"<think>",
|
1422 |
+
"</think>",
|
1423 |
+
"<tool_call>",
|
1424 |
+
"</tool_call>",
|
1425 |
+
"<function",
|
1426 |
+
"<tools>",
|
1427 |
+
"</tools>",
|
1428 |
+
"<response>",
|
1429 |
+
"</response>",
|
1430 |
+
"<function_call>",
|
1431 |
+
"</function_call>",
|
1432 |
+
"<json>",
|
1433 |
+
"</json>",
|
1434 |
+
"<JSON>",
|
1435 |
+
"</JSON>",
|
1436 |
+
"```",
|
1437 |
+
"```json",
|
1438 |
+
"```xml",
|
1439 |
+
};
|
1440 |
+
});
|
1441 |
+
|
1442 |
+
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
1443 |
+
data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING : COMMON_CHAT_FORMAT_HERMES_2_PRO;
|
1444 |
+
return data;
|
1445 |
+
}
|
1446 |
+
static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input, bool extract_reasoning) {
|
1447 |
+
return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
|
1448 |
+
static const std::regex open_regex(
|
1449 |
+
"(?:"
|
1450 |
+
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
|
1451 |
+
"(<tool_call>" // match 2 (open_tag)
|
1452 |
+
"|<function_call>"
|
1453 |
+
"|<tool>"
|
1454 |
+
"|<tools>"
|
1455 |
+
"|<response>"
|
1456 |
+
"|<json>"
|
1457 |
+
"|<xml>"
|
1458 |
+
"|<JSON>"
|
1459 |
+
")?"
|
1460 |
+
"(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)" // match 3 (named tool call + rest)
|
1461 |
+
")"
|
1462 |
+
"|"
|
1463 |
+
"(?:<function=([^>]+)>" // match 4 (function name)
|
1464 |
+
"|<function name=\"([^\"]+)\">)" // match 5 (function name again)
|
1465 |
+
"([\\s\\S]*)" // match 6 (function arguments + rest)})"
|
1466 |
+
);
|
1467 |
+
|
1468 |
+
try {
|
1469 |
+
common_chat_msg msg;
|
1470 |
+
msg.role = "assistant";
|
1471 |
+
|
1472 |
+
std::string::const_iterator it = input.begin();
|
1473 |
+
const std::string::const_iterator end = input.end();
|
1474 |
+
std::smatch match;
|
1475 |
+
|
1476 |
+
while (it != end) {
|
1477 |
+
if (std::regex_search(it, end, match, open_regex)) {
|
1478 |
+
// Add content before the match
|
1479 |
+
msg.content += std::string(it, match[0].first);
|
1480 |
+
|
1481 |
+
auto block_start = match[1].str();
|
1482 |
+
std::string block_end = block_start.empty() ? "" : "```";
|
1483 |
+
|
1484 |
+
auto open_tag = match[2].str();
|
1485 |
+
std::string close_tag;
|
1486 |
+
|
1487 |
+
if (match[3].matched) {
|
1488 |
+
close_tag = open_tag.empty() ? "" : "</" + open_tag.substr(1);
|
1489 |
+
auto json_it = match[3].first;
|
1490 |
+
json tool_call;
|
1491 |
+
if (parse_json(json_it, end, tool_call) && tool_call.contains("name") && tool_call.contains("arguments")) {
|
1492 |
+
|
1493 |
+
msg.tool_calls.emplace_back(process_tool_call(tool_call));
|
1494 |
+
it = json_it; // Move iterator past parsed JSON
|
1495 |
+
|
1496 |
+
// Handle close tags
|
1497 |
+
consume_spaces(it, end);
|
1498 |
+
if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
|
1499 |
+
throw std::runtime_error("Failed to parse closing tag");
|
1500 |
+
}
|
1501 |
+
consume_spaces(it, end);
|
1502 |
+
if (!block_end.empty() && !parse_literal(it, end, block_end)) {
|
1503 |
+
throw std::runtime_error("Failed to parse block end");
|
1504 |
+
}
|
1505 |
+
consume_spaces(it, end);
|
1506 |
+
} else {
|
1507 |
+
// Not a valid tool call, treat as content
|
1508 |
+
msg.content += std::string(match[0].first, match[0].second);
|
1509 |
+
it = match[0].second;
|
1510 |
+
}
|
1511 |
+
} else {
|
1512 |
+
auto function_name = match[4].str();
|
1513 |
+
if (function_name.empty()) {
|
1514 |
+
function_name = match[5].str();
|
1515 |
+
}
|
1516 |
+
GGML_ASSERT(!function_name.empty());
|
1517 |
+
|
1518 |
+
close_tag = "</function>";
|
1519 |
+
// Start parsing from after the opening tags
|
1520 |
+
auto json_it = match[6].first;
|
1521 |
+
json arguments;
|
1522 |
+
if (parse_json(json_it, end, arguments)) {
|
1523 |
+
msg.tool_calls.emplace_back(process_tool_call({
|
1524 |
+
{"name", function_name},
|
1525 |
+
{"arguments", arguments},
|
1526 |
+
}));
|
1527 |
+
it = json_it; // Move iterator past parsed JSON
|
1528 |
+
|
1529 |
+
// Handle close tags
|
1530 |
+
consume_spaces(it, end);
|
1531 |
+
if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
|
1532 |
+
throw std::runtime_error("Failed to parse closing tag");
|
1533 |
+
}
|
1534 |
+
consume_spaces(it, end);
|
1535 |
+
if (!block_end.empty() && !parse_literal(it, end, block_end)) {
|
1536 |
+
throw std::runtime_error("Failed to parse block end");
|
1537 |
+
}
|
1538 |
+
consume_spaces(it, end);
|
1539 |
+
} else {
|
1540 |
+
// Not a valid tool call, treat as content
|
1541 |
+
msg.content += std::string(match[0].first, match[0].second);
|
1542 |
+
it = match[0].second;
|
1543 |
+
}
|
1544 |
+
}
|
1545 |
+
} else {
|
1546 |
+
// Add remaining content
|
1547 |
+
msg.content += std::string(it, end);
|
1548 |
+
break;
|
1549 |
+
}
|
1550 |
+
}
|
1551 |
+
return msg;
|
1552 |
+
} catch (const std::exception & e) {
|
1553 |
+
LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
|
1554 |
+
common_chat_msg msg;
|
1555 |
+
msg.role = "assistant";
|
1556 |
+
msg.content = input;
|
1557 |
+
return msg;
|
1558 |
+
}
|
1559 |
+
});
|
1560 |
+
}
|
1561 |
+
|
1562 |
+
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
1563 |
+
common_chat_params data;
|
1564 |
+
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
1565 |
+
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
1566 |
+
data.grammar_lazy = false;
|
1567 |
+
if (!inputs.json_schema.is_null()) {
|
1568 |
+
if (!inputs.grammar.empty()) {
|
1569 |
+
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
1570 |
+
}
|
1571 |
+
data.grammar = json_schema_to_grammar(inputs.json_schema);
|
1572 |
+
} else {
|
1573 |
+
data.grammar = inputs.grammar;
|
1574 |
+
}
|
1575 |
+
return data;
|
1576 |
+
}
|
1577 |
+
|
1578 |
+
static common_chat_params common_chat_templates_apply_jinja(
|
1579 |
+
const struct common_chat_templates * tmpls,
|
1580 |
+
const struct common_chat_templates_inputs & inputs)
|
1581 |
+
{
|
1582 |
+
templates_params params;
|
1583 |
+
params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
|
1584 |
+
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
|
1585 |
+
? *tmpls->template_tool_use
|
1586 |
+
: *tmpls->template_default;
|
1587 |
+
const auto & src = tmpl.source();
|
1588 |
+
const auto & caps = tmpl.original_caps();
|
1589 |
+
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
|
1590 |
+
params.add_generation_prompt = inputs.add_generation_prompt;
|
1591 |
+
params.extract_reasoning = inputs.extract_reasoning;
|
1592 |
+
params.tool_choice = inputs.tool_choice;
|
1593 |
+
params.grammar = inputs.grammar;
|
1594 |
+
if (!inputs.json_schema.empty()) {
|
1595 |
+
params.json_schema = json::parse(inputs.json_schema);
|
1596 |
+
}
|
1597 |
+
|
1598 |
+
if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
|
1599 |
+
LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
|
1600 |
+
params.parallel_tool_calls = false;
|
1601 |
+
} else {
|
1602 |
+
params.parallel_tool_calls = inputs.parallel_tool_calls;
|
1603 |
+
}
|
1604 |
+
|
1605 |
+
if (params.tools.is_array()) {
|
1606 |
+
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
|
1607 |
+
throw std::runtime_error("Cannot specify grammar with tools");
|
1608 |
+
}
|
1609 |
+
if (caps.supports_tool_calls && !caps.supports_tools) {
|
1610 |
+
LOG_WRN("Template supports tool calls but does not natively describe tools. The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
|
1611 |
+
}
|
1612 |
+
}
|
1613 |
+
|
1614 |
+
// DeepSeek R1: use handler in all cases except json schema (thinking / tools).
|
1615 |
+
if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) {
|
1616 |
+
return common_chat_params_init_deepseek_r1(tmpl, params);
|
1617 |
+
}
|
1618 |
+
|
1619 |
+
// Command R7B: : use handler in all cases except json schema (thinking / tools).
|
1620 |
+
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
|
1621 |
+
return common_chat_params_init_command_r7b(tmpl, params);
|
1622 |
+
}
|
1623 |
+
|
1624 |
+
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
1625 |
+
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
|
1626 |
+
return common_chat_params_init_hermes_2_pro(tmpl, params);
|
1627 |
+
}
|
1628 |
+
|
1629 |
+
// Use generic handler when mixing tools + JSON schema.
|
1630 |
+
// TODO: support that mix in handlers below.
|
1631 |
+
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
1632 |
+
return common_chat_params_init_generic(tmpl, params);
|
1633 |
+
}
|
1634 |
+
|
1635 |
+
// Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
|
1636 |
+
if (src.find(">>>all") != std::string::npos) {
|
1637 |
+
return common_chat_params_init_functionary_v3_2(tmpl, params);
|
1638 |
+
}
|
1639 |
+
|
1640 |
+
// Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
|
1641 |
+
if (src.find(" functools[") != std::string::npos) {
|
1642 |
+
return common_chat_params_init_firefunction_v2(tmpl, params);
|
1643 |
+
}
|
1644 |
+
|
1645 |
+
// Plain handler (no tools)
|
1646 |
+
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
1647 |
+
return common_chat_params_init_without_tools(tmpl, params);
|
1648 |
+
}
|
1649 |
+
|
1650 |
+
// Functionary v3.1 (w/ tools)
|
1651 |
+
if (src.find("<|start_header_id|>") != std::string::npos
|
1652 |
+
&& src.find("<function=") != std::string::npos) {
|
1653 |
+
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
|
1654 |
+
}
|
1655 |
+
|
1656 |
+
// Llama 3.1, 3.2, 3.3 (w/ tools)
|
1657 |
+
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
1658 |
+
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
1659 |
+
return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
|
1660 |
+
}
|
1661 |
+
|
1662 |
+
// Mistral Nemo (w/ tools)
|
1663 |
+
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
1664 |
+
return common_chat_params_init_mistral_nemo(tmpl, params);
|
1665 |
+
}
|
1666 |
+
|
1667 |
+
// Generic fallback
|
1668 |
+
return common_chat_params_init_generic(tmpl, params);
|
1669 |
+
}
|
1670 |
+
|
1671 |
+
// Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template.
|
1672 |
+
static common_chat_params common_chat_templates_apply_legacy(
|
1673 |
+
const struct common_chat_templates * tmpls,
|
1674 |
+
const struct common_chat_templates_inputs & inputs)
|
1675 |
+
{
|
1676 |
+
int alloc_size = 0;
|
1677 |
+
std::vector<llama_chat_message> chat;
|
1678 |
+
std::vector<std::string> contents;
|
1679 |
+
for (const auto & msg : inputs.messages) {
|
1680 |
+
auto content = msg.content;
|
1681 |
+
for (const auto & part : msg.content_parts) {
|
1682 |
+
if (part.type != "text") {
|
1683 |
+
LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
|
1684 |
+
continue;
|
1685 |
+
}
|
1686 |
+
if (!content.empty()) {
|
1687 |
+
content += "\n";;
|
1688 |
+
}
|
1689 |
+
content += part.text;
|
1690 |
+
}
|
1691 |
+
contents.emplace_back(std::move(content));
|
1692 |
+
}
|
1693 |
+
for (size_t i = 0; i < contents.size(); ++i) {
|
1694 |
+
const auto & msg = inputs.messages[i];
|
1695 |
+
const auto & content = contents[i];
|
1696 |
+
chat.push_back({msg.role.c_str(), content.c_str()});
|
1697 |
+
alloc_size += (msg.role.size() + content.size()) * 1.25;
|
1698 |
+
}
|
1699 |
+
|
1700 |
+
std::vector<char> buf(alloc_size);
|
1701 |
+
|
1702 |
+
// run the first time to get the total output length
|
1703 |
+
const auto & src = tmpls->template_default->source();
|
1704 |
+
int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
|
1705 |
+
|
1706 |
+
// error: chat template is not supported
|
1707 |
+
if (res < 0) {
|
1708 |
+
// if the custom "tmpl" is not supported, we throw an error
|
1709 |
+
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
|
1710 |
+
throw std::runtime_error("this custom template is not supported");
|
1711 |
+
}
|
1712 |
+
|
1713 |
+
// if it turns out that our buffer is too small, we resize it
|
1714 |
+
if ((size_t) res > buf.size()) {
|
1715 |
+
buf.resize(res);
|
1716 |
+
res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
|
1717 |
+
}
|
1718 |
+
|
1719 |
+
common_chat_params params;
|
1720 |
+
params.prompt = std::string(buf.data(), res);
|
1721 |
+
if (!inputs.json_schema.empty()) {
|
1722 |
+
params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema));
|
1723 |
+
} else {
|
1724 |
+
params.grammar = inputs.grammar;
|
1725 |
+
}
|
1726 |
+
return params;
|
1727 |
+
}
|
1728 |
+
|
1729 |
+
common_chat_params common_chat_templates_apply(
|
1730 |
+
const struct common_chat_templates * tmpls,
|
1731 |
+
const struct common_chat_templates_inputs & inputs)
|
1732 |
+
{
|
1733 |
+
GGML_ASSERT(tmpls != nullptr);
|
1734 |
+
return inputs.use_jinja
|
1735 |
+
? common_chat_templates_apply_jinja(tmpls, inputs)
|
1736 |
+
: common_chat_templates_apply_legacy(tmpls, inputs);
|
1737 |
+
}
|
1738 |
+
|
1739 |
+
static common_chat_msg common_chat_parse_content_only(const std::string & input) {
|
1740 |
+
common_chat_msg msg;
|
1741 |
+
msg.role = "assistant";
|
1742 |
+
msg.content = input;
|
1743 |
+
return msg;
|
1744 |
+
}
|
1745 |
+
|
1746 |
+
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
|
1747 |
+
switch (format) {
|
1748 |
+
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
|
1749 |
+
return common_chat_parse_content_only(input);
|
1750 |
+
case COMMON_CHAT_FORMAT_GENERIC:
|
1751 |
+
return common_chat_parse_generic(input);
|
1752 |
+
case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
|
1753 |
+
return common_chat_parse_mistral_nemo(input);
|
1754 |
+
case COMMON_CHAT_FORMAT_LLAMA_3_X:
|
1755 |
+
return common_chat_parse_llama_3_1(input);
|
1756 |
+
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
|
1757 |
+
return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true);
|
1758 |
+
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
|
1759 |
+
return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ false);
|
1760 |
+
case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING:
|
1761 |
+
return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ true);
|
1762 |
+
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
|
1763 |
+
return common_chat_parse_functionary_v3_2(input);
|
1764 |
+
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
|
1765 |
+
return common_chat_parse_functionary_v3_1_llama_3_1(input);
|
1766 |
+
case COMMON_CHAT_FORMAT_HERMES_2_PRO:
|
1767 |
+
return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ false);
|
1768 |
+
case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING:
|
1769 |
+
return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ true);
|
1770 |
+
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
|
1771 |
+
return common_chat_parse_firefunction_v2(input);
|
1772 |
+
case COMMON_CHAT_FORMAT_COMMAND_R7B:
|
1773 |
+
return common_chat_parse_command_r7b(input, /* extract_reasoning= */ false);
|
1774 |
+
case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING:
|
1775 |
+
return common_chat_parse_command_r7b(input, /* extract_reasoning= */ true);
|
1776 |
+
default:
|
1777 |
+
throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
|
1778 |
+
}
|
1779 |
+
}
|
common/chat.h
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include "common.h"
|
6 |
+
#include <string>
|
7 |
+
#include <vector>
|
8 |
+
|
9 |
+
struct common_chat_templates;
|
10 |
+
|
11 |
+
struct common_chat_tool_call {
|
12 |
+
std::string name;
|
13 |
+
std::string arguments;
|
14 |
+
std::string id;
|
15 |
+
};
|
16 |
+
|
17 |
+
struct common_chat_msg_content_part {
|
18 |
+
std::string type;
|
19 |
+
std::string text;
|
20 |
+
};
|
21 |
+
|
22 |
+
struct common_chat_msg {
|
23 |
+
std::string role;
|
24 |
+
std::string content;
|
25 |
+
std::vector<common_chat_msg_content_part> content_parts = {};
|
26 |
+
std::vector<common_chat_tool_call> tool_calls = {};
|
27 |
+
std::string reasoning_content;
|
28 |
+
std::string tool_name;
|
29 |
+
std::string tool_call_id;
|
30 |
+
};
|
31 |
+
|
32 |
+
struct common_chat_tool {
|
33 |
+
std::string name;
|
34 |
+
std::string description;
|
35 |
+
std::string parameters;
|
36 |
+
};
|
37 |
+
|
38 |
+
enum common_chat_tool_choice {
|
39 |
+
COMMON_CHAT_TOOL_CHOICE_AUTO,
|
40 |
+
COMMON_CHAT_TOOL_CHOICE_REQUIRED,
|
41 |
+
COMMON_CHAT_TOOL_CHOICE_NONE,
|
42 |
+
};
|
43 |
+
|
44 |
+
enum common_chat_format {
|
45 |
+
COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
46 |
+
COMMON_CHAT_FORMAT_GENERIC,
|
47 |
+
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
|
48 |
+
COMMON_CHAT_FORMAT_LLAMA_3_X,
|
49 |
+
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
|
50 |
+
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
|
51 |
+
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
|
52 |
+
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
|
53 |
+
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
|
54 |
+
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
55 |
+
COMMON_CHAT_FORMAT_HERMES_2_PRO,
|
56 |
+
COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING,
|
57 |
+
COMMON_CHAT_FORMAT_COMMAND_R7B,
|
58 |
+
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
|
59 |
+
|
60 |
+
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
61 |
+
};
|
62 |
+
|
63 |
+
struct common_chat_templates_inputs {
|
64 |
+
std::vector<common_chat_msg> messages;
|
65 |
+
std::string grammar;
|
66 |
+
std::string json_schema;
|
67 |
+
bool add_generation_prompt = true;
|
68 |
+
bool use_jinja = true;
|
69 |
+
// Parameters below only supported when use_jinja is true
|
70 |
+
std::vector<common_chat_tool> tools;
|
71 |
+
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
72 |
+
bool parallel_tool_calls = false;
|
73 |
+
bool extract_reasoning = true;
|
74 |
+
};
|
75 |
+
|
76 |
+
struct common_chat_params {
|
77 |
+
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
78 |
+
std::string prompt;
|
79 |
+
std::string grammar;
|
80 |
+
bool grammar_lazy = false;
|
81 |
+
std::vector<common_grammar_trigger> grammar_triggers;
|
82 |
+
std::vector<std::string> preserved_tokens;
|
83 |
+
std::vector<std::string> additional_stops;
|
84 |
+
};
|
85 |
+
|
86 |
+
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
87 |
+
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
|
88 |
+
|
89 |
+
void common_chat_templates_free(struct common_chat_templates * tmpls);
|
90 |
+
|
91 |
+
struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };
|
92 |
+
|
93 |
+
typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
|
94 |
+
|
95 |
+
common_chat_templates_ptr common_chat_templates_init(
|
96 |
+
const struct llama_model * model,
|
97 |
+
const std::string & chat_template_override,
|
98 |
+
const std::string & bos_token_override = "",
|
99 |
+
const std::string & eos_token_override = "");
|
100 |
+
|
101 |
+
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
|
102 |
+
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
|
103 |
+
|
104 |
+
|
105 |
+
struct common_chat_params common_chat_templates_apply(
|
106 |
+
const struct common_chat_templates * tmpls,
|
107 |
+
const struct common_chat_templates_inputs & inputs);
|
108 |
+
|
109 |
+
// Format single message, while taking into account the position of that message in chat history
|
110 |
+
std::string common_chat_format_single(
|
111 |
+
const struct common_chat_templates * tmpls,
|
112 |
+
const std::vector<common_chat_msg> & past_msg,
|
113 |
+
const common_chat_msg & new_msg,
|
114 |
+
bool add_ass,
|
115 |
+
bool use_jinja);
|
116 |
+
|
117 |
+
// Returns an example of formatted chat
|
118 |
+
std::string common_chat_format_example(
|
119 |
+
const struct common_chat_templates * tmpls,
|
120 |
+
bool use_jinja);
|
121 |
+
|
122 |
+
std::string common_chat_format_name(common_chat_format format);
|
123 |
+
common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
|
124 |
+
|
125 |
+
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
|
126 |
+
|
127 |
+
// Parses a JSON array of messages in OpenAI's chat completion API format.
|
128 |
+
// T can be std::string containing JSON or nlohmann::ordered_json
|
129 |
+
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
|
130 |
+
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
131 |
+
|
132 |
+
// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
|
133 |
+
// T can be std::string containing JSON or nlohmann::ordered_json
|
134 |
+
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
|
135 |
+
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
|
common/common.cpp
ADDED
@@ -0,0 +1,2058 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#if defined(_MSC_VER)
|
2 |
+
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
|
3 |
+
#endif
|
4 |
+
|
5 |
+
#include "ggml.h"
|
6 |
+
#include "gguf.h"
|
7 |
+
|
8 |
+
#include "common.h"
|
9 |
+
#include "log.h"
|
10 |
+
#include "build-info.h"
|
11 |
+
#include "log.cpp"
|
12 |
+
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
13 |
+
#define JSON_ASSERT GGML_ASSERT
|
14 |
+
#include "json.hpp"
|
15 |
+
#include "json-schema-to-grammar.cpp"
|
16 |
+
#include "llama.h"
|
17 |
+
#include "chat.cpp"
|
18 |
+
|
19 |
+
#include <algorithm>
|
20 |
+
#include <cinttypes>
|
21 |
+
#include <climits>
|
22 |
+
#include <cmath>
|
23 |
+
#include <codecvt>
|
24 |
+
#include <cstdarg>
|
25 |
+
#include <cstring>
|
26 |
+
#include <ctime>
|
27 |
+
#include <filesystem>
|
28 |
+
#include <fstream>
|
29 |
+
#include <iostream>
|
30 |
+
#include <iterator>
|
31 |
+
#include <regex>
|
32 |
+
#include <sstream>
|
33 |
+
#include <string>
|
34 |
+
#include <thread>
|
35 |
+
#include <unordered_map>
|
36 |
+
#include <unordered_set>
|
37 |
+
#include <vector>
|
38 |
+
|
39 |
+
#if defined(__APPLE__) && defined(__MACH__)
|
40 |
+
#include <sys/types.h>
|
41 |
+
#include <sys/sysctl.h>
|
42 |
+
#endif
|
43 |
+
|
44 |
+
#if defined(_WIN32)
|
45 |
+
#define WIN32_LEAN_AND_MEAN
|
46 |
+
#ifndef NOMINMAX
|
47 |
+
# define NOMINMAX
|
48 |
+
#endif
|
49 |
+
#include <locale>
|
50 |
+
#include <windows.h>
|
51 |
+
#include <fcntl.h>
|
52 |
+
#include <io.h>
|
53 |
+
#else
|
54 |
+
#include <sys/ioctl.h>
|
55 |
+
#include <sys/stat.h>
|
56 |
+
#include <unistd.h>
|
57 |
+
#endif
|
58 |
+
#if defined(LLAMA_USE_CURL)
|
59 |
+
#include <curl/curl.h>
|
60 |
+
#include <curl/easy.h>
|
61 |
+
#include <future>
|
62 |
+
#endif
|
63 |
+
|
64 |
+
#if defined(_MSC_VER)
|
65 |
+
#pragma warning(disable: 4244 4267) // possible loss of data
|
66 |
+
#endif
|
67 |
+
|
68 |
+
#if defined(LLAMA_USE_CURL)
|
69 |
+
#ifdef __linux__
|
70 |
+
#include <linux/limits.h>
|
71 |
+
#elif defined(_WIN32)
|
72 |
+
# if !defined(PATH_MAX)
|
73 |
+
# define PATH_MAX MAX_PATH
|
74 |
+
# endif
|
75 |
+
#else
|
76 |
+
#include <sys/syslimits.h>
|
77 |
+
#endif
|
78 |
+
#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
|
79 |
+
|
80 |
+
//
|
81 |
+
// CURL utils
|
82 |
+
//
|
83 |
+
|
84 |
+
using curl_ptr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
|
85 |
+
|
86 |
+
// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one
|
87 |
+
struct curl_slist_ptr {
|
88 |
+
struct curl_slist * ptr = nullptr;
|
89 |
+
~curl_slist_ptr() {
|
90 |
+
if (ptr) {
|
91 |
+
curl_slist_free_all(ptr);
|
92 |
+
}
|
93 |
+
}
|
94 |
+
};
|
95 |
+
#endif // LLAMA_USE_CURL
|
96 |
+
|
97 |
+
using json = nlohmann::ordered_json;
|
98 |
+
|
99 |
+
//
|
100 |
+
// CPU utils
|
101 |
+
//
|
102 |
+
|
103 |
+
int32_t cpu_get_num_physical_cores() {
|
104 |
+
#ifdef __linux__
|
105 |
+
// enumerate the set of thread siblings, num entries is num cores
|
106 |
+
std::unordered_set<std::string> siblings;
|
107 |
+
for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) {
|
108 |
+
std::ifstream thread_siblings("/sys/devices/system/cpu/cpu"
|
109 |
+
+ std::to_string(cpu) + "/topology/thread_siblings");
|
110 |
+
if (!thread_siblings.is_open()) {
|
111 |
+
break; // no more cpus
|
112 |
+
}
|
113 |
+
std::string line;
|
114 |
+
if (std::getline(thread_siblings, line)) {
|
115 |
+
siblings.insert(line);
|
116 |
+
}
|
117 |
+
}
|
118 |
+
if (!siblings.empty()) {
|
119 |
+
return static_cast<int32_t>(siblings.size());
|
120 |
+
}
|
121 |
+
#elif defined(__APPLE__) && defined(__MACH__)
|
122 |
+
int32_t num_physical_cores;
|
123 |
+
size_t len = sizeof(num_physical_cores);
|
124 |
+
int result = sysctlbyname("hw.perflevel0.physicalcpu", &num_physical_cores, &len, NULL, 0);
|
125 |
+
if (result == 0) {
|
126 |
+
return num_physical_cores;
|
127 |
+
}
|
128 |
+
result = sysctlbyname("hw.physicalcpu", &num_physical_cores, &len, NULL, 0);
|
129 |
+
if (result == 0) {
|
130 |
+
return num_physical_cores;
|
131 |
+
}
|
132 |
+
#elif defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later
|
133 |
+
// TODO: windows + arm64 + mingw64
|
134 |
+
unsigned int n_threads_win = std::thread::hardware_concurrency();
|
135 |
+
unsigned int default_threads = n_threads_win > 0 ? (n_threads_win <= 4 ? n_threads_win : n_threads_win / 2) : 4;
|
136 |
+
|
137 |
+
DWORD buffer_size = 0;
|
138 |
+
if (!GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &buffer_size)) {
|
139 |
+
if (GetLastError() != ERROR_INSUFFICIENT_BUFFER) {
|
140 |
+
return default_threads;
|
141 |
+
}
|
142 |
+
}
|
143 |
+
|
144 |
+
std::vector<char> buffer(buffer_size);
|
145 |
+
if (!GetLogicalProcessorInformationEx(RelationProcessorCore, reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data()), &buffer_size)) {
|
146 |
+
return default_threads;
|
147 |
+
}
|
148 |
+
|
149 |
+
int32_t num_physical_cores = 0;
|
150 |
+
PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data());
|
151 |
+
while (buffer_size > 0) {
|
152 |
+
if (info->Relationship == RelationProcessorCore) {
|
153 |
+
num_physical_cores += info->Processor.GroupCount;
|
154 |
+
}
|
155 |
+
buffer_size -= info->Size;
|
156 |
+
info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(reinterpret_cast<char*>(info) + info->Size);
|
157 |
+
}
|
158 |
+
|
159 |
+
return num_physical_cores > 0 ? num_physical_cores : default_threads;
|
160 |
+
#endif
|
161 |
+
unsigned int n_threads = std::thread::hardware_concurrency();
|
162 |
+
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
|
163 |
+
}
|
164 |
+
|
165 |
+
#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__)
|
166 |
+
#include <pthread.h>
|
167 |
+
|
168 |
+
static void cpuid(unsigned leaf, unsigned subleaf,
|
169 |
+
unsigned *eax, unsigned *ebx, unsigned *ecx, unsigned *edx) {
|
170 |
+
__asm__("movq\t%%rbx,%%rsi\n\t"
|
171 |
+
"cpuid\n\t"
|
172 |
+
"xchgq\t%%rbx,%%rsi"
|
173 |
+
: "=a"(*eax), "=S"(*ebx), "=c"(*ecx), "=d"(*edx)
|
174 |
+
: "0"(leaf), "2"(subleaf));
|
175 |
+
}
|
176 |
+
|
177 |
+
static int pin_cpu(int cpu) {
|
178 |
+
cpu_set_t mask;
|
179 |
+
CPU_ZERO(&mask);
|
180 |
+
CPU_SET(cpu, &mask);
|
181 |
+
return pthread_setaffinity_np(pthread_self(), sizeof(mask), &mask);
|
182 |
+
}
|
183 |
+
|
184 |
+
static bool is_hybrid_cpu(void) {
|
185 |
+
unsigned eax, ebx, ecx, edx;
|
186 |
+
cpuid(7, 0, &eax, &ebx, &ecx, &edx);
|
187 |
+
return !!(edx & (1u << 15));
|
188 |
+
}
|
189 |
+
|
190 |
+
static bool is_running_on_efficiency_core(void) {
|
191 |
+
unsigned eax, ebx, ecx, edx;
|
192 |
+
cpuid(0x1a, 0, &eax, &ebx, &ecx, &edx);
|
193 |
+
int intel_atom = 0x20;
|
194 |
+
int core_type = (eax & 0xff000000u) >> 24;
|
195 |
+
return core_type == intel_atom;
|
196 |
+
}
|
197 |
+
|
198 |
+
static int cpu_count_math_cpus(int n_cpu) {
|
199 |
+
int result = 0;
|
200 |
+
for (int cpu = 0; cpu < n_cpu; ++cpu) {
|
201 |
+
if (pin_cpu(cpu)) {
|
202 |
+
return -1;
|
203 |
+
}
|
204 |
+
if (is_running_on_efficiency_core()) {
|
205 |
+
continue; // efficiency cores harm lockstep threading
|
206 |
+
}
|
207 |
+
++cpu; // hyperthreading isn't useful for linear algebra
|
208 |
+
++result;
|
209 |
+
}
|
210 |
+
return result;
|
211 |
+
}
|
212 |
+
|
213 |
+
#endif // __x86_64__ && __linux__
|
214 |
+
|
215 |
+
/**
|
216 |
+
* Returns number of CPUs on system that are useful for math.
|
217 |
+
*/
|
218 |
+
int32_t cpu_get_num_math() {
|
219 |
+
#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__)
|
220 |
+
int n_cpu = sysconf(_SC_NPROCESSORS_ONLN);
|
221 |
+
if (n_cpu < 1) {
|
222 |
+
return cpu_get_num_physical_cores();
|
223 |
+
}
|
224 |
+
if (is_hybrid_cpu()) {
|
225 |
+
cpu_set_t affinity;
|
226 |
+
if (!pthread_getaffinity_np(pthread_self(), sizeof(affinity), &affinity)) {
|
227 |
+
int result = cpu_count_math_cpus(n_cpu);
|
228 |
+
pthread_setaffinity_np(pthread_self(), sizeof(affinity), &affinity);
|
229 |
+
if (result > 0) {
|
230 |
+
return result;
|
231 |
+
}
|
232 |
+
}
|
233 |
+
}
|
234 |
+
#endif
|
235 |
+
return cpu_get_num_physical_cores();
|
236 |
+
}
|
237 |
+
|
238 |
+
// Helper for setting process priority
|
239 |
+
|
240 |
+
#if defined(_WIN32)
|
241 |
+
|
242 |
+
bool set_process_priority(enum ggml_sched_priority prio) {
|
243 |
+
if (prio == GGML_SCHED_PRIO_NORMAL) {
|
244 |
+
return true;
|
245 |
+
}
|
246 |
+
|
247 |
+
DWORD p = NORMAL_PRIORITY_CLASS;
|
248 |
+
switch (prio) {
|
249 |
+
case GGML_SCHED_PRIO_NORMAL: p = NORMAL_PRIORITY_CLASS; break;
|
250 |
+
case GGML_SCHED_PRIO_MEDIUM: p = ABOVE_NORMAL_PRIORITY_CLASS; break;
|
251 |
+
case GGML_SCHED_PRIO_HIGH: p = HIGH_PRIORITY_CLASS; break;
|
252 |
+
case GGML_SCHED_PRIO_REALTIME: p = REALTIME_PRIORITY_CLASS; break;
|
253 |
+
}
|
254 |
+
|
255 |
+
if (!SetPriorityClass(GetCurrentProcess(), p)) {
|
256 |
+
LOG_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError());
|
257 |
+
return false;
|
258 |
+
}
|
259 |
+
|
260 |
+
return true;
|
261 |
+
}
|
262 |
+
|
263 |
+
#else // MacOS and POSIX
|
264 |
+
#include <sys/types.h>
|
265 |
+
#include <sys/resource.h>
|
266 |
+
|
267 |
+
bool set_process_priority(enum ggml_sched_priority prio) {
|
268 |
+
if (prio == GGML_SCHED_PRIO_NORMAL) {
|
269 |
+
return true;
|
270 |
+
}
|
271 |
+
|
272 |
+
int p = 0;
|
273 |
+
switch (prio) {
|
274 |
+
case GGML_SCHED_PRIO_NORMAL: p = 0; break;
|
275 |
+
case GGML_SCHED_PRIO_MEDIUM: p = -5; break;
|
276 |
+
case GGML_SCHED_PRIO_HIGH: p = -10; break;
|
277 |
+
case GGML_SCHED_PRIO_REALTIME: p = -20; break;
|
278 |
+
}
|
279 |
+
|
280 |
+
if (!setpriority(PRIO_PROCESS, 0, p)) {
|
281 |
+
LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
|
282 |
+
return false;
|
283 |
+
}
|
284 |
+
return true;
|
285 |
+
}
|
286 |
+
|
287 |
+
#endif
|
288 |
+
|
289 |
+
//
|
290 |
+
// CLI argument parsing
|
291 |
+
//
|
292 |
+
|
293 |
+
|
294 |
+
void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model) {
|
295 |
+
int32_t n_set = 0;
|
296 |
+
|
297 |
+
if (cpuparams.n_threads < 0) {
|
298 |
+
// Assuming everything about cpuparams is invalid
|
299 |
+
if (role_model != nullptr) {
|
300 |
+
cpuparams = *role_model;
|
301 |
+
} else {
|
302 |
+
cpuparams.n_threads = cpu_get_num_math();
|
303 |
+
}
|
304 |
+
}
|
305 |
+
|
306 |
+
for (int32_t i = 0; i < GGML_MAX_N_THREADS; i++) {
|
307 |
+
if (cpuparams.cpumask[i]) {
|
308 |
+
n_set++;
|
309 |
+
}
|
310 |
+
}
|
311 |
+
|
312 |
+
if (n_set && n_set < cpuparams.n_threads) {
|
313 |
+
// Not enough set bits, may experience performance issues.
|
314 |
+
LOG_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads);
|
315 |
+
}
|
316 |
+
}
|
317 |
+
|
318 |
+
bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THREADS]) {
|
319 |
+
size_t dash_loc = range.find('-');
|
320 |
+
if (dash_loc == std::string::npos) {
|
321 |
+
LOG_ERR("Format of CPU range is invalid! Expected [<start>]-[<end>].\n");
|
322 |
+
return false;
|
323 |
+
}
|
324 |
+
|
325 |
+
size_t start_i;
|
326 |
+
size_t end_i;
|
327 |
+
|
328 |
+
if (dash_loc == 0) {
|
329 |
+
start_i = 0;
|
330 |
+
} else {
|
331 |
+
start_i = std::stoull(range.substr(0, dash_loc));
|
332 |
+
if (start_i >= GGML_MAX_N_THREADS) {
|
333 |
+
LOG_ERR("Start index out of bounds!\n");
|
334 |
+
return false;
|
335 |
+
}
|
336 |
+
}
|
337 |
+
|
338 |
+
if (dash_loc == range.length() - 1) {
|
339 |
+
end_i = GGML_MAX_N_THREADS - 1;
|
340 |
+
} else {
|
341 |
+
end_i = std::stoull(range.substr(dash_loc + 1));
|
342 |
+
if (end_i >= GGML_MAX_N_THREADS) {
|
343 |
+
LOG_ERR("End index out of bounds!\n");
|
344 |
+
return false;
|
345 |
+
}
|
346 |
+
}
|
347 |
+
|
348 |
+
for (size_t i = start_i; i <= end_i; i++) {
|
349 |
+
boolmask[i] = true;
|
350 |
+
}
|
351 |
+
|
352 |
+
return true;
|
353 |
+
}
|
354 |
+
|
355 |
+
bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREADS]) {
|
356 |
+
// Discard potential 0x prefix
|
357 |
+
size_t start_i = 0;
|
358 |
+
if (mask.length() >= 2 && mask.substr(0, 2) == "0x") {
|
359 |
+
start_i = 2;
|
360 |
+
}
|
361 |
+
|
362 |
+
size_t num_digits = mask.length() - start_i;
|
363 |
+
if (num_digits > 128) num_digits = 128;
|
364 |
+
|
365 |
+
size_t end_i = num_digits + start_i;
|
366 |
+
|
367 |
+
for (size_t i = start_i, n = (num_digits*4 - 1); i < end_i; i++, n-=4) {
|
368 |
+
char c = mask.at(i);
|
369 |
+
int8_t id = c;
|
370 |
+
|
371 |
+
if ((c >= '0' && c <= '9')) {
|
372 |
+
id -= '0';
|
373 |
+
} else if (c >= 'a' && c <= 'f') {
|
374 |
+
id -= 'a' - 10;
|
375 |
+
} else if (c >= 'A' && c <= 'F') {
|
376 |
+
id -= 'A' - 10;
|
377 |
+
} else {
|
378 |
+
LOG_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i));
|
379 |
+
return false;
|
380 |
+
}
|
381 |
+
|
382 |
+
boolmask[ n ] = boolmask[ n ] || ((id & 8) != 0);
|
383 |
+
boolmask[n - 1] = boolmask[n - 1] || ((id & 4) != 0);
|
384 |
+
boolmask[n - 2] = boolmask[n - 2] || ((id & 2) != 0);
|
385 |
+
boolmask[n - 3] = boolmask[n - 3] || ((id & 1) != 0);
|
386 |
+
}
|
387 |
+
|
388 |
+
return true;
|
389 |
+
}
|
390 |
+
|
391 |
+
void common_init() {
|
392 |
+
llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
|
393 |
+
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
|
394 |
+
common_log_add(common_log_main(), level, "%s", text);
|
395 |
+
}
|
396 |
+
}, NULL);
|
397 |
+
|
398 |
+
#ifdef NDEBUG
|
399 |
+
const char * build_type = "";
|
400 |
+
#else
|
401 |
+
const char * build_type = " (debug)";
|
402 |
+
#endif
|
403 |
+
|
404 |
+
LOG_INF("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type);
|
405 |
+
}
|
406 |
+
|
407 |
+
std::string common_params_get_system_info(const common_params & params) {
|
408 |
+
std::ostringstream os;
|
409 |
+
|
410 |
+
os << "system_info: n_threads = " << params.cpuparams.n_threads;
|
411 |
+
if (params.cpuparams_batch.n_threads != -1) {
|
412 |
+
os << " (n_threads_batch = " << params.cpuparams_batch.n_threads << ")";
|
413 |
+
}
|
414 |
+
#if defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later
|
415 |
+
// TODO: windows + arm64 + mingw64
|
416 |
+
DWORD logicalProcessorCount = GetActiveProcessorCount(ALL_PROCESSOR_GROUPS);
|
417 |
+
os << " / " << logicalProcessorCount << " | " << llama_print_system_info();
|
418 |
+
#else
|
419 |
+
os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info();
|
420 |
+
#endif
|
421 |
+
|
422 |
+
return os.str();
|
423 |
+
}
|
424 |
+
|
425 |
+
//
|
426 |
+
// String utils
|
427 |
+
//
|
428 |
+
|
429 |
+
std::string string_format(const char * fmt, ...) {
|
430 |
+
va_list ap;
|
431 |
+
va_list ap2;
|
432 |
+
va_start(ap, fmt);
|
433 |
+
va_copy(ap2, ap);
|
434 |
+
int size = vsnprintf(NULL, 0, fmt, ap);
|
435 |
+
GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
|
436 |
+
std::vector<char> buf(size + 1);
|
437 |
+
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
|
438 |
+
GGML_ASSERT(size2 == size);
|
439 |
+
va_end(ap2);
|
440 |
+
va_end(ap);
|
441 |
+
return std::string(buf.data(), size);
|
442 |
+
}
|
443 |
+
|
444 |
+
std::string string_strip(const std::string & str) {
|
445 |
+
size_t start = 0;
|
446 |
+
size_t end = str.size();
|
447 |
+
while (start < end && std::isspace(str[start])) {
|
448 |
+
start++;
|
449 |
+
}
|
450 |
+
while (end > start && std::isspace(str[end - 1])) {
|
451 |
+
end--;
|
452 |
+
}
|
453 |
+
return str.substr(start, end - start);
|
454 |
+
}
|
455 |
+
|
456 |
+
std::string string_get_sortable_timestamp() {
|
457 |
+
using clock = std::chrono::system_clock;
|
458 |
+
|
459 |
+
const clock::time_point current_time = clock::now();
|
460 |
+
const time_t as_time_t = clock::to_time_t(current_time);
|
461 |
+
char timestamp_no_ns[100];
|
462 |
+
std::strftime(timestamp_no_ns, 100, "%Y_%m_%d-%H_%M_%S", std::localtime(&as_time_t));
|
463 |
+
|
464 |
+
const int64_t ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
|
465 |
+
current_time.time_since_epoch() % 1000000000).count();
|
466 |
+
char timestamp_ns[11];
|
467 |
+
snprintf(timestamp_ns, 11, "%09" PRId64, ns);
|
468 |
+
|
469 |
+
return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns);
|
470 |
+
}
|
471 |
+
|
472 |
+
void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
|
473 |
+
if (search.empty()) {
|
474 |
+
return;
|
475 |
+
}
|
476 |
+
std::string builder;
|
477 |
+
builder.reserve(s.length());
|
478 |
+
size_t pos = 0;
|
479 |
+
size_t last_pos = 0;
|
480 |
+
while ((pos = s.find(search, last_pos)) != std::string::npos) {
|
481 |
+
builder.append(s, last_pos, pos - last_pos);
|
482 |
+
builder.append(replace);
|
483 |
+
last_pos = pos + search.length();
|
484 |
+
}
|
485 |
+
builder.append(s, last_pos, std::string::npos);
|
486 |
+
s = std::move(builder);
|
487 |
+
}
|
488 |
+
|
489 |
+
std::string regex_escape(const std::string & s) {
|
490 |
+
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
|
491 |
+
return std::regex_replace(s, special_chars, "\\$0");
|
492 |
+
}
|
493 |
+
|
494 |
+
std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
|
495 |
+
std::ostringstream result;
|
496 |
+
for (size_t i = 0; i < values.size(); ++i) {
|
497 |
+
if (i > 0) {
|
498 |
+
result << separator;
|
499 |
+
}
|
500 |
+
result << values[i];
|
501 |
+
}
|
502 |
+
return result.str();
|
503 |
+
}
|
504 |
+
|
505 |
+
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter) {
|
506 |
+
std::vector<std::string> parts;
|
507 |
+
size_t start = 0;
|
508 |
+
size_t end = str.find(delimiter);
|
509 |
+
|
510 |
+
while (end != std::string::npos) {
|
511 |
+
parts.push_back(str.substr(start, end - start));
|
512 |
+
start = end + delimiter.length();
|
513 |
+
end = str.find(delimiter, start);
|
514 |
+
}
|
515 |
+
|
516 |
+
parts.push_back(str.substr(start));
|
517 |
+
|
518 |
+
return parts;
|
519 |
+
}
|
520 |
+
|
521 |
+
std::string string_repeat(const std::string & str, size_t n) {
|
522 |
+
if (n == 0) {
|
523 |
+
return "";
|
524 |
+
}
|
525 |
+
|
526 |
+
std::string result;
|
527 |
+
result.reserve(str.length() * n);
|
528 |
+
|
529 |
+
for (size_t i = 0; i < n; ++i) {
|
530 |
+
result += str;
|
531 |
+
}
|
532 |
+
|
533 |
+
return result;
|
534 |
+
}
|
535 |
+
|
536 |
+
std::string string_from(bool value) {
|
537 |
+
return value ? "true" : "false";
|
538 |
+
}
|
539 |
+
|
540 |
+
std::string string_from(const std::vector<int> & values) {
|
541 |
+
std::stringstream buf;
|
542 |
+
|
543 |
+
buf << "[ ";
|
544 |
+
bool first = true;
|
545 |
+
for (auto e : values) {
|
546 |
+
if (first) {
|
547 |
+
first = false;
|
548 |
+
} else {
|
549 |
+
buf << ", ";
|
550 |
+
}
|
551 |
+
buf << std::to_string(e);
|
552 |
+
}
|
553 |
+
buf << " ]";
|
554 |
+
|
555 |
+
return buf.str();
|
556 |
+
}
|
557 |
+
|
558 |
+
std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens) {
|
559 |
+
std::stringstream buf;
|
560 |
+
|
561 |
+
buf << "[ ";
|
562 |
+
|
563 |
+
bool first = true;
|
564 |
+
for (const auto & token : tokens) {
|
565 |
+
if (!first) {
|
566 |
+
buf << ", ";
|
567 |
+
} else {
|
568 |
+
first = false;
|
569 |
+
}
|
570 |
+
|
571 |
+
auto detokenized = common_token_to_piece(ctx, token);
|
572 |
+
|
573 |
+
detokenized.erase(
|
574 |
+
std::remove_if(
|
575 |
+
detokenized.begin(),
|
576 |
+
detokenized.end(),
|
577 |
+
[](const unsigned char c) { return !std::isprint(c); }),
|
578 |
+
detokenized.end());
|
579 |
+
|
580 |
+
buf << "'" << detokenized << "'"
|
581 |
+
<< ":" << std::to_string(token);
|
582 |
+
}
|
583 |
+
|
584 |
+
buf << " ]";
|
585 |
+
|
586 |
+
return buf.str();
|
587 |
+
}
|
588 |
+
|
589 |
+
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) {
|
590 |
+
std::stringstream buf;
|
591 |
+
|
592 |
+
buf << "[ ";
|
593 |
+
|
594 |
+
bool first = true;
|
595 |
+
for (int i = 0; i < batch.n_tokens; ++i) {
|
596 |
+
if (!first) {
|
597 |
+
buf << ", ";
|
598 |
+
} else {
|
599 |
+
first = false;
|
600 |
+
}
|
601 |
+
|
602 |
+
auto detokenized = common_token_to_piece(ctx, batch.token[i]);
|
603 |
+
|
604 |
+
detokenized.erase(
|
605 |
+
std::remove_if(
|
606 |
+
detokenized.begin(),
|
607 |
+
detokenized.end(),
|
608 |
+
[](const unsigned char c) { return !std::isprint(c); }),
|
609 |
+
detokenized.end());
|
610 |
+
|
611 |
+
buf << "\n" << std::to_string(i)
|
612 |
+
<< ", token '" << detokenized << "'"
|
613 |
+
<< ", pos " << std::to_string(batch.pos[i])
|
614 |
+
<< ", n_seq_id " << std::to_string(batch.n_seq_id[i])
|
615 |
+
<< ", seq_id " << std::to_string(batch.seq_id[i][0])
|
616 |
+
<< ", logits " << std::to_string(batch.logits[i]);
|
617 |
+
}
|
618 |
+
|
619 |
+
buf << " ]";
|
620 |
+
|
621 |
+
return buf.str();
|
622 |
+
}
|
623 |
+
|
624 |
+
void string_process_escapes(std::string & input) {
|
625 |
+
std::size_t input_len = input.length();
|
626 |
+
std::size_t output_idx = 0;
|
627 |
+
|
628 |
+
for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) {
|
629 |
+
if (input[input_idx] == '\\' && input_idx + 1 < input_len) {
|
630 |
+
switch (input[++input_idx]) {
|
631 |
+
case 'n': input[output_idx++] = '\n'; break;
|
632 |
+
case 'r': input[output_idx++] = '\r'; break;
|
633 |
+
case 't': input[output_idx++] = '\t'; break;
|
634 |
+
case '\'': input[output_idx++] = '\''; break;
|
635 |
+
case '\"': input[output_idx++] = '\"'; break;
|
636 |
+
case '\\': input[output_idx++] = '\\'; break;
|
637 |
+
case 'x':
|
638 |
+
// Handle \x12, etc
|
639 |
+
if (input_idx + 2 < input_len) {
|
640 |
+
const char x[3] = { input[input_idx + 1], input[input_idx + 2], 0 };
|
641 |
+
char *err_p = nullptr;
|
642 |
+
const long val = std::strtol(x, &err_p, 16);
|
643 |
+
if (err_p == x + 2) {
|
644 |
+
input_idx += 2;
|
645 |
+
input[output_idx++] = char(val);
|
646 |
+
break;
|
647 |
+
}
|
648 |
+
}
|
649 |
+
// fall through
|
650 |
+
default: input[output_idx++] = '\\';
|
651 |
+
input[output_idx++] = input[input_idx]; break;
|
652 |
+
}
|
653 |
+
} else {
|
654 |
+
input[output_idx++] = input[input_idx];
|
655 |
+
}
|
656 |
+
}
|
657 |
+
|
658 |
+
input.resize(output_idx);
|
659 |
+
}
|
660 |
+
|
661 |
+
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
|
662 |
+
const char * sep = strchr(data, '=');
|
663 |
+
if (sep == nullptr || sep - data >= 128) {
|
664 |
+
LOG_ERR("%s: malformed KV override '%s'\n", __func__, data);
|
665 |
+
return false;
|
666 |
+
}
|
667 |
+
llama_model_kv_override kvo;
|
668 |
+
std::strncpy(kvo.key, data, sep - data);
|
669 |
+
kvo.key[sep - data] = 0;
|
670 |
+
sep++;
|
671 |
+
if (strncmp(sep, "int:", 4) == 0) {
|
672 |
+
sep += 4;
|
673 |
+
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
|
674 |
+
kvo.val_i64 = std::atol(sep);
|
675 |
+
} else if (strncmp(sep, "float:", 6) == 0) {
|
676 |
+
sep += 6;
|
677 |
+
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
|
678 |
+
kvo.val_f64 = std::atof(sep);
|
679 |
+
} else if (strncmp(sep, "bool:", 5) == 0) {
|
680 |
+
sep += 5;
|
681 |
+
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
|
682 |
+
if (std::strcmp(sep, "true") == 0) {
|
683 |
+
kvo.val_bool = true;
|
684 |
+
} else if (std::strcmp(sep, "false") == 0) {
|
685 |
+
kvo.val_bool = false;
|
686 |
+
} else {
|
687 |
+
LOG_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data);
|
688 |
+
return false;
|
689 |
+
}
|
690 |
+
} else if (strncmp(sep, "str:", 4) == 0) {
|
691 |
+
sep += 4;
|
692 |
+
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
|
693 |
+
if (strlen(sep) > 127) {
|
694 |
+
LOG_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
|
695 |
+
return false;
|
696 |
+
}
|
697 |
+
strncpy(kvo.val_str, sep, 127);
|
698 |
+
kvo.val_str[127] = '\0';
|
699 |
+
} else {
|
700 |
+
LOG_ERR("%s: invalid type for KV override '%s'\n", __func__, data);
|
701 |
+
return false;
|
702 |
+
}
|
703 |
+
overrides.emplace_back(std::move(kvo));
|
704 |
+
return true;
|
705 |
+
}
|
706 |
+
|
707 |
+
//
|
708 |
+
// Filesystem utils
|
709 |
+
//
|
710 |
+
|
711 |
+
// Validate if a filename is safe to use
|
712 |
+
// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
|
713 |
+
bool fs_validate_filename(const std::string & filename) {
|
714 |
+
if (!filename.length()) {
|
715 |
+
// Empty filename invalid
|
716 |
+
return false;
|
717 |
+
}
|
718 |
+
if (filename.length() > 255) {
|
719 |
+
// Limit at common largest possible filename on Linux filesystems
|
720 |
+
// to avoid unnecessary further validation
|
721 |
+
// (On systems with smaller limits it will be caught by the OS)
|
722 |
+
return false;
|
723 |
+
}
|
724 |
+
|
725 |
+
std::u32string filename_utf32;
|
726 |
+
try {
|
727 |
+
#if defined(__clang__)
|
728 |
+
// disable C++17 deprecation warning for std::codecvt_utf8
|
729 |
+
# pragma clang diagnostic push
|
730 |
+
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
731 |
+
#endif
|
732 |
+
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
|
733 |
+
|
734 |
+
#if defined(__clang__)
|
735 |
+
# pragma clang diagnostic pop
|
736 |
+
#endif
|
737 |
+
|
738 |
+
filename_utf32 = converter.from_bytes(filename);
|
739 |
+
|
740 |
+
// If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
|
741 |
+
// or invalid encodings were encountered. Reject such attempts
|
742 |
+
std::string filename_reencoded = converter.to_bytes(filename_utf32);
|
743 |
+
if (filename_reencoded != filename) {
|
744 |
+
return false;
|
745 |
+
}
|
746 |
+
} catch (const std::exception &) {
|
747 |
+
return false;
|
748 |
+
}
|
749 |
+
|
750 |
+
// Check for forbidden codepoints:
|
751 |
+
// - Control characters
|
752 |
+
// - Unicode equivalents of illegal characters
|
753 |
+
// - UTF-16 surrogate pairs
|
754 |
+
// - UTF-8 replacement character
|
755 |
+
// - Byte order mark (BOM)
|
756 |
+
// - Illegal characters: / \ : * ? " < > |
|
757 |
+
for (char32_t c : filename_utf32) {
|
758 |
+
if (c <= 0x1F // Control characters (C0)
|
759 |
+
|| c == 0x7F // Control characters (DEL)
|
760 |
+
|| (c >= 0x80 && c <= 0x9F) // Control characters (C1)
|
761 |
+
|| c == 0xFF0E // Fullwidth Full Stop (period equivalent)
|
762 |
+
|| c == 0x2215 // Division Slash (forward slash equivalent)
|
763 |
+
|| c == 0x2216 // Set Minus (backslash equivalent)
|
764 |
+
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|
765 |
+
|| c == 0xFFFD // Replacement Character (UTF-8)
|
766 |
+
|| c == 0xFEFF // Byte Order Mark (BOM)
|
767 |
+
|| c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
|
768 |
+
|| c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
|
769 |
+
return false;
|
770 |
+
}
|
771 |
+
}
|
772 |
+
|
773 |
+
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
|
774 |
+
// Unicode and other whitespace is not affected, only 0x20 space
|
775 |
+
if (filename.front() == ' ' || filename.back() == ' ' || filename.back() == '.') {
|
776 |
+
return false;
|
777 |
+
}
|
778 |
+
|
779 |
+
// Reject any ".." (currently stricter than necessary, it should be fine to just check for == ".." instead)
|
780 |
+
if (filename.find("..") != std::string::npos) {
|
781 |
+
return false;
|
782 |
+
}
|
783 |
+
|
784 |
+
// Reject "."
|
785 |
+
if (filename == ".") {
|
786 |
+
return false;
|
787 |
+
}
|
788 |
+
|
789 |
+
return true;
|
790 |
+
}
|
791 |
+
|
792 |
+
// returns true if successful, false otherwise
|
793 |
+
bool fs_create_directory_with_parents(const std::string & path) {
|
794 |
+
#ifdef _WIN32
|
795 |
+
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
796 |
+
std::wstring wpath = converter.from_bytes(path);
|
797 |
+
|
798 |
+
// if the path already exists, check whether it's a directory
|
799 |
+
const DWORD attributes = GetFileAttributesW(wpath.c_str());
|
800 |
+
if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) {
|
801 |
+
return true;
|
802 |
+
}
|
803 |
+
|
804 |
+
size_t pos_slash = 0;
|
805 |
+
|
806 |
+
// process path from front to back, procedurally creating directories
|
807 |
+
while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
|
808 |
+
const std::wstring subpath = wpath.substr(0, pos_slash);
|
809 |
+
const wchar_t * test = subpath.c_str();
|
810 |
+
|
811 |
+
const bool success = CreateDirectoryW(test, NULL);
|
812 |
+
if (!success) {
|
813 |
+
const DWORD error = GetLastError();
|
814 |
+
|
815 |
+
// if the path already exists, ensure that it's a directory
|
816 |
+
if (error == ERROR_ALREADY_EXISTS) {
|
817 |
+
const DWORD attributes = GetFileAttributesW(subpath.c_str());
|
818 |
+
if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) {
|
819 |
+
return false;
|
820 |
+
}
|
821 |
+
} else {
|
822 |
+
return false;
|
823 |
+
}
|
824 |
+
}
|
825 |
+
|
826 |
+
pos_slash += 1;
|
827 |
+
}
|
828 |
+
|
829 |
+
return true;
|
830 |
+
#else
|
831 |
+
// if the path already exists, check whether it's a directory
|
832 |
+
struct stat info;
|
833 |
+
if (stat(path.c_str(), &info) == 0) {
|
834 |
+
return S_ISDIR(info.st_mode);
|
835 |
+
}
|
836 |
+
|
837 |
+
size_t pos_slash = 1; // skip leading slashes for directory creation
|
838 |
+
|
839 |
+
// process path from front to back, procedurally creating directories
|
840 |
+
while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) {
|
841 |
+
const std::string subpath = path.substr(0, pos_slash);
|
842 |
+
struct stat info;
|
843 |
+
|
844 |
+
// if the path already exists, ensure that it's a directory
|
845 |
+
if (stat(subpath.c_str(), &info) == 0) {
|
846 |
+
if (!S_ISDIR(info.st_mode)) {
|
847 |
+
return false;
|
848 |
+
}
|
849 |
+
} else {
|
850 |
+
// create parent directories
|
851 |
+
const int ret = mkdir(subpath.c_str(), 0755);
|
852 |
+
if (ret != 0) {
|
853 |
+
return false;
|
854 |
+
}
|
855 |
+
}
|
856 |
+
|
857 |
+
pos_slash += 1;
|
858 |
+
}
|
859 |
+
|
860 |
+
return true;
|
861 |
+
#endif // _WIN32
|
862 |
+
}
|
863 |
+
|
864 |
+
std::string fs_get_cache_directory() {
|
865 |
+
std::string cache_directory = "";
|
866 |
+
auto ensure_trailing_slash = [](std::string p) {
|
867 |
+
// Make sure to add trailing slash
|
868 |
+
if (p.back() != DIRECTORY_SEPARATOR) {
|
869 |
+
p += DIRECTORY_SEPARATOR;
|
870 |
+
}
|
871 |
+
return p;
|
872 |
+
};
|
873 |
+
if (getenv("LLAMA_CACHE")) {
|
874 |
+
cache_directory = std::getenv("LLAMA_CACHE");
|
875 |
+
} else {
|
876 |
+
#ifdef __linux__
|
877 |
+
if (std::getenv("XDG_CACHE_HOME")) {
|
878 |
+
cache_directory = std::getenv("XDG_CACHE_HOME");
|
879 |
+
} else {
|
880 |
+
cache_directory = std::getenv("HOME") + std::string("/.cache/");
|
881 |
+
}
|
882 |
+
#elif defined(__APPLE__)
|
883 |
+
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
|
884 |
+
#elif defined(_WIN32)
|
885 |
+
cache_directory = std::getenv("LOCALAPPDATA");
|
886 |
+
#endif // __linux__
|
887 |
+
cache_directory = ensure_trailing_slash(cache_directory);
|
888 |
+
cache_directory += "llama.cpp";
|
889 |
+
}
|
890 |
+
return ensure_trailing_slash(cache_directory);
|
891 |
+
}
|
892 |
+
|
893 |
+
std::string fs_get_cache_file(const std::string & filename) {
|
894 |
+
GGML_ASSERT(filename.find(DIRECTORY_SEPARATOR) == std::string::npos);
|
895 |
+
std::string cache_directory = fs_get_cache_directory();
|
896 |
+
const bool success = fs_create_directory_with_parents(cache_directory);
|
897 |
+
if (!success) {
|
898 |
+
throw std::runtime_error("failed to create cache directory: " + cache_directory);
|
899 |
+
}
|
900 |
+
return cache_directory + filename;
|
901 |
+
}
|
902 |
+
|
903 |
+
|
904 |
+
//
|
905 |
+
// Model utils
|
906 |
+
//
|
907 |
+
struct common_init_result common_init_from_params(common_params & params) {
|
908 |
+
common_init_result iparams;
|
909 |
+
auto mparams = common_model_params_to_llama(params);
|
910 |
+
|
911 |
+
llama_model * model = nullptr;
|
912 |
+
|
913 |
+
if (!params.hf_repo.empty() && !params.hf_file.empty()) {
|
914 |
+
model = common_load_model_from_hf(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams);
|
915 |
+
} else if (!params.model_url.empty()) {
|
916 |
+
model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams);
|
917 |
+
} else {
|
918 |
+
model = llama_model_load_from_file(params.model.c_str(), mparams);
|
919 |
+
}
|
920 |
+
|
921 |
+
if (model == NULL) {
|
922 |
+
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.c_str());
|
923 |
+
return iparams;
|
924 |
+
}
|
925 |
+
|
926 |
+
const llama_vocab * vocab = llama_model_get_vocab(model);
|
927 |
+
|
928 |
+
if (params.reranking) {
|
929 |
+
bool ok = true;
|
930 |
+
|
931 |
+
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
|
932 |
+
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
|
933 |
+
ok = false;
|
934 |
+
}
|
935 |
+
|
936 |
+
if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
937 |
+
LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__);
|
938 |
+
ok = false;
|
939 |
+
}
|
940 |
+
|
941 |
+
if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) {
|
942 |
+
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
|
943 |
+
ok = false;
|
944 |
+
}
|
945 |
+
|
946 |
+
if (!ok) {
|
947 |
+
llama_model_free(model);
|
948 |
+
|
949 |
+
return iparams;
|
950 |
+
}
|
951 |
+
}
|
952 |
+
|
953 |
+
auto cparams = common_context_params_to_llama(params);
|
954 |
+
|
955 |
+
llama_context * lctx = llama_init_from_model(model, cparams);
|
956 |
+
if (lctx == NULL) {
|
957 |
+
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
|
958 |
+
llama_model_free(model);
|
959 |
+
return iparams;
|
960 |
+
}
|
961 |
+
|
962 |
+
if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
|
963 |
+
LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
|
964 |
+
params.ctx_shift = false;
|
965 |
+
}
|
966 |
+
|
967 |
+
if (!params.control_vectors.empty()) {
|
968 |
+
if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
|
969 |
+
if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_model_n_layer(model);
|
970 |
+
|
971 |
+
const auto cvec = common_control_vector_load(params.control_vectors);
|
972 |
+
if (cvec.n_embd == -1) {
|
973 |
+
llama_free(lctx);
|
974 |
+
llama_model_free(model);
|
975 |
+
|
976 |
+
return iparams;
|
977 |
+
}
|
978 |
+
|
979 |
+
int err = llama_apply_adapter_cvec(
|
980 |
+
lctx,
|
981 |
+
cvec.data.data(),
|
982 |
+
cvec.data.size(),
|
983 |
+
cvec.n_embd,
|
984 |
+
params.control_vector_layer_start,
|
985 |
+
params.control_vector_layer_end);
|
986 |
+
if (err) {
|
987 |
+
llama_free(lctx);
|
988 |
+
llama_model_free(model);
|
989 |
+
|
990 |
+
return iparams;
|
991 |
+
}
|
992 |
+
}
|
993 |
+
|
994 |
+
// load and optionally apply lora adapters
|
995 |
+
for (auto & la : params.lora_adapters) {
|
996 |
+
llama_adapter_lora_ptr lora;
|
997 |
+
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
|
998 |
+
if (lora == nullptr) {
|
999 |
+
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
|
1000 |
+
llama_free(lctx);
|
1001 |
+
llama_model_free(model);
|
1002 |
+
return iparams;
|
1003 |
+
}
|
1004 |
+
|
1005 |
+
la.ptr = lora.get();
|
1006 |
+
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
|
1007 |
+
}
|
1008 |
+
|
1009 |
+
if (!params.lora_init_without_apply) {
|
1010 |
+
common_set_adapter_lora(lctx, params.lora_adapters);
|
1011 |
+
}
|
1012 |
+
|
1013 |
+
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
1014 |
+
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
1015 |
+
params.sampling.ignore_eos = false;
|
1016 |
+
}
|
1017 |
+
|
1018 |
+
if (params.sampling.ignore_eos) {
|
1019 |
+
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
1020 |
+
if (llama_vocab_is_eog(vocab, i)) {
|
1021 |
+
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
|
1022 |
+
params.sampling.logit_bias.push_back({i, -INFINITY});
|
1023 |
+
}
|
1024 |
+
}
|
1025 |
+
}
|
1026 |
+
|
1027 |
+
if (params.sampling.penalty_last_n == -1) {
|
1028 |
+
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
1029 |
+
params.sampling.penalty_last_n = llama_n_ctx(lctx);
|
1030 |
+
}
|
1031 |
+
|
1032 |
+
if (params.sampling.dry_penalty_last_n == -1) {
|
1033 |
+
LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
1034 |
+
params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
|
1035 |
+
}
|
1036 |
+
|
1037 |
+
if (params.warmup) {
|
1038 |
+
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
1039 |
+
|
1040 |
+
std::vector<llama_token> tmp;
|
1041 |
+
llama_token bos = llama_vocab_bos(vocab);
|
1042 |
+
llama_token eos = llama_vocab_eos(vocab);
|
1043 |
+
|
1044 |
+
// some models (e.g. T5) don't have a BOS token
|
1045 |
+
if (bos != LLAMA_TOKEN_NULL) {
|
1046 |
+
tmp.push_back(bos);
|
1047 |
+
}
|
1048 |
+
if (eos != LLAMA_TOKEN_NULL) {
|
1049 |
+
tmp.push_back(eos);
|
1050 |
+
}
|
1051 |
+
if (tmp.empty()) {
|
1052 |
+
tmp.push_back(0);
|
1053 |
+
}
|
1054 |
+
|
1055 |
+
if (llama_model_has_encoder(model)) {
|
1056 |
+
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
|
1057 |
+
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
|
1058 |
+
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
|
1059 |
+
decoder_start_token_id = bos;
|
1060 |
+
}
|
1061 |
+
tmp.clear();
|
1062 |
+
tmp.push_back(decoder_start_token_id);
|
1063 |
+
}
|
1064 |
+
if (llama_model_has_decoder(model)) {
|
1065 |
+
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
|
1066 |
+
}
|
1067 |
+
llama_kv_cache_clear(lctx);
|
1068 |
+
llama_synchronize(lctx);
|
1069 |
+
llama_perf_context_reset(lctx);
|
1070 |
+
}
|
1071 |
+
|
1072 |
+
iparams.model.reset(model);
|
1073 |
+
iparams.context.reset(lctx);
|
1074 |
+
|
1075 |
+
return iparams;
|
1076 |
+
}
|
1077 |
+
|
1078 |
+
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
|
1079 |
+
llama_clear_adapter_lora(ctx);
|
1080 |
+
for (auto & la : lora) {
|
1081 |
+
if (la.scale != 0.0f) {
|
1082 |
+
llama_set_adapter_lora(ctx, la.ptr, la.scale);
|
1083 |
+
}
|
1084 |
+
}
|
1085 |
+
}
|
1086 |
+
|
1087 |
+
struct llama_model_params common_model_params_to_llama(common_params & params) {
|
1088 |
+
auto mparams = llama_model_default_params();
|
1089 |
+
|
1090 |
+
if (!params.devices.empty()) {
|
1091 |
+
mparams.devices = params.devices.data();
|
1092 |
+
}
|
1093 |
+
if (params.n_gpu_layers != -1) {
|
1094 |
+
mparams.n_gpu_layers = params.n_gpu_layers;
|
1095 |
+
}
|
1096 |
+
mparams.main_gpu = params.main_gpu;
|
1097 |
+
mparams.split_mode = params.split_mode;
|
1098 |
+
mparams.tensor_split = params.tensor_split;
|
1099 |
+
mparams.use_mmap = params.use_mmap;
|
1100 |
+
mparams.use_mlock = params.use_mlock;
|
1101 |
+
mparams.check_tensors = params.check_tensors;
|
1102 |
+
if (params.kv_overrides.empty()) {
|
1103 |
+
mparams.kv_overrides = NULL;
|
1104 |
+
} else {
|
1105 |
+
GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
|
1106 |
+
mparams.kv_overrides = params.kv_overrides.data();
|
1107 |
+
}
|
1108 |
+
|
1109 |
+
return mparams;
|
1110 |
+
}
|
1111 |
+
|
1112 |
+
struct llama_context_params common_context_params_to_llama(const common_params & params) {
|
1113 |
+
auto cparams = llama_context_default_params();
|
1114 |
+
|
1115 |
+
cparams.n_ctx = params.n_ctx;
|
1116 |
+
cparams.n_seq_max = params.n_parallel;
|
1117 |
+
cparams.n_batch = params.n_batch;
|
1118 |
+
cparams.n_ubatch = params.n_ubatch;
|
1119 |
+
cparams.n_threads = params.cpuparams.n_threads;
|
1120 |
+
cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
|
1121 |
+
params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
|
1122 |
+
cparams.logits_all = params.logits_all;
|
1123 |
+
cparams.embeddings = params.embedding;
|
1124 |
+
cparams.rope_scaling_type = params.rope_scaling_type;
|
1125 |
+
cparams.rope_freq_base = params.rope_freq_base;
|
1126 |
+
cparams.rope_freq_scale = params.rope_freq_scale;
|
1127 |
+
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
1128 |
+
cparams.yarn_attn_factor = params.yarn_attn_factor;
|
1129 |
+
cparams.yarn_beta_fast = params.yarn_beta_fast;
|
1130 |
+
cparams.yarn_beta_slow = params.yarn_beta_slow;
|
1131 |
+
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
|
1132 |
+
cparams.pooling_type = params.pooling_type;
|
1133 |
+
cparams.attention_type = params.attention_type;
|
1134 |
+
cparams.defrag_thold = params.defrag_thold;
|
1135 |
+
cparams.cb_eval = params.cb_eval;
|
1136 |
+
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
1137 |
+
cparams.offload_kqv = !params.no_kv_offload;
|
1138 |
+
cparams.flash_attn = params.flash_attn;
|
1139 |
+
cparams.no_perf = params.no_perf;
|
1140 |
+
|
1141 |
+
if (params.reranking) {
|
1142 |
+
cparams.embeddings = true;
|
1143 |
+
cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
|
1144 |
+
}
|
1145 |
+
|
1146 |
+
cparams.type_k = params.cache_type_k;
|
1147 |
+
cparams.type_v = params.cache_type_v;
|
1148 |
+
|
1149 |
+
return cparams;
|
1150 |
+
}
|
1151 |
+
|
1152 |
+
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params) {
|
1153 |
+
struct ggml_threadpool_params tpp;
|
1154 |
+
|
1155 |
+
ggml_threadpool_params_init(&tpp, params.n_threads); // setup the defaults
|
1156 |
+
|
1157 |
+
if (params.mask_valid) {
|
1158 |
+
std::memcpy(&tpp.cpumask, ¶ms.cpumask, GGML_MAX_N_THREADS);
|
1159 |
+
}
|
1160 |
+
|
1161 |
+
tpp.prio = params.priority;
|
1162 |
+
tpp.poll = params.poll;
|
1163 |
+
tpp.strict_cpu = params.strict_cpu;
|
1164 |
+
|
1165 |
+
return tpp;
|
1166 |
+
}
|
1167 |
+
|
1168 |
+
#ifdef LLAMA_USE_CURL
|
1169 |
+
|
1170 |
+
#define CURL_MAX_RETRY 3
|
1171 |
+
#define CURL_RETRY_DELAY_SECONDS 2
|
1172 |
+
|
1173 |
+
static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
|
1174 |
+
int remaining_attempts = max_attempts;
|
1175 |
+
|
1176 |
+
while (remaining_attempts > 0) {
|
1177 |
+
LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
|
1178 |
+
|
1179 |
+
CURLcode res = curl_easy_perform(curl);
|
1180 |
+
if (res == CURLE_OK) {
|
1181 |
+
return true;
|
1182 |
+
}
|
1183 |
+
|
1184 |
+
int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
|
1185 |
+
LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);
|
1186 |
+
|
1187 |
+
remaining_attempts--;
|
1188 |
+
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
|
1189 |
+
}
|
1190 |
+
|
1191 |
+
LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
|
1192 |
+
|
1193 |
+
return false;
|
1194 |
+
}
|
1195 |
+
|
1196 |
+
static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
|
1197 |
+
// Initialize libcurl
|
1198 |
+
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
|
1199 |
+
curl_slist_ptr http_headers;
|
1200 |
+
if (!curl) {
|
1201 |
+
LOG_ERR("%s: error initializing libcurl\n", __func__);
|
1202 |
+
return false;
|
1203 |
+
}
|
1204 |
+
|
1205 |
+
bool force_download = false;
|
1206 |
+
|
1207 |
+
// Set the URL, allow to follow http redirection
|
1208 |
+
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
|
1209 |
+
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
|
1210 |
+
|
1211 |
+
// Check if hf-token or bearer-token was specified
|
1212 |
+
if (!hf_token.empty()) {
|
1213 |
+
std::string auth_header = "Authorization: Bearer " + hf_token;
|
1214 |
+
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
|
1215 |
+
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
|
1216 |
+
}
|
1217 |
+
|
1218 |
+
#if defined(_WIN32)
|
1219 |
+
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
|
1220 |
+
// operating system. Currently implemented under MS-Windows.
|
1221 |
+
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
|
1222 |
+
#endif
|
1223 |
+
|
1224 |
+
// Check if the file already exists locally
|
1225 |
+
auto file_exists = std::filesystem::exists(path);
|
1226 |
+
|
1227 |
+
// If the file exists, check its JSON metadata companion file.
|
1228 |
+
std::string metadata_path = path + ".json";
|
1229 |
+
nlohmann::json metadata;
|
1230 |
+
std::string etag;
|
1231 |
+
std::string last_modified;
|
1232 |
+
|
1233 |
+
if (file_exists) {
|
1234 |
+
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
|
1235 |
+
std::ifstream metadata_in(metadata_path);
|
1236 |
+
if (metadata_in.good()) {
|
1237 |
+
try {
|
1238 |
+
metadata_in >> metadata;
|
1239 |
+
LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
|
1240 |
+
if (metadata.contains("url") && metadata.at("url").is_string()) {
|
1241 |
+
auto previous_url = metadata.at("url").get<std::string>();
|
1242 |
+
if (previous_url != url) {
|
1243 |
+
LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str());
|
1244 |
+
return false;
|
1245 |
+
}
|
1246 |
+
}
|
1247 |
+
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
|
1248 |
+
etag = metadata.at("etag");
|
1249 |
+
}
|
1250 |
+
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
|
1251 |
+
last_modified = metadata.at("lastModified");
|
1252 |
+
}
|
1253 |
+
} catch (const nlohmann::json::exception & e) {
|
1254 |
+
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
|
1255 |
+
return false;
|
1256 |
+
}
|
1257 |
+
}
|
1258 |
+
} else {
|
1259 |
+
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
|
1260 |
+
}
|
1261 |
+
|
1262 |
+
// Send a HEAD request to retrieve the etag and last-modified headers
|
1263 |
+
struct common_load_model_from_url_headers {
|
1264 |
+
std::string etag;
|
1265 |
+
std::string last_modified;
|
1266 |
+
};
|
1267 |
+
|
1268 |
+
common_load_model_from_url_headers headers;
|
1269 |
+
|
1270 |
+
{
|
1271 |
+
typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
|
1272 |
+
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
|
1273 |
+
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
|
1274 |
+
|
1275 |
+
static std::regex header_regex("([^:]+): (.*)\r\n");
|
1276 |
+
static std::regex etag_regex("ETag", std::regex_constants::icase);
|
1277 |
+
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
|
1278 |
+
|
1279 |
+
std::string header(buffer, n_items);
|
1280 |
+
std::smatch match;
|
1281 |
+
if (std::regex_match(header, match, header_regex)) {
|
1282 |
+
const std::string & key = match[1];
|
1283 |
+
const std::string & value = match[2];
|
1284 |
+
if (std::regex_match(key, match, etag_regex)) {
|
1285 |
+
headers->etag = value;
|
1286 |
+
} else if (std::regex_match(key, match, last_modified_regex)) {
|
1287 |
+
headers->last_modified = value;
|
1288 |
+
}
|
1289 |
+
}
|
1290 |
+
return n_items;
|
1291 |
+
};
|
1292 |
+
|
1293 |
+
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
|
1294 |
+
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
|
1295 |
+
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
|
1296 |
+
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
|
1297 |
+
|
1298 |
+
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
|
1299 |
+
if (!was_perform_successful) {
|
1300 |
+
return false;
|
1301 |
+
}
|
1302 |
+
|
1303 |
+
long http_code = 0;
|
1304 |
+
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
|
1305 |
+
if (http_code != 200) {
|
1306 |
+
// HEAD not supported, we don't know if the file has changed
|
1307 |
+
// force trigger downloading
|
1308 |
+
force_download = true;
|
1309 |
+
LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
|
1310 |
+
}
|
1311 |
+
}
|
1312 |
+
|
1313 |
+
bool should_download = !file_exists || force_download;
|
1314 |
+
if (!should_download) {
|
1315 |
+
if (!etag.empty() && etag != headers.etag) {
|
1316 |
+
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
|
1317 |
+
should_download = true;
|
1318 |
+
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
|
1319 |
+
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
|
1320 |
+
should_download = true;
|
1321 |
+
}
|
1322 |
+
}
|
1323 |
+
if (should_download) {
|
1324 |
+
std::string path_temporary = path + ".downloadInProgress";
|
1325 |
+
if (file_exists) {
|
1326 |
+
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
|
1327 |
+
if (remove(path.c_str()) != 0) {
|
1328 |
+
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
1329 |
+
return false;
|
1330 |
+
}
|
1331 |
+
}
|
1332 |
+
|
1333 |
+
// Set the output file
|
1334 |
+
|
1335 |
+
struct FILE_deleter {
|
1336 |
+
void operator()(FILE * f) const {
|
1337 |
+
fclose(f);
|
1338 |
+
}
|
1339 |
+
};
|
1340 |
+
|
1341 |
+
std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
|
1342 |
+
if (!outfile) {
|
1343 |
+
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
|
1344 |
+
return false;
|
1345 |
+
}
|
1346 |
+
|
1347 |
+
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
|
1348 |
+
auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
|
1349 |
+
return fwrite(data, size, nmemb, (FILE *)fd);
|
1350 |
+
};
|
1351 |
+
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
|
1352 |
+
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
|
1353 |
+
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
|
1354 |
+
|
1355 |
+
// display download progress
|
1356 |
+
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
|
1357 |
+
|
1358 |
+
// helper function to hide password in URL
|
1359 |
+
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
|
1360 |
+
std::size_t protocol_pos = url.find("://");
|
1361 |
+
if (protocol_pos == std::string::npos) {
|
1362 |
+
return url; // Malformed URL
|
1363 |
+
}
|
1364 |
+
|
1365 |
+
std::size_t at_pos = url.find('@', protocol_pos + 3);
|
1366 |
+
if (at_pos == std::string::npos) {
|
1367 |
+
return url; // No password in URL
|
1368 |
+
}
|
1369 |
+
|
1370 |
+
return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
|
1371 |
+
};
|
1372 |
+
|
1373 |
+
// start the download
|
1374 |
+
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
|
1375 |
+
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
|
1376 |
+
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
|
1377 |
+
if (!was_perform_successful) {
|
1378 |
+
return false;
|
1379 |
+
}
|
1380 |
+
|
1381 |
+
long http_code = 0;
|
1382 |
+
curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
|
1383 |
+
if (http_code < 200 || http_code >= 400) {
|
1384 |
+
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
|
1385 |
+
return false;
|
1386 |
+
}
|
1387 |
+
|
1388 |
+
// Causes file to be closed explicitly here before we rename it.
|
1389 |
+
outfile.reset();
|
1390 |
+
|
1391 |
+
// Write the updated JSON metadata file.
|
1392 |
+
metadata.update({
|
1393 |
+
{"url", url},
|
1394 |
+
{"etag", headers.etag},
|
1395 |
+
{"lastModified", headers.last_modified}
|
1396 |
+
});
|
1397 |
+
std::ofstream(metadata_path) << metadata.dump(4);
|
1398 |
+
LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
|
1399 |
+
|
1400 |
+
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
|
1401 |
+
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
1402 |
+
return false;
|
1403 |
+
}
|
1404 |
+
}
|
1405 |
+
|
1406 |
+
return true;
|
1407 |
+
}
|
1408 |
+
|
1409 |
+
struct llama_model * common_load_model_from_url(
|
1410 |
+
const std::string & model_url,
|
1411 |
+
const std::string & local_path,
|
1412 |
+
const std::string & hf_token,
|
1413 |
+
const struct llama_model_params & params) {
|
1414 |
+
// Basic validation of the model_url
|
1415 |
+
if (model_url.empty()) {
|
1416 |
+
LOG_ERR("%s: invalid model_url\n", __func__);
|
1417 |
+
return NULL;
|
1418 |
+
}
|
1419 |
+
|
1420 |
+
if (!common_download_file(model_url, local_path, hf_token)) {
|
1421 |
+
return NULL;
|
1422 |
+
}
|
1423 |
+
|
1424 |
+
// check for additional GGUFs split to download
|
1425 |
+
int n_split = 0;
|
1426 |
+
{
|
1427 |
+
struct gguf_init_params gguf_params = {
|
1428 |
+
/*.no_alloc = */ true,
|
1429 |
+
/*.ctx = */ NULL,
|
1430 |
+
};
|
1431 |
+
auto * ctx_gguf = gguf_init_from_file(local_path.c_str(), gguf_params);
|
1432 |
+
if (!ctx_gguf) {
|
1433 |
+
LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, local_path.c_str());
|
1434 |
+
return NULL;
|
1435 |
+
}
|
1436 |
+
|
1437 |
+
auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
|
1438 |
+
if (key_n_split >= 0) {
|
1439 |
+
n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
|
1440 |
+
}
|
1441 |
+
|
1442 |
+
gguf_free(ctx_gguf);
|
1443 |
+
}
|
1444 |
+
|
1445 |
+
if (n_split > 1) {
|
1446 |
+
char split_prefix[PATH_MAX] = {0};
|
1447 |
+
char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0};
|
1448 |
+
|
1449 |
+
// Verify the first split file format
|
1450 |
+
// and extract split URL and PATH prefixes
|
1451 |
+
{
|
1452 |
+
if (!llama_split_prefix(split_prefix, sizeof(split_prefix), local_path.c_str(), 0, n_split)) {
|
1453 |
+
LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, local_path.c_str(), n_split);
|
1454 |
+
return NULL;
|
1455 |
+
}
|
1456 |
+
|
1457 |
+
if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url.c_str(), 0, n_split)) {
|
1458 |
+
LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url.c_str(), n_split);
|
1459 |
+
return NULL;
|
1460 |
+
}
|
1461 |
+
}
|
1462 |
+
|
1463 |
+
// Prepare download in parallel
|
1464 |
+
std::vector<std::future<bool>> futures_download;
|
1465 |
+
for (int idx = 1; idx < n_split; idx++) {
|
1466 |
+
futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool {
|
1467 |
+
char split_path[PATH_MAX] = {0};
|
1468 |
+
llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);
|
1469 |
+
|
1470 |
+
char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
|
1471 |
+
llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
|
1472 |
+
|
1473 |
+
return common_download_file(split_url, split_path, hf_token);
|
1474 |
+
}, idx));
|
1475 |
+
}
|
1476 |
+
|
1477 |
+
// Wait for all downloads to complete
|
1478 |
+
for (auto & f : futures_download) {
|
1479 |
+
if (!f.get()) {
|
1480 |
+
return NULL;
|
1481 |
+
}
|
1482 |
+
}
|
1483 |
+
}
|
1484 |
+
|
1485 |
+
return llama_model_load_from_file(local_path.c_str(), params);
|
1486 |
+
}
|
1487 |
+
|
1488 |
+
struct llama_model * common_load_model_from_hf(
|
1489 |
+
const std::string & repo,
|
1490 |
+
const std::string & remote_path,
|
1491 |
+
const std::string & local_path,
|
1492 |
+
const std::string & hf_token,
|
1493 |
+
const struct llama_model_params & params) {
|
1494 |
+
// construct hugging face model url:
|
1495 |
+
//
|
1496 |
+
// --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
|
1497 |
+
// https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
|
1498 |
+
//
|
1499 |
+
// --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
|
1500 |
+
// https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
|
1501 |
+
//
|
1502 |
+
|
1503 |
+
std::string model_url = "https://huggingface.co/";
|
1504 |
+
model_url += repo;
|
1505 |
+
model_url += "/resolve/main/";
|
1506 |
+
model_url += remote_path;
|
1507 |
+
|
1508 |
+
return common_load_model_from_url(model_url, local_path, hf_token, params);
|
1509 |
+
}
|
1510 |
+
|
1511 |
+
/**
|
1512 |
+
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
|
1513 |
+
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
|
1514 |
+
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
|
1515 |
+
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
|
1516 |
+
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
|
1517 |
+
*
|
1518 |
+
* Return pair of <repo, file> (with "repo" already having tag removed)
|
1519 |
+
*
|
1520 |
+
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
|
1521 |
+
*/
|
1522 |
+
std::pair<std::string, std::string> common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) {
|
1523 |
+
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
|
1524 |
+
std::string tag = parts.size() > 1 ? parts.back() : "latest";
|
1525 |
+
std::string hf_repo = parts[0];
|
1526 |
+
if (string_split<std::string>(hf_repo, '/').size() != 2) {
|
1527 |
+
throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
|
1528 |
+
}
|
1529 |
+
|
1530 |
+
// fetch model info from Hugging Face Hub API
|
1531 |
+
json model_info;
|
1532 |
+
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
|
1533 |
+
curl_slist_ptr http_headers;
|
1534 |
+
std::string res_str;
|
1535 |
+
std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag;
|
1536 |
+
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
|
1537 |
+
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
|
1538 |
+
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
|
1539 |
+
auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
|
1540 |
+
static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
|
1541 |
+
return size * nmemb;
|
1542 |
+
};
|
1543 |
+
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
|
1544 |
+
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
|
1545 |
+
#if defined(_WIN32)
|
1546 |
+
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
|
1547 |
+
#endif
|
1548 |
+
if (!hf_token.empty()) {
|
1549 |
+
std::string auth_header = "Authorization: Bearer " + hf_token;
|
1550 |
+
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
|
1551 |
+
}
|
1552 |
+
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
|
1553 |
+
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
|
1554 |
+
http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
|
1555 |
+
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
|
1556 |
+
|
1557 |
+
CURLcode res = curl_easy_perform(curl.get());
|
1558 |
+
|
1559 |
+
if (res != CURLE_OK) {
|
1560 |
+
throw std::runtime_error("error: cannot make GET request to HF API");
|
1561 |
+
}
|
1562 |
+
|
1563 |
+
long res_code;
|
1564 |
+
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
|
1565 |
+
if (res_code == 200) {
|
1566 |
+
model_info = json::parse(res_str);
|
1567 |
+
} else if (res_code == 401) {
|
1568 |
+
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
|
1569 |
+
} else {
|
1570 |
+
throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
|
1571 |
+
}
|
1572 |
+
|
1573 |
+
// check response
|
1574 |
+
if (!model_info.contains("ggufFile")) {
|
1575 |
+
throw std::runtime_error("error: model does not have ggufFile");
|
1576 |
+
}
|
1577 |
+
json & gguf_file = model_info.at("ggufFile");
|
1578 |
+
if (!gguf_file.contains("rfilename")) {
|
1579 |
+
throw std::runtime_error("error: ggufFile does not have rfilename");
|
1580 |
+
}
|
1581 |
+
|
1582 |
+
return std::make_pair(hf_repo, gguf_file.at("rfilename"));
|
1583 |
+
}
|
1584 |
+
|
1585 |
+
#else
|
1586 |
+
|
1587 |
+
struct llama_model * common_load_model_from_url(
|
1588 |
+
const std::string & /*model_url*/,
|
1589 |
+
const std::string & /*local_path*/,
|
1590 |
+
const std::string & /*hf_token*/,
|
1591 |
+
const struct llama_model_params & /*params*/) {
|
1592 |
+
LOG_WRN("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
|
1593 |
+
return nullptr;
|
1594 |
+
}
|
1595 |
+
|
1596 |
+
struct llama_model * common_load_model_from_hf(
|
1597 |
+
const std::string & /*repo*/,
|
1598 |
+
const std::string & /*remote_path*/,
|
1599 |
+
const std::string & /*local_path*/,
|
1600 |
+
const std::string & /*hf_token*/,
|
1601 |
+
const struct llama_model_params & /*params*/) {
|
1602 |
+
LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
|
1603 |
+
return nullptr;
|
1604 |
+
}
|
1605 |
+
|
1606 |
+
std::pair<std::string, std::string> common_get_hf_file(const std::string &, const std::string &) {
|
1607 |
+
LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
|
1608 |
+
return std::make_pair("", "");
|
1609 |
+
}
|
1610 |
+
|
1611 |
+
#endif // LLAMA_USE_CURL
|
1612 |
+
|
1613 |
+
//
|
1614 |
+
// Batch utils
|
1615 |
+
//
|
1616 |
+
|
1617 |
+
void common_batch_clear(struct llama_batch & batch) {
|
1618 |
+
batch.n_tokens = 0;
|
1619 |
+
}
|
1620 |
+
|
1621 |
+
void common_batch_add(
|
1622 |
+
struct llama_batch & batch,
|
1623 |
+
llama_token id,
|
1624 |
+
llama_pos pos,
|
1625 |
+
const std::vector<llama_seq_id> & seq_ids,
|
1626 |
+
bool logits) {
|
1627 |
+
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
|
1628 |
+
|
1629 |
+
batch.token [batch.n_tokens] = id;
|
1630 |
+
batch.pos [batch.n_tokens] = pos;
|
1631 |
+
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
|
1632 |
+
for (size_t i = 0; i < seq_ids.size(); ++i) {
|
1633 |
+
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
|
1634 |
+
}
|
1635 |
+
batch.logits [batch.n_tokens] = logits;
|
1636 |
+
|
1637 |
+
batch.n_tokens++;
|
1638 |
+
}
|
1639 |
+
|
1640 |
+
//
|
1641 |
+
// Token utils
|
1642 |
+
//
|
1643 |
+
|
1644 |
+
size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
|
1645 |
+
size_t i;
|
1646 |
+
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
|
1647 |
+
|
1648 |
+
return i;
|
1649 |
+
}
|
1650 |
+
|
1651 |
+
size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
|
1652 |
+
// check for empty sequences
|
1653 |
+
if (a.empty() || b.empty()) {
|
1654 |
+
return 0;
|
1655 |
+
}
|
1656 |
+
|
1657 |
+
// get the lengths of the input sequences
|
1658 |
+
size_t a_len = a.size();
|
1659 |
+
size_t b_len = b.size();
|
1660 |
+
|
1661 |
+
// initialize the maximum length of the longest common subsequence (LCS)
|
1662 |
+
size_t max_length = 0;
|
1663 |
+
|
1664 |
+
// use two rows instead of a 2D matrix to optimize space
|
1665 |
+
std::vector<size_t> prev_row(b_len + 1, 0);
|
1666 |
+
std::vector<size_t> curr_row(b_len + 1, 0);
|
1667 |
+
|
1668 |
+
// iterate through the elements of a
|
1669 |
+
for (size_t i = 1; i <= a_len; i++) {
|
1670 |
+
// iterate through the elements of b
|
1671 |
+
for (size_t j = 1; j <= b_len; j++) {
|
1672 |
+
// if elements at the current positions match
|
1673 |
+
if (a[i - 1] == b[j - 1]) {
|
1674 |
+
// if it's the first element of either sequences, set LCS length to 1
|
1675 |
+
if (i == 1 || j == 1) {
|
1676 |
+
curr_row[j] = 1;
|
1677 |
+
} else {
|
1678 |
+
// increment LCS length by 1 compared to the previous element
|
1679 |
+
curr_row[j] = prev_row[j - 1] + 1;
|
1680 |
+
}
|
1681 |
+
|
1682 |
+
// update max_length if necessary
|
1683 |
+
if (curr_row[j] > max_length) {
|
1684 |
+
max_length = curr_row[j];
|
1685 |
+
}
|
1686 |
+
} else {
|
1687 |
+
// reset LCS length if elements don't match
|
1688 |
+
curr_row[j] = 0;
|
1689 |
+
}
|
1690 |
+
}
|
1691 |
+
|
1692 |
+
// update the previous row for the next iteration
|
1693 |
+
prev_row = curr_row;
|
1694 |
+
}
|
1695 |
+
|
1696 |
+
// return the maximum length of the LCS
|
1697 |
+
return max_length;
|
1698 |
+
}
|
1699 |
+
|
1700 |
+
//
|
1701 |
+
// Vocab utils
|
1702 |
+
//
|
1703 |
+
|
1704 |
+
std::vector<llama_token> common_tokenize(
|
1705 |
+
const struct llama_context * ctx,
|
1706 |
+
const std::string & text,
|
1707 |
+
bool add_special,
|
1708 |
+
bool parse_special) {
|
1709 |
+
const llama_model * model = llama_get_model(ctx);
|
1710 |
+
const llama_vocab * vocab = llama_model_get_vocab(model);
|
1711 |
+
return common_tokenize(vocab, text, add_special, parse_special);
|
1712 |
+
}
|
1713 |
+
|
1714 |
+
std::vector<llama_token> common_tokenize(
|
1715 |
+
const struct llama_vocab * vocab,
|
1716 |
+
const std::string & text,
|
1717 |
+
bool add_special,
|
1718 |
+
bool parse_special) {
|
1719 |
+
// upper limit for the number of tokens
|
1720 |
+
int n_tokens = text.length() + 2 * add_special;
|
1721 |
+
std::vector<llama_token> result(n_tokens);
|
1722 |
+
n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
|
1723 |
+
if (n_tokens < 0) {
|
1724 |
+
result.resize(-n_tokens);
|
1725 |
+
int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
|
1726 |
+
GGML_ASSERT(check == -n_tokens);
|
1727 |
+
} else {
|
1728 |
+
result.resize(n_tokens);
|
1729 |
+
}
|
1730 |
+
return result;
|
1731 |
+
}
|
1732 |
+
|
1733 |
+
std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
|
1734 |
+
const llama_model * model = llama_get_model(ctx);
|
1735 |
+
const llama_vocab * vocab = llama_model_get_vocab(model);
|
1736 |
+
return common_token_to_piece(vocab, token, special);
|
1737 |
+
}
|
1738 |
+
|
1739 |
+
std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) {
|
1740 |
+
std::string piece;
|
1741 |
+
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
|
1742 |
+
const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
|
1743 |
+
if (n_chars < 0) {
|
1744 |
+
piece.resize(-n_chars);
|
1745 |
+
int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
|
1746 |
+
GGML_ASSERT(check == -n_chars);
|
1747 |
+
}
|
1748 |
+
else {
|
1749 |
+
piece.resize(n_chars);
|
1750 |
+
}
|
1751 |
+
|
1752 |
+
return piece;
|
1753 |
+
}
|
1754 |
+
|
1755 |
+
std::string common_detokenize(const struct llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
|
1756 |
+
const llama_model * model = llama_get_model(ctx);
|
1757 |
+
const llama_vocab * vocab = llama_model_get_vocab(model);
|
1758 |
+
return common_detokenize(vocab, tokens, special);
|
1759 |
+
}
|
1760 |
+
|
1761 |
+
std::string common_detokenize(const struct llama_vocab * vocab, const std::vector<llama_token> & tokens, bool special) {
|
1762 |
+
std::string text;
|
1763 |
+
text.resize(std::max(text.capacity(), tokens.size()));
|
1764 |
+
int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
|
1765 |
+
if (n_chars < 0) {
|
1766 |
+
text.resize(-n_chars);
|
1767 |
+
n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
|
1768 |
+
GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
|
1769 |
+
}
|
1770 |
+
|
1771 |
+
text.resize(n_chars);
|
1772 |
+
|
1773 |
+
// NOTE: the original tokenizer decodes bytes after collecting the pieces.
|
1774 |
+
return text;
|
1775 |
+
}
|
1776 |
+
|
1777 |
+
//
|
1778 |
+
// KV cache utils
|
1779 |
+
//
|
1780 |
+
|
1781 |
+
void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) {
|
1782 |
+
static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
|
1783 |
+
|
1784 |
+
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
|
1785 |
+
view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
|
1786 |
+
|
1787 |
+
llama_kv_cache_view_cell * c_curr = view.cells;
|
1788 |
+
llama_seq_id * cs_curr = view.cells_sequences;
|
1789 |
+
|
1790 |
+
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
|
1791 |
+
if (i % row_size == 0) {
|
1792 |
+
printf("\n%5d: ", i);
|
1793 |
+
}
|
1794 |
+
int seq_count = 0;
|
1795 |
+
for (int j = 0; j < view.n_seq_max; j++) {
|
1796 |
+
if (cs_curr[j] >= 0) { seq_count++; }
|
1797 |
+
}
|
1798 |
+
putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
|
1799 |
+
}
|
1800 |
+
|
1801 |
+
printf("\n=== Done dumping\n");
|
1802 |
+
}
|
1803 |
+
|
1804 |
+
void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) {
|
1805 |
+
static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
|
1806 |
+
|
1807 |
+
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
|
1808 |
+
view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
|
1809 |
+
|
1810 |
+
std::unordered_map<llama_seq_id, size_t> seqs;
|
1811 |
+
llama_kv_cache_view_cell * c_curr = view.cells;
|
1812 |
+
llama_seq_id * cs_curr = view.cells_sequences;
|
1813 |
+
|
1814 |
+
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
|
1815 |
+
for (int j = 0; j < view.n_seq_max; j++) {
|
1816 |
+
if (cs_curr[j] < 0) { continue; }
|
1817 |
+
if (seqs.find(cs_curr[j]) == seqs.end()) {
|
1818 |
+
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
|
1819 |
+
const size_t sz = seqs.size();
|
1820 |
+
seqs[cs_curr[j]] = sz;
|
1821 |
+
}
|
1822 |
+
}
|
1823 |
+
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
|
1824 |
+
}
|
1825 |
+
|
1826 |
+
printf("=== Sequence legend: ");
|
1827 |
+
for (const auto & it : seqs) {
|
1828 |
+
printf("%zu=%d, ", it.second, it.first);
|
1829 |
+
}
|
1830 |
+
printf("'+'=other sequence ids");
|
1831 |
+
|
1832 |
+
c_curr = view.cells;
|
1833 |
+
cs_curr = view.cells_sequences;
|
1834 |
+
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
|
1835 |
+
if (i % row_size == 0) {
|
1836 |
+
printf("\n%5d: ", i);
|
1837 |
+
}
|
1838 |
+
for (int j = 0; j < view.n_seq_max; j++) {
|
1839 |
+
if (cs_curr[j] >= 0) {
|
1840 |
+
const auto & it = seqs.find(cs_curr[j]);
|
1841 |
+
putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
|
1842 |
+
} else {
|
1843 |
+
putchar('.');
|
1844 |
+
}
|
1845 |
+
}
|
1846 |
+
putchar(' ');
|
1847 |
+
}
|
1848 |
+
|
1849 |
+
printf("\n=== Done dumping\n");
|
1850 |
+
}
|
1851 |
+
|
1852 |
+
//
|
1853 |
+
// Embedding utils
|
1854 |
+
//
|
1855 |
+
|
1856 |
+
void common_embd_normalize(const float * inp, float * out, int n, int embd_norm) {
|
1857 |
+
double sum = 0.0;
|
1858 |
+
|
1859 |
+
switch (embd_norm) {
|
1860 |
+
case -1: // no normalisation
|
1861 |
+
sum = 1.0;
|
1862 |
+
break;
|
1863 |
+
case 0: // max absolute
|
1864 |
+
for (int i = 0; i < n; i++) {
|
1865 |
+
if (sum < std::abs(inp[i])) {
|
1866 |
+
sum = std::abs(inp[i]);
|
1867 |
+
}
|
1868 |
+
}
|
1869 |
+
sum /= 32760.0; // make an int16 range
|
1870 |
+
break;
|
1871 |
+
case 2: // euclidean
|
1872 |
+
for (int i = 0; i < n; i++) {
|
1873 |
+
sum += inp[i] * inp[i];
|
1874 |
+
}
|
1875 |
+
sum = std::sqrt(sum);
|
1876 |
+
break;
|
1877 |
+
default: // p-norm (euclidean is p-norm p=2)
|
1878 |
+
for (int i = 0; i < n; i++) {
|
1879 |
+
sum += std::pow(std::abs(inp[i]), embd_norm);
|
1880 |
+
}
|
1881 |
+
sum = std::pow(sum, 1.0 / embd_norm);
|
1882 |
+
break;
|
1883 |
+
}
|
1884 |
+
|
1885 |
+
const float norm = sum > 0.0 ? 1.0 / sum : 0.0f;
|
1886 |
+
|
1887 |
+
for (int i = 0; i < n; i++) {
|
1888 |
+
out[i] = inp[i] * norm;
|
1889 |
+
}
|
1890 |
+
}
|
1891 |
+
|
1892 |
+
float common_embd_similarity_cos(const float * embd1, const float * embd2, int n){
|
1893 |
+
double sum = 0.0;
|
1894 |
+
double sum1 = 0.0;
|
1895 |
+
double sum2 = 0.0;
|
1896 |
+
|
1897 |
+
for (int i = 0; i < n; i++) {
|
1898 |
+
sum += embd1[i] * embd2[i];
|
1899 |
+
sum1 += embd1[i] * embd1[i];
|
1900 |
+
sum2 += embd2[i] * embd2[i];
|
1901 |
+
}
|
1902 |
+
|
1903 |
+
// Handle the case where one or both vectors are zero vectors
|
1904 |
+
if (sum1 == 0.0 || sum2 == 0.0) {
|
1905 |
+
if (sum1 == 0.0 && sum2 == 0.0) {
|
1906 |
+
return 1.0f; // two zero vectors are similar
|
1907 |
+
}
|
1908 |
+
return 0.0f;
|
1909 |
+
}
|
1910 |
+
|
1911 |
+
return sum / (sqrt(sum1) * sqrt(sum2));
|
1912 |
+
}
|
1913 |
+
|
1914 |
+
//
|
1915 |
+
// Control vector utils
|
1916 |
+
//
|
1917 |
+
|
1918 |
+
static common_control_vector_data common_control_vector_load_one(const common_control_vector_load_info & load_info) {
|
1919 |
+
common_control_vector_data result = { -1, {} };
|
1920 |
+
|
1921 |
+
ggml_context * ctx = nullptr;
|
1922 |
+
struct gguf_init_params meta_gguf_params = {
|
1923 |
+
/* .no_alloc = */ false,
|
1924 |
+
/* .ctx = */ &ctx,
|
1925 |
+
};
|
1926 |
+
struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params);
|
1927 |
+
if (!ctx_gguf) {
|
1928 |
+
LOG_ERR("%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str());
|
1929 |
+
return result;
|
1930 |
+
}
|
1931 |
+
|
1932 |
+
int32_t n_tensors = gguf_get_n_tensors(ctx_gguf);
|
1933 |
+
if (n_tensors == 0) {
|
1934 |
+
LOG_WRN("%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str());
|
1935 |
+
}
|
1936 |
+
|
1937 |
+
for (int i = 0; i < n_tensors; i++) {
|
1938 |
+
std::string name = gguf_get_tensor_name(ctx_gguf, i);
|
1939 |
+
|
1940 |
+
int layer_idx = -1;
|
1941 |
+
|
1942 |
+
// split on '.'
|
1943 |
+
size_t dotpos = name.find('.');
|
1944 |
+
if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") {
|
1945 |
+
try {
|
1946 |
+
layer_idx = std::stoi(name.substr(dotpos + 1));
|
1947 |
+
} catch (...) {
|
1948 |
+
layer_idx = -1;
|
1949 |
+
}
|
1950 |
+
}
|
1951 |
+
if (layer_idx < 0) {
|
1952 |
+
LOG_ERR("%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
|
1953 |
+
result.n_embd = -1;
|
1954 |
+
break;
|
1955 |
+
} else if (layer_idx == 0) {
|
1956 |
+
LOG_ERR("%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
|
1957 |
+
result.n_embd = -1;
|
1958 |
+
break;
|
1959 |
+
}
|
1960 |
+
|
1961 |
+
struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str());
|
1962 |
+
if (tensor->type != GGML_TYPE_F32) {
|
1963 |
+
LOG_ERR("%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str());
|
1964 |
+
result.n_embd = -1;
|
1965 |
+
break;
|
1966 |
+
}
|
1967 |
+
if (ggml_n_dims(tensor) != 1) {
|
1968 |
+
LOG_ERR("%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str());
|
1969 |
+
result.n_embd = -1;
|
1970 |
+
break;
|
1971 |
+
}
|
1972 |
+
|
1973 |
+
if (result.n_embd == -1) {
|
1974 |
+
result.n_embd = ggml_nelements(tensor);
|
1975 |
+
} else if (ggml_nelements(tensor) != result.n_embd) {
|
1976 |
+
LOG_ERR("%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str());
|
1977 |
+
result.n_embd = -1;
|
1978 |
+
break;
|
1979 |
+
}
|
1980 |
+
|
1981 |
+
// extend if necessary - do not store data for layer 0 (it's not used)
|
1982 |
+
result.data.resize(std::max(result.data.size(), static_cast<size_t>(result.n_embd * layer_idx)), 0.0f);
|
1983 |
+
|
1984 |
+
const float * src = (const float *) tensor->data;
|
1985 |
+
float * dst = result.data.data() + result.n_embd * (layer_idx - 1); // layer 1 at [0]
|
1986 |
+
for (int j = 0; j < result.n_embd; j++) {
|
1987 |
+
dst[j] += src[j] * load_info.strength; // allows multiple directions for same layer in same file
|
1988 |
+
}
|
1989 |
+
|
1990 |
+
}
|
1991 |
+
|
1992 |
+
if (result.n_embd == -1) {
|
1993 |
+
LOG_WRN("%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str());
|
1994 |
+
result.data.clear();
|
1995 |
+
}
|
1996 |
+
|
1997 |
+
gguf_free(ctx_gguf);
|
1998 |
+
ggml_free(ctx);
|
1999 |
+
|
2000 |
+
return result;
|
2001 |
+
}
|
2002 |
+
|
2003 |
+
common_control_vector_data common_control_vector_load(const std::vector<common_control_vector_load_info> & load_infos) {
|
2004 |
+
common_control_vector_data result = { -1, {} };
|
2005 |
+
|
2006 |
+
for (const auto & info : load_infos) {
|
2007 |
+
auto cur = common_control_vector_load_one(info);
|
2008 |
+
|
2009 |
+
if (cur.n_embd == -1) {
|
2010 |
+
result.n_embd = -1;
|
2011 |
+
break;
|
2012 |
+
}
|
2013 |
+
if (result.n_embd != -1 && result.n_embd != cur.n_embd) {
|
2014 |
+
LOG_ERR("%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str());
|
2015 |
+
result.n_embd = -1;
|
2016 |
+
break;
|
2017 |
+
}
|
2018 |
+
|
2019 |
+
if (result.n_embd == -1) {
|
2020 |
+
result = std::move(cur);
|
2021 |
+
} else {
|
2022 |
+
result.data.resize(std::max(result.data.size(), cur.data.size()), 0.0f); // extend if necessary
|
2023 |
+
for (size_t i = 0; i < cur.data.size(); i++) {
|
2024 |
+
result.data[i] += cur.data[i];
|
2025 |
+
}
|
2026 |
+
}
|
2027 |
+
}
|
2028 |
+
|
2029 |
+
if (result.n_embd == -1) {
|
2030 |
+
LOG_ERR("%s: no valid control vector files passed\n", __func__);
|
2031 |
+
result.data.clear();
|
2032 |
+
}
|
2033 |
+
|
2034 |
+
return result;
|
2035 |
+
}
|
2036 |
+
|
2037 |
+
template <>
|
2038 |
+
json common_grammar_trigger::to_json() const {
|
2039 |
+
json out {
|
2040 |
+
{"type", (int) type},
|
2041 |
+
{"value", value},
|
2042 |
+
};
|
2043 |
+
if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
|
2044 |
+
out["token"] = (int) token;
|
2045 |
+
}
|
2046 |
+
return out;
|
2047 |
+
}
|
2048 |
+
|
2049 |
+
template <>
|
2050 |
+
common_grammar_trigger common_grammar_trigger::from_json(const json & in) {
|
2051 |
+
common_grammar_trigger out;
|
2052 |
+
out.type = (common_grammar_trigger_type) in.at("type").get<int>();
|
2053 |
+
out.value = in.at("value").get<std::string>();
|
2054 |
+
if (out.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
|
2055 |
+
out.token = (llama_token) in.at("token").get<int>();
|
2056 |
+
}
|
2057 |
+
return out;
|
2058 |
+
}
|
common/common.h
ADDED
@@ -0,0 +1,681 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Various helper functions and utilities
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include "llama-cpp.h"
|
6 |
+
|
7 |
+
#include <set>
|
8 |
+
#include <string>
|
9 |
+
#include <vector>
|
10 |
+
#include <sstream>
|
11 |
+
|
12 |
+
#ifdef _WIN32
|
13 |
+
#define DIRECTORY_SEPARATOR '\\'
|
14 |
+
#else
|
15 |
+
#define DIRECTORY_SEPARATOR '/'
|
16 |
+
#endif // _WIN32
|
17 |
+
|
18 |
+
#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0)
|
19 |
+
#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
|
20 |
+
|
21 |
+
#define print_build_info() do { \
|
22 |
+
fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \
|
23 |
+
fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
|
24 |
+
} while(0)
|
25 |
+
|
26 |
+
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
|
27 |
+
|
28 |
+
struct common_adapter_lora_info {
|
29 |
+
std::string path;
|
30 |
+
float scale;
|
31 |
+
|
32 |
+
struct llama_adapter_lora * ptr;
|
33 |
+
};
|
34 |
+
|
35 |
+
using llama_tokens = std::vector<llama_token>;
|
36 |
+
|
37 |
+
// build info
|
38 |
+
|
39 |
+
struct common_control_vector_load_info;
|
40 |
+
|
41 |
+
//
|
42 |
+
// CPU utils
|
43 |
+
//
|
44 |
+
|
45 |
+
struct cpu_params {
|
46 |
+
int n_threads = -1;
|
47 |
+
bool cpumask[GGML_MAX_N_THREADS] = {false}; // CPU affinity mask.
|
48 |
+
bool mask_valid = false; // Default: any CPU
|
49 |
+
enum ggml_sched_priority priority = GGML_SCHED_PRIO_NORMAL; // Scheduling prio : (0 - normal, 1 - medium, 2 - high, 3 - realtime)
|
50 |
+
bool strict_cpu = false; // Use strict CPU placement
|
51 |
+
uint32_t poll = 50; // Polling (busywait) level (0 - no polling, 100 - mostly polling)
|
52 |
+
};
|
53 |
+
|
54 |
+
int32_t cpu_get_num_physical_cores();
|
55 |
+
int32_t cpu_get_num_math();
|
56 |
+
|
57 |
+
//
|
58 |
+
// Common params
|
59 |
+
//
|
60 |
+
|
61 |
+
enum llama_example {
|
62 |
+
LLAMA_EXAMPLE_COMMON,
|
63 |
+
LLAMA_EXAMPLE_SPECULATIVE,
|
64 |
+
LLAMA_EXAMPLE_MAIN,
|
65 |
+
LLAMA_EXAMPLE_INFILL,
|
66 |
+
LLAMA_EXAMPLE_EMBEDDING,
|
67 |
+
LLAMA_EXAMPLE_PERPLEXITY,
|
68 |
+
LLAMA_EXAMPLE_RETRIEVAL,
|
69 |
+
LLAMA_EXAMPLE_PASSKEY,
|
70 |
+
LLAMA_EXAMPLE_IMATRIX,
|
71 |
+
LLAMA_EXAMPLE_BENCH,
|
72 |
+
LLAMA_EXAMPLE_SERVER,
|
73 |
+
LLAMA_EXAMPLE_CVECTOR_GENERATOR,
|
74 |
+
LLAMA_EXAMPLE_EXPORT_LORA,
|
75 |
+
LLAMA_EXAMPLE_LLAVA,
|
76 |
+
LLAMA_EXAMPLE_LOOKUP,
|
77 |
+
LLAMA_EXAMPLE_PARALLEL,
|
78 |
+
LLAMA_EXAMPLE_TTS,
|
79 |
+
|
80 |
+
LLAMA_EXAMPLE_COUNT,
|
81 |
+
};
|
82 |
+
|
83 |
+
enum common_sampler_type {
|
84 |
+
COMMON_SAMPLER_TYPE_NONE = 0,
|
85 |
+
COMMON_SAMPLER_TYPE_DRY = 1,
|
86 |
+
COMMON_SAMPLER_TYPE_TOP_K = 2,
|
87 |
+
COMMON_SAMPLER_TYPE_TOP_P = 3,
|
88 |
+
COMMON_SAMPLER_TYPE_MIN_P = 4,
|
89 |
+
//COMMON_SAMPLER_TYPE_TFS_Z = 5,
|
90 |
+
COMMON_SAMPLER_TYPE_TYPICAL_P = 6,
|
91 |
+
COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
|
92 |
+
COMMON_SAMPLER_TYPE_XTC = 8,
|
93 |
+
COMMON_SAMPLER_TYPE_INFILL = 9,
|
94 |
+
COMMON_SAMPLER_TYPE_PENALTIES = 10,
|
95 |
+
};
|
96 |
+
|
97 |
+
// dimensionality reduction methods, used by cvector-generator
|
98 |
+
enum dimre_method {
|
99 |
+
DIMRE_METHOD_PCA,
|
100 |
+
DIMRE_METHOD_MEAN,
|
101 |
+
};
|
102 |
+
|
103 |
+
enum common_conversation_mode {
|
104 |
+
COMMON_CONVERSATION_MODE_DISABLED = 0,
|
105 |
+
COMMON_CONVERSATION_MODE_ENABLED = 1,
|
106 |
+
COMMON_CONVERSATION_MODE_AUTO = 2,
|
107 |
+
};
|
108 |
+
|
109 |
+
enum common_grammar_trigger_type {
|
110 |
+
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
|
111 |
+
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
112 |
+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
|
113 |
+
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
|
114 |
+
};
|
115 |
+
|
116 |
+
struct common_grammar_trigger {
|
117 |
+
common_grammar_trigger_type type;
|
118 |
+
std::string value;
|
119 |
+
llama_token token = LLAMA_TOKEN_NULL;
|
120 |
+
|
121 |
+
// T can only be nlohmann::ordered_json
|
122 |
+
template <class T> T to_json() const;
|
123 |
+
template <class T> static common_grammar_trigger from_json(const T & in);
|
124 |
+
};
|
125 |
+
|
126 |
+
// sampling parameters
|
127 |
+
struct common_params_sampling {
|
128 |
+
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
129 |
+
|
130 |
+
int32_t n_prev = 64; // number of previous tokens to remember
|
131 |
+
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
132 |
+
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
|
133 |
+
int32_t top_k = 40; // <= 0 to use vocab size
|
134 |
+
float top_p = 0.95f; // 1.0 = disabled
|
135 |
+
float min_p = 0.05f; // 0.0 = disabled
|
136 |
+
float xtc_probability = 0.00f; // 0.0 = disabled
|
137 |
+
float xtc_threshold = 0.10f; // > 0.5 disables XTC
|
138 |
+
float typ_p = 1.00f; // typical_p, 1.0 = disabled
|
139 |
+
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
140 |
+
float dynatemp_range = 0.00f; // 0.0 = disabled
|
141 |
+
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
|
142 |
+
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
143 |
+
float penalty_repeat = 1.00f; // 1.0 = disabled
|
144 |
+
float penalty_freq = 0.00f; // 0.0 = disabled
|
145 |
+
float penalty_present = 0.00f; // 0.0 = disabled
|
146 |
+
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
|
147 |
+
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
|
148 |
+
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
|
149 |
+
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
|
150 |
+
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
151 |
+
float top_n_sigma = -1.00f;// -1.0 = disabled
|
152 |
+
float mirostat_tau = 5.00f; // target entropy
|
153 |
+
float mirostat_eta = 0.10f; // learning rate
|
154 |
+
bool ignore_eos = false;
|
155 |
+
bool no_perf = false; // disable performance metrics
|
156 |
+
bool timing_per_token = false;
|
157 |
+
|
158 |
+
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
|
159 |
+
|
160 |
+
|
161 |
+
std::vector<enum common_sampler_type> samplers = {
|
162 |
+
COMMON_SAMPLER_TYPE_PENALTIES,
|
163 |
+
COMMON_SAMPLER_TYPE_DRY,
|
164 |
+
COMMON_SAMPLER_TYPE_TOP_K,
|
165 |
+
COMMON_SAMPLER_TYPE_TYPICAL_P,
|
166 |
+
COMMON_SAMPLER_TYPE_TOP_P,
|
167 |
+
COMMON_SAMPLER_TYPE_MIN_P,
|
168 |
+
COMMON_SAMPLER_TYPE_XTC,
|
169 |
+
COMMON_SAMPLER_TYPE_TEMPERATURE,
|
170 |
+
};
|
171 |
+
|
172 |
+
std::string grammar; // optional BNF-like grammar to constrain sampling
|
173 |
+
bool grammar_lazy = false;
|
174 |
+
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
|
175 |
+
std::set<llama_token> preserved_tokens;
|
176 |
+
|
177 |
+
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
178 |
+
|
179 |
+
// print the parameters into a string
|
180 |
+
std::string print() const;
|
181 |
+
};
|
182 |
+
|
183 |
+
struct common_params_speculative {
|
184 |
+
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
185 |
+
|
186 |
+
int32_t n_ctx = 0; // draft context size
|
187 |
+
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
188 |
+
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
189 |
+
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
190 |
+
float p_split = 0.1f; // speculative decoding split probability
|
191 |
+
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
192 |
+
|
193 |
+
struct cpu_params cpuparams;
|
194 |
+
struct cpu_params cpuparams_batch;
|
195 |
+
|
196 |
+
std::string hf_repo = ""; // HF repo // NOLINT
|
197 |
+
std::string hf_file = ""; // HF file // NOLINT
|
198 |
+
|
199 |
+
std::string model = ""; // draft model for speculative decoding // NOLINT
|
200 |
+
std::string model_url = ""; // model url to download // NOLINT
|
201 |
+
};
|
202 |
+
|
203 |
+
struct common_params_vocoder {
|
204 |
+
std::string hf_repo = ""; // HF repo // NOLINT
|
205 |
+
std::string hf_file = ""; // HF file // NOLINT
|
206 |
+
|
207 |
+
std::string model = ""; // model path // NOLINT
|
208 |
+
std::string model_url = ""; // model url to download // NOLINT
|
209 |
+
|
210 |
+
std::string speaker_file = ""; // speaker file path // NOLINT
|
211 |
+
|
212 |
+
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
213 |
+
};
|
214 |
+
|
215 |
+
enum common_reasoning_format {
|
216 |
+
COMMON_REASONING_FORMAT_NONE,
|
217 |
+
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`
|
218 |
+
};
|
219 |
+
|
220 |
+
struct common_params {
|
221 |
+
int32_t n_predict = -1; // new tokens to predict
|
222 |
+
int32_t n_ctx = 4096; // context size
|
223 |
+
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
|
224 |
+
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
|
225 |
+
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
226 |
+
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
|
227 |
+
int32_t n_parallel = 1; // number of parallel sequences to decode
|
228 |
+
int32_t n_sequences = 1; // number of sequences to decode
|
229 |
+
int32_t grp_attn_n = 1; // group-attention factor
|
230 |
+
int32_t grp_attn_w = 512; // group-attention width
|
231 |
+
int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
|
232 |
+
float rope_freq_base = 0.0f; // RoPE base frequency
|
233 |
+
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
|
234 |
+
float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
|
235 |
+
float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
|
236 |
+
float yarn_beta_fast = 32.0f; // YaRN low correction dim
|
237 |
+
float yarn_beta_slow = 1.0f; // YaRN high correction dim
|
238 |
+
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
239 |
+
float defrag_thold = 0.1f; // KV cache defragmentation threshold
|
240 |
+
|
241 |
+
// offload params
|
242 |
+
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
243 |
+
|
244 |
+
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
245 |
+
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
246 |
+
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
|
247 |
+
|
248 |
+
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
|
249 |
+
|
250 |
+
struct cpu_params cpuparams;
|
251 |
+
struct cpu_params cpuparams_batch;
|
252 |
+
|
253 |
+
ggml_backend_sched_eval_callback cb_eval = nullptr;
|
254 |
+
void * cb_eval_user_data = nullptr;
|
255 |
+
|
256 |
+
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
|
257 |
+
|
258 |
+
enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
|
259 |
+
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
|
260 |
+
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
|
261 |
+
|
262 |
+
struct common_params_sampling sampling;
|
263 |
+
struct common_params_speculative speculative;
|
264 |
+
struct common_params_vocoder vocoder;
|
265 |
+
|
266 |
+
std::string model = ""; // model path // NOLINT
|
267 |
+
std::string model_alias = ""; // model alias // NOLINT
|
268 |
+
std::string model_url = ""; // model url to download // NOLINT
|
269 |
+
std::string hf_token = ""; // HF token // NOLINT
|
270 |
+
std::string hf_repo = ""; // HF repo // NOLINT
|
271 |
+
std::string hf_file = ""; // HF file // NOLINT
|
272 |
+
std::string prompt = ""; // NOLINT
|
273 |
+
std::string system_prompt = ""; // NOLINT
|
274 |
+
std::string prompt_file = ""; // store the external prompt file name // NOLINT
|
275 |
+
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
|
276 |
+
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
|
277 |
+
std::string input_suffix = ""; // string to suffix user inputs with // NOLINT
|
278 |
+
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
|
279 |
+
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
280 |
+
std::string logits_file = ""; // file for saving *all* logits // NOLINT
|
281 |
+
|
282 |
+
std::vector<std::string> in_files; // all input files
|
283 |
+
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
|
284 |
+
std::vector<llama_model_kv_override> kv_overrides;
|
285 |
+
|
286 |
+
bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply)
|
287 |
+
std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale
|
288 |
+
|
289 |
+
std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
|
290 |
+
|
291 |
+
int32_t verbosity = 0;
|
292 |
+
int32_t control_vector_layer_start = -1; // layer range for control vector
|
293 |
+
int32_t control_vector_layer_end = -1; // layer range for control vector
|
294 |
+
|
295 |
+
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
|
296 |
+
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
|
297 |
+
// (which is more convenient to use for plotting)
|
298 |
+
//
|
299 |
+
bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
|
300 |
+
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
|
301 |
+
|
302 |
+
bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt
|
303 |
+
size_t winogrande_tasks = 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed
|
304 |
+
|
305 |
+
bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt
|
306 |
+
size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed
|
307 |
+
|
308 |
+
bool kl_divergence = false; // compute KL divergence
|
309 |
+
|
310 |
+
bool usage = false; // print usage
|
311 |
+
bool completion = false; // print source-able completion script
|
312 |
+
bool use_color = false; // use color to distinguish generations and inputs
|
313 |
+
bool special = false; // enable special token output
|
314 |
+
bool interactive = false; // interactive mode
|
315 |
+
bool interactive_first = false; // wait for user input immediately
|
316 |
+
bool prompt_cache_all = false; // save user input and generations to prompt cache
|
317 |
+
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
|
318 |
+
|
319 |
+
bool escape = true; // escape "\n", "\r", "\t", "\'", "\"", and "\\"
|
320 |
+
bool multiline_input = false; // reverse the usage of `\`
|
321 |
+
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
|
322 |
+
bool cont_batching = true; // insert new sequences for decoding on-the-fly
|
323 |
+
bool flash_attn = false; // flash attention
|
324 |
+
bool no_perf = false; // disable performance metrics
|
325 |
+
bool ctx_shift = true; // context shift on inifinite text generation
|
326 |
+
|
327 |
+
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
328 |
+
bool logits_all = false; // return logits for all tokens in the batch
|
329 |
+
bool use_mmap = true; // use mmap for faster loads
|
330 |
+
bool use_mlock = false; // use mlock to keep model in memory
|
331 |
+
bool verbose_prompt = false; // print prompt tokens before generation
|
332 |
+
bool display_prompt = true; // print prompt before generation
|
333 |
+
bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
|
334 |
+
bool no_kv_offload = false; // disable KV offloading
|
335 |
+
bool warmup = true; // warmup run
|
336 |
+
bool check_tensors = false; // validate tensor data
|
337 |
+
|
338 |
+
bool single_turn = false; // single turn chat conversation
|
339 |
+
|
340 |
+
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
341 |
+
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
|
342 |
+
|
343 |
+
common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
|
344 |
+
|
345 |
+
// multimodal models (see examples/llava)
|
346 |
+
std::string mmproj = ""; // path to multimodal projector // NOLINT
|
347 |
+
std::vector<std::string> image; // path to image file(s)
|
348 |
+
|
349 |
+
// embedding
|
350 |
+
bool embedding = false; // get only sentence embedding
|
351 |
+
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
|
352 |
+
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
|
353 |
+
std::string embd_sep = "\n"; // separator of embeddings
|
354 |
+
bool reranking = false; // enable reranking support on server
|
355 |
+
|
356 |
+
// server params
|
357 |
+
int32_t port = 8080; // server listens on this network port
|
358 |
+
int32_t timeout_read = 600; // http read timeout in seconds
|
359 |
+
int32_t timeout_write = timeout_read; // http write timeout in seconds
|
360 |
+
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
|
361 |
+
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
|
362 |
+
|
363 |
+
std::string hostname = "127.0.0.1";
|
364 |
+
std::string public_path = ""; // NOLINT
|
365 |
+
std::string chat_template = ""; // NOLINT
|
366 |
+
bool use_jinja = false; // NOLINT
|
367 |
+
bool enable_chat_template = true;
|
368 |
+
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
369 |
+
|
370 |
+
std::vector<std::string> api_keys;
|
371 |
+
|
372 |
+
std::string ssl_file_key = ""; // NOLINT
|
373 |
+
std::string ssl_file_cert = ""; // NOLINT
|
374 |
+
|
375 |
+
// "advanced" endpoints are disabled by default for better security
|
376 |
+
bool webui = true;
|
377 |
+
bool endpoint_slots = false;
|
378 |
+
bool endpoint_props = false; // only control POST requests, not GET
|
379 |
+
bool endpoint_metrics = false;
|
380 |
+
|
381 |
+
bool log_json = false;
|
382 |
+
|
383 |
+
std::string slot_save_path;
|
384 |
+
|
385 |
+
float slot_prompt_similarity = 0.5f;
|
386 |
+
|
387 |
+
// batched-bench params
|
388 |
+
bool is_pp_shared = false;
|
389 |
+
|
390 |
+
std::vector<int32_t> n_pp;
|
391 |
+
std::vector<int32_t> n_tg;
|
392 |
+
std::vector<int32_t> n_pl;
|
393 |
+
|
394 |
+
// retrieval params
|
395 |
+
std::vector<std::string> context_files; // context files to embed
|
396 |
+
|
397 |
+
int32_t chunk_size = 64; // chunk size for context embedding
|
398 |
+
|
399 |
+
std::string chunk_separator = "\n"; // chunk separator for context embedding
|
400 |
+
|
401 |
+
// passkey params
|
402 |
+
int32_t n_junk = 250; // number of times to repeat the junk text
|
403 |
+
int32_t i_pos = -1; // position of the passkey in the junk text
|
404 |
+
|
405 |
+
// imatrix params
|
406 |
+
int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
|
407 |
+
int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
|
408 |
+
int32_t i_chunk = 0; // start processing from this chunk
|
409 |
+
|
410 |
+
bool process_output = false; // collect data for the output tensor
|
411 |
+
bool compute_ppl = true; // whether to compute perplexity
|
412 |
+
|
413 |
+
// cvector-generator params
|
414 |
+
int n_pca_batch = 100;
|
415 |
+
int n_pca_iterations = 1000;
|
416 |
+
dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
|
417 |
+
std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
|
418 |
+
std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
|
419 |
+
|
420 |
+
bool spm_infill = false; // suffix/prefix/middle pattern for infill
|
421 |
+
|
422 |
+
// batched-bench params
|
423 |
+
bool batched_bench_output_jsonl = false;
|
424 |
+
|
425 |
+
// common params
|
426 |
+
std::string out_file; // output filename for all example programs
|
427 |
+
};
|
428 |
+
|
429 |
+
// call once at the start of a program if it uses libcommon
|
430 |
+
// initializes the logging system and prints info about the build
|
431 |
+
void common_init();
|
432 |
+
|
433 |
+
std::string common_params_get_system_info(const common_params & params);
|
434 |
+
|
435 |
+
bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]);
|
436 |
+
bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
|
437 |
+
void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr);
|
438 |
+
bool set_process_priority(enum ggml_sched_priority prio);
|
439 |
+
|
440 |
+
//
|
441 |
+
// String utils
|
442 |
+
//
|
443 |
+
|
444 |
+
#ifdef __GNUC__
|
445 |
+
# if defined(__MINGW32__) && !defined(__clang__)
|
446 |
+
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
447 |
+
# else
|
448 |
+
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
449 |
+
# endif
|
450 |
+
#else
|
451 |
+
# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
|
452 |
+
#endif
|
453 |
+
|
454 |
+
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
|
455 |
+
std::string string_format(const char * fmt, ...);
|
456 |
+
|
457 |
+
std::string string_strip(const std::string & str);
|
458 |
+
std::string string_get_sortable_timestamp();
|
459 |
+
|
460 |
+
std::string string_join(const std::vector<std::string> & values, const std::string & separator);
|
461 |
+
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter);
|
462 |
+
std::string string_repeat(const std::string & str, size_t n);
|
463 |
+
|
464 |
+
void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
|
465 |
+
|
466 |
+
std::string regex_escape(const std::string & s);
|
467 |
+
|
468 |
+
template<class T>
|
469 |
+
static std::vector<T> string_split(const std::string & str, char delim) {
|
470 |
+
static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
|
471 |
+
std::vector<T> values;
|
472 |
+
std::istringstream str_stream(str);
|
473 |
+
std::string token;
|
474 |
+
while (std::getline(str_stream, token, delim)) {
|
475 |
+
T value;
|
476 |
+
std::istringstream token_stream(token);
|
477 |
+
token_stream >> value;
|
478 |
+
values.push_back(value);
|
479 |
+
}
|
480 |
+
return values;
|
481 |
+
}
|
482 |
+
|
483 |
+
template<>
|
484 |
+
std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
|
485 |
+
{
|
486 |
+
std::vector<std::string> parts;
|
487 |
+
size_t begin_pos = 0;
|
488 |
+
size_t separator_pos = input.find(separator);
|
489 |
+
while (separator_pos != std::string::npos) {
|
490 |
+
std::string part = input.substr(begin_pos, separator_pos - begin_pos);
|
491 |
+
parts.emplace_back(part);
|
492 |
+
begin_pos = separator_pos + 1;
|
493 |
+
separator_pos = input.find(separator, begin_pos);
|
494 |
+
}
|
495 |
+
parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
|
496 |
+
return parts;
|
497 |
+
}
|
498 |
+
|
499 |
+
static bool string_starts_with(const std::string & str,
|
500 |
+
const std::string & prefix) { // While we wait for C++20's std::string::starts_with...
|
501 |
+
return str.rfind(prefix, 0) == 0;
|
502 |
+
}
|
503 |
+
|
504 |
+
static bool string_ends_with(const std::string & str,
|
505 |
+
const std::string & suffix) { // While we wait for C++20's std::string::ends_with...
|
506 |
+
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
|
507 |
+
}
|
508 |
+
|
509 |
+
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
|
510 |
+
void string_process_escapes(std::string & input);
|
511 |
+
|
512 |
+
std::string string_from(bool value);
|
513 |
+
std::string string_from(const std::vector<int> & values);
|
514 |
+
std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens);
|
515 |
+
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);
|
516 |
+
|
517 |
+
//
|
518 |
+
// Filesystem utils
|
519 |
+
//
|
520 |
+
|
521 |
+
bool fs_validate_filename(const std::string & filename);
|
522 |
+
bool fs_create_directory_with_parents(const std::string & path);
|
523 |
+
|
524 |
+
std::string fs_get_cache_directory();
|
525 |
+
std::string fs_get_cache_file(const std::string & filename);
|
526 |
+
|
527 |
+
//
|
528 |
+
// Model utils
|
529 |
+
//
|
530 |
+
|
531 |
+
// note: defines object's lifetime
|
532 |
+
struct common_init_result {
|
533 |
+
llama_model_ptr model;
|
534 |
+
llama_context_ptr context;
|
535 |
+
|
536 |
+
std::vector<llama_adapter_lora_ptr> lora;
|
537 |
+
};
|
538 |
+
|
539 |
+
struct common_init_result common_init_from_params(common_params & params);
|
540 |
+
|
541 |
+
struct llama_model_params common_model_params_to_llama ( common_params & params);
|
542 |
+
struct llama_context_params common_context_params_to_llama(const common_params & params);
|
543 |
+
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
|
544 |
+
|
545 |
+
struct llama_model * common_load_model_from_url(
|
546 |
+
const std::string & model_url,
|
547 |
+
const std::string & local_path,
|
548 |
+
const std::string & hf_token,
|
549 |
+
const struct llama_model_params & params);
|
550 |
+
|
551 |
+
struct llama_model * common_load_model_from_hf(
|
552 |
+
const std::string & repo,
|
553 |
+
const std::string & remote_path,
|
554 |
+
const std::string & local_path,
|
555 |
+
const std::string & hf_token,
|
556 |
+
const struct llama_model_params & params);
|
557 |
+
|
558 |
+
std::pair<std::string, std::string> common_get_hf_file(
|
559 |
+
const std::string & hf_repo_with_tag,
|
560 |
+
const std::string & hf_token);
|
561 |
+
|
562 |
+
// clear LoRA adapters from context, then apply new list of adapters
|
563 |
+
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
|
564 |
+
|
565 |
+
//
|
566 |
+
// Batch utils
|
567 |
+
//
|
568 |
+
|
569 |
+
void common_batch_clear(struct llama_batch & batch);
|
570 |
+
|
571 |
+
void common_batch_add(
|
572 |
+
struct llama_batch & batch,
|
573 |
+
llama_token id,
|
574 |
+
llama_pos pos,
|
575 |
+
const std::vector<llama_seq_id> & seq_ids,
|
576 |
+
bool logits);
|
577 |
+
|
578 |
+
//
|
579 |
+
// Token utils
|
580 |
+
//
|
581 |
+
|
582 |
+
// longest common prefix
|
583 |
+
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
|
584 |
+
|
585 |
+
// longet common subsequence
|
586 |
+
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
|
587 |
+
|
588 |
+
//
|
589 |
+
// Vocab utils
|
590 |
+
//
|
591 |
+
|
592 |
+
// tokenizes a string into a vector of tokens
|
593 |
+
// should work similar to Python's `tokenizer.encode`
|
594 |
+
std::vector<llama_token> common_tokenize(
|
595 |
+
const struct llama_context * ctx,
|
596 |
+
const std::string & text,
|
597 |
+
bool add_special,
|
598 |
+
bool parse_special = false);
|
599 |
+
|
600 |
+
std::vector<llama_token> common_tokenize(
|
601 |
+
const struct llama_vocab * vocab,
|
602 |
+
const std::string & text,
|
603 |
+
bool add_special,
|
604 |
+
bool parse_special = false);
|
605 |
+
|
606 |
+
// tokenizes a token into a piece, optionally renders special/control tokens
|
607 |
+
// should work similar to Python's `tokenizer.id_to_piece`
|
608 |
+
std::string common_token_to_piece(
|
609 |
+
const struct llama_context * ctx,
|
610 |
+
llama_token token,
|
611 |
+
bool special = true);
|
612 |
+
|
613 |
+
std::string common_token_to_piece(
|
614 |
+
const struct llama_vocab * vocab,
|
615 |
+
llama_token token,
|
616 |
+
bool special = true);
|
617 |
+
|
618 |
+
// detokenizes a vector of tokens into a string
|
619 |
+
// should work similar to Python's `tokenizer.decode`
|
620 |
+
// optionally renders special/control tokens
|
621 |
+
std::string common_detokenize(
|
622 |
+
const struct llama_context * ctx,
|
623 |
+
const std::vector<llama_token> & tokens,
|
624 |
+
bool special = true);
|
625 |
+
|
626 |
+
std::string common_detokenize(
|
627 |
+
const struct llama_vocab * vocab,
|
628 |
+
const std::vector<llama_token> & tokens,
|
629 |
+
bool special = true);
|
630 |
+
|
631 |
+
//
|
632 |
+
// KV cache utils
|
633 |
+
//
|
634 |
+
|
635 |
+
// Dump the KV cache view with the number of sequences per cell.
|
636 |
+
void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80);
|
637 |
+
|
638 |
+
// Dump the KV cache view showing individual sequences in each cell (long output).
|
639 |
+
void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
|
640 |
+
|
641 |
+
//
|
642 |
+
// Embedding utils
|
643 |
+
//
|
644 |
+
|
645 |
+
// TODO: repace embd_norm with an enum
|
646 |
+
void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
|
647 |
+
|
648 |
+
float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);
|
649 |
+
|
650 |
+
//
|
651 |
+
// Control vector utils
|
652 |
+
//
|
653 |
+
|
654 |
+
struct common_control_vector_data {
|
655 |
+
int n_embd;
|
656 |
+
|
657 |
+
// stores data for layers [1, n_layer] where n_layer = data.size() / n_embd
|
658 |
+
std::vector<float> data;
|
659 |
+
};
|
660 |
+
|
661 |
+
struct common_control_vector_load_info {
|
662 |
+
float strength;
|
663 |
+
|
664 |
+
std::string fname;
|
665 |
+
};
|
666 |
+
|
667 |
+
// Load control vectors, scale each by strength, and add them together.
|
668 |
+
// On error, returns {-1, empty}
|
669 |
+
common_control_vector_data common_control_vector_load(const std::vector<common_control_vector_load_info> & load_infos);
|
670 |
+
|
671 |
+
//
|
672 |
+
// Split utils
|
673 |
+
//
|
674 |
+
|
675 |
+
namespace {
|
676 |
+
|
677 |
+
const char * const LLM_KV_SPLIT_NO = "split.no";
|
678 |
+
const char * const LLM_KV_SPLIT_COUNT = "split.count";
|
679 |
+
const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
|
680 |
+
|
681 |
+
}
|
common/console.cpp
ADDED
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "console.h"
|
2 |
+
#include <vector>
|
3 |
+
#include <iostream>
|
4 |
+
|
5 |
+
#if defined(_WIN32)
|
6 |
+
#define WIN32_LEAN_AND_MEAN
|
7 |
+
#ifndef NOMINMAX
|
8 |
+
#define NOMINMAX
|
9 |
+
#endif
|
10 |
+
#include <windows.h>
|
11 |
+
#include <fcntl.h>
|
12 |
+
#include <io.h>
|
13 |
+
#ifndef ENABLE_VIRTUAL_TERMINAL_PROCESSING
|
14 |
+
#define ENABLE_VIRTUAL_TERMINAL_PROCESSING 0x0004
|
15 |
+
#endif
|
16 |
+
#else
|
17 |
+
#include <climits>
|
18 |
+
#include <sys/ioctl.h>
|
19 |
+
#include <unistd.h>
|
20 |
+
#include <wchar.h>
|
21 |
+
#include <stdio.h>
|
22 |
+
#include <stdlib.h>
|
23 |
+
#include <signal.h>
|
24 |
+
#include <termios.h>
|
25 |
+
#endif
|
26 |
+
|
27 |
+
#define ANSI_COLOR_RED "\x1b[31m"
|
28 |
+
#define ANSI_COLOR_GREEN "\x1b[32m"
|
29 |
+
#define ANSI_COLOR_YELLOW "\x1b[33m"
|
30 |
+
#define ANSI_COLOR_BLUE "\x1b[34m"
|
31 |
+
#define ANSI_COLOR_MAGENTA "\x1b[35m"
|
32 |
+
#define ANSI_COLOR_CYAN "\x1b[36m"
|
33 |
+
#define ANSI_COLOR_RESET "\x1b[0m"
|
34 |
+
#define ANSI_BOLD "\x1b[1m"
|
35 |
+
|
36 |
+
namespace console {
|
37 |
+
|
38 |
+
//
|
39 |
+
// Console state
|
40 |
+
//
|
41 |
+
|
42 |
+
static bool advanced_display = false;
|
43 |
+
static bool simple_io = true;
|
44 |
+
static display_t current_display = reset;
|
45 |
+
|
46 |
+
static FILE* out = stdout;
|
47 |
+
|
48 |
+
#if defined (_WIN32)
|
49 |
+
static void* hConsole;
|
50 |
+
#else
|
51 |
+
static FILE* tty = nullptr;
|
52 |
+
static termios initial_state;
|
53 |
+
#endif
|
54 |
+
|
55 |
+
//
|
56 |
+
// Init and cleanup
|
57 |
+
//
|
58 |
+
|
59 |
+
void init(bool use_simple_io, bool use_advanced_display) {
|
60 |
+
advanced_display = use_advanced_display;
|
61 |
+
simple_io = use_simple_io;
|
62 |
+
#if defined(_WIN32)
|
63 |
+
// Windows-specific console initialization
|
64 |
+
DWORD dwMode = 0;
|
65 |
+
hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
|
66 |
+
if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) {
|
67 |
+
hConsole = GetStdHandle(STD_ERROR_HANDLE);
|
68 |
+
if (hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(hConsole, &dwMode))) {
|
69 |
+
hConsole = nullptr;
|
70 |
+
simple_io = true;
|
71 |
+
}
|
72 |
+
}
|
73 |
+
if (hConsole) {
|
74 |
+
// Check conditions combined to reduce nesting
|
75 |
+
if (advanced_display && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING) &&
|
76 |
+
!SetConsoleMode(hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING)) {
|
77 |
+
advanced_display = false;
|
78 |
+
}
|
79 |
+
// Set console output codepage to UTF8
|
80 |
+
SetConsoleOutputCP(CP_UTF8);
|
81 |
+
}
|
82 |
+
HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE);
|
83 |
+
if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) {
|
84 |
+
// Set console input codepage to UTF16
|
85 |
+
_setmode(_fileno(stdin), _O_WTEXT);
|
86 |
+
|
87 |
+
// Set ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT)
|
88 |
+
if (simple_io) {
|
89 |
+
dwMode |= ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT;
|
90 |
+
} else {
|
91 |
+
dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT);
|
92 |
+
}
|
93 |
+
if (!SetConsoleMode(hConIn, dwMode)) {
|
94 |
+
simple_io = true;
|
95 |
+
}
|
96 |
+
}
|
97 |
+
if (simple_io) {
|
98 |
+
_setmode(_fileno(stdin), _O_U8TEXT);
|
99 |
+
}
|
100 |
+
#else
|
101 |
+
// POSIX-specific console initialization
|
102 |
+
if (!simple_io) {
|
103 |
+
struct termios new_termios;
|
104 |
+
tcgetattr(STDIN_FILENO, &initial_state);
|
105 |
+
new_termios = initial_state;
|
106 |
+
new_termios.c_lflag &= ~(ICANON | ECHO);
|
107 |
+
new_termios.c_cc[VMIN] = 1;
|
108 |
+
new_termios.c_cc[VTIME] = 0;
|
109 |
+
tcsetattr(STDIN_FILENO, TCSANOW, &new_termios);
|
110 |
+
|
111 |
+
tty = fopen("/dev/tty", "w+");
|
112 |
+
if (tty != nullptr) {
|
113 |
+
out = tty;
|
114 |
+
}
|
115 |
+
}
|
116 |
+
|
117 |
+
setlocale(LC_ALL, "");
|
118 |
+
#endif
|
119 |
+
}
|
120 |
+
|
121 |
+
void cleanup() {
|
122 |
+
// Reset console display
|
123 |
+
set_display(reset);
|
124 |
+
|
125 |
+
#if !defined(_WIN32)
|
126 |
+
// Restore settings on POSIX systems
|
127 |
+
if (!simple_io) {
|
128 |
+
if (tty != nullptr) {
|
129 |
+
out = stdout;
|
130 |
+
fclose(tty);
|
131 |
+
tty = nullptr;
|
132 |
+
}
|
133 |
+
tcsetattr(STDIN_FILENO, TCSANOW, &initial_state);
|
134 |
+
}
|
135 |
+
#endif
|
136 |
+
}
|
137 |
+
|
138 |
+
//
|
139 |
+
// Display and IO
|
140 |
+
//
|
141 |
+
|
142 |
+
// Keep track of current display and only emit ANSI code if it changes
|
143 |
+
void set_display(display_t display) {
|
144 |
+
if (advanced_display && current_display != display) {
|
145 |
+
fflush(stdout);
|
146 |
+
switch(display) {
|
147 |
+
case reset:
|
148 |
+
fprintf(out, ANSI_COLOR_RESET);
|
149 |
+
break;
|
150 |
+
case prompt:
|
151 |
+
fprintf(out, ANSI_COLOR_YELLOW);
|
152 |
+
break;
|
153 |
+
case user_input:
|
154 |
+
fprintf(out, ANSI_BOLD ANSI_COLOR_GREEN);
|
155 |
+
break;
|
156 |
+
case error:
|
157 |
+
fprintf(out, ANSI_BOLD ANSI_COLOR_RED);
|
158 |
+
}
|
159 |
+
current_display = display;
|
160 |
+
fflush(out);
|
161 |
+
}
|
162 |
+
}
|
163 |
+
|
164 |
+
static char32_t getchar32() {
|
165 |
+
#if defined(_WIN32)
|
166 |
+
HANDLE hConsole = GetStdHandle(STD_INPUT_HANDLE);
|
167 |
+
wchar_t high_surrogate = 0;
|
168 |
+
|
169 |
+
while (true) {
|
170 |
+
INPUT_RECORD record;
|
171 |
+
DWORD count;
|
172 |
+
if (!ReadConsoleInputW(hConsole, &record, 1, &count) || count == 0) {
|
173 |
+
return WEOF;
|
174 |
+
}
|
175 |
+
|
176 |
+
if (record.EventType == KEY_EVENT && record.Event.KeyEvent.bKeyDown) {
|
177 |
+
wchar_t wc = record.Event.KeyEvent.uChar.UnicodeChar;
|
178 |
+
if (wc == 0) {
|
179 |
+
continue;
|
180 |
+
}
|
181 |
+
|
182 |
+
if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate
|
183 |
+
high_surrogate = wc;
|
184 |
+
continue;
|
185 |
+
}
|
186 |
+
if ((wc >= 0xDC00) && (wc <= 0xDFFF)) { // Check if wc is a low surrogate
|
187 |
+
if (high_surrogate != 0) { // Check if we have a high surrogate
|
188 |
+
return ((high_surrogate - 0xD800) << 10) + (wc - 0xDC00) + 0x10000;
|
189 |
+
}
|
190 |
+
}
|
191 |
+
|
192 |
+
high_surrogate = 0; // Reset the high surrogate
|
193 |
+
return static_cast<char32_t>(wc);
|
194 |
+
}
|
195 |
+
}
|
196 |
+
#else
|
197 |
+
wchar_t wc = getwchar();
|
198 |
+
if (static_cast<wint_t>(wc) == WEOF) {
|
199 |
+
return WEOF;
|
200 |
+
}
|
201 |
+
|
202 |
+
#if WCHAR_MAX == 0xFFFF
|
203 |
+
if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate
|
204 |
+
wchar_t low_surrogate = getwchar();
|
205 |
+
if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate
|
206 |
+
return (static_cast<char32_t>(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000;
|
207 |
+
}
|
208 |
+
}
|
209 |
+
if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair
|
210 |
+
return 0xFFFD; // Return the replacement character U+FFFD
|
211 |
+
}
|
212 |
+
#endif
|
213 |
+
|
214 |
+
return static_cast<char32_t>(wc);
|
215 |
+
#endif
|
216 |
+
}
|
217 |
+
|
218 |
+
static void pop_cursor() {
|
219 |
+
#if defined(_WIN32)
|
220 |
+
if (hConsole != NULL) {
|
221 |
+
CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
|
222 |
+
GetConsoleScreenBufferInfo(hConsole, &bufferInfo);
|
223 |
+
|
224 |
+
COORD newCursorPosition = bufferInfo.dwCursorPosition;
|
225 |
+
if (newCursorPosition.X == 0) {
|
226 |
+
newCursorPosition.X = bufferInfo.dwSize.X - 1;
|
227 |
+
newCursorPosition.Y -= 1;
|
228 |
+
} else {
|
229 |
+
newCursorPosition.X -= 1;
|
230 |
+
}
|
231 |
+
|
232 |
+
SetConsoleCursorPosition(hConsole, newCursorPosition);
|
233 |
+
return;
|
234 |
+
}
|
235 |
+
#endif
|
236 |
+
putc('\b', out);
|
237 |
+
}
|
238 |
+
|
239 |
+
static int estimateWidth(char32_t codepoint) {
|
240 |
+
#if defined(_WIN32)
|
241 |
+
(void)codepoint;
|
242 |
+
return 1;
|
243 |
+
#else
|
244 |
+
return wcwidth(codepoint);
|
245 |
+
#endif
|
246 |
+
}
|
247 |
+
|
248 |
+
static int put_codepoint(const char* utf8_codepoint, size_t length, int expectedWidth) {
|
249 |
+
#if defined(_WIN32)
|
250 |
+
CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
|
251 |
+
if (!GetConsoleScreenBufferInfo(hConsole, &bufferInfo)) {
|
252 |
+
// go with the default
|
253 |
+
return expectedWidth;
|
254 |
+
}
|
255 |
+
COORD initialPosition = bufferInfo.dwCursorPosition;
|
256 |
+
DWORD nNumberOfChars = length;
|
257 |
+
WriteConsole(hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL);
|
258 |
+
|
259 |
+
CONSOLE_SCREEN_BUFFER_INFO newBufferInfo;
|
260 |
+
GetConsoleScreenBufferInfo(hConsole, &newBufferInfo);
|
261 |
+
|
262 |
+
// Figure out our real position if we're in the last column
|
263 |
+
if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) {
|
264 |
+
DWORD nNumberOfChars;
|
265 |
+
WriteConsole(hConsole, &" \b", 2, &nNumberOfChars, NULL);
|
266 |
+
GetConsoleScreenBufferInfo(hConsole, &newBufferInfo);
|
267 |
+
}
|
268 |
+
|
269 |
+
int width = newBufferInfo.dwCursorPosition.X - initialPosition.X;
|
270 |
+
if (width < 0) {
|
271 |
+
width += newBufferInfo.dwSize.X;
|
272 |
+
}
|
273 |
+
return width;
|
274 |
+
#else
|
275 |
+
// We can trust expectedWidth if we've got one
|
276 |
+
if (expectedWidth >= 0 || tty == nullptr) {
|
277 |
+
fwrite(utf8_codepoint, length, 1, out);
|
278 |
+
return expectedWidth;
|
279 |
+
}
|
280 |
+
|
281 |
+
fputs("\033[6n", tty); // Query cursor position
|
282 |
+
int x1;
|
283 |
+
int y1;
|
284 |
+
int x2;
|
285 |
+
int y2;
|
286 |
+
int results = 0;
|
287 |
+
results = fscanf(tty, "\033[%d;%dR", &y1, &x1);
|
288 |
+
|
289 |
+
fwrite(utf8_codepoint, length, 1, tty);
|
290 |
+
|
291 |
+
fputs("\033[6n", tty); // Query cursor position
|
292 |
+
results += fscanf(tty, "\033[%d;%dR", &y2, &x2);
|
293 |
+
|
294 |
+
if (results != 4) {
|
295 |
+
return expectedWidth;
|
296 |
+
}
|
297 |
+
|
298 |
+
int width = x2 - x1;
|
299 |
+
if (width < 0) {
|
300 |
+
// Calculate the width considering text wrapping
|
301 |
+
struct winsize w;
|
302 |
+
ioctl(STDOUT_FILENO, TIOCGWINSZ, &w);
|
303 |
+
width += w.ws_col;
|
304 |
+
}
|
305 |
+
return width;
|
306 |
+
#endif
|
307 |
+
}
|
308 |
+
|
309 |
+
static void replace_last(char ch) {
|
310 |
+
#if defined(_WIN32)
|
311 |
+
pop_cursor();
|
312 |
+
put_codepoint(&ch, 1, 1);
|
313 |
+
#else
|
314 |
+
fprintf(out, "\b%c", ch);
|
315 |
+
#endif
|
316 |
+
}
|
317 |
+
|
318 |
+
static void append_utf8(char32_t ch, std::string & out) {
|
319 |
+
if (ch <= 0x7F) {
|
320 |
+
out.push_back(static_cast<unsigned char>(ch));
|
321 |
+
} else if (ch <= 0x7FF) {
|
322 |
+
out.push_back(static_cast<unsigned char>(0xC0 | ((ch >> 6) & 0x1F)));
|
323 |
+
out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
|
324 |
+
} else if (ch <= 0xFFFF) {
|
325 |
+
out.push_back(static_cast<unsigned char>(0xE0 | ((ch >> 12) & 0x0F)));
|
326 |
+
out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 6) & 0x3F)));
|
327 |
+
out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
|
328 |
+
} else if (ch <= 0x10FFFF) {
|
329 |
+
out.push_back(static_cast<unsigned char>(0xF0 | ((ch >> 18) & 0x07)));
|
330 |
+
out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 12) & 0x3F)));
|
331 |
+
out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 6) & 0x3F)));
|
332 |
+
out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
|
333 |
+
} else {
|
334 |
+
// Invalid Unicode code point
|
335 |
+
}
|
336 |
+
}
|
337 |
+
|
338 |
+
// Helper function to remove the last UTF-8 character from a string
|
339 |
+
static void pop_back_utf8_char(std::string & line) {
|
340 |
+
if (line.empty()) {
|
341 |
+
return;
|
342 |
+
}
|
343 |
+
|
344 |
+
size_t pos = line.length() - 1;
|
345 |
+
|
346 |
+
// Find the start of the last UTF-8 character (checking up to 4 bytes back)
|
347 |
+
for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) {
|
348 |
+
if ((line[pos] & 0xC0) != 0x80) {
|
349 |
+
break; // Found the start of the character
|
350 |
+
}
|
351 |
+
}
|
352 |
+
line.erase(pos);
|
353 |
+
}
|
354 |
+
|
355 |
+
static bool readline_advanced(std::string & line, bool multiline_input) {
|
356 |
+
if (out != stdout) {
|
357 |
+
fflush(stdout);
|
358 |
+
}
|
359 |
+
|
360 |
+
line.clear();
|
361 |
+
std::vector<int> widths;
|
362 |
+
bool is_special_char = false;
|
363 |
+
bool end_of_stream = false;
|
364 |
+
|
365 |
+
char32_t input_char;
|
366 |
+
while (true) {
|
367 |
+
fflush(out); // Ensure all output is displayed before waiting for input
|
368 |
+
input_char = getchar32();
|
369 |
+
|
370 |
+
if (input_char == '\r' || input_char == '\n') {
|
371 |
+
break;
|
372 |
+
}
|
373 |
+
|
374 |
+
if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D*/) {
|
375 |
+
end_of_stream = true;
|
376 |
+
break;
|
377 |
+
}
|
378 |
+
|
379 |
+
if (is_special_char) {
|
380 |
+
set_display(user_input);
|
381 |
+
replace_last(line.back());
|
382 |
+
is_special_char = false;
|
383 |
+
}
|
384 |
+
|
385 |
+
if (input_char == '\033') { // Escape sequence
|
386 |
+
char32_t code = getchar32();
|
387 |
+
if (code == '[' || code == 0x1B) {
|
388 |
+
// Discard the rest of the escape sequence
|
389 |
+
while ((code = getchar32()) != (char32_t) WEOF) {
|
390 |
+
if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') {
|
391 |
+
break;
|
392 |
+
}
|
393 |
+
}
|
394 |
+
}
|
395 |
+
} else if (input_char == 0x08 || input_char == 0x7F) { // Backspace
|
396 |
+
if (!widths.empty()) {
|
397 |
+
int count;
|
398 |
+
do {
|
399 |
+
count = widths.back();
|
400 |
+
widths.pop_back();
|
401 |
+
// Move cursor back, print space, and move cursor back again
|
402 |
+
for (int i = 0; i < count; i++) {
|
403 |
+
replace_last(' ');
|
404 |
+
pop_cursor();
|
405 |
+
}
|
406 |
+
pop_back_utf8_char(line);
|
407 |
+
} while (count == 0 && !widths.empty());
|
408 |
+
}
|
409 |
+
} else {
|
410 |
+
int offset = line.length();
|
411 |
+
append_utf8(input_char, line);
|
412 |
+
int width = put_codepoint(line.c_str() + offset, line.length() - offset, estimateWidth(input_char));
|
413 |
+
if (width < 0) {
|
414 |
+
width = 0;
|
415 |
+
}
|
416 |
+
widths.push_back(width);
|
417 |
+
}
|
418 |
+
|
419 |
+
if (!line.empty() && (line.back() == '\\' || line.back() == '/')) {
|
420 |
+
set_display(prompt);
|
421 |
+
replace_last(line.back());
|
422 |
+
is_special_char = true;
|
423 |
+
}
|
424 |
+
}
|
425 |
+
|
426 |
+
bool has_more = multiline_input;
|
427 |
+
if (is_special_char) {
|
428 |
+
replace_last(' ');
|
429 |
+
pop_cursor();
|
430 |
+
|
431 |
+
char last = line.back();
|
432 |
+
line.pop_back();
|
433 |
+
if (last == '\\') {
|
434 |
+
line += '\n';
|
435 |
+
fputc('\n', out);
|
436 |
+
has_more = !has_more;
|
437 |
+
} else {
|
438 |
+
// llama will just eat the single space, it won't act as a space
|
439 |
+
if (line.length() == 1 && line.back() == ' ') {
|
440 |
+
line.clear();
|
441 |
+
pop_cursor();
|
442 |
+
}
|
443 |
+
has_more = false;
|
444 |
+
}
|
445 |
+
} else {
|
446 |
+
if (end_of_stream) {
|
447 |
+
has_more = false;
|
448 |
+
} else {
|
449 |
+
line += '\n';
|
450 |
+
fputc('\n', out);
|
451 |
+
}
|
452 |
+
}
|
453 |
+
|
454 |
+
fflush(out);
|
455 |
+
return has_more;
|
456 |
+
}
|
457 |
+
|
458 |
+
static bool readline_simple(std::string & line, bool multiline_input) {
|
459 |
+
#if defined(_WIN32)
|
460 |
+
std::wstring wline;
|
461 |
+
if (!std::getline(std::wcin, wline)) {
|
462 |
+
// Input stream is bad or EOF received
|
463 |
+
line.clear();
|
464 |
+
GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0);
|
465 |
+
return false;
|
466 |
+
}
|
467 |
+
|
468 |
+
int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
|
469 |
+
line.resize(size_needed);
|
470 |
+
WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
|
471 |
+
#else
|
472 |
+
if (!std::getline(std::cin, line)) {
|
473 |
+
// Input stream is bad or EOF received
|
474 |
+
line.clear();
|
475 |
+
return false;
|
476 |
+
}
|
477 |
+
#endif
|
478 |
+
if (!line.empty()) {
|
479 |
+
char last = line.back();
|
480 |
+
if (last == '/') { // Always return control on '/' symbol
|
481 |
+
line.pop_back();
|
482 |
+
return false;
|
483 |
+
}
|
484 |
+
if (last == '\\') { // '\\' changes the default action
|
485 |
+
line.pop_back();
|
486 |
+
multiline_input = !multiline_input;
|
487 |
+
}
|
488 |
+
}
|
489 |
+
line += '\n';
|
490 |
+
|
491 |
+
// By default, continue input if multiline_input is set
|
492 |
+
return multiline_input;
|
493 |
+
}
|
494 |
+
|
495 |
+
bool readline(std::string & line, bool multiline_input) {
|
496 |
+
set_display(user_input);
|
497 |
+
|
498 |
+
if (simple_io) {
|
499 |
+
return readline_simple(line, multiline_input);
|
500 |
+
}
|
501 |
+
return readline_advanced(line, multiline_input);
|
502 |
+
}
|
503 |
+
|
504 |
+
}
|
common/console.h
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// Console functions
|
2 |
+
|
3 |
+
#pragma once
|
4 |
+
|
5 |
+
#include <string>
|
6 |
+
|
7 |
+
namespace console {
|
8 |
+
enum display_t {
|
9 |
+
reset = 0,
|
10 |
+
prompt,
|
11 |
+
user_input,
|
12 |
+
error
|
13 |
+
};
|
14 |
+
|
15 |
+
void init(bool use_simple_io, bool use_advanced_display);
|
16 |
+
void cleanup();
|
17 |
+
void set_display(display_t display);
|
18 |
+
bool readline(std::string & line, bool multiline_input);
|
19 |
+
}
|
common/json-schema-to-grammar.cpp
ADDED
@@ -0,0 +1,1024 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "json-schema-to-grammar.h"
|
2 |
+
#include "common.h"
|
3 |
+
|
4 |
+
#include <algorithm>
|
5 |
+
#include <fstream>
|
6 |
+
#include <map>
|
7 |
+
#include <regex>
|
8 |
+
#include <sstream>
|
9 |
+
#include <string>
|
10 |
+
#include <unordered_map>
|
11 |
+
#include <unordered_set>
|
12 |
+
#include <vector>
|
13 |
+
|
14 |
+
using json = nlohmann::ordered_json;
|
15 |
+
|
16 |
+
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
|
17 |
+
auto has_max = max_items != std::numeric_limits<int>::max();
|
18 |
+
|
19 |
+
if (min_items == 0 && max_items == 1) {
|
20 |
+
return item_rule + "?";
|
21 |
+
}
|
22 |
+
|
23 |
+
if (separator_rule.empty()) {
|
24 |
+
if (min_items == 1 && !has_max) {
|
25 |
+
return item_rule + "+";
|
26 |
+
} else if (min_items == 0 && !has_max) {
|
27 |
+
return item_rule + "*";
|
28 |
+
} else {
|
29 |
+
return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}";
|
30 |
+
}
|
31 |
+
}
|
32 |
+
|
33 |
+
auto result = item_rule + " " + build_repetition("(" + separator_rule + " " + item_rule + ")", min_items == 0 ? 0 : min_items - 1, has_max ? max_items - 1 : max_items);
|
34 |
+
if (min_items == 0) {
|
35 |
+
result = "(" + result + ")?";
|
36 |
+
}
|
37 |
+
return result;
|
38 |
+
}
|
39 |
+
|
40 |
+
/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */
|
41 |
+
class string_view {
|
42 |
+
const std::string & _str;
|
43 |
+
const size_t _start;
|
44 |
+
const size_t _end;
|
45 |
+
public:
|
46 |
+
string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}
|
47 |
+
|
48 |
+
size_t size() const {
|
49 |
+
return _end - _start;
|
50 |
+
}
|
51 |
+
|
52 |
+
size_t length() const {
|
53 |
+
return size();
|
54 |
+
}
|
55 |
+
|
56 |
+
operator std::string() const {
|
57 |
+
return str();
|
58 |
+
}
|
59 |
+
|
60 |
+
std::string str() const {
|
61 |
+
return _str.substr(_start, _end - _start);
|
62 |
+
}
|
63 |
+
|
64 |
+
string_view substr(size_t pos, size_t len = std::string::npos) const {
|
65 |
+
return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len);
|
66 |
+
}
|
67 |
+
|
68 |
+
char operator[](size_t pos) const {
|
69 |
+
auto index = _start + pos;
|
70 |
+
if (index >= _end) {
|
71 |
+
throw std::out_of_range("string_view index out of range");
|
72 |
+
}
|
73 |
+
return _str[_start + pos];
|
74 |
+
}
|
75 |
+
|
76 |
+
bool operator==(const string_view & other) const {
|
77 |
+
std::string this_str = *this;
|
78 |
+
std::string other_str = other;
|
79 |
+
return this_str == other_str;
|
80 |
+
}
|
81 |
+
};
|
82 |
+
|
83 |
+
static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
|
84 |
+
auto has_min = min_value != std::numeric_limits<int>::min();
|
85 |
+
auto has_max = max_value != std::numeric_limits<int>::max();
|
86 |
+
|
87 |
+
auto digit_range = [&](char from, char to) {
|
88 |
+
out << "[";
|
89 |
+
if (from == to) {
|
90 |
+
out << from;
|
91 |
+
} else {
|
92 |
+
out << from << "-" << to;
|
93 |
+
}
|
94 |
+
out << "]";
|
95 |
+
};
|
96 |
+
auto more_digits = [&](int min_digits, int max_digits) {
|
97 |
+
out << "[0-9]";
|
98 |
+
if (min_digits == max_digits && min_digits == 1) {
|
99 |
+
return;
|
100 |
+
}
|
101 |
+
out << "{";
|
102 |
+
out << min_digits;
|
103 |
+
if (max_digits != min_digits) {
|
104 |
+
out << ",";
|
105 |
+
if (max_digits != std::numeric_limits<int>::max()) {
|
106 |
+
out << max_digits;
|
107 |
+
}
|
108 |
+
}
|
109 |
+
out << "}";
|
110 |
+
};
|
111 |
+
std::function<void(const string_view &, const string_view &)> uniform_range =
|
112 |
+
[&](const string_view & from, const string_view & to) {
|
113 |
+
size_t i = 0;
|
114 |
+
while (i < from.length() && i < to.length() && from[i] == to[i]) {
|
115 |
+
i++;
|
116 |
+
}
|
117 |
+
if (i > 0) {
|
118 |
+
out << "\"" << from.substr(0, i).str() << "\"";
|
119 |
+
}
|
120 |
+
if (i < from.length() && i < to.length()) {
|
121 |
+
if (i > 0) {
|
122 |
+
out << " ";
|
123 |
+
}
|
124 |
+
auto sub_len = from.length() - i - 1;
|
125 |
+
if (sub_len > 0) {
|
126 |
+
auto from_sub = from.substr(i + 1);
|
127 |
+
auto to_sub = to.substr(i + 1);
|
128 |
+
auto sub_zeros = string_repeat("0", sub_len);
|
129 |
+
auto sub_nines = string_repeat("9", sub_len);
|
130 |
+
|
131 |
+
auto to_reached = false;
|
132 |
+
out << "(";
|
133 |
+
if (from_sub == sub_zeros) {
|
134 |
+
digit_range(from[i], to[i] - 1);
|
135 |
+
out << " ";
|
136 |
+
more_digits(sub_len, sub_len);
|
137 |
+
} else {
|
138 |
+
out << "[" << from[i] << "] ";
|
139 |
+
out << "(";
|
140 |
+
uniform_range(from_sub, sub_nines);
|
141 |
+
out << ")";
|
142 |
+
if (from[i] < to[i] - 1) {
|
143 |
+
out << " | ";
|
144 |
+
if (to_sub == sub_nines) {
|
145 |
+
digit_range(from[i] + 1, to[i]);
|
146 |
+
to_reached = true;
|
147 |
+
} else {
|
148 |
+
digit_range(from[i] + 1, to[i] - 1);
|
149 |
+
}
|
150 |
+
out << " ";
|
151 |
+
more_digits(sub_len, sub_len);
|
152 |
+
}
|
153 |
+
}
|
154 |
+
if (!to_reached) {
|
155 |
+
out << " | ";
|
156 |
+
digit_range(to[i], to[i]);
|
157 |
+
out << " ";
|
158 |
+
uniform_range(sub_zeros, to_sub);
|
159 |
+
}
|
160 |
+
out << ")";
|
161 |
+
} else {
|
162 |
+
out << "[" << from[i] << "-" << to[i] << "]";
|
163 |
+
}
|
164 |
+
}
|
165 |
+
};
|
166 |
+
|
167 |
+
if (has_min && has_max) {
|
168 |
+
if (min_value < 0 && max_value < 0) {
|
169 |
+
out << "\"-\" (";
|
170 |
+
_build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true);
|
171 |
+
out << ")";
|
172 |
+
return;
|
173 |
+
}
|
174 |
+
|
175 |
+
if (min_value < 0) {
|
176 |
+
out << "\"-\" (";
|
177 |
+
_build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true);
|
178 |
+
out << ") | ";
|
179 |
+
min_value = 0;
|
180 |
+
}
|
181 |
+
|
182 |
+
auto min_s = std::to_string(min_value);
|
183 |
+
auto max_s = std::to_string(max_value);
|
184 |
+
auto min_digits = min_s.length();
|
185 |
+
auto max_digits = max_s.length();
|
186 |
+
|
187 |
+
for (auto digits = min_digits; digits < max_digits; digits++) {
|
188 |
+
uniform_range(min_s, string_repeat("9", digits));
|
189 |
+
min_s = "1" + string_repeat("0", digits);
|
190 |
+
out << " | ";
|
191 |
+
}
|
192 |
+
uniform_range(min_s, max_s);
|
193 |
+
return;
|
194 |
+
}
|
195 |
+
|
196 |
+
auto less_decimals = std::max(decimals_left - 1, 1);
|
197 |
+
|
198 |
+
if (has_min) {
|
199 |
+
if (min_value < 0) {
|
200 |
+
out << "\"-\" (";
|
201 |
+
_build_min_max_int(std::numeric_limits<int>::min(), -min_value, out, decimals_left, /* top_level= */ false);
|
202 |
+
out << ") | [0] | [1-9] ";
|
203 |
+
more_digits(0, decimals_left - 1);
|
204 |
+
} else if (min_value == 0) {
|
205 |
+
if (top_level) {
|
206 |
+
out << "[0] | [1-9] ";
|
207 |
+
more_digits(0, less_decimals);
|
208 |
+
} else {
|
209 |
+
more_digits(1, decimals_left);
|
210 |
+
}
|
211 |
+
} else if (min_value <= 9) {
|
212 |
+
char c = '0' + min_value;
|
213 |
+
auto range_start = top_level ? '1' : '0';
|
214 |
+
if (c > range_start) {
|
215 |
+
digit_range(range_start, c - 1);
|
216 |
+
out << " ";
|
217 |
+
more_digits(1, less_decimals);
|
218 |
+
out << " | ";
|
219 |
+
}
|
220 |
+
digit_range(c, '9');
|
221 |
+
out << " ";
|
222 |
+
more_digits(0, less_decimals);
|
223 |
+
} else {
|
224 |
+
auto min_s = std::to_string(min_value);
|
225 |
+
auto len = min_s.length();
|
226 |
+
auto c = min_s[0];
|
227 |
+
|
228 |
+
if (c > '1') {
|
229 |
+
digit_range(top_level ? '1' : '0', c - 1);
|
230 |
+
out << " ";
|
231 |
+
more_digits(len, less_decimals);
|
232 |
+
out << " | ";
|
233 |
+
}
|
234 |
+
digit_range(c, c);
|
235 |
+
out << " (";
|
236 |
+
_build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits<int>::max(), out, less_decimals, /* top_level= */ false);
|
237 |
+
out << ")";
|
238 |
+
if (c < '9') {
|
239 |
+
out << " | ";
|
240 |
+
digit_range(c + 1, '9');
|
241 |
+
out << " ";
|
242 |
+
more_digits(len - 1, less_decimals);
|
243 |
+
}
|
244 |
+
}
|
245 |
+
return;
|
246 |
+
}
|
247 |
+
|
248 |
+
if (has_max) {
|
249 |
+
if (max_value >= 0) {
|
250 |
+
if (top_level) {
|
251 |
+
out << "\"-\" [1-9] ";
|
252 |
+
more_digits(0, less_decimals);
|
253 |
+
out << " | ";
|
254 |
+
}
|
255 |
+
_build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
|
256 |
+
} else {
|
257 |
+
out << "\"-\" (";
|
258 |
+
_build_min_max_int(-max_value, std::numeric_limits<int>::max(), out, decimals_left, /* top_level= */ false);
|
259 |
+
out << ")";
|
260 |
+
}
|
261 |
+
return;
|
262 |
+
}
|
263 |
+
|
264 |
+
throw std::runtime_error("At least one of min_value or max_value must be set");
|
265 |
+
}
|
266 |
+
|
267 |
+
const std::string SPACE_RULE = "| \" \" | \"\\n\"{1,2} [ \\t]{0,20}";
|
268 |
+
|
269 |
+
struct BuiltinRule {
|
270 |
+
std::string content;
|
271 |
+
std::vector<std::string> deps;
|
272 |
+
};
|
273 |
+
|
274 |
+
std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
|
275 |
+
{"boolean", {"(\"true\" | \"false\") space", {}}},
|
276 |
+
{"decimal-part", {"[0-9]{1,16}", {}}},
|
277 |
+
{"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}},
|
278 |
+
{"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}},
|
279 |
+
{"integer", {"(\"-\"? integral-part) space", {"integral-part"}}},
|
280 |
+
{"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}},
|
281 |
+
{"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}},
|
282 |
+
{"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}},
|
283 |
+
{"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}},
|
284 |
+
{"char", {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}},
|
285 |
+
{"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}},
|
286 |
+
{"null", {"\"null\" space", {}}},
|
287 |
+
};
|
288 |
+
|
289 |
+
std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
|
290 |
+
{"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}},
|
291 |
+
{"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}},
|
292 |
+
{"date-time", {"date \"T\" time", {"date", "time"}}},
|
293 |
+
{"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}},
|
294 |
+
{"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}},
|
295 |
+
{"date-time-string", {"\"\\\"\" date-time \"\\\"\" space", {"date-time"}}}
|
296 |
+
};
|
297 |
+
|
298 |
+
static bool is_reserved_name(const std::string & name) {
|
299 |
+
static std::unordered_set<std::string> RESERVED_NAMES;
|
300 |
+
if (RESERVED_NAMES.empty()) {
|
301 |
+
RESERVED_NAMES.insert("root");
|
302 |
+
for (const auto &p : PRIMITIVE_RULES) RESERVED_NAMES.insert(p.first);
|
303 |
+
for (const auto &p : STRING_FORMAT_RULES) RESERVED_NAMES.insert(p.first);
|
304 |
+
}
|
305 |
+
return RESERVED_NAMES.find(name) != RESERVED_NAMES.end();
|
306 |
+
}
|
307 |
+
|
308 |
+
std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
|
309 |
+
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");
|
310 |
+
std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
|
311 |
+
std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
|
312 |
+
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}
|
313 |
+
};
|
314 |
+
|
315 |
+
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
|
316 |
+
std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
|
317 |
+
|
318 |
+
static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
|
319 |
+
std::smatch match;
|
320 |
+
std::string result;
|
321 |
+
|
322 |
+
std::string::const_iterator searchStart(input.cbegin());
|
323 |
+
std::string::const_iterator searchEnd(input.cend());
|
324 |
+
|
325 |
+
while (std::regex_search(searchStart, searchEnd, match, regex)) {
|
326 |
+
result.append(searchStart, searchStart + match.position());
|
327 |
+
result.append(replacement(match));
|
328 |
+
searchStart = match.suffix().first;
|
329 |
+
}
|
330 |
+
|
331 |
+
result.append(searchStart, searchEnd);
|
332 |
+
|
333 |
+
return result;
|
334 |
+
}
|
335 |
+
|
336 |
+
static std::string format_literal(const std::string & literal) {
|
337 |
+
std::string escaped = replacePattern(literal, GRAMMAR_LITERAL_ESCAPE_RE, [&](const std::smatch & match) {
|
338 |
+
char c = match.str()[0];
|
339 |
+
return GRAMMAR_LITERAL_ESCAPES.at(c);
|
340 |
+
});
|
341 |
+
return "\"" + escaped + "\"";
|
342 |
+
}
|
343 |
+
|
344 |
+
class SchemaConverter {
|
345 |
+
private:
|
346 |
+
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
|
347 |
+
std::function<json(const std::string &)> _fetch_json;
|
348 |
+
bool _dotall;
|
349 |
+
std::map<std::string, std::string> _rules;
|
350 |
+
std::unordered_map<std::string, json> _refs;
|
351 |
+
std::unordered_set<std::string> _refs_being_resolved;
|
352 |
+
std::vector<std::string> _errors;
|
353 |
+
std::vector<std::string> _warnings;
|
354 |
+
|
355 |
+
std::string _add_rule(const std::string & name, const std::string & rule) {
|
356 |
+
std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-");
|
357 |
+
if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) {
|
358 |
+
_rules[esc_name] = rule;
|
359 |
+
return esc_name;
|
360 |
+
} else {
|
361 |
+
int i = 0;
|
362 |
+
while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) {
|
363 |
+
i++;
|
364 |
+
}
|
365 |
+
std::string key = esc_name + std::to_string(i);
|
366 |
+
_rules[key] = rule;
|
367 |
+
return key;
|
368 |
+
}
|
369 |
+
}
|
370 |
+
|
371 |
+
std::string _generate_union_rule(const std::string & name, const std::vector<json> & alt_schemas) {
|
372 |
+
std::vector<std::string> rules;
|
373 |
+
for (size_t i = 0; i < alt_schemas.size(); i++) {
|
374 |
+
rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
|
375 |
+
}
|
376 |
+
return string_join(rules, " | ");
|
377 |
+
}
|
378 |
+
|
379 |
+
std::string _visit_pattern(const std::string & pattern, const std::string & name) {
|
380 |
+
if (!(pattern.front() == '^' && pattern.back() == '$')) {
|
381 |
+
_errors.push_back("Pattern must start with '^' and end with '$'");
|
382 |
+
return "";
|
383 |
+
}
|
384 |
+
std::string sub_pattern = pattern.substr(1, pattern.length() - 2);
|
385 |
+
std::unordered_map<std::string, std::string> sub_rule_ids;
|
386 |
+
|
387 |
+
size_t i = 0;
|
388 |
+
size_t length = sub_pattern.length();
|
389 |
+
|
390 |
+
using literal_or_rule = std::pair<std::string, bool>;
|
391 |
+
auto to_rule = [&](const literal_or_rule & ls) {
|
392 |
+
auto is_literal = ls.second;
|
393 |
+
auto s = ls.first;
|
394 |
+
return is_literal ? "\"" + s + "\"" : s;
|
395 |
+
};
|
396 |
+
std::function<literal_or_rule()> transform = [&]() -> literal_or_rule {
|
397 |
+
size_t start = i;
|
398 |
+
std::vector<literal_or_rule> seq;
|
399 |
+
|
400 |
+
auto get_dot = [&]() {
|
401 |
+
std::string rule;
|
402 |
+
if (_dotall) {
|
403 |
+
rule = "[\\U00000000-\\U0010FFFF]";
|
404 |
+
} else {
|
405 |
+
rule = "[^\\x0A\\x0D]";
|
406 |
+
}
|
407 |
+
return _add_rule("dot", rule);
|
408 |
+
};
|
409 |
+
|
410 |
+
// Joins the sequence, merging consecutive literals together.
|
411 |
+
auto join_seq = [&]() {
|
412 |
+
std::vector<literal_or_rule> ret;
|
413 |
+
|
414 |
+
std::string literal;
|
415 |
+
auto flush_literal = [&]() {
|
416 |
+
if (literal.empty()) {
|
417 |
+
return false;
|
418 |
+
}
|
419 |
+
ret.emplace_back(literal, true);
|
420 |
+
literal.clear();
|
421 |
+
return true;
|
422 |
+
};
|
423 |
+
|
424 |
+
for (const auto & item : seq) {
|
425 |
+
auto is_literal = item.second;
|
426 |
+
if (is_literal) {
|
427 |
+
literal += item.first;
|
428 |
+
} else {
|
429 |
+
flush_literal();
|
430 |
+
ret.push_back(item);
|
431 |
+
}
|
432 |
+
}
|
433 |
+
flush_literal();
|
434 |
+
|
435 |
+
std::vector<std::string> results;
|
436 |
+
for (const auto & item : ret) {
|
437 |
+
results.push_back(to_rule(item));
|
438 |
+
}
|
439 |
+
return std::make_pair(string_join(results, " "), false);
|
440 |
+
};
|
441 |
+
|
442 |
+
while (i < length) {
|
443 |
+
char c = sub_pattern[i];
|
444 |
+
if (c == '.') {
|
445 |
+
seq.emplace_back(get_dot(), false);
|
446 |
+
i++;
|
447 |
+
} else if (c == '(') {
|
448 |
+
i++;
|
449 |
+
if (i < length) {
|
450 |
+
if (sub_pattern[i] == '?') {
|
451 |
+
_warnings.push_back("Unsupported pattern syntax");
|
452 |
+
}
|
453 |
+
}
|
454 |
+
seq.emplace_back("(" + to_rule(transform()) + ")", false);
|
455 |
+
} else if (c == ')') {
|
456 |
+
i++;
|
457 |
+
if (start > 0 && sub_pattern[start - 1] != '(') {
|
458 |
+
_errors.push_back("Unbalanced parentheses");
|
459 |
+
}
|
460 |
+
return join_seq();
|
461 |
+
} else if (c == '[') {
|
462 |
+
std::string square_brackets = std::string(1, c);
|
463 |
+
i++;
|
464 |
+
while (i < length && sub_pattern[i] != ']') {
|
465 |
+
if (sub_pattern[i] == '\\') {
|
466 |
+
square_brackets += sub_pattern.substr(i, 2);
|
467 |
+
i += 2;
|
468 |
+
} else {
|
469 |
+
square_brackets += sub_pattern[i];
|
470 |
+
i++;
|
471 |
+
}
|
472 |
+
}
|
473 |
+
if (i >= length) {
|
474 |
+
_errors.push_back("Unbalanced square brackets");
|
475 |
+
}
|
476 |
+
square_brackets += ']';
|
477 |
+
i++;
|
478 |
+
seq.emplace_back(square_brackets, false);
|
479 |
+
} else if (c == '|') {
|
480 |
+
seq.emplace_back("|", false);
|
481 |
+
i++;
|
482 |
+
} else if (c == '*' || c == '+' || c == '?') {
|
483 |
+
seq.back() = std::make_pair(to_rule(seq.back()) + c, false);
|
484 |
+
i++;
|
485 |
+
} else if (c == '{') {
|
486 |
+
std::string curly_brackets = std::string(1, c);
|
487 |
+
i++;
|
488 |
+
while (i < length && sub_pattern[i] != '}') {
|
489 |
+
curly_brackets += sub_pattern[i];
|
490 |
+
i++;
|
491 |
+
}
|
492 |
+
if (i >= length) {
|
493 |
+
_errors.push_back("Unbalanced curly brackets");
|
494 |
+
}
|
495 |
+
curly_brackets += '}';
|
496 |
+
i++;
|
497 |
+
auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
|
498 |
+
int min_times = 0;
|
499 |
+
int max_times = std::numeric_limits<int>::max();
|
500 |
+
try {
|
501 |
+
if (nums.size() == 1) {
|
502 |
+
min_times = max_times = std::stoi(nums[0]);
|
503 |
+
} else if (nums.size() != 2) {
|
504 |
+
_errors.push_back("Wrong number of values in curly brackets");
|
505 |
+
} else {
|
506 |
+
if (!nums[0].empty()) {
|
507 |
+
min_times = std::stoi(nums[0]);
|
508 |
+
}
|
509 |
+
if (!nums[1].empty()) {
|
510 |
+
max_times = std::stoi(nums[1]);
|
511 |
+
}
|
512 |
+
}
|
513 |
+
} catch (const std::invalid_argument & e) {
|
514 |
+
_errors.push_back("Invalid number in curly brackets");
|
515 |
+
return std::make_pair("", false);
|
516 |
+
}
|
517 |
+
auto &last = seq.back();
|
518 |
+
auto &sub = last.first;
|
519 |
+
auto sub_is_literal = last.second;
|
520 |
+
|
521 |
+
if (!sub_is_literal) {
|
522 |
+
std::string & sub_id = sub_rule_ids[sub];
|
523 |
+
if (sub_id.empty()) {
|
524 |
+
sub_id = _add_rule(name + "-" + std::to_string(sub_rule_ids.size()), sub);
|
525 |
+
}
|
526 |
+
sub = sub_id;
|
527 |
+
}
|
528 |
+
seq.back().first = build_repetition(
|
529 |
+
sub_is_literal ? "\"" + sub + "\"" : sub,
|
530 |
+
min_times,
|
531 |
+
max_times,
|
532 |
+
""
|
533 |
+
);
|
534 |
+
seq.back().second = false;
|
535 |
+
} else {
|
536 |
+
std::string literal;
|
537 |
+
auto is_non_literal = [&](char c) {
|
538 |
+
return NON_LITERAL_SET.find(c) != NON_LITERAL_SET.end();
|
539 |
+
};
|
540 |
+
while (i < length) {
|
541 |
+
if (sub_pattern[i] == '\\' && i < length - 1) {
|
542 |
+
char next = sub_pattern[i + 1];
|
543 |
+
if (ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.find(next) != ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.end()) {
|
544 |
+
i++;
|
545 |
+
literal += sub_pattern[i];
|
546 |
+
i++;
|
547 |
+
} else {
|
548 |
+
literal += sub_pattern.substr(i, 2);
|
549 |
+
i += 2;
|
550 |
+
}
|
551 |
+
} else if (sub_pattern[i] == '"') {
|
552 |
+
literal += "\\\"";
|
553 |
+
i++;
|
554 |
+
} else if (!is_non_literal(sub_pattern[i]) &&
|
555 |
+
(i == length - 1 || literal.empty() || sub_pattern[i + 1] == '.' || !is_non_literal(sub_pattern[i + 1]))) {
|
556 |
+
literal += sub_pattern[i];
|
557 |
+
i++;
|
558 |
+
} else {
|
559 |
+
break;
|
560 |
+
}
|
561 |
+
}
|
562 |
+
if (!literal.empty()) {
|
563 |
+
seq.emplace_back(literal, true);
|
564 |
+
}
|
565 |
+
}
|
566 |
+
}
|
567 |
+
return join_seq();
|
568 |
+
};
|
569 |
+
return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space");
|
570 |
+
}
|
571 |
+
|
572 |
+
/*
|
573 |
+
Returns a rule that matches a JSON string that is none of the provided strings
|
574 |
+
|
575 |
+
not_strings({"a"})
|
576 |
+
-> ["] ( [a] char+ | [^"a] char* )? ["] space
|
577 |
+
not_strings({"and", "also"})
|
578 |
+
-> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
|
579 |
+
*/
|
580 |
+
std::string _not_strings(const std::vector<std::string> & strings) {
|
581 |
+
|
582 |
+
struct TrieNode {
|
583 |
+
std::map<char, TrieNode> children;
|
584 |
+
bool is_end_of_string;
|
585 |
+
|
586 |
+
TrieNode() : is_end_of_string(false) {}
|
587 |
+
|
588 |
+
void insert(const std::string & string) {
|
589 |
+
auto node = this;
|
590 |
+
for (char c : string) {
|
591 |
+
node = &node->children[c];
|
592 |
+
}
|
593 |
+
node->is_end_of_string = true;
|
594 |
+
}
|
595 |
+
};
|
596 |
+
|
597 |
+
TrieNode trie;
|
598 |
+
for (const auto & s : strings) {
|
599 |
+
trie.insert(s);
|
600 |
+
}
|
601 |
+
|
602 |
+
std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
|
603 |
+
std::ostringstream out;
|
604 |
+
out << "[\"] ( ";
|
605 |
+
std::function<void(const TrieNode &)> visit = [&](const TrieNode & node) {
|
606 |
+
std::ostringstream rejects;
|
607 |
+
auto first = true;
|
608 |
+
for (const auto & kv : node.children) {
|
609 |
+
rejects << kv.first;
|
610 |
+
if (first) {
|
611 |
+
first = false;
|
612 |
+
} else {
|
613 |
+
out << " | ";
|
614 |
+
}
|
615 |
+
out << "[" << kv.first << "]";
|
616 |
+
if (!kv.second.children.empty()) {
|
617 |
+
out << " (";
|
618 |
+
visit(kv.second);
|
619 |
+
out << ")";
|
620 |
+
} else if (kv.second.is_end_of_string) {
|
621 |
+
out << " " << char_rule << "+";
|
622 |
+
}
|
623 |
+
}
|
624 |
+
if (!node.children.empty()) {
|
625 |
+
if (!first) {
|
626 |
+
out << " | ";
|
627 |
+
}
|
628 |
+
out << "[^\"" << rejects.str() << "] " << char_rule << "*";
|
629 |
+
}
|
630 |
+
};
|
631 |
+
visit(trie);
|
632 |
+
|
633 |
+
out << " )";
|
634 |
+
if (!trie.is_end_of_string) {
|
635 |
+
out << "?";
|
636 |
+
}
|
637 |
+
out << " [\"] space";
|
638 |
+
return out.str();
|
639 |
+
}
|
640 |
+
|
641 |
+
std::string _resolve_ref(const std::string & ref) {
|
642 |
+
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
|
643 |
+
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
|
644 |
+
_refs_being_resolved.insert(ref);
|
645 |
+
json resolved = _refs[ref];
|
646 |
+
ref_name = visit(resolved, ref_name);
|
647 |
+
_refs_being_resolved.erase(ref);
|
648 |
+
}
|
649 |
+
return ref_name;
|
650 |
+
}
|
651 |
+
|
652 |
+
std::string _build_object_rule(
|
653 |
+
const std::vector<std::pair<std::string, json>> & properties,
|
654 |
+
const std::unordered_set<std::string> & required,
|
655 |
+
const std::string & name,
|
656 |
+
const json & additional_properties)
|
657 |
+
{
|
658 |
+
std::vector<std::string> required_props;
|
659 |
+
std::vector<std::string> optional_props;
|
660 |
+
std::unordered_map<std::string, std::string> prop_kv_rule_names;
|
661 |
+
std::vector<std::string> prop_names;
|
662 |
+
for (const auto & kv : properties) {
|
663 |
+
const auto &prop_name = kv.first;
|
664 |
+
const auto &prop_schema = kv.second;
|
665 |
+
|
666 |
+
std::string prop_rule_name = visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name);
|
667 |
+
prop_kv_rule_names[prop_name] = _add_rule(
|
668 |
+
name + (name.empty() ? "" : "-") + prop_name + "-kv",
|
669 |
+
format_literal(json(prop_name).dump()) + " space \":\" space " + prop_rule_name
|
670 |
+
);
|
671 |
+
if (required.find(prop_name) != required.end()) {
|
672 |
+
required_props.push_back(prop_name);
|
673 |
+
} else {
|
674 |
+
optional_props.push_back(prop_name);
|
675 |
+
}
|
676 |
+
prop_names.push_back(prop_name);
|
677 |
+
}
|
678 |
+
if ((additional_properties.is_boolean() && additional_properties.get<bool>()) || additional_properties.is_object()) {
|
679 |
+
std::string sub_name = name + (name.empty() ? "" : "-") + "additional";
|
680 |
+
std::string value_rule =
|
681 |
+
additional_properties.is_object() ? visit(additional_properties, sub_name + "-value")
|
682 |
+
: _add_primitive("value", PRIMITIVE_RULES.at("value"));
|
683 |
+
|
684 |
+
auto key_rule =
|
685 |
+
prop_names.empty() ? _add_primitive("string", PRIMITIVE_RULES.at("string"))
|
686 |
+
: _add_rule(sub_name + "-k", _not_strings(prop_names));
|
687 |
+
std::string kv_rule = _add_rule(sub_name + "-kv", key_rule + " \":\" space " + value_rule);
|
688 |
+
prop_kv_rule_names["*"] = kv_rule;
|
689 |
+
optional_props.push_back("*");
|
690 |
+
}
|
691 |
+
|
692 |
+
std::string rule = "\"{\" space ";
|
693 |
+
for (size_t i = 0; i < required_props.size(); i++) {
|
694 |
+
if (i > 0) {
|
695 |
+
rule += " \",\" space ";
|
696 |
+
}
|
697 |
+
rule += prop_kv_rule_names[required_props[i]];
|
698 |
+
}
|
699 |
+
|
700 |
+
if (!optional_props.empty()) {
|
701 |
+
rule += " (";
|
702 |
+
if (!required_props.empty()) {
|
703 |
+
rule += " \",\" space ( ";
|
704 |
+
}
|
705 |
+
|
706 |
+
std::function<std::string(const std::vector<std::string> &, bool)> get_recursive_refs = [&](const std::vector<std::string> & ks, bool first_is_optional) {
|
707 |
+
std::string res;
|
708 |
+
if (ks.empty()) {
|
709 |
+
return res;
|
710 |
+
}
|
711 |
+
std::string k = ks[0];
|
712 |
+
std::string kv_rule_name = prop_kv_rule_names[k];
|
713 |
+
std::string comma_ref = "( \",\" space " + kv_rule_name + " )";
|
714 |
+
if (first_is_optional) {
|
715 |
+
res = comma_ref + (k == "*" ? "*" : "?");
|
716 |
+
} else {
|
717 |
+
res = kv_rule_name + (k == "*" ? " " + comma_ref + "*" : "");
|
718 |
+
}
|
719 |
+
if (ks.size() > 1) {
|
720 |
+
res += " " + _add_rule(
|
721 |
+
name + (name.empty() ? "" : "-") + k + "-rest",
|
722 |
+
get_recursive_refs(std::vector<std::string>(ks.begin() + 1, ks.end()), true)
|
723 |
+
);
|
724 |
+
}
|
725 |
+
return res;
|
726 |
+
};
|
727 |
+
|
728 |
+
for (size_t i = 0; i < optional_props.size(); i++) {
|
729 |
+
if (i > 0) {
|
730 |
+
rule += " | ";
|
731 |
+
}
|
732 |
+
rule += get_recursive_refs(std::vector<std::string>(optional_props.begin() + i, optional_props.end()), false);
|
733 |
+
}
|
734 |
+
if (!required_props.empty()) {
|
735 |
+
rule += " )";
|
736 |
+
}
|
737 |
+
rule += " )?";
|
738 |
+
}
|
739 |
+
|
740 |
+
rule += " \"}\" space";
|
741 |
+
|
742 |
+
return rule;
|
743 |
+
}
|
744 |
+
|
745 |
+
std::string _add_primitive(const std::string & name, const BuiltinRule & rule) {
|
746 |
+
auto n = _add_rule(name, rule.content);
|
747 |
+
for (const auto & dep : rule.deps) {
|
748 |
+
BuiltinRule dep_rule;
|
749 |
+
auto it = PRIMITIVE_RULES.find(dep);
|
750 |
+
if (it == PRIMITIVE_RULES.end()) {
|
751 |
+
it = STRING_FORMAT_RULES.find(dep);
|
752 |
+
if (it == STRING_FORMAT_RULES.end()) {
|
753 |
+
_errors.push_back("Rule " + dep + " not known");
|
754 |
+
continue;
|
755 |
+
}
|
756 |
+
}
|
757 |
+
if (_rules.find(dep) == _rules.end()) {
|
758 |
+
_add_primitive(dep, it->second);
|
759 |
+
}
|
760 |
+
}
|
761 |
+
return n;
|
762 |
+
}
|
763 |
+
|
764 |
+
public:
|
765 |
+
SchemaConverter(
|
766 |
+
const std::function<json(const std::string &)> & fetch_json,
|
767 |
+
bool dotall)
|
768 |
+
: _fetch_json(fetch_json), _dotall(dotall)
|
769 |
+
{
|
770 |
+
_rules["space"] = SPACE_RULE;
|
771 |
+
}
|
772 |
+
|
773 |
+
void resolve_refs(json & schema, const std::string & url) {
|
774 |
+
/*
|
775 |
+
* Resolves all $ref fields in the given schema, fetching any remote schemas,
|
776 |
+
* replacing each $ref with absolute reference URL and populates _refs with the
|
777 |
+
* respective referenced (sub)schema dictionaries.
|
778 |
+
*/
|
779 |
+
std::function<void(json &)> visit_refs = [&](json & n) {
|
780 |
+
if (n.is_array()) {
|
781 |
+
for (auto & x : n) {
|
782 |
+
visit_refs(x);
|
783 |
+
}
|
784 |
+
} else if (n.is_object()) {
|
785 |
+
if (n.contains("$ref")) {
|
786 |
+
std::string ref = n["$ref"];
|
787 |
+
if (_refs.find(ref) == _refs.end()) {
|
788 |
+
json target;
|
789 |
+
if (ref.find("https://") == 0) {
|
790 |
+
std::string base_url = ref.substr(0, ref.find('#'));
|
791 |
+
auto it = _refs.find(base_url);
|
792 |
+
if (it != _refs.end()) {
|
793 |
+
target = it->second;
|
794 |
+
} else {
|
795 |
+
// Fetch the referenced schema and resolve its refs
|
796 |
+
auto referenced = _fetch_json(ref);
|
797 |
+
resolve_refs(referenced, base_url);
|
798 |
+
_refs[base_url] = referenced;
|
799 |
+
}
|
800 |
+
if (ref.find('#') == std::string::npos || ref.substr(ref.find('#') + 1).empty()) {
|
801 |
+
return;
|
802 |
+
}
|
803 |
+
} else if (ref.find("#/") == 0) {
|
804 |
+
target = schema;
|
805 |
+
n["$ref"] = url + ref;
|
806 |
+
ref = url + ref;
|
807 |
+
} else {
|
808 |
+
_errors.push_back("Unsupported ref: " + ref);
|
809 |
+
return;
|
810 |
+
}
|
811 |
+
std::string pointer = ref.substr(ref.find('#') + 1);
|
812 |
+
std::vector<std::string> tokens = string_split(pointer, "/");
|
813 |
+
for (size_t i = 1; i < tokens.size(); ++i) {
|
814 |
+
std::string sel = tokens[i];
|
815 |
+
if (target.is_null() || !target.contains(sel)) {
|
816 |
+
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
|
817 |
+
return;
|
818 |
+
}
|
819 |
+
target = target[sel];
|
820 |
+
}
|
821 |
+
_refs[ref] = target;
|
822 |
+
}
|
823 |
+
} else {
|
824 |
+
for (auto & kv : n.items()) {
|
825 |
+
visit_refs(kv.value());
|
826 |
+
}
|
827 |
+
}
|
828 |
+
}
|
829 |
+
};
|
830 |
+
|
831 |
+
visit_refs(schema);
|
832 |
+
}
|
833 |
+
|
834 |
+
std::string _generate_constant_rule(const json & value) {
|
835 |
+
return format_literal(value.dump());
|
836 |
+
}
|
837 |
+
|
838 |
+
std::string visit(const json & schema, const std::string & name) {
|
839 |
+
json schema_type = schema.contains("type") ? schema["type"] : json();
|
840 |
+
std::string schema_format = schema.contains("format") ? schema["format"].get<std::string>() : "";
|
841 |
+
std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name;
|
842 |
+
|
843 |
+
if (schema.contains("$ref")) {
|
844 |
+
return _add_rule(rule_name, _resolve_ref(schema["$ref"]));
|
845 |
+
} else if (schema.contains("oneOf") || schema.contains("anyOf")) {
|
846 |
+
std::vector<json> alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>() : schema["anyOf"].get<std::vector<json>>();
|
847 |
+
return _add_rule(rule_name, _generate_union_rule(name, alt_schemas));
|
848 |
+
} else if (schema_type.is_array()) {
|
849 |
+
std::vector<json> schema_types;
|
850 |
+
for (const auto & t : schema_type) {
|
851 |
+
json schema_copy(schema);
|
852 |
+
schema_copy["type"] = t;
|
853 |
+
schema_types.push_back(schema_copy);
|
854 |
+
}
|
855 |
+
return _add_rule(rule_name, _generate_union_rule(name, schema_types));
|
856 |
+
} else if (schema.contains("const")) {
|
857 |
+
return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space");
|
858 |
+
} else if (schema.contains("enum")) {
|
859 |
+
std::vector<std::string> enum_values;
|
860 |
+
for (const auto & v : schema["enum"]) {
|
861 |
+
enum_values.push_back(_generate_constant_rule(v));
|
862 |
+
}
|
863 |
+
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
|
864 |
+
} else if ((schema_type.is_null() || schema_type == "object")
|
865 |
+
&& (schema.contains("properties") ||
|
866 |
+
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
|
867 |
+
std::unordered_set<std::string> required;
|
868 |
+
if (schema.contains("required") && schema["required"].is_array()) {
|
869 |
+
for (const auto & item : schema["required"]) {
|
870 |
+
if (item.is_string()) {
|
871 |
+
required.insert(item.get<std::string>());
|
872 |
+
}
|
873 |
+
}
|
874 |
+
}
|
875 |
+
std::vector<std::pair<std::string, json>> properties;
|
876 |
+
if (schema.contains("properties")) {
|
877 |
+
for (const auto & prop : schema["properties"].items()) {
|
878 |
+
properties.emplace_back(prop.key(), prop.value());
|
879 |
+
}
|
880 |
+
}
|
881 |
+
return _add_rule(rule_name,
|
882 |
+
_build_object_rule(
|
883 |
+
properties, required, name,
|
884 |
+
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
|
885 |
+
} else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
|
886 |
+
std::unordered_set<std::string> required;
|
887 |
+
std::vector<std::pair<std::string, json>> properties;
|
888 |
+
std::string hybrid_name = name;
|
889 |
+
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
|
890 |
+
if (comp_schema.contains("$ref")) {
|
891 |
+
add_component(_refs[comp_schema["$ref"]], is_required);
|
892 |
+
} else if (comp_schema.contains("properties")) {
|
893 |
+
for (const auto & prop : comp_schema["properties"].items()) {
|
894 |
+
properties.emplace_back(prop.key(), prop.value());
|
895 |
+
if (is_required) {
|
896 |
+
required.insert(prop.key());
|
897 |
+
}
|
898 |
+
}
|
899 |
+
} else {
|
900 |
+
// todo warning
|
901 |
+
}
|
902 |
+
};
|
903 |
+
for (auto & t : schema["allOf"]) {
|
904 |
+
if (t.contains("anyOf")) {
|
905 |
+
for (auto & tt : t["anyOf"]) {
|
906 |
+
add_component(tt, false);
|
907 |
+
}
|
908 |
+
} else {
|
909 |
+
add_component(t, true);
|
910 |
+
}
|
911 |
+
}
|
912 |
+
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
|
913 |
+
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
|
914 |
+
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
|
915 |
+
if (items.is_array()) {
|
916 |
+
std::string rule = "\"[\" space ";
|
917 |
+
for (size_t i = 0; i < items.size(); i++) {
|
918 |
+
if (i > 0) {
|
919 |
+
rule += " \",\" space ";
|
920 |
+
}
|
921 |
+
rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i));
|
922 |
+
}
|
923 |
+
rule += " \"]\" space";
|
924 |
+
return _add_rule(rule_name, rule);
|
925 |
+
} else {
|
926 |
+
std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
|
927 |
+
int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0;
|
928 |
+
json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
|
929 |
+
int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
|
930 |
+
|
931 |
+
return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
|
932 |
+
}
|
933 |
+
} else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
|
934 |
+
return _visit_pattern(schema["pattern"], rule_name);
|
935 |
+
} else if ((schema_type.is_null() || schema_type == "string") && std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) {
|
936 |
+
return _add_primitive(rule_name == "root" ? "root" : schema_format, PRIMITIVE_RULES.at("uuid"));
|
937 |
+
} else if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) {
|
938 |
+
auto prim_name = schema_format + "-string";
|
939 |
+
return _add_rule(rule_name, _add_primitive(prim_name, STRING_FORMAT_RULES.at(prim_name)));
|
940 |
+
} else if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) {
|
941 |
+
std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
|
942 |
+
int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
|
943 |
+
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
|
944 |
+
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
|
945 |
+
} else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
|
946 |
+
int min_value = std::numeric_limits<int>::min();
|
947 |
+
int max_value = std::numeric_limits<int>::max();
|
948 |
+
if (schema.contains("minimum")) {
|
949 |
+
min_value = schema["minimum"].get<int>();
|
950 |
+
} else if (schema.contains("exclusiveMinimum")) {
|
951 |
+
min_value = schema["exclusiveMinimum"].get<int>() + 1;
|
952 |
+
}
|
953 |
+
if (schema.contains("maximum")) {
|
954 |
+
max_value = schema["maximum"].get<int>();
|
955 |
+
} else if (schema.contains("exclusiveMaximum")) {
|
956 |
+
max_value = schema["exclusiveMaximum"].get<int>() - 1;
|
957 |
+
}
|
958 |
+
std::stringstream out;
|
959 |
+
out << "(";
|
960 |
+
_build_min_max_int(min_value, max_value, out);
|
961 |
+
out << ") space";
|
962 |
+
return _add_rule(rule_name, out.str());
|
963 |
+
} else if (schema.empty() || schema_type == "object") {
|
964 |
+
return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
|
965 |
+
} else {
|
966 |
+
if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) {
|
967 |
+
_errors.push_back("Unrecognized schema: " + schema.dump());
|
968 |
+
return "";
|
969 |
+
}
|
970 |
+
// TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
|
971 |
+
return _add_primitive(rule_name == "root" ? "root" : schema_type.get<std::string>(), PRIMITIVE_RULES.at(schema_type.get<std::string>()));
|
972 |
+
}
|
973 |
+
}
|
974 |
+
|
975 |
+
void check_errors() {
|
976 |
+
if (!_errors.empty()) {
|
977 |
+
throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
|
978 |
+
}
|
979 |
+
if (!_warnings.empty()) {
|
980 |
+
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
|
981 |
+
}
|
982 |
+
}
|
983 |
+
|
984 |
+
std::string format_grammar() {
|
985 |
+
std::stringstream ss;
|
986 |
+
for (const auto & kv : _rules) {
|
987 |
+
ss << kv.first << " ::= " << kv.second << std::endl;
|
988 |
+
}
|
989 |
+
return ss.str();
|
990 |
+
}
|
991 |
+
};
|
992 |
+
|
993 |
+
std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
|
994 |
+
#ifdef LLAMA_USE_LLGUIDANCE
|
995 |
+
if (!force_gbnf) {
|
996 |
+
return "%llguidance {}\nstart: %json " + schema.dump();
|
997 |
+
}
|
998 |
+
#else
|
999 |
+
(void)force_gbnf;
|
1000 |
+
#endif // LLAMA_USE_LLGUIDANCE
|
1001 |
+
return build_grammar([&](const common_grammar_builder & callbacks) {
|
1002 |
+
auto copy = schema;
|
1003 |
+
callbacks.resolve_refs(copy);
|
1004 |
+
callbacks.add_schema("", copy);
|
1005 |
+
});
|
1006 |
+
}
|
1007 |
+
|
1008 |
+
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
|
1009 |
+
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
|
1010 |
+
common_grammar_builder builder {
|
1011 |
+
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
|
1012 |
+
return converter._add_rule(name, rule);
|
1013 |
+
},
|
1014 |
+
/* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) {
|
1015 |
+
return converter.visit(schema, name == "root" ? "" : name);
|
1016 |
+
},
|
1017 |
+
/* .resolve_refs = */ [&](nlohmann::ordered_json & schema) {
|
1018 |
+
converter.resolve_refs(schema, "");
|
1019 |
+
}
|
1020 |
+
};
|
1021 |
+
cb(builder);
|
1022 |
+
converter.check_errors();
|
1023 |
+
return converter.format_grammar();
|
1024 |
+
}
|
common/json-schema-to-grammar.h
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include "ggml.h"
|
4 |
+
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
5 |
+
#define JSON_ASSERT GGML_ASSERT
|
6 |
+
#include "json.hpp"
|
7 |
+
|
8 |
+
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
|
9 |
+
bool force_gbnf = false);
|
10 |
+
|
11 |
+
struct common_grammar_builder {
|
12 |
+
std::function<std::string(const std::string &, const std::string &)> add_rule;
|
13 |
+
std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
|
14 |
+
std::function<void(nlohmann::ordered_json &)> resolve_refs;
|
15 |
+
};
|
16 |
+
|
17 |
+
struct common_grammar_options {
|
18 |
+
bool dotall = false;
|
19 |
+
};
|
20 |
+
|
21 |
+
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
|
common/json.hpp
ADDED
The diff for this file is too large to render.
See raw diff
|
|
common/llguidance.cpp
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "sampling.h"
|
2 |
+
#include "log.h"
|
3 |
+
|
4 |
+
#ifdef LLAMA_USE_LLGUIDANCE
|
5 |
+
|
6 |
+
# include "llguidance.h"
|
7 |
+
# include <cmath>
|
8 |
+
|
9 |
+
struct llama_sampler_llg {
|
10 |
+
const llama_vocab * vocab;
|
11 |
+
std::string grammar_kind;
|
12 |
+
std::string grammar_data;
|
13 |
+
LlgTokenizer * tokenizer;
|
14 |
+
LlgConstraint * grammar;
|
15 |
+
LlgMaskResult llg_res;
|
16 |
+
bool has_llg_res;
|
17 |
+
};
|
18 |
+
|
19 |
+
static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
|
20 |
+
const char * grammar_data) {
|
21 |
+
LlgConstraintInit cinit;
|
22 |
+
llg_constraint_init_set_defaults(&cinit, tokenizer);
|
23 |
+
const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
|
24 |
+
if (log_level && *log_level) {
|
25 |
+
cinit.log_stderr_level = atoi(log_level);
|
26 |
+
}
|
27 |
+
auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
|
28 |
+
if (llg_get_error(c)) {
|
29 |
+
LOG_ERR("llg error: %s\n", llg_get_error(c));
|
30 |
+
llg_free_constraint(c);
|
31 |
+
return nullptr;
|
32 |
+
}
|
33 |
+
return c;
|
34 |
+
}
|
35 |
+
|
36 |
+
static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
|
37 |
+
return "llguidance";
|
38 |
+
}
|
39 |
+
|
40 |
+
static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
|
41 |
+
auto * ctx = (llama_sampler_llg *) smpl->ctx;
|
42 |
+
if (ctx->grammar) {
|
43 |
+
LlgCommitResult res;
|
44 |
+
llg_commit_token(ctx->grammar, token, &res);
|
45 |
+
ctx->has_llg_res = false;
|
46 |
+
}
|
47 |
+
}
|
48 |
+
|
49 |
+
static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
|
50 |
+
auto * ctx = (llama_sampler_llg *) smpl->ctx;
|
51 |
+
if (ctx->grammar) {
|
52 |
+
if (!ctx->has_llg_res) {
|
53 |
+
if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
|
54 |
+
ctx->has_llg_res = true;
|
55 |
+
} else {
|
56 |
+
LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
|
57 |
+
llg_free_constraint(ctx->grammar);
|
58 |
+
ctx->grammar = nullptr;
|
59 |
+
}
|
60 |
+
}
|
61 |
+
if (ctx->has_llg_res) {
|
62 |
+
if (ctx->llg_res.is_stop) {
|
63 |
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
64 |
+
if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
|
65 |
+
cur_p->data[i].logit = -INFINITY;
|
66 |
+
}
|
67 |
+
}
|
68 |
+
} else {
|
69 |
+
const uint32_t * mask = ctx->llg_res.sample_mask;
|
70 |
+
for (size_t i = 0; i < cur_p->size; ++i) {
|
71 |
+
auto token = cur_p->data[i].id;
|
72 |
+
if ((mask[token / 32] & (1 << (token % 32))) == 0) {
|
73 |
+
cur_p->data[i].logit = -INFINITY;
|
74 |
+
}
|
75 |
+
}
|
76 |
+
}
|
77 |
+
}
|
78 |
+
}
|
79 |
+
}
|
80 |
+
|
81 |
+
static void llama_sampler_llg_reset(llama_sampler * smpl) {
|
82 |
+
auto * ctx = (llama_sampler_llg *) smpl->ctx;
|
83 |
+
if (!ctx->grammar) {
|
84 |
+
return;
|
85 |
+
}
|
86 |
+
|
87 |
+
auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
|
88 |
+
llg_free_constraint(ctx->grammar);
|
89 |
+
ctx->grammar = grammar_new;
|
90 |
+
ctx->has_llg_res = false;
|
91 |
+
}
|
92 |
+
|
93 |
+
static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
|
94 |
+
const auto * ctx = (const llama_sampler_llg *) smpl->ctx;
|
95 |
+
|
96 |
+
auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);
|
97 |
+
|
98 |
+
// copy the state
|
99 |
+
{
|
100 |
+
auto * result_ctx = (llama_sampler_llg *) result->ctx;
|
101 |
+
|
102 |
+
if (ctx->grammar) {
|
103 |
+
result_ctx->grammar_kind = ctx->grammar_kind;
|
104 |
+
result_ctx->grammar_data = ctx->grammar_data;
|
105 |
+
result_ctx->grammar = llg_clone_constraint(ctx->grammar);
|
106 |
+
result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
|
107 |
+
}
|
108 |
+
}
|
109 |
+
|
110 |
+
return result;
|
111 |
+
}
|
112 |
+
|
113 |
+
static void llama_sampler_llg_free(llama_sampler * smpl) {
|
114 |
+
const auto * ctx = (llama_sampler_llg *) smpl->ctx;
|
115 |
+
|
116 |
+
if (ctx->grammar) {
|
117 |
+
llg_free_constraint(ctx->grammar);
|
118 |
+
llg_free_tokenizer(ctx->tokenizer);
|
119 |
+
}
|
120 |
+
|
121 |
+
delete ctx;
|
122 |
+
}
|
123 |
+
|
124 |
+
static llama_sampler_i llama_sampler_llg_i = {
|
125 |
+
/* .name = */ llama_sampler_llg_name,
|
126 |
+
/* .accept = */ llama_sampler_llg_accept_impl,
|
127 |
+
/* .apply = */ llama_sampler_llg_apply,
|
128 |
+
/* .reset = */ llama_sampler_llg_reset,
|
129 |
+
/* .clone = */ llama_sampler_llg_clone,
|
130 |
+
/* .free = */ llama_sampler_llg_free,
|
131 |
+
};
|
132 |
+
|
133 |
+
static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
|
134 |
+
uint32_t * output_tokens, size_t output_tokens_len) {
|
135 |
+
const llama_vocab * vocab = (const llama_vocab *) user_data;
|
136 |
+
int r = 0;
|
137 |
+
try {
|
138 |
+
r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
|
139 |
+
true);
|
140 |
+
} catch (const std::exception & e) {
|
141 |
+
GGML_ABORT("llama_tokenize failed: %s\n", e.what());
|
142 |
+
}
|
143 |
+
if (r < 0) {
|
144 |
+
return -r;
|
145 |
+
}
|
146 |
+
return r;
|
147 |
+
}
|
148 |
+
|
149 |
+
static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
|
150 |
+
// TODO store the tokenizer in the vocab somehow
|
151 |
+
static const llama_vocab * vocab_cache;
|
152 |
+
static LlgTokenizer * tokenizer_cache;
|
153 |
+
|
154 |
+
if (vocab_cache == vocab) {
|
155 |
+
return llg_clone_tokenizer(tokenizer_cache);
|
156 |
+
}
|
157 |
+
|
158 |
+
auto tok_eos = llama_vocab_eot(vocab);
|
159 |
+
if (tok_eos == LLAMA_TOKEN_NULL) {
|
160 |
+
tok_eos = llama_vocab_eos(vocab);
|
161 |
+
}
|
162 |
+
|
163 |
+
size_t vocab_size = llama_vocab_n_tokens(vocab);
|
164 |
+
|
165 |
+
auto token_lens = new uint32_t[vocab_size];
|
166 |
+
// we typically have ~7 bytes per token; let's go on the safe side here
|
167 |
+
auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
|
168 |
+
auto token_bytes = new uint8_t[token_bytes_size];
|
169 |
+
|
170 |
+
size_t offset = 0;
|
171 |
+
for (size_t i = 0; i < vocab_size; i++) {
|
172 |
+
size_t max_token = 1024;
|
173 |
+
if (token_bytes_size - offset < max_token) {
|
174 |
+
GGML_ABORT("token_bytes buffer too small\n");
|
175 |
+
}
|
176 |
+
|
177 |
+
llama_token token = i;
|
178 |
+
auto dp = (char *) token_bytes + offset;
|
179 |
+
auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
|
180 |
+
if (size < 0) {
|
181 |
+
GGML_ABORT("llama_detokenize failed\n");
|
182 |
+
}
|
183 |
+
if (size == 0) {
|
184 |
+
size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
|
185 |
+
if (size < 0) {
|
186 |
+
GGML_ABORT("llama_detokenize failed\n");
|
187 |
+
}
|
188 |
+
if (size != 0) {
|
189 |
+
*dp = '\xff'; // special token prefix marker
|
190 |
+
size += 1;
|
191 |
+
}
|
192 |
+
}
|
193 |
+
|
194 |
+
token_lens[i] = size;
|
195 |
+
offset += size;
|
196 |
+
}
|
197 |
+
|
198 |
+
LlgTokenizerInit tinit = {
|
199 |
+
/* .vocab_size = */ (uint32_t) vocab_size,
|
200 |
+
/* .tok_eos = */ (uint32_t) tok_eos,
|
201 |
+
/* .token_lens = */ token_lens,
|
202 |
+
/* .token_bytes = */ token_bytes,
|
203 |
+
/* .tokenizer_json = */ nullptr,
|
204 |
+
/* .tokenize_assumes_string = */ true,
|
205 |
+
/* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
|
206 |
+
/* .use_approximate_greedy_tokenize_fn = */ false,
|
207 |
+
/* .tokenize_user_data = */ vocab,
|
208 |
+
};
|
209 |
+
|
210 |
+
char error_buffer[1024];
|
211 |
+
LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
|
212 |
+
|
213 |
+
delete[] token_bytes;
|
214 |
+
delete[] token_lens;
|
215 |
+
|
216 |
+
if (tokenizer == nullptr) {
|
217 |
+
LOG_ERR("llg tokenizer error: %s\n", error_buffer);
|
218 |
+
return tokenizer;
|
219 |
+
}
|
220 |
+
|
221 |
+
if (tokenizer_cache) {
|
222 |
+
llg_free_tokenizer(tokenizer_cache);
|
223 |
+
}
|
224 |
+
vocab_cache = vocab;
|
225 |
+
tokenizer_cache = tokenizer;
|
226 |
+
|
227 |
+
return llg_clone_tokenizer(tokenizer_cache);
|
228 |
+
}
|
229 |
+
|
230 |
+
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
|
231 |
+
const char * grammar_data) {
|
232 |
+
auto * ctx = new llama_sampler_llg;
|
233 |
+
|
234 |
+
if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
|
235 |
+
auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
|
236 |
+
*ctx = {
|
237 |
+
/* .vocab = */ vocab,
|
238 |
+
/* .grammar_kind = */ grammar_kind,
|
239 |
+
/* .grammar_data = */ grammar_data,
|
240 |
+
/* .tokenizer = */ tokenizer,
|
241 |
+
/* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
|
242 |
+
/* .llg_res = */ {},
|
243 |
+
/* .has_llg_res = */ false,
|
244 |
+
};
|
245 |
+
} else {
|
246 |
+
*ctx = {
|
247 |
+
/* .vocab = */ vocab,
|
248 |
+
/* .grammar_kind = */ {},
|
249 |
+
/* .grammar_data = */ {},
|
250 |
+
/* .tokenizer = */ nullptr,
|
251 |
+
/* .grammar = */ nullptr,
|
252 |
+
/* .llg_res = */ {},
|
253 |
+
/* .has_llg_res = */ false,
|
254 |
+
};
|
255 |
+
}
|
256 |
+
|
257 |
+
return llama_sampler_init(
|
258 |
+
/* .iface = */ &llama_sampler_llg_i,
|
259 |
+
/* .ctx = */ ctx
|
260 |
+
);
|
261 |
+
}
|
262 |
+
|
263 |
+
#else
|
264 |
+
|
265 |
+
llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
|
266 |
+
LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
267 |
+
return nullptr;
|
268 |
+
}
|
269 |
+
|
270 |
+
#endif // LLAMA_USE_LLGUIDANCE
|
common/log.cpp
ADDED
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "log.h"
|
2 |
+
|
3 |
+
#include <chrono>
|
4 |
+
#include <condition_variable>
|
5 |
+
#include <cstdarg>
|
6 |
+
#include <cstdio>
|
7 |
+
#include <mutex>
|
8 |
+
#include <sstream>
|
9 |
+
#include <thread>
|
10 |
+
#include <vector>
|
11 |
+
|
12 |
+
int common_log_verbosity_thold = LOG_DEFAULT_LLAMA;
|
13 |
+
|
14 |
+
void common_log_set_verbosity_thold(int verbosity) {
|
15 |
+
common_log_verbosity_thold = verbosity;
|
16 |
+
}
|
17 |
+
|
18 |
+
static int64_t t_us() {
|
19 |
+
return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
|
20 |
+
}
|
21 |
+
|
22 |
+
// colors
|
23 |
+
enum common_log_col : int {
|
24 |
+
COMMON_LOG_COL_DEFAULT = 0,
|
25 |
+
COMMON_LOG_COL_BOLD,
|
26 |
+
COMMON_LOG_COL_RED,
|
27 |
+
COMMON_LOG_COL_GREEN,
|
28 |
+
COMMON_LOG_COL_YELLOW,
|
29 |
+
COMMON_LOG_COL_BLUE,
|
30 |
+
COMMON_LOG_COL_MAGENTA,
|
31 |
+
COMMON_LOG_COL_CYAN,
|
32 |
+
COMMON_LOG_COL_WHITE,
|
33 |
+
};
|
34 |
+
|
35 |
+
// disable colors by default
|
36 |
+
static std::vector<const char *> g_col = {
|
37 |
+
"",
|
38 |
+
"",
|
39 |
+
"",
|
40 |
+
"",
|
41 |
+
"",
|
42 |
+
"",
|
43 |
+
"",
|
44 |
+
"",
|
45 |
+
"",
|
46 |
+
};
|
47 |
+
|
48 |
+
struct common_log_entry {
|
49 |
+
enum ggml_log_level level;
|
50 |
+
|
51 |
+
bool prefix;
|
52 |
+
|
53 |
+
int64_t timestamp;
|
54 |
+
|
55 |
+
std::vector<char> msg;
|
56 |
+
|
57 |
+
// signals the worker thread to stop
|
58 |
+
bool is_end;
|
59 |
+
|
60 |
+
void print(FILE * file = nullptr) const {
|
61 |
+
FILE * fcur = file;
|
62 |
+
if (!fcur) {
|
63 |
+
// stderr displays DBG messages only when their verbosity level is not higher than the threshold
|
64 |
+
// these messages will still be logged to a file
|
65 |
+
if (level == GGML_LOG_LEVEL_DEBUG && common_log_verbosity_thold < LOG_DEFAULT_DEBUG) {
|
66 |
+
return;
|
67 |
+
}
|
68 |
+
|
69 |
+
fcur = stdout;
|
70 |
+
|
71 |
+
if (level != GGML_LOG_LEVEL_NONE) {
|
72 |
+
fcur = stderr;
|
73 |
+
}
|
74 |
+
}
|
75 |
+
|
76 |
+
if (level != GGML_LOG_LEVEL_NONE && level != GGML_LOG_LEVEL_CONT && prefix) {
|
77 |
+
if (timestamp) {
|
78 |
+
// [M.s.ms.us]
|
79 |
+
fprintf(fcur, "%s%d.%02d.%03d.%03d%s ",
|
80 |
+
g_col[COMMON_LOG_COL_BLUE],
|
81 |
+
(int) (timestamp / 1000000 / 60),
|
82 |
+
(int) (timestamp / 1000000 % 60),
|
83 |
+
(int) (timestamp / 1000 % 1000),
|
84 |
+
(int) (timestamp % 1000),
|
85 |
+
g_col[COMMON_LOG_COL_DEFAULT]);
|
86 |
+
}
|
87 |
+
|
88 |
+
switch (level) {
|
89 |
+
case GGML_LOG_LEVEL_INFO: fprintf(fcur, "%sI %s", g_col[COMMON_LOG_COL_GREEN], g_col[COMMON_LOG_COL_DEFAULT]); break;
|
90 |
+
case GGML_LOG_LEVEL_WARN: fprintf(fcur, "%sW %s", g_col[COMMON_LOG_COL_MAGENTA], "" ); break;
|
91 |
+
case GGML_LOG_LEVEL_ERROR: fprintf(fcur, "%sE %s", g_col[COMMON_LOG_COL_RED], "" ); break;
|
92 |
+
case GGML_LOG_LEVEL_DEBUG: fprintf(fcur, "%sD %s", g_col[COMMON_LOG_COL_YELLOW], "" ); break;
|
93 |
+
default:
|
94 |
+
break;
|
95 |
+
}
|
96 |
+
}
|
97 |
+
|
98 |
+
fprintf(fcur, "%s", msg.data());
|
99 |
+
|
100 |
+
if (level == GGML_LOG_LEVEL_WARN || level == GGML_LOG_LEVEL_ERROR || level == GGML_LOG_LEVEL_DEBUG) {
|
101 |
+
fprintf(fcur, "%s", g_col[COMMON_LOG_COL_DEFAULT]);
|
102 |
+
}
|
103 |
+
|
104 |
+
fflush(fcur);
|
105 |
+
}
|
106 |
+
};
|
107 |
+
|
108 |
+
struct common_log {
|
109 |
+
// default capacity - will be expanded if needed
|
110 |
+
common_log() : common_log(256) {}
|
111 |
+
|
112 |
+
common_log(size_t capacity) {
|
113 |
+
file = nullptr;
|
114 |
+
prefix = false;
|
115 |
+
timestamps = false;
|
116 |
+
running = false;
|
117 |
+
t_start = t_us();
|
118 |
+
|
119 |
+
// initial message size - will be expanded if longer messages arrive
|
120 |
+
entries.resize(capacity);
|
121 |
+
for (auto & entry : entries) {
|
122 |
+
entry.msg.resize(256);
|
123 |
+
}
|
124 |
+
|
125 |
+
head = 0;
|
126 |
+
tail = 0;
|
127 |
+
|
128 |
+
resume();
|
129 |
+
}
|
130 |
+
|
131 |
+
~common_log() {
|
132 |
+
pause();
|
133 |
+
if (file) {
|
134 |
+
fclose(file);
|
135 |
+
}
|
136 |
+
}
|
137 |
+
|
138 |
+
private:
|
139 |
+
std::mutex mtx;
|
140 |
+
std::thread thrd;
|
141 |
+
std::condition_variable cv;
|
142 |
+
|
143 |
+
FILE * file;
|
144 |
+
|
145 |
+
bool prefix;
|
146 |
+
bool timestamps;
|
147 |
+
bool running;
|
148 |
+
|
149 |
+
int64_t t_start;
|
150 |
+
|
151 |
+
// ring buffer of entries
|
152 |
+
std::vector<common_log_entry> entries;
|
153 |
+
size_t head;
|
154 |
+
size_t tail;
|
155 |
+
|
156 |
+
// worker thread copies into this
|
157 |
+
common_log_entry cur;
|
158 |
+
|
159 |
+
public:
|
160 |
+
void add(enum ggml_log_level level, const char * fmt, va_list args) {
|
161 |
+
std::lock_guard<std::mutex> lock(mtx);
|
162 |
+
|
163 |
+
if (!running) {
|
164 |
+
// discard messages while the worker thread is paused
|
165 |
+
return;
|
166 |
+
}
|
167 |
+
|
168 |
+
auto & entry = entries[tail];
|
169 |
+
|
170 |
+
{
|
171 |
+
// cannot use args twice, so make a copy in case we need to expand the buffer
|
172 |
+
va_list args_copy;
|
173 |
+
va_copy(args_copy, args);
|
174 |
+
|
175 |
+
#if 1
|
176 |
+
const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args);
|
177 |
+
if (n >= entry.msg.size()) {
|
178 |
+
entry.msg.resize(n + 1);
|
179 |
+
vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args_copy);
|
180 |
+
}
|
181 |
+
#else
|
182 |
+
// hack for bolding arguments
|
183 |
+
|
184 |
+
std::stringstream ss;
|
185 |
+
for (int i = 0; fmt[i] != 0; i++) {
|
186 |
+
if (fmt[i] == '%') {
|
187 |
+
ss << LOG_COL_BOLD;
|
188 |
+
while (fmt[i] != ' ' && fmt[i] != ')' && fmt[i] != ']' && fmt[i] != 0) ss << fmt[i++];
|
189 |
+
ss << LOG_COL_DEFAULT;
|
190 |
+
if (fmt[i] == 0) break;
|
191 |
+
}
|
192 |
+
ss << fmt[i];
|
193 |
+
}
|
194 |
+
const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args);
|
195 |
+
if (n >= entry.msg.size()) {
|
196 |
+
entry.msg.resize(n + 1);
|
197 |
+
vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args_copy);
|
198 |
+
}
|
199 |
+
#endif
|
200 |
+
va_end(args_copy);
|
201 |
+
}
|
202 |
+
|
203 |
+
entry.level = level;
|
204 |
+
entry.prefix = prefix;
|
205 |
+
entry.timestamp = 0;
|
206 |
+
if (timestamps) {
|
207 |
+
entry.timestamp = t_us() - t_start;
|
208 |
+
}
|
209 |
+
entry.is_end = false;
|
210 |
+
|
211 |
+
tail = (tail + 1) % entries.size();
|
212 |
+
if (tail == head) {
|
213 |
+
// expand the buffer
|
214 |
+
std::vector<common_log_entry> new_entries(2*entries.size());
|
215 |
+
|
216 |
+
size_t new_tail = 0;
|
217 |
+
|
218 |
+
do {
|
219 |
+
new_entries[new_tail] = std::move(entries[head]);
|
220 |
+
|
221 |
+
head = (head + 1) % entries.size();
|
222 |
+
new_tail = (new_tail + 1);
|
223 |
+
} while (head != tail);
|
224 |
+
|
225 |
+
head = 0;
|
226 |
+
tail = new_tail;
|
227 |
+
|
228 |
+
for (size_t i = tail; i < new_entries.size(); i++) {
|
229 |
+
new_entries[i].msg.resize(256);
|
230 |
+
}
|
231 |
+
|
232 |
+
entries = std::move(new_entries);
|
233 |
+
}
|
234 |
+
|
235 |
+
cv.notify_one();
|
236 |
+
}
|
237 |
+
|
238 |
+
void resume() {
|
239 |
+
std::lock_guard<std::mutex> lock(mtx);
|
240 |
+
|
241 |
+
if (running) {
|
242 |
+
return;
|
243 |
+
}
|
244 |
+
|
245 |
+
running = true;
|
246 |
+
|
247 |
+
thrd = std::thread([this]() {
|
248 |
+
while (true) {
|
249 |
+
{
|
250 |
+
std::unique_lock<std::mutex> lock(mtx);
|
251 |
+
cv.wait(lock, [this]() { return head != tail; });
|
252 |
+
|
253 |
+
cur = entries[head];
|
254 |
+
|
255 |
+
head = (head + 1) % entries.size();
|
256 |
+
}
|
257 |
+
|
258 |
+
if (cur.is_end) {
|
259 |
+
break;
|
260 |
+
}
|
261 |
+
|
262 |
+
cur.print(); // stdout and stderr
|
263 |
+
|
264 |
+
if (file) {
|
265 |
+
cur.print(file);
|
266 |
+
}
|
267 |
+
}
|
268 |
+
});
|
269 |
+
}
|
270 |
+
|
271 |
+
void pause() {
|
272 |
+
{
|
273 |
+
std::lock_guard<std::mutex> lock(mtx);
|
274 |
+
|
275 |
+
if (!running) {
|
276 |
+
return;
|
277 |
+
}
|
278 |
+
|
279 |
+
running = false;
|
280 |
+
|
281 |
+
// push an entry to signal the worker thread to stop
|
282 |
+
{
|
283 |
+
auto & entry = entries[tail];
|
284 |
+
entry.is_end = true;
|
285 |
+
|
286 |
+
tail = (tail + 1) % entries.size();
|
287 |
+
}
|
288 |
+
|
289 |
+
cv.notify_one();
|
290 |
+
}
|
291 |
+
|
292 |
+
thrd.join();
|
293 |
+
}
|
294 |
+
|
295 |
+
void set_file(const char * path) {
|
296 |
+
pause();
|
297 |
+
|
298 |
+
if (file) {
|
299 |
+
fclose(file);
|
300 |
+
}
|
301 |
+
|
302 |
+
if (path) {
|
303 |
+
file = fopen(path, "w");
|
304 |
+
} else {
|
305 |
+
file = nullptr;
|
306 |
+
}
|
307 |
+
|
308 |
+
resume();
|
309 |
+
}
|
310 |
+
|
311 |
+
void set_colors(bool colors) {
|
312 |
+
pause();
|
313 |
+
|
314 |
+
if (colors) {
|
315 |
+
g_col[COMMON_LOG_COL_DEFAULT] = LOG_COL_DEFAULT;
|
316 |
+
g_col[COMMON_LOG_COL_BOLD] = LOG_COL_BOLD;
|
317 |
+
g_col[COMMON_LOG_COL_RED] = LOG_COL_RED;
|
318 |
+
g_col[COMMON_LOG_COL_GREEN] = LOG_COL_GREEN;
|
319 |
+
g_col[COMMON_LOG_COL_YELLOW] = LOG_COL_YELLOW;
|
320 |
+
g_col[COMMON_LOG_COL_BLUE] = LOG_COL_BLUE;
|
321 |
+
g_col[COMMON_LOG_COL_MAGENTA] = LOG_COL_MAGENTA;
|
322 |
+
g_col[COMMON_LOG_COL_CYAN] = LOG_COL_CYAN;
|
323 |
+
g_col[COMMON_LOG_COL_WHITE] = LOG_COL_WHITE;
|
324 |
+
} else {
|
325 |
+
for (size_t i = 0; i < g_col.size(); i++) {
|
326 |
+
g_col[i] = "";
|
327 |
+
}
|
328 |
+
}
|
329 |
+
|
330 |
+
resume();
|
331 |
+
}
|
332 |
+
|
333 |
+
void set_prefix(bool prefix) {
|
334 |
+
std::lock_guard<std::mutex> lock(mtx);
|
335 |
+
|
336 |
+
this->prefix = prefix;
|
337 |
+
}
|
338 |
+
|
339 |
+
void set_timestamps(bool timestamps) {
|
340 |
+
std::lock_guard<std::mutex> lock(mtx);
|
341 |
+
|
342 |
+
this->timestamps = timestamps;
|
343 |
+
}
|
344 |
+
};
|
345 |
+
|
346 |
+
//
|
347 |
+
// public API
|
348 |
+
//
|
349 |
+
|
350 |
+
struct common_log * common_log_init() {
|
351 |
+
return new common_log;
|
352 |
+
}
|
353 |
+
|
354 |
+
struct common_log * common_log_main() {
|
355 |
+
static struct common_log log;
|
356 |
+
|
357 |
+
return &log;
|
358 |
+
}
|
359 |
+
|
360 |
+
void common_log_pause(struct common_log * log) {
|
361 |
+
log->pause();
|
362 |
+
}
|
363 |
+
|
364 |
+
void common_log_resume(struct common_log * log) {
|
365 |
+
log->resume();
|
366 |
+
}
|
367 |
+
|
368 |
+
void common_log_free(struct common_log * log) {
|
369 |
+
delete log;
|
370 |
+
}
|
371 |
+
|
372 |
+
void common_log_add(struct common_log * log, enum ggml_log_level level, const char * fmt, ...) {
|
373 |
+
va_list args;
|
374 |
+
va_start(args, fmt);
|
375 |
+
log->add(level, fmt, args);
|
376 |
+
va_end(args);
|
377 |
+
}
|
378 |
+
|
379 |
+
void common_log_set_file(struct common_log * log, const char * file) {
|
380 |
+
log->set_file(file);
|
381 |
+
}
|
382 |
+
|
383 |
+
void common_log_set_colors(struct common_log * log, bool colors) {
|
384 |
+
log->set_colors(colors);
|
385 |
+
}
|
386 |
+
|
387 |
+
void common_log_set_prefix(struct common_log * log, bool prefix) {
|
388 |
+
log->set_prefix(prefix);
|
389 |
+
}
|
390 |
+
|
391 |
+
void common_log_set_timestamps(struct common_log * log, bool timestamps) {
|
392 |
+
log->set_timestamps(timestamps);
|
393 |
+
}
|
common/log.h
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include "ggml.h" // for ggml_log_level
|
4 |
+
|
5 |
+
#define LOG_CLR_TO_EOL "\033[K\r"
|
6 |
+
#define LOG_COL_DEFAULT "\033[0m"
|
7 |
+
#define LOG_COL_BOLD "\033[1m"
|
8 |
+
#define LOG_COL_RED "\033[31m"
|
9 |
+
#define LOG_COL_GREEN "\033[32m"
|
10 |
+
#define LOG_COL_YELLOW "\033[33m"
|
11 |
+
#define LOG_COL_BLUE "\033[34m"
|
12 |
+
#define LOG_COL_MAGENTA "\033[35m"
|
13 |
+
#define LOG_COL_CYAN "\033[36m"
|
14 |
+
#define LOG_COL_WHITE "\033[37m"
|
15 |
+
|
16 |
+
#ifndef __GNUC__
|
17 |
+
# define LOG_ATTRIBUTE_FORMAT(...)
|
18 |
+
#elif defined(__MINGW32__) && !defined(__clang__)
|
19 |
+
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
|
20 |
+
#else
|
21 |
+
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
|
22 |
+
#endif
|
23 |
+
|
24 |
+
#define LOG_DEFAULT_DEBUG 1
|
25 |
+
#define LOG_DEFAULT_LLAMA 0
|
26 |
+
|
27 |
+
// needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower
|
28 |
+
// set via common_log_set_verbosity()
|
29 |
+
extern int common_log_verbosity_thold;
|
30 |
+
|
31 |
+
void common_log_set_verbosity_thold(int verbosity); // not thread-safe
|
32 |
+
|
33 |
+
// the common_log uses an internal worker thread to print/write log messages
|
34 |
+
// when the worker thread is paused, incoming log messages are discarded
|
35 |
+
struct common_log;
|
36 |
+
|
37 |
+
struct common_log * common_log_init();
|
38 |
+
struct common_log * common_log_main(); // singleton, automatically destroys itself on exit
|
39 |
+
void common_log_pause (struct common_log * log); // pause the worker thread, not thread-safe
|
40 |
+
void common_log_resume(struct common_log * log); // resume the worker thread, not thread-safe
|
41 |
+
void common_log_free (struct common_log * log);
|
42 |
+
|
43 |
+
LOG_ATTRIBUTE_FORMAT(3, 4)
|
44 |
+
void common_log_add(struct common_log * log, enum ggml_log_level level, const char * fmt, ...);
|
45 |
+
|
46 |
+
// defaults: file = NULL, colors = false, prefix = false, timestamps = false
|
47 |
+
//
|
48 |
+
// regular log output:
|
49 |
+
//
|
50 |
+
// ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34)
|
51 |
+
// llm_load_tensors: ggml ctx size = 0.27 MiB
|
52 |
+
// llm_load_tensors: offloading 32 repeating layers to GPU
|
53 |
+
// llm_load_tensors: offloading non-repeating layers to GPU
|
54 |
+
//
|
55 |
+
// with prefix = true, timestamps = true, the log output will look like this:
|
56 |
+
//
|
57 |
+
// 0.00.035.060 D ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34)
|
58 |
+
// 0.00.035.064 I llm_load_tensors: ggml ctx size = 0.27 MiB
|
59 |
+
// 0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU
|
60 |
+
// 0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU
|
61 |
+
//
|
62 |
+
// I - info (stdout, V = 0)
|
63 |
+
// W - warning (stderr, V = 0)
|
64 |
+
// E - error (stderr, V = 0)
|
65 |
+
// D - debug (stderr, V = LOG_DEFAULT_DEBUG)
|
66 |
+
//
|
67 |
+
|
68 |
+
void common_log_set_file (struct common_log * log, const char * file); // not thread-safe
|
69 |
+
void common_log_set_colors (struct common_log * log, bool colors); // not thread-safe
|
70 |
+
void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log
|
71 |
+
void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix
|
72 |
+
|
73 |
+
// helper macros for logging
|
74 |
+
// use these to avoid computing log arguments if the verbosity of the log is higher than the threshold
|
75 |
+
//
|
76 |
+
// for example:
|
77 |
+
//
|
78 |
+
// LOG_DBG("this is a debug message: %d\n", expensive_function());
|
79 |
+
//
|
80 |
+
// this will avoid calling expensive_function() if LOG_DEFAULT_DEBUG > common_log_verbosity_thold
|
81 |
+
//
|
82 |
+
|
83 |
+
#define LOG_TMPL(level, verbosity, ...) \
|
84 |
+
do { \
|
85 |
+
if ((verbosity) <= common_log_verbosity_thold) { \
|
86 |
+
common_log_add(common_log_main(), (level), __VA_ARGS__); \
|
87 |
+
} \
|
88 |
+
} while (0)
|
89 |
+
|
90 |
+
#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, 0, __VA_ARGS__)
|
91 |
+
#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)
|
92 |
+
|
93 |
+
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, 0, __VA_ARGS__)
|
94 |
+
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, 0, __VA_ARGS__)
|
95 |
+
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, 0, __VA_ARGS__)
|
96 |
+
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__)
|
97 |
+
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, 0, __VA_ARGS__)
|
98 |
+
|
99 |
+
#define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__)
|
100 |
+
#define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__)
|
101 |
+
#define LOG_ERRV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, verbosity, __VA_ARGS__)
|
102 |
+
#define LOG_DBGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, verbosity, __VA_ARGS__)
|
103 |
+
#define LOG_CNTV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_CONT, verbosity, __VA_ARGS__)
|
common/minja/chat-template.hpp
ADDED
@@ -0,0 +1,529 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*
|
2 |
+
Copyright 2024 Google LLC
|
3 |
+
|
4 |
+
Use of this source code is governed by an MIT-style
|
5 |
+
license that can be found in the LICENSE file or at
|
6 |
+
https://opensource.org/licenses/MIT.
|
7 |
+
*/
|
8 |
+
// SPDX-License-Identifier: MIT
|
9 |
+
#pragma once
|
10 |
+
|
11 |
+
#include "minja.hpp"
|
12 |
+
#include <json.hpp>
|
13 |
+
#include <string>
|
14 |
+
#include <vector>
|
15 |
+
|
16 |
+
using json = nlohmann::ordered_json;
|
17 |
+
|
18 |
+
namespace minja {
|
19 |
+
|
20 |
+
struct chat_template_caps {
|
21 |
+
bool supports_tools = false;
|
22 |
+
bool supports_tool_calls = false;
|
23 |
+
bool supports_tool_responses = false;
|
24 |
+
bool supports_system_role = false;
|
25 |
+
bool supports_parallel_tool_calls = false;
|
26 |
+
bool supports_tool_call_id = false;
|
27 |
+
// meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object.
|
28 |
+
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
|
29 |
+
bool requires_object_arguments = false;
|
30 |
+
// CohereForAI/c4ai-command-r-plus simple variant
|
31 |
+
bool requires_non_null_content = false;
|
32 |
+
// MiniMaxAI/MiniMax-Text-01 special
|
33 |
+
bool requires_typed_content = false;
|
34 |
+
};
|
35 |
+
|
36 |
+
struct chat_template_inputs {
|
37 |
+
nlohmann::ordered_json messages;
|
38 |
+
nlohmann::ordered_json tools;
|
39 |
+
bool add_generation_prompt = true;
|
40 |
+
nlohmann::ordered_json extra_context;
|
41 |
+
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
42 |
+
};
|
43 |
+
|
44 |
+
struct chat_template_options {
|
45 |
+
bool apply_polyfills = true;
|
46 |
+
bool use_bos_token = true;
|
47 |
+
bool use_eos_token = true;
|
48 |
+
bool define_strftime_now = true;
|
49 |
+
|
50 |
+
bool polyfill_tools = true;
|
51 |
+
bool polyfill_tool_call_examples = true;
|
52 |
+
bool polyfill_tool_calls = true;
|
53 |
+
bool polyfill_tool_responses = true;
|
54 |
+
bool polyfill_system_role = true;
|
55 |
+
bool polyfill_object_arguments = true;
|
56 |
+
bool polyfill_typed_content = true;
|
57 |
+
};
|
58 |
+
|
59 |
+
class chat_template {
|
60 |
+
|
61 |
+
private:
|
62 |
+
chat_template_caps caps_;
|
63 |
+
std::string source_;
|
64 |
+
std::string bos_token_;
|
65 |
+
std::string eos_token_;
|
66 |
+
std::shared_ptr<minja::TemplateNode> template_root_;
|
67 |
+
std::string tool_call_example_;
|
68 |
+
|
69 |
+
std::string try_raw_render(
|
70 |
+
const nlohmann::ordered_json & messages,
|
71 |
+
const nlohmann::ordered_json & tools,
|
72 |
+
bool add_generation_prompt,
|
73 |
+
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
|
74 |
+
{
|
75 |
+
try {
|
76 |
+
chat_template_inputs inputs;
|
77 |
+
inputs.messages = messages;
|
78 |
+
inputs.tools = tools;
|
79 |
+
inputs.add_generation_prompt = add_generation_prompt;
|
80 |
+
inputs.extra_context = extra_context;
|
81 |
+
// Use fixed date for tests
|
82 |
+
inputs.now = std::chrono::system_clock::from_time_t(0);
|
83 |
+
|
84 |
+
chat_template_options opts;
|
85 |
+
opts.apply_polyfills = false;
|
86 |
+
|
87 |
+
auto prompt = apply(inputs, opts);
|
88 |
+
// fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
|
89 |
+
return prompt;
|
90 |
+
} catch (const std::exception & e) {
|
91 |
+
// fprintf(stderr, "try_raw_render error: %s\n", e.what());
|
92 |
+
return "";
|
93 |
+
}
|
94 |
+
}
|
95 |
+
|
96 |
+
public:
|
97 |
+
|
98 |
+
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
|
99 |
+
: source_(source), bos_token_(bos_token), eos_token_(eos_token)
|
100 |
+
{
|
101 |
+
template_root_ = minja::Parser::parse(source_, {
|
102 |
+
/* .trim_blocks = */ true,
|
103 |
+
/* .lstrip_blocks = */ true,
|
104 |
+
/* .keep_trailing_newline = */ false,
|
105 |
+
});
|
106 |
+
|
107 |
+
auto contains = [](const std::string & haystack, const std::string & needle) {
|
108 |
+
return haystack.find(needle) != std::string::npos;
|
109 |
+
};
|
110 |
+
|
111 |
+
const std::string user_needle = "<User Needle>";
|
112 |
+
const std::string sys_needle = "<System Needle>";
|
113 |
+
const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}};
|
114 |
+
const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}};
|
115 |
+
|
116 |
+
caps_.requires_typed_content =
|
117 |
+
!contains(try_raw_render(json::array({dummy_str_user_msg}), {}, false), user_needle)
|
118 |
+
&& contains(try_raw_render(json::array({dummy_typed_user_msg}), {}, false), user_needle);
|
119 |
+
|
120 |
+
const auto dummy_user_msg = caps_.requires_typed_content
|
121 |
+
? dummy_typed_user_msg
|
122 |
+
: dummy_str_user_msg;
|
123 |
+
const json needle_system_msg = {
|
124 |
+
{"role", "system"},
|
125 |
+
{"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)},
|
126 |
+
};
|
127 |
+
|
128 |
+
caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle);
|
129 |
+
|
130 |
+
auto out = try_raw_render(json::array({
|
131 |
+
dummy_user_msg
|
132 |
+
}), json::array({
|
133 |
+
{
|
134 |
+
{"name", "some_tool"},
|
135 |
+
{"type", "function"},
|
136 |
+
{"function", {
|
137 |
+
{"name", "some_tool"},
|
138 |
+
{"description", "Some tool."},
|
139 |
+
{"parameters", {
|
140 |
+
{"type", "object"},
|
141 |
+
{"properties", {
|
142 |
+
{"arg", {
|
143 |
+
{"type", "string"},
|
144 |
+
{"description", "Some argument."},
|
145 |
+
}},
|
146 |
+
}},
|
147 |
+
{"required", json::array({ "arg" })},
|
148 |
+
}},
|
149 |
+
}},
|
150 |
+
},
|
151 |
+
}), false);
|
152 |
+
caps_.supports_tools = contains(out, "some_tool");
|
153 |
+
|
154 |
+
auto make_tool_calls_msg = [&](const json & tool_calls) {
|
155 |
+
return json {
|
156 |
+
{"role", "assistant"},
|
157 |
+
{"content", nullptr},
|
158 |
+
{"tool_calls", tool_calls},
|
159 |
+
};
|
160 |
+
};
|
161 |
+
auto make_tool_call = [](const std::string & tool_name, const json & arguments) {
|
162 |
+
return json {
|
163 |
+
{"id", "call_1___"},
|
164 |
+
{"type", "function"},
|
165 |
+
{"function", {
|
166 |
+
{"arguments", arguments},
|
167 |
+
{"name", tool_name},
|
168 |
+
}},
|
169 |
+
};
|
170 |
+
};
|
171 |
+
const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};
|
172 |
+
|
173 |
+
// Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
|
174 |
+
out = try_raw_render(json::array({
|
175 |
+
dummy_user_msg,
|
176 |
+
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
|
177 |
+
}), {}, false);
|
178 |
+
auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
|
179 |
+
out = try_raw_render(json::array({
|
180 |
+
dummy_user_msg,
|
181 |
+
make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
|
182 |
+
}), {}, false);
|
183 |
+
auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
|
184 |
+
|
185 |
+
caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
|
186 |
+
caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
|
187 |
+
auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false);
|
188 |
+
auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false);
|
189 |
+
caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle);
|
190 |
+
|
191 |
+
if (caps_.supports_tool_calls) {
|
192 |
+
auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump());
|
193 |
+
auto tc1 = make_tool_call("test_tool1", dummy_args);
|
194 |
+
auto tc2 = make_tool_call("test_tool2", dummy_args);
|
195 |
+
auto out = try_raw_render(json::array({
|
196 |
+
dummy_user_msg,
|
197 |
+
make_tool_calls_msg(json::array({tc1, tc2})),
|
198 |
+
}), {}, false);
|
199 |
+
caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2");
|
200 |
+
|
201 |
+
out = try_raw_render(json::array({
|
202 |
+
dummy_user_msg,
|
203 |
+
make_tool_calls_msg(json::array({tc1})),
|
204 |
+
{
|
205 |
+
{"role", "tool"},
|
206 |
+
{"name", "test_tool1"},
|
207 |
+
{"content", "Some response!"},
|
208 |
+
{"tool_call_id", "call_911_"},
|
209 |
+
}
|
210 |
+
}), {}, false);
|
211 |
+
caps_.supports_tool_responses = contains(out, "Some response!");
|
212 |
+
caps_.supports_tool_call_id = contains(out, "call_911_");
|
213 |
+
}
|
214 |
+
|
215 |
+
try {
|
216 |
+
if (!caps_.supports_tools) {
|
217 |
+
const json user_msg {
|
218 |
+
{"role", "user"},
|
219 |
+
{"content", "Hey"},
|
220 |
+
};
|
221 |
+
const json args {
|
222 |
+
{"arg1", "some_value"},
|
223 |
+
};
|
224 |
+
const json tool_call_msg {
|
225 |
+
{"role", "assistant"},
|
226 |
+
{"content", nullptr},
|
227 |
+
{"tool_calls", json::array({
|
228 |
+
{
|
229 |
+
// TODO: detect if requires numerical id or fixed length == 6 like Nemo
|
230 |
+
{"id", "call_1___"},
|
231 |
+
{"type", "function"},
|
232 |
+
{"function", {
|
233 |
+
{"name", "tool_name"},
|
234 |
+
{"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))},
|
235 |
+
}},
|
236 |
+
},
|
237 |
+
})},
|
238 |
+
};
|
239 |
+
std::string prefix, full;
|
240 |
+
{
|
241 |
+
chat_template_inputs inputs;
|
242 |
+
inputs.messages = json::array({user_msg});
|
243 |
+
inputs.add_generation_prompt = true;
|
244 |
+
prefix = apply(inputs);
|
245 |
+
}
|
246 |
+
{
|
247 |
+
chat_template_inputs inputs;
|
248 |
+
inputs.messages = json::array({user_msg, tool_call_msg});
|
249 |
+
inputs.add_generation_prompt = false;
|
250 |
+
full = apply(inputs);
|
251 |
+
}
|
252 |
+
auto eos_pos_last = full.rfind(eos_token_);
|
253 |
+
if (eos_pos_last == prefix.size() - eos_token_.size() ||
|
254 |
+
(full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
|
255 |
+
full = full.substr(0, eos_pos_last);
|
256 |
+
}
|
257 |
+
size_t common_prefix_length = 0;
|
258 |
+
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
|
259 |
+
if (prefix[i] != full[i]) {
|
260 |
+
break;
|
261 |
+
}
|
262 |
+
if (prefix[i] == '<') {
|
263 |
+
// DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
|
264 |
+
// but it removes thinking tags for past messages.
|
265 |
+
// The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
|
266 |
+
continue;
|
267 |
+
}
|
268 |
+
common_prefix_length = i + 1;
|
269 |
+
}
|
270 |
+
auto example = full.substr(common_prefix_length);
|
271 |
+
if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
|
272 |
+
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
|
273 |
+
} else {
|
274 |
+
tool_call_example_ = example;
|
275 |
+
}
|
276 |
+
}
|
277 |
+
} catch (const std::exception & e) {
|
278 |
+
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
|
279 |
+
}
|
280 |
+
}
|
281 |
+
|
282 |
+
const std::string & source() const { return source_; }
|
283 |
+
const std::string & bos_token() const { return bos_token_; }
|
284 |
+
const std::string & eos_token() const { return eos_token_; }
|
285 |
+
const chat_template_caps & original_caps() const { return caps_; }
|
286 |
+
|
287 |
+
// Deprecated, please use the form with chat_template_inputs and chat_template_options
|
288 |
+
std::string apply(
|
289 |
+
const nlohmann::ordered_json & messages,
|
290 |
+
const nlohmann::ordered_json & tools,
|
291 |
+
bool add_generation_prompt,
|
292 |
+
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
|
293 |
+
bool apply_polyfills = true)
|
294 |
+
{
|
295 |
+
fprintf(stderr, "[%s] Deprecated!\n", __func__);
|
296 |
+
chat_template_inputs inputs;
|
297 |
+
inputs.messages = messages;
|
298 |
+
inputs.tools = tools;
|
299 |
+
inputs.add_generation_prompt = add_generation_prompt;
|
300 |
+
inputs.extra_context = extra_context;
|
301 |
+
inputs.now = std::chrono::system_clock::now();
|
302 |
+
|
303 |
+
chat_template_options opts;
|
304 |
+
opts.apply_polyfills = apply_polyfills;
|
305 |
+
|
306 |
+
return apply(inputs, opts);
|
307 |
+
}
|
308 |
+
|
309 |
+
std::string apply(
|
310 |
+
const chat_template_inputs & inputs,
|
311 |
+
const chat_template_options & opts = chat_template_options()) const
|
312 |
+
{
|
313 |
+
json actual_messages;
|
314 |
+
|
315 |
+
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
316 |
+
auto has_tool_calls = false;
|
317 |
+
auto has_tool_responses = false;
|
318 |
+
auto has_string_content = false;
|
319 |
+
for (const auto & message : inputs.messages) {
|
320 |
+
if (message.contains("tool_calls") && !message["tool_calls"].is_null()) {
|
321 |
+
has_tool_calls = true;
|
322 |
+
}
|
323 |
+
if (message.contains("role") && message["role"] == "tool") {
|
324 |
+
has_tool_responses = true;
|
325 |
+
}
|
326 |
+
if (message.contains("content") && message["content"].is_string()) {
|
327 |
+
has_string_content = true;
|
328 |
+
}
|
329 |
+
}
|
330 |
+
|
331 |
+
auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role;
|
332 |
+
auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools;
|
333 |
+
auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples;
|
334 |
+
auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls;
|
335 |
+
auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses;
|
336 |
+
auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments;
|
337 |
+
auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content;
|
338 |
+
|
339 |
+
auto needs_polyfills = opts.apply_polyfills && (false
|
340 |
+
|| polyfill_system_role
|
341 |
+
|| polyfill_tools
|
342 |
+
|| polyfill_tool_calls
|
343 |
+
|| polyfill_tool_responses
|
344 |
+
|| polyfill_object_arguments
|
345 |
+
|| polyfill_typed_content
|
346 |
+
);
|
347 |
+
|
348 |
+
if (needs_polyfills) {
|
349 |
+
actual_messages = json::array();
|
350 |
+
|
351 |
+
auto add_message = [&](const json & msg) {
|
352 |
+
if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
|
353 |
+
actual_messages.push_back({
|
354 |
+
{"role", msg.at("role")},
|
355 |
+
{"content", {{
|
356 |
+
{"type", "text"},
|
357 |
+
{"text", msg.at("content")},
|
358 |
+
}}},
|
359 |
+
});
|
360 |
+
} else {
|
361 |
+
actual_messages.push_back(msg);
|
362 |
+
}
|
363 |
+
};
|
364 |
+
|
365 |
+
std::string pending_system;
|
366 |
+
auto flush_sys = [&]() {
|
367 |
+
if (!pending_system.empty()) {
|
368 |
+
add_message({
|
369 |
+
{"role", "user"},
|
370 |
+
{"content", pending_system},
|
371 |
+
});
|
372 |
+
pending_system.clear();
|
373 |
+
}
|
374 |
+
};
|
375 |
+
|
376 |
+
json adjusted_messages;
|
377 |
+
if (polyfill_tools) {
|
378 |
+
adjusted_messages = add_system(inputs.messages,
|
379 |
+
"You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
|
380 |
+
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
|
381 |
+
} else {
|
382 |
+
adjusted_messages = inputs.messages;
|
383 |
+
}
|
384 |
+
|
385 |
+
for (const auto & message_ : adjusted_messages) {
|
386 |
+
auto message = message_;
|
387 |
+
if (!message.contains("role") || !message.contains("content")) {
|
388 |
+
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
|
389 |
+
}
|
390 |
+
std::string role = message.at("role");
|
391 |
+
|
392 |
+
if (message.contains("tool_calls")) {
|
393 |
+
if (polyfill_object_arguments || polyfill_tool_calls) {
|
394 |
+
for (auto & tool_call : message.at("tool_calls")) {
|
395 |
+
if (tool_call["type"] == "function") {
|
396 |
+
auto & function = tool_call.at("function");
|
397 |
+
auto & arguments = function.at("arguments");
|
398 |
+
if (arguments.is_string()) {
|
399 |
+
try {
|
400 |
+
arguments = json::parse(arguments.get<std::string>());
|
401 |
+
} catch (const std::exception & ecvt) {
|
402 |
+
fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
|
403 |
+
}
|
404 |
+
}
|
405 |
+
}
|
406 |
+
}
|
407 |
+
}
|
408 |
+
if (polyfill_tool_calls) {
|
409 |
+
auto content = message.at("content");
|
410 |
+
auto tool_calls = json::array();
|
411 |
+
for (const auto & tool_call : message.at("tool_calls")) {
|
412 |
+
if (tool_call.at("type") != "function") {
|
413 |
+
continue;
|
414 |
+
}
|
415 |
+
const auto & function = tool_call.at("function");
|
416 |
+
auto tc = json {
|
417 |
+
{"name", function.at("name")},
|
418 |
+
{"arguments", function.at("arguments")},
|
419 |
+
};
|
420 |
+
if (tool_call.contains("id")) {
|
421 |
+
tc["id"] = tool_call["id"];
|
422 |
+
}
|
423 |
+
tool_calls.push_back(tc);
|
424 |
+
}
|
425 |
+
auto obj = json {
|
426 |
+
{"tool_calls", tool_calls},
|
427 |
+
};
|
428 |
+
if (!content.is_null() && content != "") {
|
429 |
+
obj["content"] = content;
|
430 |
+
}
|
431 |
+
message["content"] = obj.dump(2);
|
432 |
+
message.erase("tool_calls");
|
433 |
+
}
|
434 |
+
}
|
435 |
+
if (polyfill_tool_responses && role == "tool") {
|
436 |
+
message["role"] = "user";
|
437 |
+
auto obj = json {
|
438 |
+
{"tool_response", {
|
439 |
+
{"content", message.at("content")},
|
440 |
+
}},
|
441 |
+
};
|
442 |
+
if (message.contains("name")) {
|
443 |
+
obj["tool_response"]["name"] = message.at("name");
|
444 |
+
}
|
445 |
+
if (message.contains("tool_call_id")) {
|
446 |
+
obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
|
447 |
+
}
|
448 |
+
message["content"] = obj.dump(2);
|
449 |
+
message.erase("name");
|
450 |
+
}
|
451 |
+
|
452 |
+
if (!message["content"].is_null() && polyfill_system_role) {
|
453 |
+
std::string content = message.at("content");
|
454 |
+
if (role == "system") {
|
455 |
+
if (!pending_system.empty()) pending_system += "\n";
|
456 |
+
pending_system += content;
|
457 |
+
continue;
|
458 |
+
} else {
|
459 |
+
if (role == "user") {
|
460 |
+
if (!pending_system.empty()) {
|
461 |
+
message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
|
462 |
+
pending_system.clear();
|
463 |
+
}
|
464 |
+
} else {
|
465 |
+
flush_sys();
|
466 |
+
}
|
467 |
+
}
|
468 |
+
}
|
469 |
+
add_message(message);
|
470 |
+
}
|
471 |
+
flush_sys();
|
472 |
+
} else {
|
473 |
+
actual_messages = inputs.messages;
|
474 |
+
}
|
475 |
+
|
476 |
+
auto context = minja::Context::make(json({
|
477 |
+
{"messages", actual_messages},
|
478 |
+
{"add_generation_prompt", inputs.add_generation_prompt},
|
479 |
+
}));
|
480 |
+
context->set("bos_token", opts.use_bos_token ? bos_token_ : "");
|
481 |
+
context->set("eos_token", opts.use_eos_token ? eos_token_ : "");
|
482 |
+
if (opts.define_strftime_now) {
|
483 |
+
auto now = inputs.now;
|
484 |
+
context->set("strftime_now", Value::callable([now](const std::shared_ptr<minja::Context> &, minja::ArgumentsValue & args) {
|
485 |
+
args.expectArgs("strftime_now", {1, 1}, {0, 0});
|
486 |
+
auto format = args.args[0].get<std::string>();
|
487 |
+
|
488 |
+
auto time = std::chrono::system_clock::to_time_t(now);
|
489 |
+
auto local_time = *std::localtime(&time);
|
490 |
+
std::ostringstream ss;
|
491 |
+
ss << std::put_time(&local_time, format.c_str());
|
492 |
+
return ss.str();
|
493 |
+
}));
|
494 |
+
}
|
495 |
+
if (!inputs.tools.is_null()) {
|
496 |
+
context->set("tools", minja::Value(inputs.tools));
|
497 |
+
}
|
498 |
+
if (!inputs.extra_context.is_null()) {
|
499 |
+
for (auto & kv : inputs.extra_context.items()) {
|
500 |
+
context->set(kv.key(), minja::Value(kv.value()));
|
501 |
+
}
|
502 |
+
}
|
503 |
+
|
504 |
+
auto ret = template_root_->render(context);
|
505 |
+
// fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str());
|
506 |
+
// fprintf(stderr, "apply: %s\n\n", ret.c_str());
|
507 |
+
return ret;
|
508 |
+
}
|
509 |
+
|
510 |
+
static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
|
511 |
+
json messages_with_system = messages;
|
512 |
+
|
513 |
+
if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
|
514 |
+
std::string existing_system = messages_with_system.at(0).at("content");
|
515 |
+
messages_with_system[0] = json {
|
516 |
+
{"role", "system"},
|
517 |
+
{"content", existing_system + "\n\n" + system_prompt},
|
518 |
+
};
|
519 |
+
} else {
|
520 |
+
messages_with_system.insert(messages_with_system.begin(), json {
|
521 |
+
{"role", "system"},
|
522 |
+
{"content", system_prompt},
|
523 |
+
});
|
524 |
+
}
|
525 |
+
return messages_with_system;
|
526 |
+
}
|
527 |
+
};
|
528 |
+
|
529 |
+
} // namespace minja
|
common/minja/minja.hpp
ADDED
The diff for this file is too large to render.
See raw diff
|
|
common/ngram-cache.cpp
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "ngram-cache.h"
|
2 |
+
#include "common.h"
|
3 |
+
#include "log.h"
|
4 |
+
|
5 |
+
#include <cinttypes>
|
6 |
+
#include <cstdint>
|
7 |
+
#include <cstdio>
|
8 |
+
#include <fstream>
|
9 |
+
#include <thread>
|
10 |
+
#include <algorithm>
|
11 |
+
|
12 |
+
void common_ngram_cache_update(common_ngram_cache & ngram_cache, int ngram_min, int ngram_max,
|
13 |
+
std::vector<llama_token> & inp, int nnew, bool print_progress) {
|
14 |
+
const int64_t t_start_ms = ggml_time_ms();
|
15 |
+
const int64_t inp_size = inp.size();
|
16 |
+
|
17 |
+
const int64_t n_todo = inp_size * (ngram_max - ngram_min + 1);
|
18 |
+
int64_t n_done = 0;
|
19 |
+
|
20 |
+
for (int64_t ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
|
21 |
+
const int64_t i_start = std::max(inp_size - nnew, ngram_size);
|
22 |
+
for (int64_t i = i_start; i < inp_size; ++i) {
|
23 |
+
const int64_t ngram_start = i - ngram_size;
|
24 |
+
common_ngram ngram(&inp[ngram_start], ngram_size);
|
25 |
+
const llama_token token = inp[i];
|
26 |
+
|
27 |
+
common_ngram_cache::iterator part_it = ngram_cache.find(ngram);
|
28 |
+
if (part_it == ngram_cache.end()) {
|
29 |
+
common_ngram_cache_part part;
|
30 |
+
part.emplace(token, 1);
|
31 |
+
ngram_cache.emplace(ngram, part);
|
32 |
+
} else {
|
33 |
+
common_ngram_cache_part::iterator token_count_it = part_it->second.find(token);
|
34 |
+
if (token_count_it == part_it->second.end()) {
|
35 |
+
part_it->second.emplace(token, 1);
|
36 |
+
} else {
|
37 |
+
token_count_it->second++;
|
38 |
+
}
|
39 |
+
}
|
40 |
+
++n_done;
|
41 |
+
|
42 |
+
if (print_progress && n_done % 10000000 == 0) {
|
43 |
+
const int64_t t_now_ms = ggml_time_ms();
|
44 |
+
const int64_t eta_ms = (inp_size*(ngram_max-ngram_min+1) - n_done) * (t_now_ms - t_start_ms) / n_done;
|
45 |
+
const int64_t eta_min = eta_ms / (60*1000);
|
46 |
+
const int64_t eta_s = (eta_ms - 60*1000*eta_min) / 1000;
|
47 |
+
|
48 |
+
fprintf(stderr, "%s: %" PRId64 "/%" PRId64 " done, ETA: %02" PRId64 ":%02" PRId64 "\n", __func__, n_done, n_todo, eta_min, eta_s);
|
49 |
+
}
|
50 |
+
}
|
51 |
+
}
|
52 |
+
}
|
53 |
+
|
54 |
+
// Helper function to get a token from the combined, speculative sequence of inp and draft.
|
55 |
+
static llama_token get_token(const std::vector<llama_token> & inp, const std::vector<llama_token> & draft, const size_t i) {
|
56 |
+
return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
|
57 |
+
}
|
58 |
+
|
59 |
+
// If sample size or percentage are below these thresholds the draft is aborted early:
|
60 |
+
constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 2, 2, 1, 1};
|
61 |
+
constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX] = {66, 50, 50, 50};
|
62 |
+
constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4, 3, 2, 2};
|
63 |
+
constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {75, 66, 66, 66};
|
64 |
+
|
65 |
+
// Helper function that tries to draft a token from only the static ngram cache:
|
66 |
+
static llama_token try_draft(common_ngram_cache & nc_static, const common_ngram ngram_static) {
|
67 |
+
common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
|
68 |
+
if (part_static_it == nc_static.end()) {
|
69 |
+
return LLAMA_TOKEN_NULL;
|
70 |
+
}
|
71 |
+
const common_ngram_cache_part part_static = part_static_it->second;
|
72 |
+
|
73 |
+
int max_count_static = 0;
|
74 |
+
int sum_count_static = 0;
|
75 |
+
llama_token max_token = LLAMA_TOKEN_NULL;
|
76 |
+
|
77 |
+
for (std::pair<llama_token, int> token_count_static : part_static) {
|
78 |
+
const llama_token token = token_count_static.first;
|
79 |
+
const int32_t count_static = token_count_static.second;
|
80 |
+
|
81 |
+
if (count_static > max_count_static) {
|
82 |
+
max_token = token;
|
83 |
+
max_count_static = count_static;
|
84 |
+
}
|
85 |
+
sum_count_static += count_static;
|
86 |
+
}
|
87 |
+
|
88 |
+
if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) {
|
89 |
+
return LLAMA_TOKEN_NULL;
|
90 |
+
}
|
91 |
+
if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
|
92 |
+
return LLAMA_TOKEN_NULL;
|
93 |
+
}
|
94 |
+
return max_token;
|
95 |
+
}
|
96 |
+
|
97 |
+
// Try to draft a token from primary cache (context/dynamic), validate with static cache:
|
98 |
+
static llama_token try_draft(
|
99 |
+
common_ngram_cache & nc_primary, const std::vector<common_ngram> & ngrams_primary, common_ngram_cache_part & part_static,
|
100 |
+
const int * min_sample_size, const int * min_percent) {
|
101 |
+
|
102 |
+
llama_token drafted_token = LLAMA_TOKEN_NULL;
|
103 |
+
|
104 |
+
for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == LLAMA_TOKEN_NULL; --i) {
|
105 |
+
const common_ngram ngram_primary = ngrams_primary[i];
|
106 |
+
|
107 |
+
common_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
|
108 |
+
if (part_primary_it == nc_primary.end()) {
|
109 |
+
continue;
|
110 |
+
}
|
111 |
+
const common_ngram_cache_part part_primary = part_primary_it->second;
|
112 |
+
|
113 |
+
int max_count_primary = 0;
|
114 |
+
int max_count_static = 0;
|
115 |
+
int sum_count_primary = 0;
|
116 |
+
llama_token max_token = LLAMA_TOKEN_NULL;
|
117 |
+
|
118 |
+
for (std::pair<llama_token, int> token_count_primary : part_primary) {
|
119 |
+
const llama_token token = token_count_primary.first;
|
120 |
+
|
121 |
+
common_ngram_cache_part::iterator token_count_static_it = part_static.find(token);
|
122 |
+
|
123 |
+
const int32_t count_primary = token_count_primary.second;
|
124 |
+
const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
|
125 |
+
|
126 |
+
if (count_primary*count_static > max_count_primary*max_count_static) {
|
127 |
+
max_token = token;
|
128 |
+
max_count_primary = count_primary;
|
129 |
+
max_count_static = count_static;
|
130 |
+
}
|
131 |
+
sum_count_primary += count_primary;
|
132 |
+
}
|
133 |
+
|
134 |
+
if (sum_count_primary < min_sample_size[i]) {
|
135 |
+
continue;
|
136 |
+
}
|
137 |
+
if (100*max_count_primary < min_percent[i]*sum_count_primary) {
|
138 |
+
continue;;
|
139 |
+
}
|
140 |
+
drafted_token = max_token;
|
141 |
+
}
|
142 |
+
|
143 |
+
return drafted_token;
|
144 |
+
}
|
145 |
+
|
146 |
+
void common_ngram_cache_draft(
|
147 |
+
std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
|
148 |
+
common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static
|
149 |
+
) {
|
150 |
+
GGML_ASSERT(draft.size() == 1);
|
151 |
+
const int inp_size = inp.size();
|
152 |
+
|
153 |
+
if (inp_size < LLAMA_NGRAM_STATIC) {
|
154 |
+
return;
|
155 |
+
}
|
156 |
+
|
157 |
+
while ((int) draft.size()-1 < n_draft) {
|
158 |
+
llama_token drafted_token = LLAMA_TOKEN_NULL;
|
159 |
+
|
160 |
+
const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
|
161 |
+
common_ngram ngram_static;
|
162 |
+
for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
|
163 |
+
ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j);
|
164 |
+
}
|
165 |
+
common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
|
166 |
+
common_ngram_cache_part part_static;
|
167 |
+
if (part_static_it != nc_static.end()) {
|
168 |
+
part_static = part_static_it->second;
|
169 |
+
}
|
170 |
+
|
171 |
+
// cd = context + dynamic
|
172 |
+
std::vector<common_ngram> ngrams_cd;
|
173 |
+
for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
|
174 |
+
const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
|
175 |
+
common_ngram ngram_cd;
|
176 |
+
for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
|
177 |
+
ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j);
|
178 |
+
}
|
179 |
+
ngrams_cd.push_back(ngram_cd);
|
180 |
+
}
|
181 |
+
if (drafted_token == LLAMA_TOKEN_NULL) {
|
182 |
+
drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax);
|
183 |
+
}
|
184 |
+
if (drafted_token == LLAMA_TOKEN_NULL) {
|
185 |
+
drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict);
|
186 |
+
}
|
187 |
+
if (drafted_token == LLAMA_TOKEN_NULL) {
|
188 |
+
drafted_token = try_draft(nc_static, ngram_static);
|
189 |
+
}
|
190 |
+
|
191 |
+
if (drafted_token == LLAMA_TOKEN_NULL) {
|
192 |
+
break;
|
193 |
+
}
|
194 |
+
|
195 |
+
LOG(" - draft candidate: token=%d\n", drafted_token);
|
196 |
+
draft.push_back(drafted_token);
|
197 |
+
}
|
198 |
+
}
|
199 |
+
|
200 |
+
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) {
|
201 |
+
std::ofstream file_out(filename, std::ios::binary);
|
202 |
+
for (std::pair<common_ngram, common_ngram_cache_part> item : ngram_cache) {
|
203 |
+
const common_ngram ngram = item.first;
|
204 |
+
common_ngram_cache_part token_counts = item.second;
|
205 |
+
GGML_ASSERT(!token_counts.empty());
|
206 |
+
const int32_t ntokens = token_counts.size();
|
207 |
+
GGML_ASSERT(ntokens > 0);
|
208 |
+
|
209 |
+
file_out.write(reinterpret_cast<const char *>(&ngram), sizeof(common_ngram));
|
210 |
+
file_out.write(reinterpret_cast<const char *>(&ntokens), sizeof(int32_t));
|
211 |
+
for (std::pair<llama_token, int32_t> item2 : token_counts) {
|
212 |
+
const llama_token token = item2.first;
|
213 |
+
const int32_t count = item2.second;
|
214 |
+
GGML_ASSERT(count > 0);
|
215 |
+
|
216 |
+
file_out.write(reinterpret_cast<const char *>(&token), sizeof(llama_token));
|
217 |
+
file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
|
218 |
+
}
|
219 |
+
}
|
220 |
+
|
221 |
+
}
|
222 |
+
|
223 |
+
common_ngram_cache common_ngram_cache_load(std::string & filename) {
|
224 |
+
std::ifstream hashmap_file(filename, std::ios::binary);
|
225 |
+
if (!hashmap_file) {
|
226 |
+
throw std::ifstream::failure("Unable to open file " + filename);
|
227 |
+
}
|
228 |
+
common_ngram_cache ngram_cache;
|
229 |
+
|
230 |
+
common_ngram ngram;
|
231 |
+
int32_t ntokens;
|
232 |
+
llama_token token;
|
233 |
+
int32_t count;
|
234 |
+
|
235 |
+
char * ngramc = reinterpret_cast<char*>(&ngram);
|
236 |
+
char * ntokensc = reinterpret_cast<char*>(&ntokens);
|
237 |
+
char * tokenc = reinterpret_cast<char*>(&token);
|
238 |
+
char * countc = reinterpret_cast<char*>(&count);
|
239 |
+
while(hashmap_file.read(ngramc, sizeof(common_ngram))) {
|
240 |
+
GGML_ASSERT(!hashmap_file.eof());
|
241 |
+
GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t)));
|
242 |
+
GGML_ASSERT(ntokens > 0);
|
243 |
+
common_ngram_cache_part token_counts;
|
244 |
+
|
245 |
+
for (int i = 0; i < ntokens; ++i) {
|
246 |
+
GGML_ASSERT(!hashmap_file.eof());
|
247 |
+
GGML_ASSERT(hashmap_file.read(tokenc, sizeof(llama_token)));
|
248 |
+
GGML_ASSERT(!hashmap_file.eof());
|
249 |
+
GGML_ASSERT(hashmap_file.read(countc, sizeof(int32_t)));
|
250 |
+
GGML_ASSERT(count > 0);
|
251 |
+
token_counts.emplace(token, count);
|
252 |
+
}
|
253 |
+
|
254 |
+
ngram_cache.emplace(ngram, token_counts);
|
255 |
+
}
|
256 |
+
GGML_ASSERT(hashmap_file.eof());
|
257 |
+
|
258 |
+
return ngram_cache;
|
259 |
+
}
|
260 |
+
|
261 |
+
void common_ngram_cache_merge(common_ngram_cache & ngram_cache_target, common_ngram_cache & ngram_cache_add) {
|
262 |
+
for (std::pair<common_ngram, common_ngram_cache_part> ngram_part : ngram_cache_add) {
|
263 |
+
const common_ngram ngram = ngram_part.first;
|
264 |
+
common_ngram_cache_part part = ngram_part.second;
|
265 |
+
|
266 |
+
common_ngram_cache::iterator part_merged_it = ngram_cache_target.find(ngram);
|
267 |
+
if (part_merged_it == ngram_cache_target.end()) {
|
268 |
+
ngram_cache_target.emplace(ngram, part);
|
269 |
+
continue;
|
270 |
+
}
|
271 |
+
|
272 |
+
for (std::pair<llama_token, int32_t> token_count : part) {
|
273 |
+
const llama_token token = token_count.first;
|
274 |
+
const int32_t count = token_count.second;
|
275 |
+
GGML_ASSERT(count > 0);
|
276 |
+
|
277 |
+
common_ngram_cache_part::iterator token_count_merged_it = part_merged_it->second.find(token);
|
278 |
+
if (token_count_merged_it == part_merged_it->second.end()) {
|
279 |
+
part_merged_it->second.emplace(token, count);
|
280 |
+
continue;
|
281 |
+
}
|
282 |
+
|
283 |
+
token_count_merged_it->second += count;
|
284 |
+
}
|
285 |
+
}
|
286 |
+
}
|
common/ngram-cache.h
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include "llama.h"
|
4 |
+
|
5 |
+
#include <unordered_map>
|
6 |
+
#include <string>
|
7 |
+
#include <vector>
|
8 |
+
|
9 |
+
#define LLAMA_NGRAM_MIN 1
|
10 |
+
#define LLAMA_NGRAM_MAX 4
|
11 |
+
#define LLAMA_NGRAM_STATIC 2
|
12 |
+
|
13 |
+
// Data structures to map n-grams to empirical token probabilities:
|
14 |
+
|
15 |
+
struct common_ngram {
|
16 |
+
llama_token tokens[LLAMA_NGRAM_MAX];
|
17 |
+
|
18 |
+
common_ngram() {
|
19 |
+
for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
|
20 |
+
tokens[i] = LLAMA_TOKEN_NULL;
|
21 |
+
}
|
22 |
+
}
|
23 |
+
|
24 |
+
common_ngram(const llama_token * input, const int ngram_size) {
|
25 |
+
for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
|
26 |
+
tokens[i] = i < ngram_size ? input[i] : LLAMA_TOKEN_NULL;
|
27 |
+
}
|
28 |
+
}
|
29 |
+
|
30 |
+
bool operator==(const common_ngram & other) const {
|
31 |
+
for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
|
32 |
+
if (tokens[i] != other.tokens[i]) {
|
33 |
+
return false;
|
34 |
+
}
|
35 |
+
}
|
36 |
+
return true;
|
37 |
+
}
|
38 |
+
};
|
39 |
+
|
40 |
+
struct common_token_hash_function {
|
41 |
+
size_t operator()(const llama_token token) const {
|
42 |
+
// see https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/
|
43 |
+
return token * 11400714819323198485llu;
|
44 |
+
}
|
45 |
+
};
|
46 |
+
|
47 |
+
struct common_ngram_hash_function {
|
48 |
+
size_t operator()(const common_ngram & ngram) const {
|
49 |
+
size_t hash = common_token_hash_function{}(ngram.tokens[0]);
|
50 |
+
for (int i = 1; i < LLAMA_NGRAM_MAX; ++i) {
|
51 |
+
hash ^= common_token_hash_function{}(ngram.tokens[i]);
|
52 |
+
}
|
53 |
+
return hash;
|
54 |
+
}
|
55 |
+
};
|
56 |
+
|
57 |
+
// token -> number of times token has been seen
|
58 |
+
typedef std::unordered_map<llama_token, int32_t> common_ngram_cache_part;
|
59 |
+
|
60 |
+
// n-gram -> empirical distribution of following tokens
|
61 |
+
typedef std::unordered_map<common_ngram, common_ngram_cache_part, common_ngram_hash_function> common_ngram_cache;
|
62 |
+
|
63 |
+
|
64 |
+
// Update an ngram cache with tokens.
|
65 |
+
// ngram_cache: the cache to modify.
|
66 |
+
// ngram_min/ngram_max: the min/max size of the ngrams to extract from inp_data.
|
67 |
+
// inp_data: the token sequence with which to update ngram_cache.
|
68 |
+
// nnew: how many new tokens have been appended to inp_data since the last call to this function.
|
69 |
+
// print_progress: whether to print progress to stderr.
|
70 |
+
//
|
71 |
+
// In order to get correct results inp_data can ONLY BE APPENDED TO.
|
72 |
+
// Changes in the middle need a complete rebuild.
|
73 |
+
void common_ngram_cache_update(
|
74 |
+
common_ngram_cache & ngram_cache, int ngram_min, int ngram_max, std::vector<llama_token> & inp_data, int nnew, bool print_progress);
|
75 |
+
|
76 |
+
// Try to draft tokens from ngram caches.
|
77 |
+
// inp: the tokens generated so far.
|
78 |
+
// draft: the token sequence to draft. Expected to initially contain the previously sampled token.
|
79 |
+
// n_draft: maximum number of tokens to add to draft.
|
80 |
+
// ngram_min/gram_max: the min/max size of the ngrams in nc_context and nc_dynamic.
|
81 |
+
// nc_context: ngram cache based on current context.
|
82 |
+
// nc_dynamic: ngram cache based on previous user generations.
|
83 |
+
// nc_static: ngram cache generated from a large text corpus, used for validation.
|
84 |
+
void common_ngram_cache_draft(
|
85 |
+
std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
|
86 |
+
common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static);
|
87 |
+
|
88 |
+
// Save an ngram cache to a file.
|
89 |
+
// ngram_cache: the ngram cache to save.
|
90 |
+
// filename: the path under which to save the ngram cache.
|
91 |
+
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename);
|
92 |
+
|
93 |
+
// Load an ngram cache saved with common_ngram_cache_save.
|
94 |
+
// filename: the path from which to load the ngram cache.
|
95 |
+
// returns: an ngram cache containing the information saved to filename.
|
96 |
+
common_ngram_cache common_ngram_cache_load(std::string & filename);
|
97 |
+
|
98 |
+
// Merge two ngram caches.
|
99 |
+
// ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.
|
100 |
+
// ngram_cache_add: the ngram cache to add to ngram_cache_target.
|
101 |
+
void common_ngram_cache_merge(common_ngram_cache & ngram_cache_target, common_ngram_cache & ngram_cache_add);
|
common/sampling.cpp
ADDED
@@ -0,0 +1,570 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "sampling.h"
|
2 |
+
|
3 |
+
#include "common.h"
|
4 |
+
|
5 |
+
#include <cmath>
|
6 |
+
#include <unordered_map>
|
7 |
+
#include <algorithm>
|
8 |
+
|
9 |
+
// the ring buffer works similarly to std::deque, but with a fixed capacity
|
10 |
+
// TODO: deduplicate with llama-impl.h
|
11 |
+
template<typename T>
|
12 |
+
struct ring_buffer {
|
13 |
+
ring_buffer(size_t cap) : capacity(cap), data(cap) {}
|
14 |
+
|
15 |
+
T & front() {
|
16 |
+
if (sz == 0) {
|
17 |
+
throw std::runtime_error("ring buffer is empty");
|
18 |
+
}
|
19 |
+
return data[first];
|
20 |
+
}
|
21 |
+
|
22 |
+
const T & front() const {
|
23 |
+
if (sz == 0) {
|
24 |
+
throw std::runtime_error("ring buffer is empty");
|
25 |
+
}
|
26 |
+
return data[first];
|
27 |
+
}
|
28 |
+
|
29 |
+
T & back() {
|
30 |
+
if (sz == 0) {
|
31 |
+
throw std::runtime_error("ring buffer is empty");
|
32 |
+
}
|
33 |
+
return data[pos];
|
34 |
+
}
|
35 |
+
|
36 |
+
const T & back() const {
|
37 |
+
if (sz == 0) {
|
38 |
+
throw std::runtime_error("ring buffer is empty");
|
39 |
+
}
|
40 |
+
return data[pos];
|
41 |
+
}
|
42 |
+
|
43 |
+
void push_back(const T & value) {
|
44 |
+
if (sz == capacity) {
|
45 |
+
// advance the start when buffer is full
|
46 |
+
first = (first + 1) % capacity;
|
47 |
+
} else {
|
48 |
+
sz++;
|
49 |
+
}
|
50 |
+
data[pos] = value;
|
51 |
+
pos = (pos + 1) % capacity;
|
52 |
+
}
|
53 |
+
|
54 |
+
T pop_front() {
|
55 |
+
if (sz == 0) {
|
56 |
+
throw std::runtime_error("ring buffer is empty");
|
57 |
+
}
|
58 |
+
T value = data[first];
|
59 |
+
first = (first + 1) % capacity;
|
60 |
+
sz--;
|
61 |
+
return value;
|
62 |
+
}
|
63 |
+
|
64 |
+
const T & rat(size_t i) const {
|
65 |
+
if (i >= sz) {
|
66 |
+
throw std::runtime_error("ring buffer: index out of bounds");
|
67 |
+
}
|
68 |
+
return data[(first + sz - i - 1) % capacity];
|
69 |
+
}
|
70 |
+
|
71 |
+
std::vector<T> to_vector() const {
|
72 |
+
std::vector<T> result;
|
73 |
+
result.reserve(sz);
|
74 |
+
for (size_t i = 0; i < sz; i++) {
|
75 |
+
result.push_back(data[(first + i) % capacity]);
|
76 |
+
}
|
77 |
+
return result;
|
78 |
+
}
|
79 |
+
|
80 |
+
void clear() {
|
81 |
+
// here only reset the status of the buffer
|
82 |
+
sz = 0;
|
83 |
+
first = 0;
|
84 |
+
pos = 0;
|
85 |
+
}
|
86 |
+
|
87 |
+
bool empty() const {
|
88 |
+
return sz == 0;
|
89 |
+
}
|
90 |
+
|
91 |
+
size_t size() const {
|
92 |
+
return sz;
|
93 |
+
}
|
94 |
+
|
95 |
+
size_t capacity = 0;
|
96 |
+
size_t sz = 0;
|
97 |
+
size_t first = 0;
|
98 |
+
size_t pos = 0;
|
99 |
+
std::vector<T> data;
|
100 |
+
};
|
101 |
+
|
102 |
+
struct common_sampler {
|
103 |
+
common_params_sampling params;
|
104 |
+
|
105 |
+
struct llama_sampler * grmr;
|
106 |
+
struct llama_sampler * chain;
|
107 |
+
|
108 |
+
ring_buffer<llama_token> prev;
|
109 |
+
|
110 |
+
std::vector<llama_token_data> cur;
|
111 |
+
|
112 |
+
llama_token_data_array cur_p;
|
113 |
+
|
114 |
+
void set_logits(struct llama_context * ctx, int idx) {
|
115 |
+
const auto * logits = llama_get_logits_ith(ctx, idx);
|
116 |
+
|
117 |
+
const llama_model * model = llama_get_model(ctx);
|
118 |
+
const llama_vocab * vocab = llama_model_get_vocab(model);
|
119 |
+
|
120 |
+
const int n_vocab = llama_vocab_n_tokens(vocab);
|
121 |
+
|
122 |
+
cur.resize(n_vocab);
|
123 |
+
|
124 |
+
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
125 |
+
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
126 |
+
}
|
127 |
+
|
128 |
+
cur_p = { cur.data(), cur.size(), -1, false };
|
129 |
+
}
|
130 |
+
};
|
131 |
+
|
132 |
+
std::string common_params_sampling::print() const {
|
133 |
+
char result[1024];
|
134 |
+
|
135 |
+
snprintf(result, sizeof(result),
|
136 |
+
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
|
137 |
+
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
|
138 |
+
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
|
139 |
+
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
|
140 |
+
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
|
141 |
+
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
|
142 |
+
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
|
143 |
+
mirostat, mirostat_eta, mirostat_tau);
|
144 |
+
|
145 |
+
return std::string(result);
|
146 |
+
}
|
147 |
+
|
148 |
+
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
|
149 |
+
const llama_vocab * vocab = llama_model_get_vocab(model);
|
150 |
+
|
151 |
+
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
|
152 |
+
|
153 |
+
lparams.no_perf = params.no_perf;
|
154 |
+
|
155 |
+
struct llama_sampler * grmr;
|
156 |
+
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
|
157 |
+
#ifdef LLAMA_USE_LLGUIDANCE
|
158 |
+
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
|
159 |
+
#else
|
160 |
+
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
|
161 |
+
#endif // LLAMA_USE_LLGUIDANCE
|
162 |
+
} else {
|
163 |
+
std::vector<std::string> patterns_at_start;
|
164 |
+
std::vector<std::string> patterns_anywhere;
|
165 |
+
std::vector<llama_token> trigger_tokens;
|
166 |
+
for (const auto & trigger : params.grammar_triggers) {
|
167 |
+
switch (trigger.type) {
|
168 |
+
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
|
169 |
+
{
|
170 |
+
const auto & word = trigger.value;
|
171 |
+
patterns_anywhere.push_back(regex_escape(word));
|
172 |
+
break;
|
173 |
+
}
|
174 |
+
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
|
175 |
+
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
|
176 |
+
{
|
177 |
+
const auto & pattern = trigger.value;
|
178 |
+
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
|
179 |
+
break;
|
180 |
+
}
|
181 |
+
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
|
182 |
+
{
|
183 |
+
const auto token = trigger.token;
|
184 |
+
trigger_tokens.push_back(token);
|
185 |
+
break;
|
186 |
+
}
|
187 |
+
default:
|
188 |
+
GGML_ASSERT(false && "unknown trigger type");
|
189 |
+
}
|
190 |
+
}
|
191 |
+
|
192 |
+
std::vector<std::string> trigger_patterns;
|
193 |
+
if (!patterns_at_start.empty()) {
|
194 |
+
trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
|
195 |
+
}
|
196 |
+
if (!patterns_anywhere.empty()) {
|
197 |
+
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
|
198 |
+
}
|
199 |
+
|
200 |
+
std::vector<const char *> trigger_patterns_c;
|
201 |
+
trigger_patterns_c.reserve(trigger_patterns.size());
|
202 |
+
for (const auto & regex : trigger_patterns) {
|
203 |
+
trigger_patterns_c.push_back(regex.c_str());
|
204 |
+
}
|
205 |
+
|
206 |
+
grmr = params.grammar_lazy
|
207 |
+
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
|
208 |
+
trigger_patterns_c.data(), trigger_patterns_c.size(),
|
209 |
+
trigger_tokens.data(), trigger_tokens.size())
|
210 |
+
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
|
211 |
+
}
|
212 |
+
|
213 |
+
auto * result = new common_sampler {
|
214 |
+
/* .params = */ params,
|
215 |
+
/* .grmr = */ grmr,
|
216 |
+
/* .chain = */ llama_sampler_chain_init(lparams),
|
217 |
+
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
218 |
+
/* .cur = */ {},
|
219 |
+
/* .cur_p = */ {},
|
220 |
+
};
|
221 |
+
|
222 |
+
llama_sampler_chain_add(result->chain,
|
223 |
+
llama_sampler_init_logit_bias(
|
224 |
+
llama_vocab_n_tokens(vocab),
|
225 |
+
params.logit_bias.size(),
|
226 |
+
params.logit_bias.data()));
|
227 |
+
|
228 |
+
if (params.mirostat == 0) {
|
229 |
+
if (params.top_n_sigma >= 0) {
|
230 |
+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
231 |
+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp));
|
232 |
+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
|
233 |
+
} else {
|
234 |
+
for (const auto & cnstr : params.samplers) {
|
235 |
+
switch (cnstr) {
|
236 |
+
case COMMON_SAMPLER_TYPE_DRY:
|
237 |
+
{
|
238 |
+
std::vector<const char *> c_breakers;
|
239 |
+
c_breakers.reserve(params.dry_sequence_breakers.size());
|
240 |
+
for (const auto & str : params.dry_sequence_breakers) {
|
241 |
+
c_breakers.push_back(str.c_str());
|
242 |
+
}
|
243 |
+
|
244 |
+
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
245 |
+
}
|
246 |
+
break;
|
247 |
+
case COMMON_SAMPLER_TYPE_TOP_K:
|
248 |
+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
249 |
+
break;
|
250 |
+
case COMMON_SAMPLER_TYPE_TOP_P:
|
251 |
+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
252 |
+
break;
|
253 |
+
case COMMON_SAMPLER_TYPE_MIN_P:
|
254 |
+
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
255 |
+
break;
|
256 |
+
case COMMON_SAMPLER_TYPE_XTC:
|
257 |
+
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
258 |
+
break;
|
259 |
+
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
260 |
+
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
261 |
+
break;
|
262 |
+
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
263 |
+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
264 |
+
break;
|
265 |
+
case COMMON_SAMPLER_TYPE_INFILL:
|
266 |
+
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
|
267 |
+
break;
|
268 |
+
case COMMON_SAMPLER_TYPE_PENALTIES:
|
269 |
+
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
270 |
+
break;
|
271 |
+
default:
|
272 |
+
GGML_ASSERT(false && "unknown sampler type");
|
273 |
+
}
|
274 |
+
}
|
275 |
+
}
|
276 |
+
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
277 |
+
} else if (params.mirostat == 1) {
|
278 |
+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
279 |
+
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
|
280 |
+
} else if (params.mirostat == 2) {
|
281 |
+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
282 |
+
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
|
283 |
+
} else {
|
284 |
+
GGML_ASSERT(false && "unknown mirostat version");
|
285 |
+
}
|
286 |
+
|
287 |
+
return result;
|
288 |
+
}
|
289 |
+
|
290 |
+
void common_sampler_free(struct common_sampler * gsmpl) {
|
291 |
+
if (gsmpl) {
|
292 |
+
llama_sampler_free(gsmpl->grmr);
|
293 |
+
|
294 |
+
llama_sampler_free(gsmpl->chain);
|
295 |
+
|
296 |
+
delete gsmpl;
|
297 |
+
}
|
298 |
+
}
|
299 |
+
|
300 |
+
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
|
301 |
+
if (accept_grammar) {
|
302 |
+
llama_sampler_accept(gsmpl->grmr, token);
|
303 |
+
}
|
304 |
+
|
305 |
+
llama_sampler_accept(gsmpl->chain, token);
|
306 |
+
|
307 |
+
gsmpl->prev.push_back(token);
|
308 |
+
}
|
309 |
+
|
310 |
+
void common_sampler_reset(struct common_sampler * gsmpl) {
|
311 |
+
llama_sampler_reset(gsmpl->grmr);
|
312 |
+
|
313 |
+
llama_sampler_reset(gsmpl->chain);
|
314 |
+
}
|
315 |
+
|
316 |
+
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
317 |
+
return new common_sampler {
|
318 |
+
/* .params = */ gsmpl->params,
|
319 |
+
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
320 |
+
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
321 |
+
/* .prev = */ gsmpl->prev,
|
322 |
+
/* .cur = */ gsmpl->cur,
|
323 |
+
/* .cur_p = */ gsmpl->cur_p,
|
324 |
+
};
|
325 |
+
}
|
326 |
+
|
327 |
+
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
|
328 |
+
// TODO: measure grammar performance
|
329 |
+
|
330 |
+
if (gsmpl) {
|
331 |
+
llama_perf_sampler_print(gsmpl->chain);
|
332 |
+
}
|
333 |
+
if (ctx) {
|
334 |
+
llama_perf_context_print(ctx);
|
335 |
+
}
|
336 |
+
}
|
337 |
+
|
338 |
+
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
339 |
+
gsmpl->set_logits(ctx, idx);
|
340 |
+
|
341 |
+
auto & grmr = gsmpl->grmr;
|
342 |
+
auto & chain = gsmpl->chain;
|
343 |
+
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
|
344 |
+
|
345 |
+
if (grammar_first) {
|
346 |
+
llama_sampler_apply(grmr, &cur_p);
|
347 |
+
}
|
348 |
+
|
349 |
+
llama_sampler_apply(chain, &cur_p);
|
350 |
+
|
351 |
+
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
|
352 |
+
|
353 |
+
const llama_token id = cur_p.data[cur_p.selected].id;
|
354 |
+
|
355 |
+
if (grammar_first) {
|
356 |
+
return id;
|
357 |
+
}
|
358 |
+
|
359 |
+
// check if it the sampled token fits the grammar
|
360 |
+
{
|
361 |
+
llama_token_data single_token_data = { id, 1.0f, 0.0f };
|
362 |
+
llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
|
363 |
+
|
364 |
+
llama_sampler_apply(grmr, &single_token_data_array);
|
365 |
+
|
366 |
+
const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
|
367 |
+
if (is_valid) {
|
368 |
+
return id;
|
369 |
+
}
|
370 |
+
}
|
371 |
+
|
372 |
+
// resampling:
|
373 |
+
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
|
374 |
+
gsmpl->set_logits(ctx, idx);
|
375 |
+
|
376 |
+
llama_sampler_apply(grmr, &cur_p);
|
377 |
+
llama_sampler_apply(chain, &cur_p);
|
378 |
+
|
379 |
+
GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
|
380 |
+
|
381 |
+
return cur_p.data[cur_p.selected].id;
|
382 |
+
}
|
383 |
+
|
384 |
+
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
|
385 |
+
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
|
386 |
+
|
387 |
+
std::vector<llama_token> result;
|
388 |
+
result.reserve(idxs.size());
|
389 |
+
|
390 |
+
size_t i = 0;
|
391 |
+
for (; i < draft.size(); i++) {
|
392 |
+
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
|
393 |
+
|
394 |
+
common_sampler_accept(gsmpl, id, true);
|
395 |
+
|
396 |
+
result.push_back(id);
|
397 |
+
|
398 |
+
if (draft[i] != id) {
|
399 |
+
break;
|
400 |
+
}
|
401 |
+
}
|
402 |
+
|
403 |
+
if (i == draft.size()) {
|
404 |
+
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
|
405 |
+
|
406 |
+
common_sampler_accept(gsmpl, id, true);
|
407 |
+
|
408 |
+
result.push_back(id);
|
409 |
+
}
|
410 |
+
|
411 |
+
return result;
|
412 |
+
}
|
413 |
+
|
414 |
+
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
|
415 |
+
std::vector<int> idxs(draft.size() + 1);
|
416 |
+
for (size_t i = 0; i < idxs.size(); ++i) {
|
417 |
+
idxs[i] = i;
|
418 |
+
}
|
419 |
+
|
420 |
+
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
|
421 |
+
}
|
422 |
+
|
423 |
+
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
424 |
+
return llama_sampler_get_seed(gsmpl->chain);
|
425 |
+
}
|
426 |
+
|
427 |
+
// helpers
|
428 |
+
|
429 |
+
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
|
430 |
+
return &gsmpl->cur_p;
|
431 |
+
}
|
432 |
+
|
433 |
+
llama_token common_sampler_last(const struct common_sampler * gsmpl) {
|
434 |
+
return gsmpl->prev.rat(0);
|
435 |
+
}
|
436 |
+
|
437 |
+
std::string common_sampler_print(const struct common_sampler * gsmpl) {
|
438 |
+
std::string result = "logits ";
|
439 |
+
|
440 |
+
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
|
441 |
+
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
|
442 |
+
result += std::string("-> ") + llama_sampler_name(smpl) + " ";
|
443 |
+
}
|
444 |
+
|
445 |
+
return result;
|
446 |
+
}
|
447 |
+
|
448 |
+
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) {
|
449 |
+
n = std::min(n, (int) gsmpl->prev.size());
|
450 |
+
|
451 |
+
if (n <= 0) {
|
452 |
+
return "";
|
453 |
+
}
|
454 |
+
|
455 |
+
std::string result;
|
456 |
+
result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab
|
457 |
+
|
458 |
+
for (int i = n - 1; i >= 0; i--) {
|
459 |
+
const llama_token id = gsmpl->prev.rat(i);
|
460 |
+
|
461 |
+
GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");
|
462 |
+
|
463 |
+
result += common_token_to_piece(ctx_main, id);
|
464 |
+
}
|
465 |
+
|
466 |
+
return result;
|
467 |
+
}
|
468 |
+
|
469 |
+
char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
|
470 |
+
switch (cnstr) {
|
471 |
+
case COMMON_SAMPLER_TYPE_DRY: return 'd';
|
472 |
+
case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
|
473 |
+
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
|
474 |
+
case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
|
475 |
+
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
|
476 |
+
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
|
477 |
+
case COMMON_SAMPLER_TYPE_XTC: return 'x';
|
478 |
+
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
|
479 |
+
case COMMON_SAMPLER_TYPE_PENALTIES: return 'e';
|
480 |
+
default : return '?';
|
481 |
+
}
|
482 |
+
}
|
483 |
+
|
484 |
+
std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
|
485 |
+
switch (cnstr) {
|
486 |
+
case COMMON_SAMPLER_TYPE_DRY: return "dry";
|
487 |
+
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
|
488 |
+
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
|
489 |
+
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
|
490 |
+
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
|
491 |
+
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
|
492 |
+
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
|
493 |
+
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
|
494 |
+
case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties";
|
495 |
+
default : return "";
|
496 |
+
}
|
497 |
+
}
|
498 |
+
|
499 |
+
std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
|
500 |
+
std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
|
501 |
+
{ "dry", COMMON_SAMPLER_TYPE_DRY },
|
502 |
+
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K },
|
503 |
+
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P },
|
504 |
+
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
505 |
+
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P },
|
506 |
+
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
507 |
+
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
508 |
+
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
|
509 |
+
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
|
510 |
+
};
|
511 |
+
|
512 |
+
// since samplers names are written multiple ways
|
513 |
+
// make it ready for both system names and input names
|
514 |
+
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
|
515 |
+
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K },
|
516 |
+
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P },
|
517 |
+
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
|
518 |
+
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
519 |
+
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
520 |
+
{ "typ-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
521 |
+
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
522 |
+
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
|
523 |
+
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
524 |
+
};
|
525 |
+
|
526 |
+
std::vector<common_sampler_type> samplers;
|
527 |
+
samplers.reserve(names.size());
|
528 |
+
|
529 |
+
for (const auto & name : names) {
|
530 |
+
auto sampler = sampler_canonical_name_map.find(name);
|
531 |
+
if (sampler != sampler_canonical_name_map.end()) {
|
532 |
+
samplers.push_back(sampler->second);
|
533 |
+
} else {
|
534 |
+
if (allow_alt_names) {
|
535 |
+
sampler = sampler_alt_name_map.find(name);
|
536 |
+
if (sampler != sampler_alt_name_map.end()) {
|
537 |
+
samplers.push_back(sampler->second);
|
538 |
+
}
|
539 |
+
}
|
540 |
+
}
|
541 |
+
}
|
542 |
+
|
543 |
+
return samplers;
|
544 |
+
}
|
545 |
+
|
546 |
+
std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
|
547 |
+
std::unordered_map<char, common_sampler_type> sampler_name_map = {
|
548 |
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY },
|
549 |
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
|
550 |
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
|
551 |
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
|
552 |
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
|
553 |
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
|
554 |
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
|
555 |
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
|
556 |
+
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES },
|
557 |
+
};
|
558 |
+
|
559 |
+
std::vector<common_sampler_type> samplers;
|
560 |
+
samplers.reserve(chars.size());
|
561 |
+
|
562 |
+
for (const auto & c : chars) {
|
563 |
+
const auto sampler = sampler_name_map.find(c);
|
564 |
+
if (sampler != sampler_name_map.end()) {
|
565 |
+
samplers.push_back(sampler->second);
|
566 |
+
}
|
567 |
+
}
|
568 |
+
|
569 |
+
return samplers;
|
570 |
+
}
|
common/sampling.h
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include "llama.h"
|
4 |
+
|
5 |
+
#include "common.h"
|
6 |
+
|
7 |
+
#include <string>
|
8 |
+
#include <vector>
|
9 |
+
|
10 |
+
// common_sampler extends llama_sampler with additional functionality:
|
11 |
+
//
|
12 |
+
// - grammar support
|
13 |
+
// - custom sampler logic based on the parameters
|
14 |
+
// - history of the last accepted tokens
|
15 |
+
// - performance metrics
|
16 |
+
//
|
17 |
+
// This goal is to have a common implementation of the sampling logic shared across the examples.
|
18 |
+
// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
|
19 |
+
// complex (top-k, top-p, etc).
|
20 |
+
//
|
21 |
+
// Another example is related to the grammar. In general, the grammar constraints applied on the full
|
22 |
+
// vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled
|
23 |
+
// token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the
|
24 |
+
// grammar constraints are applied to the full vocabulary and the token is resampled.
|
25 |
+
//
|
26 |
+
// The common_sampler also maintains a container with the last accepted tokens. In the future, this can
|
27 |
+
// be moved into the core llama library.
|
28 |
+
//
|
29 |
+
// For convenience, the common_sampler also maintains a container with the current candidate tokens.
|
30 |
+
// This can be used to access the probabilities of the rest of the non-sampled tokens.
|
31 |
+
//
|
32 |
+
// TODO: measure grammar performance
|
33 |
+
//
|
34 |
+
|
35 |
+
struct common_sampler;
|
36 |
+
|
37 |
+
// llama_sampler API overloads
|
38 |
+
|
39 |
+
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
|
40 |
+
|
41 |
+
void common_sampler_free(struct common_sampler * gsmpl);
|
42 |
+
|
43 |
+
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
|
44 |
+
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
|
45 |
+
void common_sampler_reset (struct common_sampler * gsmpl);
|
46 |
+
struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
|
47 |
+
|
48 |
+
// arguments can be nullptr to skip printing
|
49 |
+
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
|
50 |
+
|
51 |
+
// extended sampling implementation:
|
52 |
+
//
|
53 |
+
// - set logits
|
54 |
+
// - apply the configured sampler chain
|
55 |
+
// - check if the token fits the grammar (if any)
|
56 |
+
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
|
57 |
+
//
|
58 |
+
// if grammar_first is true, the grammar is applied before the samplers (slower)
|
59 |
+
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
|
60 |
+
//
|
61 |
+
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
|
62 |
+
|
63 |
+
// generalized version of common_sampler_sample
|
64 |
+
//
|
65 |
+
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
|
66 |
+
// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
|
67 |
+
//
|
68 |
+
// common_sampler_sample_n(gsmpl, ctx, { idx }, {});
|
69 |
+
//
|
70 |
+
// is equivalent to
|
71 |
+
//
|
72 |
+
// common_sampler_sample(gsmpl, ctx, idx);
|
73 |
+
// common_sampler_accept(gsmpl, token, true);
|
74 |
+
//
|
75 |
+
// requires: idxs.size() == draft.size() + 1
|
76 |
+
//
|
77 |
+
// returns at least 1 token, up to idxs.size()
|
78 |
+
//
|
79 |
+
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
|
80 |
+
|
81 |
+
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
|
82 |
+
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
|
83 |
+
|
84 |
+
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
|
85 |
+
|
86 |
+
// helpers
|
87 |
+
|
88 |
+
// access the internal list of current candidate tokens
|
89 |
+
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl);
|
90 |
+
|
91 |
+
// get the last accepted token
|
92 |
+
llama_token common_sampler_last(const struct common_sampler * gsmpl);
|
93 |
+
|
94 |
+
// print the sampler chain into a string
|
95 |
+
std::string common_sampler_print(const struct common_sampler * gsmpl);
|
96 |
+
|
97 |
+
// get a string representation of the last accepted tokens
|
98 |
+
std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, int n);
|
99 |
+
|
100 |
+
char common_sampler_type_to_chr(enum common_sampler_type cnstr);
|
101 |
+
std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
|
102 |
+
|
103 |
+
std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
|
104 |
+
std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
|
105 |
+
|
106 |
+
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
|
107 |
+
const char * grammar_kind, const char * grammar_data);
|
common/speculative.cpp
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include "speculative.h"
|
2 |
+
|
3 |
+
#include "log.h"
|
4 |
+
#include "common.h"
|
5 |
+
#include "sampling.h"
|
6 |
+
|
7 |
+
#include <cstring>
|
8 |
+
#include <algorithm>
|
9 |
+
|
10 |
+
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
11 |
+
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
12 |
+
|
13 |
+
struct common_speculative {
|
14 |
+
struct llama_context * ctx;
|
15 |
+
struct common_sampler * smpl;
|
16 |
+
|
17 |
+
llama_batch batch;
|
18 |
+
llama_tokens prompt;
|
19 |
+
};
|
20 |
+
|
21 |
+
struct common_speculative * common_speculative_init(
|
22 |
+
struct llama_context * ctx_dft) {
|
23 |
+
auto * result = new common_speculative {
|
24 |
+
/* .ctx = */ ctx_dft,
|
25 |
+
/* .smpl = */ nullptr,
|
26 |
+
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
|
27 |
+
/* .prompt = */ {},
|
28 |
+
};
|
29 |
+
|
30 |
+
// TODO: optimize or pass from outside?
|
31 |
+
#if 0
|
32 |
+
{
|
33 |
+
common_params_sampling params;
|
34 |
+
params.no_perf = false;
|
35 |
+
|
36 |
+
params.top_k = 40;
|
37 |
+
params.top_p = 0.9;
|
38 |
+
|
39 |
+
params.samplers = {
|
40 |
+
COMMON_SAMPLER_TYPE_TOP_K,
|
41 |
+
COMMON_SAMPLER_TYPE_TOP_P,
|
42 |
+
COMMON_SAMPLER_TYPE_INFILL,
|
43 |
+
};
|
44 |
+
|
45 |
+
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
46 |
+
}
|
47 |
+
#else
|
48 |
+
{
|
49 |
+
common_params_sampling params;
|
50 |
+
params.no_perf = false;
|
51 |
+
|
52 |
+
params.top_k = 10;
|
53 |
+
|
54 |
+
params.samplers = {
|
55 |
+
COMMON_SAMPLER_TYPE_TOP_K,
|
56 |
+
};
|
57 |
+
|
58 |
+
result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
|
59 |
+
}
|
60 |
+
#endif
|
61 |
+
|
62 |
+
return result;
|
63 |
+
}
|
64 |
+
|
65 |
+
void common_speculative_free(struct common_speculative * spec) {
|
66 |
+
if (spec == nullptr) {
|
67 |
+
return;
|
68 |
+
}
|
69 |
+
|
70 |
+
common_sampler_free(spec->smpl);
|
71 |
+
|
72 |
+
llama_batch_free(spec->batch);
|
73 |
+
|
74 |
+
delete spec;
|
75 |
+
}
|
76 |
+
|
77 |
+
bool common_speculative_are_compatible(
|
78 |
+
const struct llama_context * ctx_tgt,
|
79 |
+
const struct llama_context * ctx_dft) {
|
80 |
+
const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
|
81 |
+
const struct llama_model * model_dft = llama_get_model(ctx_dft);
|
82 |
+
|
83 |
+
const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
|
84 |
+
const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
|
85 |
+
|
86 |
+
const bool vocab_type_tgt = llama_vocab_type(vocab_tgt);
|
87 |
+
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
|
88 |
+
|
89 |
+
const bool vocab_type_dft = llama_vocab_type(vocab_dft);
|
90 |
+
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
|
91 |
+
|
92 |
+
if (vocab_type_tgt != vocab_type_dft) {
|
93 |
+
LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
|
94 |
+
"vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
|
95 |
+
return false;
|
96 |
+
}
|
97 |
+
|
98 |
+
if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
|
99 |
+
llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
|
100 |
+
llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
|
101 |
+
llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) {
|
102 |
+
LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__);
|
103 |
+
LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt));
|
104 |
+
LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft));
|
105 |
+
return false;
|
106 |
+
}
|
107 |
+
|
108 |
+
{
|
109 |
+
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
|
110 |
+
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
|
111 |
+
|
112 |
+
const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
|
113 |
+
|
114 |
+
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
|
115 |
+
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
|
116 |
+
"target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
|
117 |
+
__func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
|
118 |
+
return false;
|
119 |
+
}
|
120 |
+
|
121 |
+
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
|
122 |
+
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
|
123 |
+
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
|
124 |
+
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
125 |
+
LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but "
|
126 |
+
"token %d content differs - target '%s', draft '%s'\n", __func__, i,
|
127 |
+
common_token_to_piece(ctx_tgt, i).c_str(),
|
128 |
+
common_token_to_piece(ctx_dft, i).c_str());
|
129 |
+
return false;
|
130 |
+
}
|
131 |
+
}
|
132 |
+
}
|
133 |
+
|
134 |
+
return true;
|
135 |
+
}
|
136 |
+
|
137 |
+
llama_tokens common_speculative_gen_draft(
|
138 |
+
struct common_speculative * spec,
|
139 |
+
struct common_speculative_params params,
|
140 |
+
const llama_tokens & prompt_tgt,
|
141 |
+
llama_token id_last) {
|
142 |
+
auto & batch = spec->batch;
|
143 |
+
auto & ctx = spec->ctx;
|
144 |
+
auto & smpl = spec->smpl;
|
145 |
+
auto & prompt = spec->prompt;
|
146 |
+
|
147 |
+
int reuse_i = 0;
|
148 |
+
int reuse_n = 0;
|
149 |
+
|
150 |
+
const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
|
151 |
+
|
152 |
+
const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
|
153 |
+
|
154 |
+
// reuse as much as possible from the old draft context
|
155 |
+
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
|
156 |
+
for (int i = 0; i < (int) prompt.size(); ++i) {
|
157 |
+
int cur = 0;
|
158 |
+
while (i_start + cur < (int) prompt_tgt.size() &&
|
159 |
+
i + cur < (int) prompt.size() &&
|
160 |
+
prompt_tgt[i_start + cur] == prompt[i + cur]) {
|
161 |
+
cur++;
|
162 |
+
}
|
163 |
+
|
164 |
+
if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
|
165 |
+
reuse_i = i;
|
166 |
+
reuse_n = cur;
|
167 |
+
}
|
168 |
+
}
|
169 |
+
|
170 |
+
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
|
171 |
+
|
172 |
+
llama_tokens result;
|
173 |
+
result.reserve(params.n_draft);
|
174 |
+
|
175 |
+
if (reuse_n == 0) {
|
176 |
+
llama_kv_cache_clear(ctx);
|
177 |
+
|
178 |
+
prompt.clear();
|
179 |
+
} else {
|
180 |
+
// this happens when a previous draft has been discarded (for example, due to being too small), but the
|
181 |
+
// target model agreed with it. in this case, we simply pass back the previous results to save compute
|
182 |
+
if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
|
183 |
+
for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
|
184 |
+
result.push_back(prompt[i]);
|
185 |
+
|
186 |
+
if (params.n_draft <= (int) result.size()) {
|
187 |
+
break;
|
188 |
+
}
|
189 |
+
}
|
190 |
+
|
191 |
+
return result;
|
192 |
+
}
|
193 |
+
|
194 |
+
if (reuse_i > 0) {
|
195 |
+
llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
|
196 |
+
llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
|
197 |
+
|
198 |
+
prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
|
199 |
+
}
|
200 |
+
|
201 |
+
if (reuse_n < (int) prompt.size()) {
|
202 |
+
llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
|
203 |
+
|
204 |
+
prompt.erase(prompt.begin() + reuse_n, prompt.end());
|
205 |
+
}
|
206 |
+
}
|
207 |
+
|
208 |
+
// prepare a batch to evaluate any new tokens in the prompt
|
209 |
+
common_batch_clear(batch);
|
210 |
+
|
211 |
+
for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
|
212 |
+
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
|
213 |
+
common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
|
214 |
+
|
215 |
+
prompt.push_back(prompt_tgt[i]);
|
216 |
+
}
|
217 |
+
|
218 |
+
// we should rarely end-up here during normal decoding
|
219 |
+
if (batch.n_tokens > 0) {
|
220 |
+
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
|
221 |
+
|
222 |
+
llama_decode(ctx, batch);
|
223 |
+
}
|
224 |
+
|
225 |
+
const llama_pos n_past = prompt.size();
|
226 |
+
|
227 |
+
LOG_DBG("%s: n_past = %d\n", __func__, n_past);
|
228 |
+
|
229 |
+
common_batch_clear(batch);
|
230 |
+
common_batch_add (batch, id_last, n_past, { 0 }, true);
|
231 |
+
|
232 |
+
prompt.push_back(id_last);
|
233 |
+
|
234 |
+
//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
|
235 |
+
|
236 |
+
llama_decode(ctx, batch);
|
237 |
+
|
238 |
+
common_sampler_reset(smpl);
|
239 |
+
|
240 |
+
// sample n_draft tokens from the draft model
|
241 |
+
for (int i = 0; i < params.n_draft; ++i) {
|
242 |
+
common_batch_clear(batch);
|
243 |
+
|
244 |
+
common_sampler_sample(smpl, ctx, 0, true);
|
245 |
+
|
246 |
+
const auto * cur_p = common_sampler_get_candidates(smpl);
|
247 |
+
|
248 |
+
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
|
249 |
+
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
250 |
+
k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
|
251 |
+
}
|
252 |
+
|
253 |
+
// add drafted token for each sequence
|
254 |
+
const llama_token id = cur_p->data[0].id;
|
255 |
+
|
256 |
+
common_sampler_accept(smpl, id, true);
|
257 |
+
|
258 |
+
result.push_back(id);
|
259 |
+
|
260 |
+
if (params.n_draft <= (int) result.size()) {
|
261 |
+
break;
|
262 |
+
}
|
263 |
+
|
264 |
+
// only collect very high-confidence draft tokens
|
265 |
+
if (cur_p->data[0].p < params.p_min) {
|
266 |
+
break;
|
267 |
+
}
|
268 |
+
|
269 |
+
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
|
270 |
+
|
271 |
+
// evaluate the drafted tokens on the draft model
|
272 |
+
llama_decode(ctx, batch);
|
273 |
+
|
274 |
+
prompt.push_back(id);
|
275 |
+
}
|
276 |
+
|
277 |
+
return result;
|
278 |
+
}
|
common/speculative.h
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include "llama.h"
|
4 |
+
#include "common.h"
|
5 |
+
|
6 |
+
struct common_speculative;
|
7 |
+
|
8 |
+
struct common_speculative_params {
|
9 |
+
int n_draft = 16; // max drafted tokens
|
10 |
+
int n_reuse = 256;
|
11 |
+
|
12 |
+
float p_min = 0.75f; // min probability required to accept a token in the draft
|
13 |
+
};
|
14 |
+
|
15 |
+
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
|
16 |
+
|
17 |
+
void common_speculative_free(struct common_speculative * spec);
|
18 |
+
|
19 |
+
bool common_speculative_are_compatible(
|
20 |
+
const struct llama_context * ctx_tgt,
|
21 |
+
const struct llama_context * ctx_dft);
|
22 |
+
|
23 |
+
// sample up to n_draft tokens and add them to the batch using the draft model
|
24 |
+
llama_tokens common_speculative_gen_draft(
|
25 |
+
struct common_speculative * spec,
|
26 |
+
struct common_speculative_params params,
|
27 |
+
const llama_tokens & prompt,
|
28 |
+
llama_token id_last);
|