Darkknight535 committed
Commit 1d30d42 · verified · 1 parent: 64f6551

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .clang-format +161 -0
  2. .devops/nix/package-gguf-py.nix +36 -0
  3. .devops/nix/python-scripts.nix +66 -0
  4. .editorconfig +50 -0
  5. .gitattributes +20 -0
  6. .github/ISSUE_TEMPLATE/create-new-issue.md +14 -0
  7. .github/workflows/kcpp-build-release-arm64.yaml +87 -0
  8. .github/workflows/kcpp-build-release-linux-cuda12.yaml +34 -0
  9. .github/workflows/kcpp-build-release-linux.yaml +34 -0
  10. .github/workflows/kcpp-build-release-osx.yaml +41 -0
  11. .github/workflows/kcpp-build-release-win-full-cu12.yaml +91 -0
  12. .github/workflows/kcpp-build-release-win-full.yaml +92 -0
  13. .github/workflows/kcpp-build-release-win-oldcpu-full.yaml +91 -0
  14. .gitignore +140 -0
  15. CLINFO_LICENSE +19 -0
  16. CMakeLists.txt +543 -0
  17. LICENSE.md +661 -0
  18. MIT_LICENSE_GGML_LLAMACPP_ONLY +26 -0
  19. Makefile +758 -0
  20. OpenCL.dll +0 -0
  21. README.md +194 -0
  22. Remote-Link.cmd +18 -0
  23. build-info.h +12 -0
  24. build-xcframework.sh +519 -0
  25. clblast.dll +3 -0
  26. colab.ipynb +174 -0
  27. common/arg.cpp +0 -0
  28. common/arg.h +80 -0
  29. common/base64.hpp +392 -0
  30. common/build-info.cpp.in +4 -0
  31. common/chat.cpp +1779 -0
  32. common/chat.h +135 -0
  33. common/common.cpp +2058 -0
  34. common/common.h +681 -0
  35. common/console.cpp +504 -0
  36. common/console.h +19 -0
  37. common/json-schema-to-grammar.cpp +1024 -0
  38. common/json-schema-to-grammar.h +21 -0
  39. common/json.hpp +0 -0
  40. common/llguidance.cpp +270 -0
  41. common/log.cpp +393 -0
  42. common/log.h +103 -0
  43. common/minja/chat-template.hpp +529 -0
  44. common/minja/minja.hpp +0 -0
  45. common/ngram-cache.cpp +286 -0
  46. common/ngram-cache.h +101 -0
  47. common/sampling.cpp +570 -0
  48. common/sampling.h +107 -0
  49. common/speculative.cpp +278 -0
  50. common/speculative.h +28 -0
.clang-format ADDED
@@ -0,0 +1,161 @@
+ ---
+ Language: Cpp
+ AlignAfterOpenBracket: Align
+ AlignArrayOfStructures: Left
+ AlignConsecutiveAssignments: AcrossComments
+ AlignConsecutiveBitFields: AcrossComments
+ AlignConsecutiveDeclarations: AcrossComments
+ AlignConsecutiveMacros: AcrossComments
+ # AlignConsecutiveShortCaseStatements: AcrossComments
+ AlignEscapedNewlines: Left # LeftWithLastLine
+ AlignOperands: Align
+ AlignTrailingComments:
+   Kind: Always
+   OverEmptyLines: 1
+ AllowAllArgumentsOnNextLine: true
+ AllowAllParametersOfDeclarationOnNextLine: false
+ # AllowBreakBeforeNoexceptSpecifier: OnlyWithParen
+ AllowShortBlocksOnASingleLine: Never
+ AllowShortCaseLabelsOnASingleLine: false
+ AllowShortFunctionsOnASingleLine: Inline
+ AllowShortIfStatementsOnASingleLine: Never
+ AllowShortLambdasOnASingleLine: Inline
+ AllowShortLoopsOnASingleLine: false
+ AlwaysBreakBeforeMultilineStrings: true
+ BinPackArguments: true
+ BinPackParameters: true # OnePerLine
+ BitFieldColonSpacing: Both
+ BreakBeforeBraces: Custom # Attach
+ BraceWrapping:
+   AfterCaseLabel: true
+   AfterClass: false
+   AfterControlStatement: false
+   AfterEnum: false
+   AfterFunction: false
+   AfterNamespace: false
+   AfterObjCDeclaration: false
+   AfterStruct: false
+   AfterUnion: false
+   AfterExternBlock: false
+   BeforeCatch: false
+   BeforeElse: false
+   BeforeLambdaBody: false
+   BeforeWhile: false
+   IndentBraces: false
+   SplitEmptyFunction: false
+   SplitEmptyRecord: false
+   SplitEmptyNamespace: false
+ # BreakAdjacentStringLiterals: true
+ BreakAfterAttributes: Never
+ BreakBeforeBinaryOperators: None
+ BreakBeforeInlineASMColon: OnlyMultiline
+ BreakBeforeTernaryOperators: false
+ # BreakBinaryOperations: Never
+ BreakConstructorInitializers: AfterColon
+ # BreakFunctionDefinitionParameters: false
+ BreakInheritanceList: AfterComma
+ BreakStringLiterals: true
+ # BreakTemplateDeclarations: Yes
+ ColumnLimit: 120
+ CommentPragmas: '^ IWYU pragma:'
+ CompactNamespaces: false
+ ConstructorInitializerIndentWidth: 4
+ ContinuationIndentWidth: 4
+ Cpp11BracedListStyle: false
+ DerivePointerAlignment: false
+ DisableFormat: false
+ EmptyLineBeforeAccessModifier: Leave
+ EmptyLineAfterAccessModifier: Never
+ ExperimentalAutoDetectBinPacking: false
+ FixNamespaceComments: true
+ IncludeBlocks: Regroup
+ IncludeCategories:
+   - Regex: '^<.*\.h>'
+     Priority: 1
+     SortPriority: 0
+   - Regex: '^<.*'
+     Priority: 2
+     SortPriority: 0
+   - Regex: '.*'
+     Priority: 3
+     SortPriority: 0
+ IncludeIsMainRegex: '([-_](test|unittest))?$'
+ IncludeIsMainSourceRegex: ''
+ IndentAccessModifiers: false
+ IndentCaseBlocks: true
+ IndentCaseLabels: true
+ IndentExternBlock: NoIndent
+ IndentGotoLabels: false
+ IndentPPDirectives: AfterHash
+ IndentWidth: 4
+ IndentWrappedFunctionNames: false
+ InsertBraces: true # NOTE: may lead to incorrect formatting
+ InsertNewlineAtEOF: true
+ JavaScriptQuotes: Leave
+ JavaScriptWrapImports: true
+ KeepEmptyLinesAtTheStartOfBlocks: false
+ LambdaBodyIndentation: Signature
+ LineEnding: LF
+ MacroBlockBegin: ''
+ MacroBlockEnd: ''
+ MaxEmptyLinesToKeep: 1
+ NamespaceIndentation: None
+ ObjCBinPackProtocolList: Auto
+ ObjCBlockIndentWidth: 4
+ ObjCSpaceAfterProperty: true
+ ObjCSpaceBeforeProtocolList: true
+ PPIndentWidth: -1
+ PackConstructorInitializers: CurrentLine
+ PenaltyBreakAssignment: 2
+ PenaltyBreakBeforeFirstCallParameter: 1
+ PenaltyBreakComment: 300
+ PenaltyBreakFirstLessLess: 120
+ PenaltyBreakString: 1000
+ PenaltyBreakTemplateDeclaration: 10
+ PenaltyExcessCharacter: 1000000
+ PenaltyReturnTypeOnItsOwnLine: 200
+ PointerAlignment: Middle
+ QualifierAlignment: Left
+ #QualifierOrder: ['static', 'inline', 'friend', 'constexpr', 'const', 'volatile', 'type', 'restrict']
+ RawStringFormats:
+   - Language: Cpp
+     Delimiters:
+       - cc
+       - CC
+       - cpp
+       - Cpp
+       - CPP
+       - 'c++'
+       - 'C++'
+     CanonicalDelimiter: ''
+ ReferenceAlignment: Middle
+ ReflowComments: false # IndentOnly
+ SeparateDefinitionBlocks: Always
+ SortIncludes: CaseInsensitive
+ SortUsingDeclarations: LexicographicNumeric
+ SpaceAfterCStyleCast: true
+ SpaceAfterLogicalNot: false
+ SpaceAfterTemplateKeyword: true
+ SpaceBeforeAssignmentOperators: true
+ SpaceBeforeCpp11BracedList: false
+ SpaceBeforeCtorInitializerColon: true
+ SpaceBeforeInheritanceColon: true
+ SpaceBeforeParens: ControlStatements
+ SpaceBeforeRangeBasedForLoopColon: true
+ SpaceInEmptyBlock: false
+ SpaceInEmptyParentheses: false
+ SpacesBeforeTrailingComments: 2
+ SpacesInAngles: Never
+ SpacesInContainerLiterals: true
+ SpacesInLineCommentPrefix:
+   Minimum: 1
+   Maximum: -1
+ SpacesInParentheses: false
+ SpacesInSquareBrackets: false
+ SpaceBeforeSquareBrackets: false
+ Standard: c++17
+ TabWidth: 4
+ UseTab: Never
+ WhitespaceSensitiveMacros: ['STRINGIZE']
+ ...
+
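The configuration above pins a 120-column, 4-space-indent C++ style; clang-format finds it automatically by walking up the directory tree from each file it formats. A minimal sketch of applying it, assuming a clang-format binary on PATH and illustrative glob patterns (neither is part of this commit):

# Minimal sketch: run clang-format over a tree so each file picks up the
# nearest .clang-format (the one added above). Assumes `clang-format` is
# on PATH; the glob patterns are illustrative, not from this repo.
import pathlib
import subprocess

for pattern in ("**/*.cpp", "**/*.h"):
    for path in pathlib.Path(".").glob(pattern):
        # -i rewrites the file in place using the discovered style file
        subprocess.run(["clang-format", "-i", str(path)], check=True)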
.devops/nix/package-gguf-py.nix ADDED
@@ -0,0 +1,36 @@
+ {
+   lib,
+   llamaVersion,
+   numpy,
+   tqdm,
+   sentencepiece,
+   pyyaml,
+   poetry-core,
+   buildPythonPackage,
+   pytestCheckHook,
+ }:
+
+ buildPythonPackage {
+   pname = "gguf";
+   version = llamaVersion;
+   pyproject = true;
+   nativeBuildInputs = [ poetry-core ];
+   propagatedBuildInputs = [
+     numpy
+     tqdm
+     sentencepiece
+     pyyaml
+   ];
+   src = lib.cleanSource ../../gguf-py;
+   pythonImportsCheck = [
+     "numpy"
+     "gguf"
+   ];
+   nativeCheckInputs = [ pytestCheckHook ];
+   doCheck = true;
+   meta = with lib; {
+     description = "Python package for writing binary files in the GGUF format";
+     license = licenses.mit;
+     maintainers = [ maintainers.ditsuke ];
+   };
+ }
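This derivation packages gguf-py, the Python library for reading and writing GGUF model files. A minimal usage sketch, assuming gguf-py's published GGUFReader API and a hypothetical local file model.gguf:

# Sketch only: `model.gguf` is a hypothetical path, and the attribute
# names follow gguf-py's documented reader interface.
from gguf import GGUFReader

reader = GGUFReader("model.gguf")
for name in reader.fields:         # metadata key/value entries
    print("field:", name)
for tensor in reader.tensors:      # tensor descriptors stored in the file
    print("tensor:", tensor.name, tensor.shape)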
.devops/nix/python-scripts.nix ADDED
@@ -0,0 +1,66 @@
+ {
+   lib,
+   stdenv,
+   buildPythonPackage,
+   poetry-core,
+   mkShell,
+   python3Packages,
+   gguf-py,
+ }@inputs:
+
+ let
+   llama-python-deps = with python3Packages; [
+     numpy
+     sentencepiece
+     transformers
+     protobuf
+     torchWithoutCuda
+     gguf-py
+     tqdm
+
+     # for scripts/compare-llama-bench.py
+     gitpython
+     tabulate
+
+     # for examples/pydantic-models-to-grammar-examples.py
+     docstring-parser
+     pydantic
+
+   ];
+
+   llama-python-test-deps = with python3Packages; [
+     # Server bench
+     matplotlib
+
+     # server tests
+     openai
+     pytest
+     prometheus-client
+   ];
+ in
+
+ buildPythonPackage ({
+   pname = "llama-scripts";
+   version = "0.0.0";
+   pyproject = true;
+
+   # NOTE: The files filtered out here are not visible in the build sandbox, neither
+   # do they affect the output hash. They can be modified without triggering a rebuild.
+   src = lib.cleanSourceWith {
+     filter =
+       name: type:
+       let
+         any = builtins.any (x: x);
+         baseName = builtins.baseNameOf name;
+       in
+       any [
+         (lib.hasSuffix ".py" name)
+         (baseName == "README.md")
+         (baseName == "pyproject.toml")
+       ];
+     src = lib.cleanSource ../../.;
+   };
+   nativeBuildInputs = [ poetry-core ];
+   nativeCheckInputs = llama-python-test-deps;
+   dependencies = llama-python-deps;
+ })
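The cleanSourceWith filter above keeps a path only if it is a Python file, README.md, or pyproject.toml, so unrelated edits never invalidate the build hash. A rough Python rendering of the same predicate, with illustrative names:

# Rough rendering of the Nix filter: a path survives iff any predicate matches.
import os

def keep(name: str) -> bool:
    base = os.path.basename(name)
    return any([
        name.endswith(".py"),       # lib.hasSuffix ".py" name
        base == "README.md",
        base == "pyproject.toml",
    ])

assert keep("scripts/compare-llama-bench.py")
assert not keep("Makefile")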
.editorconfig ADDED
@@ -0,0 +1,50 @@
+ # https://EditorConfig.org
+
+ # Top-most EditorConfig file
+ root = true
+
+ # Unix-style newlines with a newline ending every file, utf-8 charset
+ [*]
+ end_of_line = lf
+ insert_final_newline = true
+ trim_trailing_whitespace = true
+ charset = utf-8
+ indent_style = space
+ indent_size = 4
+
+ [Makefile]
+ indent_style = tab
+
+ [scripts/*.mk]
+ indent_style = tab
+
+ [prompts/*.txt]
+ insert_final_newline = unset
+
+ [examples/server/public/*]
+ indent_size = 2
+
+ [examples/server/public/deps_*]
+ trim_trailing_whitespace = unset
+ indent_style = unset
+ indent_size = unset
+
+ [examples/server/deps_*]
+ trim_trailing_whitespace = unset
+ indent_style = unset
+ indent_size = unset
+
+ [examples/llama.swiftui/llama.swiftui.xcodeproj/*]
+ indent_style = tab
+
+ [examples/cvector-generator/*.txt]
+ trim_trailing_whitespace = unset
+ insert_final_newline = unset
+
+ [models/templates/*.jinja]
+ indent_style = unset
+ indent_size = unset
+ end_of_line = unset
+ charset = unset
+ trim_trailing_whitespace = unset
+ insert_final_newline = unset
.gitattributes CHANGED
@@ -33,3 +33,23 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ clblast.dll filter=lfs diff=lfs merge=lfs -text
+ cudart64_110.dll filter=lfs diff=lfs merge=lfs -text
+ cudart64_12.dll filter=lfs diff=lfs merge=lfs -text
+ examples/server/themes/buttons-top/buttons_top.png filter=lfs diff=lfs merge=lfs -text
+ examples/server/themes/wild/llamapattern.png filter=lfs diff=lfs merge=lfs -text
+ examples/server/themes/wild/wild.png filter=lfs diff=lfs merge=lfs -text
+ ggml/src/ggml-vulkan-shaders.cpp filter=lfs diff=lfs merge=lfs -text
+ glslc-linux filter=lfs diff=lfs merge=lfs -text
+ glslc.exe filter=lfs diff=lfs merge=lfs -text
+ lib/clblast.lib filter=lfs diff=lfs merge=lfs -text
+ msvcp140.dll filter=lfs diff=lfs merge=lfs -text
+ nikogreen.ico filter=lfs diff=lfs merge=lfs -text
+ otherarch/sdcpp/vocab.hpp filter=lfs diff=lfs merge=lfs -text
+ taesd.embd filter=lfs diff=lfs merge=lfs -text
+ taesd_3.embd filter=lfs diff=lfs merge=lfs -text
+ taesd_f.embd filter=lfs diff=lfs merge=lfs -text
+ taesd_xl.embd filter=lfs diff=lfs merge=lfs -text
+ vcruntime140.dll filter=lfs diff=lfs merge=lfs -text
+ vulkan-1.dll filter=lfs diff=lfs merge=lfs -text
+ winclinfo.exe filter=lfs diff=lfs merge=lfs -text
.github/ISSUE_TEMPLATE/create-new-issue.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ name: Create New Issue
+ about: Please describe the issue in detail
+ title: ''
+ labels: ''
+ assignees: ''
+
+ ---
+
+ **Describe the Issue**
+ A clear and detailed description of what the issue is, and how to duplicate it (if applicable).
+
+ **Additional Information:**
+ Please provide as much relevant information about your setup as possible, such as the Operating System, CPU, GPU, KoboldCpp Version, and relevant logs (helpful to include the launch params from the terminal output, flags and crash logs)
.github/workflows/kcpp-build-release-arm64.yaml ADDED
@@ -0,0 +1,87 @@
+ name: Koboldcpp Linux ARM64
+
+ on: workflow_dispatch
+ env:
+   BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
+
+ jobs:
+   linux-arm:
+     runs-on: ubuntu-latest
+     steps:
+       - name: Clone
+         id: checkout
+         uses: actions/checkout@v3
+         with:
+           ref: ${{ github.head_ref || github.ref_name }}
+
+       - name: Install Dependencies
+         id: depends
+         run: |
+           sudo apt-get update
+           sudo apt-get install -y python3-tk python3-pip python3-dev build-essential \
+             libffi-dev libssl-dev libbz2-dev libreadline-dev libsqlite3-dev \
+             crossbuild-essential-arm64 gcc-aarch64-linux-gnu g++-aarch64-linux-gnu
+
+       - name: Install New GCC for Cross-Compilation
+         run: |
+           sudo apt-get install -y software-properties-common
+           sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
+           sudo apt-get update
+           sudo apt-get install -y gcc-12 g++-12 gcc-12-aarch64-linux-gnu g++-12-aarch64-linux-gnu
+           export CC=/usr/bin/aarch64-linux-gnu-gcc-12
+           export CXX=/usr/bin/aarch64-linux-gnu-g++-12
+           export AR=aarch64-linux-gnu-ar
+           export UNAME_M=aarch64
+           export UNAME_S=Linux
+           export PATH=/usr/bin:$PATH
+           make LLAMA_PORTABLE=1
+           chmod +x './create_ver_file.sh'
+           . create_ver_file.sh
+           mkdir -p dist
+           cp './koboldcpp_default.so' dist
+           ls
+
+       - name: Install QEMU
+         run: |
+           sudo apt-get update
+           sudo apt-get install -y qemu-user-static binfmt-support
+
+       - name: Setup QEMU for ARM64
+         run: |
+           docker run --rm --privileged multiarch/qemu-user-static --reset -p yes
+
+       - name: Build ARM64 PyInstaller
+         run: |
+           docker run --rm \
+             --platform linux/arm64 \
+             -v "${PWD}:/src" \
+             python:3.9-slim \
+             /bin/bash -c "
+             apt-get update && apt-get install -y build-essential && \
+             apt-get update && apt-get install -y gcc-12 g++-12 && \
+             export LD_LIBRARY_PATH=/usr/lib/gcc/x86_64-linux-gnu/12:$LD_LIBRARY_PATH && \
+             pip install customtkinter pyinstaller tk && \
+             cd /src && \
+             pyinstaller --noconfirm --onefile --collect-all customtkinter --collect-all psutil \
+               --add-data './koboldcpp_default.so:.' \
+               --add-data './kcpp_adapters:./kcpp_adapters' \
+               --add-data './koboldcpp.py:.' \
+               --add-data './klite.embd:.' \
+               --add-data './kcpp_docs.embd:.' \
+               --add-data './kcpp_sdui.embd:.' \
+               --add-data './taesd.embd:.' \
+               --add-data './taesd_xl.embd:.' \
+               --add-data './taesd_f.embd:.' \
+               --add-data './taesd_3.embd:.' \
+               --add-data './rwkv_vocab.embd:.' \
+               --add-data './rwkv_world_vocab.embd:.' \
+               --version-file './version.txt' \
+               --clean --console koboldcpp.py -n 'koboldcpp-linux-arm64'
+             "
+
+       - name: Save artifact
+         uses: actions/upload-artifact@v4
+         with:
+           name: kcpp_linux_arm64_binary
+           path: dist/
+
.github/workflows/kcpp-build-release-linux-cuda12.yaml ADDED
@@ -0,0 +1,34 @@
+ name: Koboldcpp Linux CUDA12
+
+ on: workflow_dispatch
+ env:
+   BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
+   KCPP_CUDA: 12.1.0
+   REBUILD_VK_SHADERS: 1
+
+ jobs:
+   linux:
+     runs-on: ubuntu-22.04
+     steps:
+       - name: Clone
+         id: checkout
+         uses: actions/checkout@v3
+         with:
+           ref: ${{ github.head_ref || github.ref_name }}
+
+       - name: Dependencies
+         id: depends
+         run: |
+           sudo apt-get update
+           sudo apt-get install git curl bzip2
+
+       - name: Build
+         id: make_build
+         run: |
+           ./koboldcpp.sh dist
+
+       - name: Save artifact
+         uses: actions/upload-artifact@v4
+         with:
+           name: kcpp_linux_binary
+           path: dist/
.github/workflows/kcpp-build-release-linux.yaml ADDED
@@ -0,0 +1,34 @@
+ name: Koboldcpp Linux
+
+ on: workflow_dispatch
+ env:
+   BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
+   NOAVX2: 1
+   REBUILD_VK_SHADERS: 1
+
+ jobs:
+   linux:
+     runs-on: ubuntu-22.04
+     steps:
+       - name: Clone
+         id: checkout
+         uses: actions/checkout@v3
+         with:
+           ref: ${{ github.head_ref || github.ref_name }}
+
+       - name: Dependencies
+         id: depends
+         run: |
+           sudo apt-get update
+           sudo apt-get install git curl bzip2
+
+       - name: Build
+         id: make_build
+         run: |
+           ./koboldcpp.sh dist
+
+       - name: Save artifact
+         uses: actions/upload-artifact@v4
+         with:
+           name: kcpp_linux_binary
+           path: dist/
.github/workflows/kcpp-build-release-osx.yaml ADDED
@@ -0,0 +1,41 @@
+ name: Koboldcpp Mac
+
+ on: workflow_dispatch
+ env:
+   BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
+
+ jobs:
+   osx:
+     runs-on: macos-latest
+     steps:
+       - name: Clone
+         id: checkout
+         uses: actions/checkout@v3
+         with:
+           ref: ${{ github.head_ref || github.ref_name }}
+
+       - name: Dependencies
+         id: depends
+         run: |
+           pip install customtkinter pyinstaller tk
+
+       - name: Build
+         id: make_build
+         run: |
+           make LLAMA_METAL=1 LLAMA_PORTABLE=1
+           chmod +x './create_ver_file.sh'
+           . create_ver_file.sh
+           pyinstaller --noconfirm --onefile --collect-all customtkinter --collect-all psutil --add-data './koboldcpp_default.so:.' --add-data './ggml-metal-merged.metal:.' --add-data './kcpp_adapters:./kcpp_adapters' --add-data './koboldcpp.py:.' --add-data './klite.embd:.' --add-data './kcpp_docs.embd:.' --add-data './kcpp_sdui.embd:.' --add-data './taesd.embd:.' --add-data './taesd_xl.embd:.' --add-data './taesd_f.embd:.' --add-data './taesd_3.embd:.' --add-data './rwkv_vocab.embd:.' --add-data './rwkv_world_vocab.embd:.' --version-file './version.txt' --clean --console koboldcpp.py -n "koboldcpp-mac-arm64"
+
+       - name: Test
+         id: test
+         run: |
+           wget https://huggingface.co/concedo/koboldcpp/resolve/main/baby_llama.gguf
+           dist/koboldcpp-mac-arm64 --model baby_llama.gguf --gpulayers 99 --benchmark --prompt 'Hi, my name is'
+
+       - name: Save artifact
+         uses: actions/upload-artifact@v4
+         with:
+           name: kcpp_mac_binary
+           path: dist/
+
.github/workflows/kcpp-build-release-win-full-cu12.yaml ADDED
@@ -0,0 +1,91 @@
+ name: Koboldcpp Windows Full Binaries CUDA 12
+
+ on: workflow_dispatch
+ env:
+   BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
+
+ jobs:
+   windows:
+     runs-on: windows-2019
+     steps:
+       - name: Clone
+         id: checkout
+         uses: actions/checkout@v3
+         with:
+           ref: ${{ github.head_ref || github.ref_name }}
+
+       - name: Get Python
+         uses: actions/setup-python@v2
+         with:
+           python-version: 3.8.10
+
+       - name: Install python dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install customtkinter==5.2.0 pyinstaller==5.11.0 psutil==5.9.5
+
+       - name: Download and install win64devkit
+         run: |
+           curl -L https://github.com/skeeto/w64devkit/releases/download/v1.22.0/w64devkit-1.22.0.zip --output w64devkit.zip
+           Expand-Archive w64devkit.zip -DestinationPath .
+
+       - name: Add w64devkit to PATH
+         run: |
+           echo "$(Get-Location)\w64devkit\bin" | Out-File -Append -FilePath $env:GITHUB_PATH -Encoding utf8
+
+       - name: Print System Environment Variables
+         id: printvars
+         run: |
+           echo "Number of processors: ${env:NUMBER_OF_PROCESSORS}"
+           echo "Processor Architecture: ${env:PROCESSOR_ARCHITECTURE}"
+           echo "Computer Name: ${env:COMPUTERNAME}"
+           wmic cpu get name
+           wmic os get TotalVisibleMemorySize, FreePhysicalMemory
+
+       - name: Rebuild Vulkan Shaders
+         id: make_vk_shaders
+         run: |
+           make vulkan_shaders_gen -j ${env:NUMBER_OF_PROCESSORS}
+
+       - name: Build Non-CUDA
+         id: make_build
+         run: |
+           make LLAMA_CLBLAST=1 LLAMA_VULKAN=1 LLAMA_PORTABLE=1 -j ${env:NUMBER_OF_PROCESSORS}
+           echo "Vulkan Shaders Rebuilt"
+
+       - uses: Jimver/[email protected]
+         id: cuda-toolkit
+         with:
+           cuda: '12.1.0'
+
+       - name: Build CUDA
+         id: cmake_build
+         run: |
+           mkdir build
+           cd build
+           cmake .. -DLLAMA_CUBLAS=ON -DCMAKE_SYSTEM_VERSION="10.0.19041.0"
+           cmake --build . --config Release -j 2
+           cd ..
+
+       # note: The libraries that come from the github cuda directory seem to be larger, so they are not recommended
+       # - name: Download CuBLAS Libraries
+       #   run: |
+       #     curl -L https://github.com/LostRuins/koboldcpp/releases/download/cuda11_cublas_libraries/cublas64_11.dll --output cublas64_11.dll
+       #     curl -L https://github.com/LostRuins/koboldcpp/releases/download/cuda11_cublas_libraries/cublasLt64_11.dll --output cublasLt64_11.dll
+       #     ls
+       - name: Copy CuBLAS Libraries
+         run: |
+           copy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\bin\cublasLt64_12.dll" .
+           copy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\bin\cublas64_12.dll" .
+           ls
+
+       - name: Package PyInstallers
+         id: make_pyinstaller
+         run: |
+           ./make_pyinstaller_cuda12.bat
+
+       - name: Save artifact
+         uses: actions/upload-artifact@v4
+         with:
+           name: kcpp_windows_pyinstallers
+           path: dist/
.github/workflows/kcpp-build-release-win-full.yaml ADDED
@@ -0,0 +1,92 @@
+ name: Koboldcpp Windows Full Binaries
+
+ on: workflow_dispatch
+ env:
+   BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
+
+ jobs:
+   windows:
+     runs-on: windows-2019
+     steps:
+       - name: Clone
+         id: checkout
+         uses: actions/checkout@v3
+         with:
+           ref: ${{ github.head_ref || github.ref_name }}
+
+       - name: Get Python
+         uses: actions/setup-python@v2
+         with:
+           python-version: 3.8.10
+
+       - name: Install python dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install customtkinter==5.2.0 pyinstaller==5.11.0 psutil==5.9.5
+
+       - name: Download and install win64devkit
+         run: |
+           curl -L https://github.com/skeeto/w64devkit/releases/download/v1.22.0/w64devkit-1.22.0.zip --output w64devkit.zip
+           Expand-Archive w64devkit.zip -DestinationPath .
+
+       - name: Add w64devkit to PATH
+         run: |
+           echo "$(Get-Location)\w64devkit\bin" | Out-File -Append -FilePath $env:GITHUB_PATH -Encoding utf8
+
+       - name: Print System Environment Variables
+         id: printvars
+         run: |
+           echo "Number of processors: ${env:NUMBER_OF_PROCESSORS}"
+           echo "Processor Architecture: ${env:PROCESSOR_ARCHITECTURE}"
+           echo "Computer Name: ${env:COMPUTERNAME}"
+           wmic cpu get name
+           wmic os get TotalVisibleMemorySize, FreePhysicalMemory
+
+       - name: Rebuild Vulkan Shaders
+         id: make_vk_shaders
+         run: |
+           make vulkan_shaders_gen -j ${env:NUMBER_OF_PROCESSORS}
+           echo "Vulkan Shaders Rebuilt"
+
+       - name: Build Non-CUDA
+         id: make_build
+         run: |
+           make LLAMA_CLBLAST=1 LLAMA_VULKAN=1 LLAMA_PORTABLE=1 -j ${env:NUMBER_OF_PROCESSORS}
+
+       - uses: Jimver/[email protected]
+         id: cuda-toolkit
+         with:
+           cuda: '11.4.4'
+
+       - name: Build CUDA
+         id: cmake_build
+         run: |
+           mkdir build
+           cd build
+           cmake .. -DLLAMA_CUBLAS=ON -DCMAKE_SYSTEM_VERSION="10.0.19041.0"
+           cmake --build . --config Release -j 2
+           cd ..
+
+       # note: The libraries that come from the github cuda directory seem to be larger, so they are not recommended
+       - name: Download CuBLAS Libraries
+         run: |
+           curl -L https://github.com/LostRuins/koboldcpp/releases/download/cuda11_cublas_libraries/cublas64_11.dll --output cublas64_11.dll
+           curl -L https://github.com/LostRuins/koboldcpp/releases/download/cuda11_cublas_libraries/cublasLt64_11.dll --output cublasLt64_11.dll
+           ls
+       # - name: Copy CuBLAS Libraries
+       #   run: |
+       #     copy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\bin\cublasLt64_11.dll" .
+       #     copy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\bin\cublas64_11.dll" .
+       #     ls
+
+       - name: Package PyInstallers
+         id: make_pyinstaller
+         run: |
+           ./make_pyinstaller.bat
+           ./make_pyinstaller_cuda.bat
+
+       - name: Save artifact
+         uses: actions/upload-artifact@v4
+         with:
+           name: kcpp_windows_pyinstallers
+           path: dist/
.github/workflows/kcpp-build-release-win-oldcpu-full.yaml ADDED
@@ -0,0 +1,91 @@
+ name: Koboldcpp Windows Full OldCPU Binaries
+
+ on: workflow_dispatch
+ env:
+   BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
+
+ jobs:
+   windows:
+     runs-on: windows-2019
+     steps:
+       - name: Clone
+         id: checkout
+         uses: actions/checkout@v3
+         with:
+           ref: ${{ github.head_ref || github.ref_name }}
+
+       - name: Get Python
+         uses: actions/setup-python@v2
+         with:
+           python-version: 3.8.10
+
+       - name: Install python dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install customtkinter==5.2.0 pyinstaller==5.11.0 psutil==5.9.5
+
+       - name: Download and install win64devkit
+         run: |
+           curl -L https://github.com/skeeto/w64devkit/releases/download/v1.22.0/w64devkit-1.22.0.zip --output w64devkit.zip
+           Expand-Archive w64devkit.zip -DestinationPath .
+
+       - name: Add w64devkit to PATH
+         run: |
+           echo "$(Get-Location)\w64devkit\bin" | Out-File -Append -FilePath $env:GITHUB_PATH -Encoding utf8
+
+       - name: Print System Environment Variables
+         id: printvars
+         run: |
+           echo "Number of processors: ${env:NUMBER_OF_PROCESSORS}"
+           echo "Processor Architecture: ${env:PROCESSOR_ARCHITECTURE}"
+           echo "Computer Name: ${env:COMPUTERNAME}"
+           wmic cpu get name
+           wmic os get TotalVisibleMemorySize, FreePhysicalMemory
+
+       - name: Rebuild Vulkan Shaders
+         id: make_vk_shaders
+         run: |
+           make vulkan_shaders_gen -j ${env:NUMBER_OF_PROCESSORS}
+           echo "Vulkan Shaders Rebuilt"
+
+       - name: Build Non-CUDA
+         id: make_build
+         run: |
+           make LLAMA_CLBLAST=1 LLAMA_VULKAN=1 LLAMA_PORTABLE=1 -j ${env:NUMBER_OF_PROCESSORS} LLAMA_NOAVX2=1
+
+       - uses: Jimver/[email protected]
+         id: cuda-toolkit
+         with:
+           cuda: '11.4.4'
+
+       - name: Build CUDA
+         id: cmake_build
+         run: |
+           mkdir build
+           cd build
+           cmake .. -DLLAMA_CUBLAS=ON -DLLAMA_AVX2=OFF -DCMAKE_SYSTEM_VERSION="10.0.19041.0"
+           cmake --build . --config Release -j 2
+           cd ..
+
+       # note: The libraries that come from the github cuda directory seem to be larger, so they are not recommended
+       - name: Download CuBLAS Libraries
+         run: |
+           curl -L https://github.com/LostRuins/koboldcpp/releases/download/cuda11_cublas_libraries/cublas64_11.dll --output cublas64_11.dll
+           curl -L https://github.com/LostRuins/koboldcpp/releases/download/cuda11_cublas_libraries/cublasLt64_11.dll --output cublasLt64_11.dll
+           ls
+       # - name: Copy CuBLAS Libraries
+       #   run: |
+       #     copy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\bin\cublasLt64_11.dll" .
+       #     copy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\bin\cublas64_11.dll" .
+       #     ls
+
+       - name: Package PyInstallers
+         id: make_pyinstaller
+         run: |
+           ./make_pyinstaller_cuda_oldcpu.bat
+
+       - name: Save artifact
+         uses: actions/upload-artifact@v4
+         with:
+           name: kcpp_windows_pyinstallers
+           path: dist/
.gitignore ADDED
@@ -0,0 +1,140 @@
+ *.o
+ *.a
+ *.bin
+ .DS_Store
+ .build/
+ .cache/
+ .ccls-cache/
+ .direnv/
+ .envrc
+ .swiftpm
+ .venv
+ .clang-tidy
+ .vs/
+ .vscode/
+
+ ggml-metal-embed.metal
+
+ lcov-report/
+ gcovr-report/
+
+ build*/
+ out/
+ tmp/
+ autogen-*.md
+
+ models/*
+ models-mnt
+
+ /Pipfile
+ /baby-llama
+ /beam-search
+ /benchmark-matmult
+ /convert-llama2c-to-ggml
+ /embd-input-test
+ /embedding
+ /eval-callback
+ /gguf
+ /gguf-llama-simple
+ /gritlm
+ /imatrix
+ /infill
+ /libllama.so
+ /llama-bench
+ /llava-cli
+ /lookahead
+ /lookup
+ /main
+ /metal
+ /passkey
+ /perplexity
+ /q8dot
+ /quantize
+ /quantize-stats
+ /result
+ /save-load-state
+ /server
+ /simple
+ /batched
+ /batched-bench
+ /export-lora
+ /finetune
+ /speculative
+ /parallel
+ /train-text-from-scratch
+ /tokenize
+ /vdot
+ /common/build-info.cpp
+ arm_neon.h
+ compile_commands.json
+ CMakeSettings.json
+
+ __pycache__
+ dist
+
+ dist/
+ *.spec
+
+ zig-out/
+ zig-cache/
+
+ ppl-*.txt
+ qnt-*.txt
+ perf-*.txt
+
+ examples/jeopardy/results.txt
+
+ poetry.lock
+ poetry.toml
+
+ ggml-metal-merged.metal
+
+ # Test binaries
+ /tests/test-llama-grammar
+ tests/test-double-float
+ tests/test-grad0
+ tests/test-opt
+ tests/test-quantize-fns
+ tests/test-quantize-perf
+ tests/test-sampling
+ tests/test-tokenizer-0
+ tests/test-tokenizer-0-llama
+ tests/test-tokenizer-0-falcon
+ tests/test-tokenizer-1-llama
+ tests/test-tokenizer-1-bpe
+ /tests/test-rope
+ /tests/test-backend-ops
+
+ /koboldcpp_default.so
+ /koboldcpp_failsafe.so
+ /koboldcpp_noavx2.so
+ /koboldcpp_clblast.so
+ /koboldcpp_clblast_noavx2.so
+ /koboldcpp_clblast_failsafe.so
+ /koboldcpp_cublas.so
+ /koboldcpp_vulkan.so
+ /koboldcpp_vulkan_noavx2.so
+ /koboldcpp_default.dll
+ /koboldcpp_failsafe.dll
+ /koboldcpp_noavx2.dll
+ /koboldcpp_clblast.dll
+ /koboldcpp_clblast_noavx2.dll
+ /koboldcpp_vulkan_noavx2.dll
+ /koboldcpp_clblast_failsafe.dll
+ /koboldcpp_cublas.dll
+ /koboldcpp_vulkan.dll
+ /cublas64_11.dll
+ /cublasLt64_11.dll
+ /cublas64_12.dll
+ /cublasLt64_12.dll
+ /rocblas/
+ rocblas.dll
+ hipblas.dll
+ koboldcpp_hipblas.so
+ koboldcpp_hipblas.dll
+
+ bin/
+ conda/
+
+ # Jetbrains idea folder
+ .idea/
CLINFO_LICENSE ADDED
@@ -0,0 +1,19 @@
+ Windows binaries obtained from the clinfo repo fork here:
+
+ https://github.com/ahoylabs/clinfo/releases/tag/master-d2baa06
+
+ Source available here:
+ https://github.com/Oblomov/clinfo
+
+ see below LICENSE file for details on clinfo license
+
+ =======
+
+ clinfo by Giuseppe Bilotta
+
+ To the extent possible under law, the person who associated CC0 with
+ clinfo has waived all copyright and related or neighboring rights
+ to clinfo.
+
+ You should have received a copy of the CC0 legalcode along with this
+ work. If not, see <http://creativecommons.org/publicdomain/zero/1.0/>
CMakeLists.txt ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # THIS FILE IS ONLY INTENDED CUBLAS BUILD PURPOSES ON WINDOWS VISUAL STUDIO.
2
+ # YOU'RE NOT RECOMMENDED TO USE IT
3
+
4
+ message(STATUS "============== ============== ==============")
5
+ message(STATUS "WARNING! Recommend NOT to use this file. It is UNSUPPORTED for normal users. Use MAKE instead.")
6
+ message(STATUS "It is ONLY for CUBLAS builds on windows visual studio. IT WILL OVERWRITE YOUR EXISTING MAKEFILE !!!")
7
+ message(STATUS "IF YOU ARE SEEING THIS, you MUST ONLY be building CUBLAS BUILDS! NOTHING ELSE WILL BE SUPPORTED !!!")
8
+ message(STATUS "============== ============== ==============")
9
+
10
+ cmake_minimum_required(VERSION 3.12) # Don't bump this version for no reason
11
+ project("llama.cpp" C CXX)
12
+
13
+ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
14
+ set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS 1)
15
+ set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
16
+ set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Release")
17
+ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
18
+ set(LLAMA_STANDALONE ON)
19
+ set(BUILD_SHARED_LIBS_DEFAULT ON)
20
+ set(LLAMA_STATIC OFF)
21
+ set(LLAMA_NATIVE OFF)
22
+ set(LLAMA_LTO OFF)
23
+ set(LLAMA_ALL_WARNINGS OFF)
24
+ set(LLAMA_ALL_WARNINGS_3RD_PARTY OFF)
25
+ set(LLAMA_GPROF OFF)
26
+ set(LLAMA_SANITIZE_THREAD OFF)
27
+ set(LLAMA_SANITIZE_ADDRESS OFF)
28
+ set(LLAMA_SANITIZE_UNDEFINED OFF)
29
+
30
+
31
+ # instruction set specific
32
+ option(LLAMA_AVX "llama: enable AVX" ON)
33
+ option(LLAMA_AVX2 "llama: enable AVX2" ON)
34
+ option(LLAMA_AVX512 "llama: enable AVX512" OFF)
35
+ option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
36
+ option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
37
+ option(LLAMA_FMA "llama: enable FMA" ON)
38
+ # in MSVC F16C is implied with AVX2/AVX512
39
+ if (NOT MSVC)
40
+ option(LLAMA_F16C "llama: enable F16C" ON)
41
+ endif()
42
+
43
+ # 3rd party libs
44
+ option(LLAMA_CUBLAS "llama: use CUDA" ON)
45
+ option(LLAMA_CUDA_F16 "llama: use 16 bit floats for dmmv CUDA kernels" OFF)
46
+ set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
47
+ "llama: max. batch size for using peer access")
48
+
49
+ option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
50
+
51
+ # Other
52
+ option(LLAMA_OPENMP "llama: use OpenMP" OFF)
53
+
54
+ #
55
+ # Compile flags
56
+ #
57
+
58
+ set(CMAKE_CXX_STANDARD 17)
59
+ set(CMAKE_CXX_STANDARD_REQUIRED true)
60
+ set(CMAKE_C_STANDARD 11)
61
+ set(CMAKE_C_STANDARD_REQUIRED true)
62
+ set(THREADS_PREFER_PTHREAD_FLAG ON)
63
+ find_package(Threads REQUIRED)
64
+
65
+ add_compile_definitions(LOG_DISABLE_LOGS)
66
+ add_compile_definitions(GGML_USE_CPU)
67
+ add_compile_definitions(GGML_USE_CPU_AARCH64)
68
+
69
+ if (MSVC)
70
+ add_compile_options("$<$<COMPILE_LANGUAGE:C>:/utf-8>")
71
+ add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/utf-8>")
72
+ add_compile_options("$<$<COMPILE_LANGUAGE:C>:/bigobj>")
73
+ add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/bigobj>")
74
+ endif()
75
+
76
+ file(GLOB GGML_SOURCES_CUDA "ggml/src/ggml-cuda/*.cu")
77
+ list(APPEND GGML_SOURCES_CUDA "ggml/src/ggml-cuda/ggml-cuda.cu")
78
+ file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/fattn-mma*.cu")
79
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
80
+ file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq*.cu")
81
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
82
+ set(GGML_V3_CUDA_SOURCES otherarch/ggml_v3-cuda.cu otherarch/ggml_v3-cuda.h)
83
+ set(GGML_V2_CUDA_SOURCES otherarch/ggml_v2-cuda.cu otherarch/ggml_v2-cuda.h)
84
+ set(GGML_V2_LEGACY_CUDA_SOURCES otherarch/ggml_v2-cuda-legacy.cu otherarch/ggml_v2-cuda-legacy.h)
85
+
86
+
87
+ if (LLAMA_CUBLAS)
88
+ cmake_minimum_required(VERSION 3.17)
89
+
90
+ find_package(CUDAToolkit)
91
+ if (CUDAToolkit_FOUND)
92
+ message(STATUS "cuBLAS found")
93
+
94
+ enable_language(CUDA)
95
+
96
+ add_compile_definitions(GGML_USE_LLAMAFILE)
97
+ add_compile_definitions(GGML_USE_CUDA)
98
+ add_compile_definitions(SD_USE_CUBLAS)
99
+
100
+ if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16)
101
+ add_compile_definitions(GGML_CUDA_F16)
102
+ endif()
103
+ add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${LLAMA_CUDA_PEER_MAX_BATCH_SIZE})
104
+
105
+ # only build minimal quants required for fattn quant kv
106
+ file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
107
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
108
+ file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
109
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
110
+ file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
111
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
112
+
113
+ if (LLAMA_STATIC)
114
+ if (WIN32)
115
+ # As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library
116
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
117
+ else ()
118
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
119
+ endif()
120
+ else()
121
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
122
+ endif()
123
+
124
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver)
125
+
126
+ if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
127
+ # 50 == lowest CUDA 12 standard
128
+ # 60 == f16 CUDA intrinsics
129
+ # 61 == integer CUDA intrinsics
130
+ # 70 == (assumed) compute capability at which unrolling a loop in mul_mat_q kernels is faster
131
+ # 75 == int8 tensor cores
132
+ if (LLAMA_CUDA_F16 OR LLAMA_CUDA_DMMV_F16)
133
+ set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75") # needed for f16 CUDA intrinsics
134
+ else()
135
+ message("CUDA Toolkit Version: ${CUDAToolkit_VERSION}")
136
+ if(CUDAToolkit_VERSION VERSION_GREATER 12)
137
+ add_compile_definitions(GGML_CUDA_USE_GRAPHS) #try enable cuda graphs on cu12 build
138
+ set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75") # lowest CUDA 12 standard + lowest for integer intrinsics
139
+ else()
140
+ set(CMAKE_CUDA_ARCHITECTURES "37;50;61;70;75") # lowest CUDA 12 standard + lowest for integer intrinsics
141
+ endif()
142
+ endif()
143
+ endif()
144
+ message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
145
+
146
+ else()
147
+ message(WARNING "cuBLAS not found")
148
+ endif()
149
+ endif()
150
+
151
+ if (LLAMA_HIPBLAS)
152
+ if (MSVC)
153
+ list(APPEND CMAKE_PREFIX_PATH "C:/Program Files/AMD/ROCm/5.5")
154
+ else()
155
+ list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
156
+ endif()
157
+
158
+
159
+ if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
160
+ message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
161
+ endif()
162
+ if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
163
+ message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
164
+ endif()
165
+
166
+ find_package(hip)
167
+ find_package(hipblas)
168
+ find_package(rocblas)
169
+
170
+ if (${hipblas_FOUND} AND ${hip_FOUND})
171
+ message(STATUS "HIP and hipBLAS found")
172
+ file(GLOB GGML_SOURCES_ROCM "ggml/src/ggml-cuda/*.cu")
173
+ list(APPEND GGML_SOURCES_ROCM "ggml/src/ggml-cuda/ggml-cuda.cu")
174
+ file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/fattn-mma*.cu")
175
+ list(APPEND GGML_SOURCES_ROCM ${SRCS})
176
+ file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq*.cu")
177
+ list(APPEND GGML_SOURCES_ROCM ${SRCS})
178
+ add_compile_definitions(GGML_USE_HIP GGML_USE_CUDA SD_USE_CUBLAS)
179
+ add_library(ggml-rocm ${GGML_SOURCES_CUDA})
180
+
181
+ file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
182
+ list(APPEND GGML_SOURCES_ROCM ${SRCS})
183
+ file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
184
+ list(APPEND GGML_SOURCES_ROCM ${SRCS})
185
+ file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
186
+ list(APPEND GGML_SOURCES_ROCM ${SRCS})
187
+
188
+ # only build minimal quants required for fattn quant kv
189
+ set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
190
+ target_link_libraries(ggml-rocm PUBLIC hip::device hip::host roc::rocblas roc::hipblas)
191
+
192
+ add_library(ggml-v2-rocm ${GGML_V2_CUDA_SOURCES})
193
+ set_source_files_properties(otherarch/ggml_v2-cuda.cu PROPERTIES LANGUAGE CXX)
194
+ target_link_libraries(ggml-v2-rocm PUBLIC hip::device hip::host roc::rocblas roc::hipblas)
195
+
196
+ add_library(ggml-v3-rocm ${GGML_V3_CUDA_SOURCES})
197
+ set_source_files_properties(otherarch/ggml_v3-cuda.cu PROPERTIES LANGUAGE CXX)
198
+ target_link_libraries(ggml-v3-rocm PUBLIC hip::device hip::host roc::rocblas roc::hipblas)
199
+
200
+ add_library(ggml-v2-legacy-rocm ${GGML_V2_LEGACY_CUDA_SOURCES})
201
+ set_source_files_properties(otherarch/ggml_v2-cuda-legacy.cu PROPERTIES LANGUAGE CXX)
202
+ target_link_libraries(ggml-v2-legacy-rocm PUBLIC hip::device hip::host roc::rocblas roc::hipblas)
203
+
204
+ if (LLAMA_STATIC)
205
+ message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
206
+ endif()
207
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-rocm ggml-v2-rocm ggml-v2-legacy-rocm)
208
+ else()
209
+ message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
210
+ endif()
211
+ endif()
212
+
213
+ if (LLAMA_ALL_WARNINGS)
214
+ if (NOT MSVC)
215
+ set(warning_flags -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function)
216
+ set(c_flags -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -Werror=implicit-int
217
+ -Werror=implicit-function-declaration)
218
+ set(cxx_flags -Wmissing-declarations -Wmissing-noreturn)
219
+
220
+ if (CMAKE_C_COMPILER_ID MATCHES "Clang")
221
+ set(warning_flags ${warning_flags} -Wunreachable-code-break -Wunreachable-code-return)
222
+ set(cxx_flags ${cxx_flags} -Wmissing-prototypes -Wextra-semi)
223
+
224
+ if (
225
+ (CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 3.8.0) OR
226
+ (CMAKE_C_COMPILER_ID STREQUAL "AppleClang" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 7.3.0)
227
+ )
228
+ set(c_flags ${c_flags} -Wdouble-promotion)
229
+ endif()
230
+ elseif (CMAKE_C_COMPILER_ID STREQUAL "GNU")
231
+ set(c_flags ${c_flags} -Wdouble-promotion)
232
+ set(cxx_flags ${cxx_flags} -Wno-array-bounds)
233
+
234
+ if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 7.1.0)
235
+ set(cxx_flags ${cxx_flags} -Wno-format-truncation)
236
+ endif()
237
+ if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 8.1.0)
238
+ set(cxx_flags ${cxx_flags} -Wextra-semi)
239
+ endif()
240
+ endif()
241
+ else()
242
+ # todo : msvc
243
+ endif()
244
+
245
+ add_compile_options(
246
+ ${warning_flags}
247
+ "$<$<COMPILE_LANGUAGE:C>:${c_flags}>"
248
+ "$<$<COMPILE_LANGUAGE:CXX>:${cxx_flags}>"
249
+ )
250
+
251
+ endif()
252
+
253
+ if (WIN32)
254
+ add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
255
+
256
+ if (BUILD_SHARED_LIBS)
257
+ set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
258
+ endif()
259
+ endif()
260
+
261
+ if (LLAMA_LTO)
262
+ include(CheckIPOSupported)
263
+ check_ipo_supported(RESULT result OUTPUT output)
264
+ if (result)
265
+ set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
266
+ else()
267
+ message(WARNING "IPO is not supported: ${output}")
268
+ endif()
269
+ endif()
270
+
271
+ if (LLAMA_OPENMP)
272
+ find_package(OpenMP)
273
+ if (OpenMP_FOUND)
274
+ message(STATUS "OpenMP found")
275
+ add_compile_definitions(GGML_USE_OPENMP)
276
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
277
+ else()
278
+ message(WARNING "OpenMP not found")
279
+ endif()
280
+ endif()
281
+
282
+ # this version of Apple ld64 is buggy
283
+ execute_process(
284
+ COMMAND ${CMAKE_C_COMPILER} ${CMAKE_EXE_LINKER_FLAGS} -Wl,-v
285
+ ERROR_VARIABLE output
286
+ )
287
+ if (output MATCHES "dyld-1015\.7")
288
+ add_compile_definitions(HAVE_BUGGY_APPLE_LINKER)
289
+ endif()
290
+
291
+ # Architecture specific
292
+ # TODO: probably these flags need to be tweaked on some architectures
293
+ # feel free to update the Makefile for your architecture and send a pull request or issue
294
+ message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
295
+ if (NOT MSVC)
296
+ if (LLAMA_STATIC)
297
+ add_link_options(-static)
298
+ if (MINGW)
299
+ add_link_options(-static-libgcc -static-libstdc++)
300
+ endif()
301
+ endif()
302
+ if (LLAMA_GPROF)
303
+ add_compile_options(-pg)
304
+ endif()
305
+ if (LLAMA_NATIVE)
306
+ add_compile_options(-march=native)
307
+ endif()
308
+ endif()
309
+
310
+ if ((${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") OR (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64"))
311
+ message(STATUS "ARM detected")
312
+ if (MSVC)
313
+ # TODO: arm msvc?
314
+ else()
315
+ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6")
316
+ # Raspberry Pi 1, Zero
317
+ add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access)
318
+ endif()
319
+ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7")
320
+ # Raspberry Pi 2
321
+ add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations)
322
+ endif()
323
+ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8")
324
+ # Raspberry Pi 3, 4, Zero 2 (32-bit)
325
+ add_compile_options(-mfp16-format=ieee -mno-unaligned-access)
326
+ endif()
327
+ endif()
328
+ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
329
+ message(STATUS "x86 detected")
330
+ if (MSVC)
331
+ if (LLAMA_AVX512)
332
+ add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
333
+ add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
334
+ # MSVC has no compile-time flags enabling specific
335
+ # AVX512 extensions, neither it defines the
336
+ # macros corresponding to the extensions.
337
+ # Do it manually.
338
+ if (LLAMA_AVX512_VBMI)
339
+ add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
340
+ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
341
+ endif()
342
+ if (LLAMA_AVX512_VNNI)
343
+ add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
344
+ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
345
+ endif()
346
+ elseif (LLAMA_AVX2)
347
+ add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
348
+ add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
349
+ elseif (LLAMA_AVX)
350
+ add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
351
+ add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
352
+ endif()
353
+ else()
354
+ if (LLAMA_F16C)
355
+ add_compile_options(-mf16c)
356
+ endif()
357
+ if (LLAMA_FMA)
358
+ add_compile_options(-mfma)
359
+ endif()
360
+ if (LLAMA_AVX)
361
+ add_compile_options(-mavx)
362
+ endif()
363
+ if (LLAMA_AVX2)
364
+ add_compile_options(-mavx2)
365
+ endif()
366
+ if (LLAMA_AVX512)
367
+ add_compile_options(-mavx512f)
368
+ add_compile_options(-mavx512bw)
369
+ endif()
370
+ if (LLAMA_AVX512_VBMI)
371
+ add_compile_options(-mavx512vbmi)
372
+ endif()
373
+ if (LLAMA_AVX512_VNNI)
374
+ add_compile_options(-mavx512vnni)
375
+ endif()
376
+ endif()
377
+ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
378
+ message(STATUS "PowerPC detected")
379
+ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
380
+ add_compile_options(-mcpu=powerpc64le)
381
+ else()
382
+ add_compile_options(-mcpu=native -mtune=native)
383
+ #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
384
+ endif()
385
+ else()
386
+ message(STATUS "Unknown architecture")
387
+ endif()
388
+
389
+ if (MINGW)
390
+ # Target Windows 8 for PrefetchVirtualMemory
391
+ add_compile_definitions(_WIN32_WINNT=0x602)
392
+ endif()
393
+
394
+ #
395
+ # Build libraries
396
+ #
397
+
398
+ add_library(ggml
399
+ ggml/src/ggml.c
400
+ ggml/include/ggml.h
401
+ ggml/src/ggml-cpu/ggml-cpu.c
402
+ ggml/include/ggml-cpu.h
403
+ ggml/src/ggml-alloc.c
404
+ ggml/include/ggml-alloc.h
405
+ ggml/src/ggml-backend.cpp
406
+ ggml/src/ggml-backend-impl.h
407
+ ggml/include/ggml-backend.h
408
+ ggml/include/ggml-cpp.h
409
+ ggml/src/ggml-quants.c
410
+ ggml/src/ggml-quants.h
411
+ ggml/src/ggml-cpu/llamafile/sgemm.cpp
412
+ ggml/src/ggml-cpu/llamafile/sgemm.h
413
+ ggml/src/ggml-cpu/ggml-cpu-traits.cpp
414
+ ggml/src/ggml-cpu/ggml-cpu-traits.h
415
+ ggml/src/ggml-threading.cpp
416
+ ggml/src/ggml-cpu/ggml-cpu.cpp
417
+ ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
418
+ ggml/src/ggml-cpu/ggml-cpu-aarch64.h
419
+ ggml/src/ggml-cpu/ggml-cpu-quants.c
420
+ ggml/src/ggml-cpu/ggml-cpu-quants.h
421
+ ggml/src/ggml-backend-reg.cpp
422
+ ggml/include/gguf.h
423
+ ggml/src/gguf.cpp
424
+ ${GGML_SOURCES_CUDA})
425
+ target_include_directories(ggml PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools)
426
+ target_compile_features(ggml PUBLIC c_std_11) # don't bump
427
+ target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
428
+ set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON)
429
+
430
+ add_library(ggml_v1
431
+ otherarch/ggml_v1.c
432
+ otherarch/ggml_v1.h)
433
+ target_include_directories(ggml_v1 PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools)
434
+ target_compile_features(ggml_v1 PUBLIC c_std_11) # don't bump
435
+ target_link_libraries(ggml_v1 PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
436
+ set_target_properties(ggml_v1 PROPERTIES POSITION_INDEPENDENT_CODE ON)
437
+
438
+ add_library(ggml_v2
439
+ otherarch/ggml_v2.c
440
+ otherarch/ggml_v2.h
441
+ ${GGML_V2_CUDA_SOURCES}
442
+ ${GGML_V2_LEGACY_CUDA_SOURCES})
443
+ target_include_directories(ggml_v2 PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools)
444
+ target_compile_features(ggml_v2 PUBLIC c_std_11) # don't bump
445
+ target_link_libraries(ggml_v2 PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
446
+ set_target_properties(ggml_v2 PROPERTIES POSITION_INDEPENDENT_CODE ON)
447
+
448
+ add_library(ggml_v3
449
+ otherarch/ggml_v3.c
450
+ otherarch/ggml_v3.h
451
+ ${GGML_V3_CUDA_SOURCES})
452
+ target_include_directories(ggml_v3 PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools)
453
+ target_compile_features(ggml_v3 PUBLIC c_std_11) # don't bump
454
+ target_link_libraries(ggml_v3 PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
455
+ set_target_properties(ggml_v3 PROPERTIES POSITION_INDEPENDENT_CODE ON)
456
+
457
+ add_library(common2
458
+ common/common.cpp
459
+ common/common.h
460
+ common/sampling.cpp
461
+ common/sampling.h
462
+ examples/llava/llava.cpp
463
+ examples/llava/llava.h
464
+ examples/llava/clip.cpp
465
+ examples/llava/clip.h
466
+ src/unicode.h
467
+ src/unicode.cpp
+ src/unicode-data.cpp
+ otherarch/utils.cpp
+ otherarch/utils.h)
+ target_include_directories(common2 PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
+ target_compile_features(common2 PUBLIC cxx_std_17) # don't bump
+ target_link_libraries(common2 PRIVATE ggml ${LLAMA_EXTRA_LIBS})
+ set_target_properties(common2 PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+ add_library(sdtype_adapter
+ otherarch/sdcpp/sdtype_adapter.cpp)
+ target_include_directories(sdtype_adapter PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
+ target_compile_features(sdtype_adapter PUBLIC cxx_std_17) # don't bump
+ target_link_libraries(sdtype_adapter PRIVATE common2 ggml ${LLAMA_EXTRA_LIBS})
+ set_target_properties(sdtype_adapter PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+ add_library(whisper_adapter
+ otherarch/whispercpp/whisper_adapter.cpp)
+ target_include_directories(whisper_adapter PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools ./otherarch/whispercpp ./examples ./common)
+ target_compile_features(whisper_adapter PUBLIC cxx_std_17) # don't bump
+ target_link_libraries(whisper_adapter PRIVATE common2 ggml ${LLAMA_EXTRA_LIBS})
+ set_target_properties(whisper_adapter PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+ add_library(tts_adapter
+ otherarch/tts_adapter.cpp)
+ target_include_directories(tts_adapter PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools ./examples ./common)
+ target_compile_features(tts_adapter PUBLIC cxx_std_17) # don't bump
+ target_link_libraries(tts_adapter PRIVATE common2 ggml ${LLAMA_EXTRA_LIBS})
+ set_target_properties(tts_adapter PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+ add_library(gpttype_adapter
+ gpttype_adapter.cpp)
+ target_include_directories(gpttype_adapter PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
+ target_compile_features(gpttype_adapter PUBLIC cxx_std_17) # don't bump
+ target_link_libraries(gpttype_adapter PRIVATE common2 ggml ggml_v1 ggml_v2 ggml_v3 ${LLAMA_EXTRA_LIBS})
+ set_target_properties(gpttype_adapter PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+ if (LLAMA_CUBLAS)
+ set(TARGET koboldcpp_cublas)
+ add_library(${TARGET} SHARED expose.cpp expose.h)
+ target_include_directories(${TARGET} PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
+ target_compile_features(${TARGET} PUBLIC cxx_std_17) # don't bump
+ set_target_properties(${TARGET} PROPERTIES PREFIX "")
+ set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME "koboldcpp_cublas")
+ set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
+ target_link_libraries(${TARGET} PUBLIC Threads::Threads ggml ggml_v1 ggml_v2 ggml_v3 common2 gpttype_adapter whisper_adapter tts_adapter sdtype_adapter ${LLAMA_EXTRA_LIBS})
+ target_compile_features(${TARGET} PRIVATE cxx_std_17)
+
+ add_custom_command(
+ TARGET koboldcpp_cublas POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy
+ $<TARGET_FILE:koboldcpp_cublas> # The generated DLL
+ ${CMAKE_SOURCE_DIR}/ # Destination directory
+ COMMENT "Copying DLL to parent directory"
+ )
+ endif()
+
+ if (LLAMA_HIPBLAS)
+ set(TARGET koboldcpp_hipblas)
+ add_library(${TARGET} SHARED expose.cpp expose.h)
+ target_include_directories(${TARGET} PUBLIC . ./ggml/include ./ggml/src ./ggml/src/ggml-cpu ./include ./otherarch ./otherarch/tools ./otherarch/sdcpp ./otherarch/sdcpp/thirdparty ./examples ./common)
+ target_compile_features(${TARGET} PUBLIC cxx_std_17) # don't bump
+ set_target_properties(${TARGET} PROPERTIES PREFIX "")
+ set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME "koboldcpp_hipblas")
+ set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
+ target_link_libraries(${TARGET} PUBLIC Threads::Threads ggml ggml_v1 ggml_v2 ggml_v3 common2 gpttype_adapter whisper_adapter tts_adapter sdtype_adapter ${LLAMA_EXTRA_LIBS})
+ target_compile_features(${TARGET} PRIVATE cxx_std_17)
+
+ add_custom_command(
+ TARGET koboldcpp_hipblas POST_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy
+ $<TARGET_FILE:koboldcpp_hipblas> # The generated DLL
+ ${CMAKE_SOURCE_DIR}/ # Destination directory
+ COMMENT "Copying DLL to parent directory"
+ )
+ endif()
+
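The two GPU targets above each build a prefix-less shared library (koboldcpp_cublas / koboldcpp_hipblas) from expose.cpp and copy it next to the sources after the build; per the licensing note further below, these libraries are consumed through Python ctypes bindings. A minimal loading sketch, assuming only that the library was built and copied as above; the actual exported entry points are declared in expose.cpp/expose.h and are deliberately not guessed at here:

    import ctypes
    import os
    import platform

    # Illustrative sketch, not part of the commit: CMake sets PREFIX "" and
    # OUTPUT_NAME "koboldcpp_cublas", and the POST_BUILD step copies the
    # binary into the source directory, so we load it from beside this file.
    libname = "koboldcpp_cublas.dll" if platform.system() == "Windows" else "koboldcpp_cublas.so"
    libpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), libname)

    handle = ctypes.CDLL(libpath)
    # Entry points would be resolved from `handle`; their names and argument
    # types live in expose.cpp/expose.h, so none are assumed in this sketch.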
LICENSE.md ADDED
@@ -0,0 +1,661 @@
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU Affero General Public License is a free, copyleft license for
+ software and other kinds of works, specifically designed to ensure
+ cooperation with the community in the case of network server software.
+
+ The licenses for most software and other practical works are designed
+ to take away your freedom to share and change the works. By contrast,
+ our General Public Licenses are intended to guarantee your freedom to
+ share and change all versions of a program--to make sure it remains free
+ software for all its users.
+
+ When we speak of free software, we are referring to freedom, not
+ price. Our General Public Licenses are designed to make sure that you
+ have the freedom to distribute copies of free software (and charge for
+ them if you wish), that you receive source code or can get it if you
+ want it, that you can change the software or use pieces of it in new
+ free programs, and that you know you can do these things.
+
+ Developers that use our General Public Licenses protect your rights
+ with two steps: (1) assert copyright on the software, and (2) offer
+ you this License which gives you legal permission to copy, distribute
+ and/or modify the software.
+
+ A secondary benefit of defending all users' freedom is that
+ improvements made in alternate versions of the program, if they
+ receive widespread use, become available for other developers to
+ incorporate. Many developers of free software are heartened and
+ encouraged by the resulting cooperation. However, in the case of
+ software used on network servers, this result may fail to come about.
+ The GNU General Public License permits making a modified version and
+ letting the public access it on a server without ever releasing its
+ source code to the public.
+
+ The GNU Affero General Public License is designed specifically to
+ ensure that, in such cases, the modified source code becomes available
+ to the community. It requires the operator of a network server to
+ provide the source code of the modified version running there to the
+ users of that server. Therefore, public use of a modified version, on
+ a publicly accessible server, gives the public access to the source
+ code of the modified version.
+
+ An older license, called the Affero General Public License and
+ published by Affero, was designed to accomplish similar goals. This is
+ a different license, not a version of the Affero GPL, but Affero has
+ released a new version of the Affero GPL which permits relicensing under
+ this license.
+
+ The precise terms and conditions for copying, distribution and
+ modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU Affero General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+ works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+ License. Each licensee is addressed as "you". "Licensees" and
+ "recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+ in a fashion requiring copyright permission, other than the making of an
+ exact copy. The resulting work is called a "modified version" of the
+ earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+ on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+ permission, would make you directly or secondarily liable for
+ infringement under applicable copyright law, except executing it on a
+ computer or modifying a private copy. Propagation includes copying,
+ distribution (with or without modification), making available to the
+ public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+ parties to make or receive copies. Mere interaction with a user through
+ a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+ to the extent that it includes a convenient and prominently visible
+ feature that (1) displays an appropriate copyright notice, and (2)
+ tells the user that there is no warranty for the work (except to the
+ extent that warranties are provided), that licensees may convey the
+ work under this License, and how to view a copy of this License. If
+ the interface presents a list of user commands or options, such as a
+ menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+ for making modifications to it. "Object code" means any non-source
+ form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+ standard defined by a recognized standards body, or, in the case of
+ interfaces specified for a particular programming language, one that
+ is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+ than the work as a whole, that (a) is included in the normal form of
+ packaging a Major Component, but which is not part of that Major
+ Component, and (b) serves only to enable use of the work with that
+ Major Component, or to implement a Standard Interface for which an
+ implementation is available to the public in source code form. A
+ "Major Component", in this context, means a major essential component
+ (kernel, window system, and so on) of the specific operating system
+ (if any) on which the executable work runs, or a compiler used to
+ produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+ the source code needed to generate, install, and (for an executable
+ work) run the object code and to modify the work, including scripts to
+ control those activities. However, it does not include the work's
+ System Libraries, or general-purpose tools or generally available free
+ programs which are used unmodified in performing those activities but
+ which are not part of the work. For example, Corresponding Source
+ includes interface definition files associated with source files for
+ the work, and the source code for shared libraries and dynamically
+ linked subprograms that the work is specifically designed to require,
+ such as by intimate data communication or control flow between those
+ subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+ can regenerate automatically from other parts of the Corresponding
+ Source.
+
+ The Corresponding Source for a work in source code form is that
+ same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+ copyright on the Program, and are irrevocable provided the stated
+ conditions are met. This License explicitly affirms your unlimited
+ permission to run the unmodified Program. The output from running a
+ covered work is covered by this License only if the output, given its
+ content, constitutes a covered work. This License acknowledges your
+ rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+ convey, without conditions so long as your license otherwise remains
+ in force. You may convey covered works to others for the sole purpose
+ of having them make modifications exclusively for you, or provide you
+ with facilities for running those works, provided that you comply with
+ the terms of this License in conveying all material for which you do
+ not control copyright. Those thus making or running the covered works
+ for you must do so exclusively on your behalf, under your direction
+ and control, on terms that prohibit them from making any copies of
+ your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+ the conditions stated below. Sublicensing is not allowed; section 10
+ makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+ measure under any applicable law fulfilling obligations under article
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
+ similar laws prohibiting or restricting circumvention of such
+ measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+ circumvention of technological measures to the extent such circumvention
+ is effected by exercising rights under this License with respect to
+ the covered work, and you disclaim any intention to limit operation or
+ modification of the work as a means of enforcing, against the work's
+ users, your or third parties' legal rights to forbid circumvention of
+ technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+ receive it, in any medium, provided that you conspicuously and
+ appropriately publish on each copy an appropriate copyright notice;
+ keep intact all notices stating that this License and any
+ non-permissive terms added in accord with section 7 apply to the code;
+ keep intact all notices of the absence of any warranty; and give all
+ recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+ and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+ produce it from the Program, in the form of source code under the
+ terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+ works, which are not by their nature extensions of the covered work,
+ and which are not combined with it such as to form a larger program,
+ in or on a volume of a storage or distribution medium, is called an
+ "aggregate" if the compilation and its resulting copyright are not
+ used to limit the access or legal rights of the compilation's users
+ beyond what the individual works permit. Inclusion of a covered work
+ in an aggregate does not cause this License to apply to the other
+ parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+ of sections 4 and 5, provided that you also convey the
+ machine-readable Corresponding Source under the terms of this License,
+ in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+ from the Corresponding Source as a System Library, need not be
+ included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+ tangible personal property which is normally used for personal, family,
+ or household purposes, or (2) anything designed or sold for incorporation
+ into a dwelling. In determining whether a product is a consumer product,
+ doubtful cases shall be resolved in favor of coverage. For a particular
+ product received by a particular user, "normally used" refers to a
+ typical or common use of that class of product, regardless of the status
+ of the particular user or of the way in which the particular user
+ actually uses, or expects or is expected to use, the product. A product
+ is a consumer product regardless of whether the product has substantial
+ commercial, industrial or non-consumer uses, unless such uses represent
+ the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+ procedures, authorization keys, or other information required to install
+ and execute modified versions of a covered work in that User Product from
+ a modified version of its Corresponding Source. The information must
+ suffice to ensure that the continued functioning of the modified object
+ code is in no case prevented or interfered with solely because
+ modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+ specifically for use in, a User Product, and the conveying occurs as
+ part of a transaction in which the right of possession and use of the
+ User Product is transferred to the recipient in perpetuity or for a
+ fixed term (regardless of how the transaction is characterized), the
+ Corresponding Source conveyed under this section must be accompanied
+ by the Installation Information. But this requirement does not apply
+ if neither you nor any third party retains the ability to install
+ modified object code on the User Product (for example, the work has
+ been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+ requirement to continue to provide support service, warranty, or updates
+ for a work that has been modified or installed by the recipient, or for
+ the User Product in which it has been modified or installed. Access to a
+ network may be denied when the modification itself materially and
+ adversely affects the operation of the network or violates the rules and
+ protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+ in accord with this section must be in a format that is publicly
+ documented (and with an implementation available to the public in
+ source code form), and must require no special password or key for
+ unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+ License by making exceptions from one or more of its conditions.
+ Additional permissions that are applicable to the entire Program shall
+ be treated as though they were included in this License, to the extent
+ that they are valid under applicable law. If additional permissions
+ apply only to part of the Program, that part may be used separately
+ under those permissions, but the entire Program remains governed by
+ this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+ remove any additional permissions from that copy, or from any part of
+ it. (Additional permissions may be written to require their own
+ removal in certain cases when you modify the work.) You may place
+ additional permissions on material, added by you to a covered work,
+ for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+ add to a covered work, you may (if authorized by the copyright holders of
+ that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+ restrictions" within the meaning of section 10. If the Program as you
+ received it, or any part of it, contains a notice stating that it is
+ governed by this License along with a term that is a further
+ restriction, you may remove that term. If a license document contains
+ a further restriction but permits relicensing or conveying under this
+ License, you may add to a covered work material governed by the terms
+ of that license document, provided that the further restriction does
+ not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+ must place, in the relevant source files, a statement of the
+ additional terms that apply to those files, or a notice indicating
+ where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+ form of a separately written license, or stated as exceptions;
+ the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+ provided under this License. Any attempt otherwise to propagate or
+ modify it is void, and will automatically terminate your rights under
+ this License (including any patent licenses granted under the third
+ paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+ license from a particular copyright holder is reinstated (a)
+ provisionally, unless and until the copyright holder explicitly and
+ finally terminates your license, and (b) permanently, if the copyright
+ holder fails to notify you of the violation by some reasonable means
+ prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+ reinstated permanently if the copyright holder notifies you of the
+ violation by some reasonable means, this is the first time you have
+ received notice of violation of this License (for any work) from that
+ copyright holder, and you cure the violation prior to 30 days after
+ your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+ licenses of parties who have received copies or rights from you under
+ this License. If your rights have been terminated and not permanently
+ reinstated, you do not qualify to receive new licenses for the same
+ material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+ run a copy of the Program. Ancillary propagation of a covered work
+ occurring solely as a consequence of using peer-to-peer transmission
+ to receive a copy likewise does not require acceptance. However,
+ nothing other than this License grants you permission to propagate or
+ modify any covered work. These actions infringe copyright if you do
+ not accept this License. Therefore, by modifying or propagating a
+ covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+ receives a license from the original licensors, to run, modify and
+ propagate that work, subject to this License. You are not responsible
+ for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+ organization, or substantially all assets of one, or subdividing an
+ organization, or merging organizations. If propagation of a covered
+ work results from an entity transaction, each party to that
+ transaction who receives a copy of the work also receives whatever
+ licenses to the work the party's predecessor in interest had or could
+ give under the previous paragraph, plus a right to possession of the
+ Corresponding Source of the work from the predecessor in interest, if
+ the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+ rights granted or affirmed under this License. For example, you may
+ not impose a license fee, royalty, or other charge for exercise of
+ rights granted under this License, and you may not initiate litigation
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
+ any patent claim is infringed by making, using, selling, offering for
+ sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+ License of the Program or a work on which the Program is based. The
+ work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+ owned or controlled by the contributor, whether already acquired or
+ hereafter acquired, that would be infringed by some manner, permitted
+ by this License, of making, using, or selling its contributor version,
+ but do not include claims that would be infringed only as a
+ consequence of further modification of the contributor version. For
+ purposes of this definition, "control" includes the right to grant
+ patent sublicenses in a manner consistent with the requirements of
+ this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+ patent license under the contributor's essential patent claims, to
+ make, use, sell, offer for sale, import and otherwise run, modify and
+ propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+ agreement or commitment, however denominated, not to enforce a patent
+ (such as an express permission to practice a patent or covenant not to
+ sue for patent infringement). To "grant" such a patent license to a
+ party means to make such an agreement or commitment not to enforce a
+ patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+ and the Corresponding Source of the work is not available for anyone
+ to copy, free of charge and under the terms of this License, through a
+ publicly available network server or other readily accessible means,
+ then you must either (1) cause the Corresponding Source to be so
+ available, or (2) arrange to deprive yourself of the benefit of the
+ patent license for this particular work, or (3) arrange, in a manner
+ consistent with the requirements of this License, to extend the patent
+ license to downstream recipients. "Knowingly relying" means you have
+ actual knowledge that, but for the patent license, your conveying the
+ covered work in a country, or your recipient's use of the covered work
+ in a country, would infringe one or more identifiable patents in that
+ country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+ arrangement, you convey, or propagate by procuring conveyance of, a
+ covered work, and grant a patent license to some of the parties
+ receiving the covered work authorizing them to use, propagate, modify
+ or convey a specific copy of the covered work, then the patent license
+ you grant is automatically extended to all recipients of the covered
+ work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+ the scope of its coverage, prohibits the exercise of, or is
+ conditioned on the non-exercise of one or more of the rights that are
+ specifically granted under this License. You may not convey a covered
+ work if you are a party to an arrangement with a third party that is
+ in the business of distributing software, under which you make payment
+ to the third party based on the extent of your activity of conveying
+ the work, and under which the third party grants, to any of the
+ parties who would receive the covered work from you, a discriminatory
+ patent license (a) in connection with copies of the covered work
+ conveyed by you (or copies made from those copies), or (b) primarily
+ for and in connection with specific products or compilations that
+ contain the covered work, unless you entered into that arrangement,
+ or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+ any implied license or other defenses to infringement that may
+ otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+ otherwise) that contradict the conditions of this License, they do not
+ excuse you from the conditions of this License. If you cannot convey a
+ covered work so as to satisfy simultaneously your obligations under this
+ License and any other pertinent obligations, then as a consequence you may
+ not convey it at all. For example, if you agree to terms that obligate you
+ to collect a royalty for further conveying from those to whom you convey
+ the Program, the only way you could satisfy both those terms and this
+ License would be to refrain entirely from conveying the Program.
+
+ 13. Remote Network Interaction; Use with the GNU General Public License.
+
+ Notwithstanding any other provision of this License, if you modify the
+ Program, your modified version must prominently offer all users
+ interacting with it remotely through a computer network (if your version
+ supports such interaction) an opportunity to receive the Corresponding
+ Source of your version by providing access to the Corresponding Source
+ from a network server at no charge, through some standard or customary
+ means of facilitating copying of software. This Corresponding Source
+ shall include the Corresponding Source for any work covered by version 3
+ of the GNU General Public License that is incorporated pursuant to the
+ following paragraph.
+
+ Notwithstanding any other provision of this License, you have
+ permission to link or combine any covered work with a work licensed
+ under version 3 of the GNU General Public License into a single
+ combined work, and to convey the resulting work. The terms of this
+ License will continue to apply to the part which is the covered work,
+ but the work with which it is combined will remain governed by version
+ 3 of the GNU General Public License.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+ the GNU Affero General Public License from time to time. Such new versions
+ will be similar in spirit to the present version, but may differ in detail to
+ address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+ Program specifies that a certain numbered version of the GNU Affero General
+ Public License "or any later version" applies to it, you have the
+ option of following the terms and conditions either of that numbered
+ version or of any later version published by the Free Software
+ Foundation. If the Program does not specify a version number of the
+ GNU Affero General Public License, you may choose any version ever published
+ by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+ versions of the GNU Affero General Public License can be used, that proxy's
+ public statement of acceptance of a version permanently authorizes you
+ to choose that version for the Program.
+
+ Later license versions may give you additional or different
+ permissions. However, no additional obligations are imposed on any
+ author or copyright holder as a result of your choosing to follow a
+ later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+ SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+ above cannot be given local legal effect according to their terms,
+ reviewing courts shall apply local law that most closely approximates
+ an absolute waiver of all civil liability in connection with the
+ Program, unless a warranty or assumption of liability accompanies a
+ copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+ possible use to the public, the best way to achieve this is to make it
+ free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+ to attach them to the start of each source file to most effectively
+ state the exclusion of warranty; and each file should have at least
+ the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published
+ by the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ Also add information on how to contact you by electronic and paper mail.
+
+ If your software can interact with users remotely through a computer
+ network, you should also make sure that it provides a way for users to
+ get its source. For example, if your program is a web application, its
+ interface could display a "Source" link that leads users to an archive
+ of the code. There are many ways you could offer source, and different
+ solutions will be better for different programs; see section 13 for the
+ specific requirements.
+
+ You should also get your employer (if you work as a programmer) or school,
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
+ For more information on this, and how to apply and follow the GNU AGPL, see
+ <https://www.gnu.org/licenses/>.
MIT_LICENSE_GGML_LLAMACPP_ONLY ADDED
@@ -0,0 +1,26 @@
+ MIT License
+
+ Copyright (c) 2023 Georgi Gerganov
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ ===================================
+
+ Note that the above license applies ONLY to the GGML library and llama.cpp by ggerganov which are licensed under the MIT License
+ KoboldAI Lite by Concedo and the provided python ctypes bindings in koboldcpp dlls are licensed under the AGPL v3.0 License
Makefile ADDED
@@ -0,0 +1,758 @@
+ # Add custom options to Makefile.local rather than editing this file.
+ -include $(abspath $(lastword ${MAKEFILE_LIST})).local
+
+ .PHONY: finishedmsg
+
+ default: koboldcpp_default koboldcpp_failsafe koboldcpp_noavx2 koboldcpp_clblast koboldcpp_clblast_noavx2 koboldcpp_clblast_failsafe koboldcpp_cublas koboldcpp_hipblas koboldcpp_vulkan koboldcpp_vulkan_noavx2 finishedmsg
+ tools: quantize_gpt2 quantize_gptj quantize_gguf quantize_neox quantize_mpt quantize_clip ttsmain whispermain sdmain gguf-split
+
+ ifndef UNAME_S
+ UNAME_S := $(shell uname -s)
+ endif
+
+ ifndef UNAME_P
+ UNAME_P := $(shell uname -p)
+ endif
+
+ ifndef UNAME_M
+ UNAME_M := $(shell uname -m)
+ endif
+
+ ifndef UNAME_O
+ UNAME_O := $(shell uname -o)
+ endif
+
+ ifneq ($(shell grep -e "Arch Linux" -e "ID_LIKE=arch" /etc/os-release 2>/dev/null),)
+ ARCH_ADD = -lcblas
+ endif
+
+
+ # Mac OS + Arm can report x86_64
+ # ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
+ ifeq ($(UNAME_S),Darwin)
+ ifneq ($(UNAME_P),arm)
+ SYSCTL_M := $(shell sysctl -n hw.optional.arm64 2>/dev/null)
+ ifeq ($(SYSCTL_M),1)
+ # UNAME_P := arm
+ # UNAME_M := arm64
+ warn := $(warning Your arch is announced as x86_64, but it seems to actually be ARM64. Not fixing that can lead to bad performance. For more info see: https://github.com/ggerganov/whisper.cpp/issues/66\#issuecomment-1282546789)
+ endif
+ endif
+ endif
+
+ #
+ # Compile flags
+ #
+
+ # keep standard at C11 and C++17
+ CFLAGS =
+ CXXFLAGS =
+ ifdef KCPP_DEBUG
+ CFLAGS = -g -O0
+ CXXFLAGS = -g -O0
+ endif
+ CFLAGS += -I. -Iggml/include -Iggml/src -Iggml/src/ggml-cpu -Iinclude -Isrc -I./include -I./include/CL -I./otherarch -I./otherarch/tools -I./otherarch/sdcpp -I./otherarch/sdcpp/thirdparty -I./include/vulkan -O3 -fno-finite-math-only -std=c11 -fPIC -DLOG_DISABLE_LOGS -D_GNU_SOURCE -DGGML_USE_CPU -DGGML_USE_CPU_AARCH64
+ CXXFLAGS += -I. -Iggml/include -Iggml/src -Iggml/src/ggml-cpu -Iinclude -Isrc -I./common -I./include -I./include/CL -I./otherarch -I./otherarch/tools -I./otherarch/sdcpp -I./otherarch/sdcpp/thirdparty -I./include/vulkan -O3 -fno-finite-math-only -std=c++17 -fPIC -DLOG_DISABLE_LOGS -D_GNU_SOURCE -DGGML_USE_CPU -DGGML_USE_CPU_AARCH64
+ ifndef KCPP_DEBUG
+ CFLAGS += -DNDEBUG -s
+ CXXFLAGS += -DNDEBUG -s
+ endif
+ ifdef LLAMA_NO_LLAMAFILE
+ GGML_NO_LLAMAFILE := 1
+ endif
+ ifndef GGML_NO_LLAMAFILE
+ CFLAGS += -DGGML_USE_LLAMAFILE
+ CXXFLAGS += -DGGML_USE_LLAMAFILE
+ endif
+
+ # let's try enabling everything
+ CFLAGS += -pthread -Wno-deprecated -Wno-deprecated-declarations -Wno-unused-variable
+ CXXFLAGS += -pthread -Wno-multichar -Wno-write-strings -Wno-deprecated -Wno-deprecated-declarations -Wno-unused-variable
+
+ LDFLAGS =
+ FASTCFLAGS = $(subst -O3,-Ofast,$(CFLAGS))
+ FASTCXXFLAGS = $(subst -O3,-Ofast,$(CXXFLAGS))
+
+ # these are used on windows, to build some libraries with extra old device compatibility
+ SIMPLECFLAGS =
+ SIMPLERCFLAGS =
+ FULLCFLAGS =
+ NONECFLAGS =
+
+ CLBLAST_FLAGS = -DGGML_USE_CLBLAST
+ FAILSAFE_FLAGS = -DUSE_FAILSAFE
+ VULKAN_FLAGS = -DGGML_USE_VULKAN -DSD_USE_VULKAN
+ ifdef LLAMA_CUBLAS
+ CUBLAS_FLAGS = -DGGML_USE_CUDA -DSD_USE_CUBLAS
+ else
+ CUBLAS_FLAGS =
+ endif
+ CUBLASLD_FLAGS =
+ CUBLAS_OBJS =
+
+ OBJS_FULL += ggml-alloc.o ggml-cpu-traits.o ggml-quants.o ggml-cpu-quants.o ggml-cpu-aarch64.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm.o common.o sampling.o kcpputils.o
+ OBJS_SIMPLE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_noavx2.o ggml-cpu-quants_noavx2.o ggml-cpu-aarch64_noavx2.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_noavx2.o common.o sampling.o kcpputils.o
+ OBJS_SIMPLER += ggml-alloc.o ggml-cpu-traits.o ggml-quants_noavx1.o ggml-cpu-quants_noavx1.o ggml-cpu-aarch64_noavx1.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_noavx1.o common.o sampling.o kcpputils.o
+ OBJS_FAILSAFE += ggml-alloc.o ggml-cpu-traits.o ggml-quants_failsafe.o ggml-cpu-quants_failsafe.o ggml-cpu-aarch64_failsafe.o unicode.o unicode-data.o ggml-threading.o ggml-cpu-cpp.o gguf.o sgemm_failsafe.o common.o sampling.o kcpputils.o
+
+ # OS specific
+ ifeq ($(UNAME_S),Linux)
+ CFLAGS += -pthread
+ CXXFLAGS += -pthread
+ LDFLAGS += -ldl
+ endif
+
+ ifeq ($(UNAME_S),Darwin)
+ CFLAGS += -pthread
+ CXXFLAGS += -pthread
+ CLANG_VER = $(shell clang -v 2>&1 | head -n 1 | awk 'BEGIN {FS="[. ]"};{print $$1 $$2 $$4}')
+ ifeq ($(CLANG_VER),Appleclang15)
+ LDFLAGS += -ld_classic
+ endif
+ endif
+ ifeq ($(UNAME_S),FreeBSD)
+ CFLAGS += -pthread
+ CXXFLAGS += -pthread
+ endif
+ ifeq ($(UNAME_S),NetBSD)
+ CFLAGS += -pthread
+ CXXFLAGS += -pthread
+ endif
+ ifeq ($(UNAME_S),OpenBSD)
+ CFLAGS += -pthread
+ CXXFLAGS += -pthread
+ endif
+ ifeq ($(UNAME_S),Haiku)
+ CFLAGS += -pthread
+ CXXFLAGS += -pthread
+ endif
+
+ ifdef LLAMA_GPROF
+ CFLAGS += -pg
+ CXXFLAGS += -pg
+ endif
+ ifdef LLAMA_PERF
+ CFLAGS += -DGGML_PERF
+ CXXFLAGS += -DGGML_PERF
+ endif
+
+ CCV := $(shell $(CC) --version | head -n 1)
+ CXXV := $(shell $(CXX) --version | head -n 1)
+
+ # Architecture specific
+ # For x86 based architectures
+ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64))
+ ifdef LLAMA_PORTABLE
+ SIMPLECFLAGS += -mavx -msse3 -mssse3
+ SIMPLERCFLAGS += -msse3 -mssse3
+ ifdef LLAMA_NOAVX2
+ FULLCFLAGS += -msse3 -mssse3 -mavx
+ else
+ FULLCFLAGS += -mavx2 -msse3 -mssse3 -mfma -mf16c -mavx
+ endif # LLAMA_NOAVX2
+ else
+ CFLAGS += -march=native -mtune=native
+ endif # LLAMA_PORTABLE
+ endif # if x86
+
+ ifndef LLAMA_NO_ACCELERATE
+ # Mac M1 - include Accelerate framework.
+ # `-framework Accelerate` works on Mac Intel as well, with negligible performance boost (as of the predict time).
+ ifeq ($(UNAME_S),Darwin)
+ CFLAGS += -DGGML_USE_ACCELERATE -DGGML_USE_BLAS -DGGML_BLAS_USE_ACCELERATE
+ CXXFLAGS += -DGGML_USE_ACCELERATE -DGGML_USE_BLAS -DGGML_BLAS_USE_ACCELERATE
+ LDFLAGS += -framework Accelerate
+ OBJS += ggml-blas.o
+ endif
+ endif
+
+ # it is recommended to use the CMAKE file to build for cublas if you can - will likely work better
+ OBJS_CUDA_TEMP_INST = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-mma*.cu))
+ OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/mmq*.cu))
+ OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu))
+ OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu))
+ OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-vec*f16-f16.cu))
+
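For readers less used to GNU make's function syntax, the list assembled above is just "expand each glob, then swap the .cu suffix for .o". A rough Python equivalent of the wildcard/patsubst pair, for illustration only (the path patterns are copied from the rules above):

    import glob

    # Rough sketch of: OBJS_CUDA_TEMP_INST = $(patsubst %.cu,%.o,$(wildcard <pattern>))
    patterns = [
        "ggml/src/ggml-cuda/template-instances/fattn-mma*.cu",
        "ggml/src/ggml-cuda/template-instances/mmq*.cu",
        "ggml/src/ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu",
        "ggml/src/ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu",
        "ggml/src/ggml-cuda/template-instances/fattn-vec*f16-f16.cu",
    ]
    objs_cuda_temp_inst = [
        path[: -len(".cu")] + ".o"              # patsubst %.cu,%.o
        for pattern in patterns
        for path in sorted(glob.glob(pattern))  # wildcard (order not significant)
    ]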
176
+ ifdef LLAMA_CUBLAS
177
+ CUBLAS_FLAGS = -DGGML_USE_CUDA -DSD_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
178
+ CUBLASLD_FLAGS = -lcuda -lcublas -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/local/cuda/targets/aarch64-linux/lib -L/usr/local/cuda/targets/sbsa-linux/lib -L/usr/lib/wsl/lib
179
+ CUBLAS_OBJS = ggml-cuda.o ggml_v3-cuda.o ggml_v2-cuda.o ggml_v2-cuda-legacy.o
180
+ CUBLAS_OBJS += $(patsubst %.cu,%.o,$(filter-out ggml/src/ggml-cuda/ggml-cuda.cu, $(wildcard ggml/src/ggml-cuda/*.cu)))
181
+ CUBLAS_OBJS += $(OBJS_CUDA_TEMP_INST)
182
+ NVCC = nvcc
183
+ NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
184
+
185
+ ifdef LLAMA_ADD_CONDA_PATHS
186
+ CUBLASLD_FLAGS += -Lconda/envs/linux/lib -Lconda/envs/linux/lib/stubs
187
+ endif
188
+
189
+ ifdef CUDA_DOCKER_ARCH
190
+ NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
191
+ else
192
+ ifdef LLAMA_PORTABLE
193
+ ifdef LLAMA_COLAB #colab does not need all targets, all-major doesnt work correctly with pascal
194
+ NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=all-major
195
+ else
196
+ NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=all
197
+ endif #LLAMA_COLAB
198
+ else
199
+ NVCCFLAGS += -arch=native
200
+ endif #LLAMA_PORTABLE
201
+ endif # CUDA_DOCKER_ARCH
202
+
203
+ ifdef LLAMA_CUDA_F16
204
+ NVCCFLAGS += -DGGML_CUDA_F16
205
+ endif # LLAMA_CUDA_F16
206
+ ifdef LLAMA_CUDA_DMMV_F16
207
+ NVCCFLAGS += -DGGML_CUDA_F16
208
+ endif # LLAMA_CUDA_DMMV_F16
209
+
210
+ ifdef LLAMA_CUDA_CCBIN
211
+ NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
212
+ endif
213
+
214
+ ggml/src/ggml-cuda/%.o: ggml/src/ggml-cuda/%.cu ggml/include/ggml.h ggml/src/ggml-common.h ggml/src/ggml-cuda/common.cuh
215
+ $(NVCC) $(NVCCFLAGS) $(subst -Ofast,-O3,$(CXXFLAGS)) $(CUBLAS_FLAGS) $(HIPFLAGS) $(CUBLAS_CXXFLAGS) -Wno-pedantic -c $< -o $@
216
+ ggml-cuda.o: ggml/src/ggml-cuda/ggml-cuda.cu ggml/include/ggml-cuda.h ggml/include/ggml.h ggml/include/ggml-backend.h ggml/src/ggml-backend-impl.h ggml/src/ggml-common.h $(wildcard ggml/src/ggml-cuda/*.cuh)
217
+ $(NVCC) $(NVCCFLAGS) $(subst -Ofast,-O3,$(CXXFLAGS)) $(CUBLAS_FLAGS) $(HIPFLAGS) $(CUBLAS_CXXFLAGS) -Wno-pedantic -c $< -o $@
218
+ ggml_v2-cuda.o: otherarch/ggml_v2-cuda.cu otherarch/ggml_v2-cuda.h
219
+ $(NVCC) $(NVCCFLAGS) $(subst -Ofast,-O3,$(CXXFLAGS)) $(CUBLAS_FLAGS) $(HIPFLAGS) $(CUBLAS_CXXFLAGS) -Wno-pedantic -c $< -o $@
220
+ ggml_v2-cuda-legacy.o: otherarch/ggml_v2-cuda-legacy.cu otherarch/ggml_v2-cuda-legacy.h
221
+ $(NVCC) $(NVCCFLAGS) $(subst -Ofast,-O3,$(CXXFLAGS)) $(CUBLAS_FLAGS) $(HIPFLAGS) $(CUBLAS_CXXFLAGS) -Wno-pedantic -c $< -o $@
222
+ ggml_v3-cuda.o: otherarch/ggml_v3-cuda.cu otherarch/ggml_v3-cuda.h
223
+ $(NVCC) $(NVCCFLAGS) $(subst -Ofast,-O3,$(CXXFLAGS)) $(CUBLAS_FLAGS) $(HIPFLAGS) $(CUBLAS_CXXFLAGS) -Wno-pedantic -c $< -o $@
224
+ endif # LLAMA_CUBLAS
225
+
226
+ ifdef LLAMA_HIPBLAS
227
+ ifeq ($(wildcard /opt/rocm),)
228
+ ROCM_PATH ?= /usr
229
+ GPU_TARGETS ?= $(shell $(shell which amdgpu-arch))
230
+ HCC := $(ROCM_PATH)/bin/hipcc
231
+ HCXX := $(ROCM_PATH)/bin/hipcc
232
+ else
233
+ ROCM_PATH ?= /opt/rocm
234
+ GPU_TARGETS ?= gfx803 gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
235
+ HCC := $(ROCM_PATH)/llvm/bin/clang
236
+ HCXX := $(ROCM_PATH)/llvm/bin/clang++
237
+ endif
238
+ HIPFLAGS += -DGGML_USE_HIP -DGGML_HIP_NO_VMM -DGGML_HIP_ROCWMMA_FATTN -DGGML_USE_CUDA -DSD_USE_CUBLAS $(shell $(ROCM_PATH)/bin/hipconfig -C)
239
+ HIPLDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
240
+ HIPLDFLAGS += -L$(ROCM_PATH)/lib64 -Wl,-rpath=$(ROCM_PATH)/lib64
241
+ HIPLDFLAGS += -lhipblas -lamdhip64 -lrocblas
242
+ HIP_OBJS += ggml-cuda.o ggml_v3-cuda.o ggml_v2-cuda.o ggml_v2-cuda-legacy.o
243
+ HIP_OBJS += $(patsubst %.cu,%.o,$(filter-out ggml/src/ggml-cuda/ggml-cuda.cu, $(wildcard ggml/src/ggml-cuda/*.cu)))
244
+ HIP_OBJS += $(OBJS_CUDA_TEMP_INST)
245
+
246
+ HIPFLAGS2 += $(addprefix --offload-arch=,$(GPU_TARGETS))
247
+
248
+ ggml/src/ggml-cuda/%.o: ggml/src/ggml-cuda/%.cu ggml/include/ggml.h ggml/src/ggml-common.h ggml/src/ggml-cuda/common.cuh
249
+ $(HCXX) $(CXXFLAGS) $(HIPFLAGS) $(HIPFLAGS2) -x hip -c -o $@ $<
250
+ ggml-cuda.o: ggml/src/ggml-cuda/ggml-cuda.cu ggml/include/ggml-cuda.h ggml/include/ggml.h ggml/include/ggml-backend.h ggml/src/ggml-backend-impl.h ggml/src/ggml-common.h $(wildcard ggml/src/ggml-cuda/*.cuh)
251
+ $(HCXX) $(CXXFLAGS) $(HIPFLAGS) $(HIPFLAGS2) -x hip -c -o $@ $<
252
+ ggml_v2-cuda.o: otherarch/ggml_v2-cuda.cu otherarch/ggml_v2-cuda.h
253
+ $(HCXX) $(CXXFLAGS) $(HIPFLAGS) $(HIPFLAGS2) -x hip -c -o $@ $<
254
+ ggml_v2-cuda-legacy.o: otherarch/ggml_v2-cuda-legacy.cu otherarch/ggml_v2-cuda-legacy.h
255
+ $(HCXX) $(CXXFLAGS) $(HIPFLAGS) $(HIPFLAGS2) -x hip -c -o $@ $<
256
+ ggml_v3-cuda.o: otherarch/ggml_v3-cuda.cu otherarch/ggml_v3-cuda.h
257
+ $(HCXX) $(CXXFLAGS) $(HIPFLAGS) $(HIPFLAGS2) -x hip -c -o $@ $<
258
+ endif # LLAMA_HIPBLAS
259
+
260
+
261
+ ifdef LLAMA_METAL
262
+ CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG -DSD_USE_METAL
263
+ CXXFLAGS += -DGGML_USE_METAL -DSD_USE_METAL
264
+ LDFLAGS += -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
265
+ OBJS += ggml-metal.o
266
+
267
+ ggml-metal.o: ggml/src/ggml-metal/ggml-metal.m ggml/src/ggml-metal/ggml-metal-impl.h ggml/include/ggml-metal.h
268
+ @echo "== Preparing merged Metal file =="
269
+ @sed -e '/#include "ggml-common.h"/r ggml/src/ggml-common.h' -e '/#include "ggml-common.h"/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal.tmp
270
+ @sed -e '/#include "ggml-metal-impl.h"/r ggml/src/ggml-metal/ggml-metal-impl.h' -e '/#include "ggml-metal-impl.h"/d' < ggml/src/ggml-metal/ggml-metal-embed.metal.tmp > ggml/src/ggml-metal/ggml-metal-merged.metal
271
+ @cp ggml/src/ggml-metal/ggml-metal-merged.metal ./ggml-metal-merged.metal
272
+ $(CC) $(CFLAGS) -c $< -o $@
273
+ endif # LLAMA_METAL
274
+
275
+ ifneq ($(filter aarch64%,$(UNAME_M)),)
276
+ # Apple M1, M2, etc.
277
+ # Raspberry Pi 3, 4, Zero 2 (64-bit)
278
+ ifdef LLAMA_PORTABLE
279
+ CFLAGS +=
280
+ CXXFLAGS +=
281
+ else
282
+ # sve is cooked on termux so we are disabling it
283
+ ifeq ($(UNAME_O), Android)
284
+ ifneq ($(findstring clang, $(CCV)), )
285
+ CFLAGS += -mcpu=native+nosve
286
+ CXXFLAGS += -mcpu=native+nosve
287
+ else
288
+ CFLAGS += -mcpu=native
289
+ CXXFLAGS += -mcpu=native
290
+ endif
291
+ else
292
+ CFLAGS += -mcpu=native
293
+ CXXFLAGS += -mcpu=native
294
+ endif
295
+ endif
296
+ endif
297
+
298
+ ifneq ($(filter armv6%,$(UNAME_M)),)
299
+ # Raspberry Pi 1, Zero
300
+ CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
301
+ CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access
302
+ endif
303
+ ifneq ($(filter armv7%,$(UNAME_M)),)
304
+ # Raspberry Pi 2
305
+ CFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
306
+ CXXFLAGS += -mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access -funsafe-math-optimizations
307
+ endif
308
+ ifneq ($(filter armv8%,$(UNAME_M)),)
309
+ # Raspberry Pi 3, 4, Zero 2 (32-bit)
310
+ CFLAGS += -mfp16-format=ieee -mno-unaligned-access
311
+ CXXFLAGS += -mfp16-format=ieee -mno-unaligned-access
312
+ endif
313
+ ifneq ($(filter ppc64%,$(UNAME_M)),)
314
+ POWER9_M := $(shell grep "POWER9" /proc/cpuinfo)
315
+ ifneq (,$(findstring POWER9,$(POWER9_M)))
316
+ CFLAGS += -mcpu=power9
317
+ CXXFLAGS += -mcpu=power9
318
+ endif
319
+ endif
320
+
321
+
322
+ DEFAULT_BUILD =
323
+ FAILSAFE_BUILD =
324
+ NOAVX2_BUILD =
325
+ CLBLAST_BUILD =
326
+ CUBLAS_BUILD =
327
+ HIPBLAS_BUILD =
328
+ VULKAN_BUILD =
329
+ NOTIFY_MSG =
330
+
331
+ ifeq ($(OS),Windows_NT)
332
+ DEFAULT_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o $@.dll $(LDFLAGS)
333
+ ifdef LLAMA_PORTABLE
334
+ FAILSAFE_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o $@.dll $(LDFLAGS)
335
+ NOAVX2_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o $@.dll $(LDFLAGS)
336
+ endif
337
+
338
+ ifdef LLAMA_CLBLAST
339
+ CLBLAST_BUILD = $(CXX) $(CXXFLAGS) $^ lib/OpenCL.lib lib/clblast.lib -shared -o $@.dll $(LDFLAGS)
340
+ endif
341
+ ifdef LLAMA_VULKAN
342
+ VULKAN_BUILD = $(CXX) $(CXXFLAGS) $^ lib/vulkan-1.lib -shared -o $@.dll $(LDFLAGS)
343
+ endif
344
+
345
+ ifdef LLAMA_CUBLAS
346
+ CUBLAS_BUILD = $(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $^ -shared -o $@.dll $(CUBLASLD_FLAGS) $(LDFLAGS)
347
+ endif
348
+ ifdef LLAMA_HIPBLAS
349
+ HIPBLAS_BUILD = $(HCXX) $(CXXFLAGS) $(HIPFLAGS) $^ -shared -o $@.dll $(HIPLDFLAGS) $(LDFLAGS)
350
+ endif
351
+ else
352
+ DEFAULT_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o $@.so $(LDFLAGS)
353
+ ifdef LLAMA_PORTABLE
354
+ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64))
355
+ FAILSAFE_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o $@.so $(LDFLAGS)
356
+ NOAVX2_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o $@.so $(LDFLAGS)
357
+ endif
358
+ endif
359
+
360
+ ifdef LLAMA_CLBLAST
361
+ ifeq ($(UNAME_S),Darwin)
362
+ CLBLAST_BUILD = $(CXX) $(CXXFLAGS) $^ -lclblast -framework OpenCL $(ARCH_ADD) -shared -o $@.so $(LDFLAGS)
363
+ else
364
+ CLBLAST_BUILD = $(CXX) $(CXXFLAGS) $^ -lclblast -lOpenCL $(ARCH_ADD) -shared -o $@.so $(LDFLAGS)
365
+ endif
366
+ endif
367
+ ifdef LLAMA_CUBLAS
368
+ CUBLAS_BUILD = $(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $^ -shared -o $@.so $(CUBLASLD_FLAGS) $(LDFLAGS)
369
+ endif
370
+ ifdef LLAMA_HIPBLAS
371
+ HIPBLAS_BUILD = $(HCXX) $(CXXFLAGS) $(HIPFLAGS) $^ -shared -o $@.so $(HIPLDFLAGS) $(LDFLAGS)
372
+ endif
373
+ ifdef LLAMA_VULKAN
374
+ VULKAN_BUILD = $(CXX) $(CXXFLAGS) $^ -lvulkan -shared -o $@.so $(LDFLAGS)
375
+ endif
376
+ endif
377
+
378
+ ifndef LLAMA_CLBLAST
379
+ ifndef LLAMA_CUBLAS
380
+ ifndef LLAMA_HIPBLAS
381
+ ifndef LLAMA_VULKAN
382
+ ifndef LLAMA_METAL
383
+ NOTIFY_MSG = @echo -e '\n***\nYou did a basic CPU build. For faster speeds, consider installing and linking a GPU BLAS library. For example, set LLAMA_CLBLAST=1 LLAMA_VULKAN=1 to compile with Vulkan and CLBlast support. Add LLAMA_PORTABLE=1 to make a sharable build that other devices can use. Read the KoboldCpp Wiki for more information. This is just a reminder, not an error.\n***\n'
384
+ endif
385
+ endif
386
+ endif
387
+ endif
388
+ endif
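+ 
+ # Illustrative invocations (assumptions, not exhaustive; see the README and wiki):
+ #   make LLAMA_VULKAN=1 LLAMA_CLBLAST=1   # GPU build with Vulkan and CLBlast
+ #   make LLAMA_PORTABLE=1                 # sharable build that other devices can use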
389
+
390
+
391
+ #
392
+ # Print build information
393
+ #
394
+
395
+ $(info I koboldcpp build info: )
396
+ $(info I UNAME_S: $(UNAME_S))
397
+ $(info I UNAME_P: $(UNAME_P))
398
+ $(info I UNAME_M: $(UNAME_M))
399
+ $(info I UNAME_O: $(UNAME_O))
400
+ $(info I CFLAGS: $(CFLAGS))
401
+ $(info I CXXFLAGS: $(CXXFLAGS))
402
+ $(info I LDFLAGS: $(LDFLAGS))
403
+ $(info I CC: $(CCV))
404
+ $(info I CXX: $(CXXV))
405
+ $(info )
406
+
407
+ #
408
+ # Build library
409
+ #
410
+
411
+ ggml.o: ggml/src/ggml.c ggml/include/ggml.h
412
+ $(CC) $(FASTCFLAGS) $(FULLCFLAGS) -c $< -o $@
413
+ ggml_v4_failsafe.o: ggml/src/ggml.c ggml/include/ggml.h
414
+ $(CC) $(FASTCFLAGS) $(NONECFLAGS) -c $< -o $@
415
+ ggml_v4_noavx2.o: ggml/src/ggml.c ggml/include/ggml.h
416
+ $(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) -c $< -o $@
417
+ ggml_v4_clblast.o: ggml/src/ggml.c ggml/include/ggml.h
418
+ $(CC) $(FASTCFLAGS) $(FULLCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
419
+ ggml_v4_cublas.o: ggml/src/ggml.c ggml/include/ggml.h
420
+ $(CC) $(FASTCFLAGS) $(FULLCFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
421
+ ggml_v4_clblast_noavx2.o: ggml/src/ggml.c ggml/include/ggml.h
422
+ $(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
423
+ ggml_v4_clblast_failsafe.o: ggml/src/ggml.c ggml/include/ggml.h
424
+ $(CC) $(FASTCFLAGS) $(SIMPLERCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
425
+ ggml_v4_vulkan.o: ggml/src/ggml.c ggml/include/ggml.h
426
+ $(CC) $(FASTCFLAGS) $(FULLCFLAGS) $(VULKAN_FLAGS) -c $< -o $@
427
+ ggml_v4_vulkan_noavx2.o: ggml/src/ggml.c ggml/include/ggml.h
428
+ $(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) $(VULKAN_FLAGS) -c $< -o $@
429
+
430
+ # cpu and clblast separated
431
+ ggml-cpu.o: ggml/src/ggml-cpu/ggml-cpu.c ggml/include/ggml-cpu.h
432
+ $(CC) $(FASTCFLAGS) $(FULLCFLAGS) -c $< -o $@
433
+ ggml-cpu_v4_failsafe.o: ggml/src/ggml-cpu/ggml-cpu.c ggml/include/ggml-cpu.h
434
+ $(CC) $(FASTCFLAGS) $(NONECFLAGS) -c $< -o $@
435
+ ggml-cpu_v4_noavx2.o: ggml/src/ggml-cpu/ggml-cpu.c ggml/include/ggml-cpu.h
436
+ $(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) -c $< -o $@
437
+ ggml-cpu_v4_clblast.o: ggml/src/ggml-cpu/ggml-cpu.c ggml/include/ggml-cpu.h
438
+ $(CC) $(FASTCFLAGS) $(FULLCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
439
+ ggml-cpu_v4_clblast_noavx2.o: ggml/src/ggml-cpu/ggml-cpu.c ggml/include/ggml-cpu.h
440
+ $(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
441
+ ggml-cpu_v4_clblast_failsafe.o: ggml/src/ggml-cpu/ggml-cpu.c ggml/include/ggml-cpu.h
442
+ $(CC) $(FASTCFLAGS) $(SIMPLERCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
443
+
444
+ #quants
445
+ ggml-quants.o: ggml/src/ggml-quants.c ggml/include/ggml.h ggml/src/ggml-quants.h ggml/src/ggml-common.h
446
+ $(CC) $(CFLAGS) $(FULLCFLAGS) -c $< -o $@
447
+ ggml-quants_noavx2.o: ggml/src/ggml-quants.c ggml/include/ggml.h ggml/src/ggml-quants.h ggml/src/ggml-common.h
448
+ $(CC) $(CFLAGS) $(SIMPLECFLAGS) -c $< -o $@
449
+ ggml-quants_noavx1.o: ggml/src/ggml-quants.c ggml/include/ggml.h ggml/src/ggml-quants.h ggml/src/ggml-common.h
450
+ $(CC) $(CFLAGS) $(SIMPLERCFLAGS) -c $< -o $@
451
+ ggml-quants_failsafe.o: ggml/src/ggml-quants.c ggml/include/ggml.h ggml/src/ggml-quants.h ggml/src/ggml-common.h
452
+ $(CC) $(CFLAGS) $(NONECFLAGS) -c $< -o $@
453
+ ggml-cpu-quants.o: ggml/src/ggml-cpu/ggml-cpu-quants.c ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-quants.h ggml/src/ggml-common.h
454
+ $(CC) $(CFLAGS) $(FULLCFLAGS) -c $< -o $@
455
+ ggml-cpu-quants_noavx2.o: ggml/src/ggml-cpu/ggml-cpu-quants.c ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-quants.h ggml/src/ggml-common.h
456
+ $(CC) $(CFLAGS) $(SIMPLECFLAGS) -c $< -o $@
457
+ ggml-cpu-quants_noavx1.o: ggml/src/ggml-cpu/ggml-cpu-quants.c ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-quants.h ggml/src/ggml-common.h
458
+ $(CC) $(CFLAGS) $(SIMPLERCFLAGS) -c $< -o $@
459
+ ggml-cpu-quants_failsafe.o: ggml/src/ggml-cpu/ggml-cpu-quants.c ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-quants.h ggml/src/ggml-common.h
460
+ $(CC) $(CFLAGS) $(NONECFLAGS) -c $< -o $@
461
+
462
+ #aarch64
463
+ ggml-cpu-aarch64.o: ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-aarch64.h
464
+ $(CXX) $(CXXFLAGS) $(FULLCFLAGS) -c $< -o $@
465
+ ggml-cpu-aarch64_noavx2.o: ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-aarch64.h
466
+ $(CXX) $(CXXFLAGS) $(SIMPLECFLAGS) -c $< -o $@
467
+ ggml-cpu-aarch64_noavx1.o: ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-aarch64.h
468
+ $(CXX) $(CXXFLAGS) $(SIMPLERCFLAGS) -c $< -o $@
469
+ ggml-cpu-aarch64_failsafe.o: ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp ggml/include/ggml.h ggml/src/ggml-cpu/ggml-cpu-aarch64.h
470
+ $(CXX) $(CXXFLAGS) $(NONECFLAGS) -c $< -o $@
471
+
472
+ #sgemm
473
+ sgemm.o: ggml/src/ggml-cpu/llamafile/sgemm.cpp ggml/src/ggml-cpu/llamafile/sgemm.h ggml/include/ggml.h
474
+ $(CXX) $(CXXFLAGS) $(FULLCFLAGS) -c $< -o $@
475
+ sgemm_noavx2.o: ggml/src/ggml-cpu/llamafile/sgemm.cpp ggml/src/ggml-cpu/llamafile/sgemm.h ggml/include/ggml.h
476
+ $(CXX) $(CXXFLAGS) $(SIMPLECFLAGS) -c $< -o $@
477
+ sgemm_noavx1.o: ggml/src/ggml-cpu/llamafile/sgemm.cpp ggml/src/ggml-cpu/llamafile/sgemm.h ggml/include/ggml.h
478
+ $(CXX) $(CXXFLAGS) $(SIMPLERCFLAGS) -c $< -o $@
479
+ sgemm_failsafe.o: ggml/src/ggml-cpu/llamafile/sgemm.cpp ggml/src/ggml-cpu/llamafile/sgemm.h ggml/include/ggml.h
480
+ $(CXX) $(CXXFLAGS) $(NONECFLAGS) -c $< -o $@
481
+
482
+ #there's no intrinsics or special gpu ops used here, so we can have a universal object
483
+ ggml-alloc.o: ggml/src/ggml-alloc.c ggml/include/ggml.h ggml/include/ggml-alloc.h
484
+ $(CC) $(CFLAGS) -c $< -o $@
485
+ llava.o: examples/llava/llava.cpp examples/llava/llava.h
486
+ $(CXX) $(CXXFLAGS) -c $< -o $@
487
+ unicode.o: src/unicode.cpp src/unicode.h
488
+ $(CXX) $(CXXFLAGS) -c $< -o $@
489
+ unicode-data.o: src/unicode-data.cpp src/unicode-data.h
490
+ $(CXX) $(CXXFLAGS) -c $< -o $@
491
+ ggml-cpu-traits.o: ggml/src/ggml-cpu/ggml-cpu-traits.cpp ggml/src/ggml-cpu/ggml-cpu-traits.h ggml/include/ggml.h
492
+ $(CXX) $(CXXFLAGS) -c $< -o $@
493
+ ggml-threading.o: ggml/src/ggml-threading.cpp ggml/include/ggml.h
494
+ $(CXX) $(CXXFLAGS) -c $< -o $@
495
+ ggml-cpu-cpp.o: ggml/src/ggml-cpu/ggml-cpu.cpp ggml/include/ggml.h ggml/src/ggml-common.h
496
+ $(CXX) $(CXXFLAGS) -c $< -o $@
497
+ gguf.o: ggml/src/gguf.cpp ggml/include/gguf.h
498
+ $(CXX) $(CXXFLAGS) -c $< -o $@
499
+ kcpputils.o: otherarch/utils.cpp otherarch/utils.h
500
+ $(CXX) $(CXXFLAGS) -c $< -o $@
501
+
502
+ #these have special gpu defines
503
+ ggml-backend_default.o: ggml/src/ggml-backend.cpp ggml/src/ggml-backend-impl.h ggml/include/ggml.h ggml/include/ggml-backend.h
504
+ $(CXX) $(CXXFLAGS) -c $< -o $@
505
+ ggml-backend_vulkan.o: ggml/src/ggml-backend.cpp ggml/src/ggml-backend-impl.h ggml/include/ggml.h ggml/include/ggml-backend.h
506
+ $(CXX) $(CXXFLAGS) $(VULKAN_FLAGS) -c $< -o $@
507
+ ggml-backend_cublas.o: ggml/src/ggml-backend.cpp ggml/src/ggml-backend-impl.h ggml/include/ggml.h ggml/include/ggml-backend.h
508
+ $(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
509
+ ggml-backend-reg_default.o: ggml/src/ggml-backend-reg.cpp ggml/src/ggml-backend-impl.h ggml/include/ggml.h ggml/include/ggml-backend.h ggml/include/ggml-cpu.h
510
+ $(CXX) $(CXXFLAGS) -c $< -o $@
511
+ ggml-backend-reg_vulkan.o: ggml/src/ggml-backend-reg.cpp ggml/src/ggml-backend-impl.h ggml/include/ggml.h ggml/include/ggml-backend.h ggml/include/ggml-cpu.h
512
+ $(CXX) $(CXXFLAGS) $(VULKAN_FLAGS) -c $< -o $@
513
+ ggml-backend-reg_cublas.o: ggml/src/ggml-backend-reg.cpp ggml/src/ggml-backend-impl.h ggml/include/ggml.h ggml/include/ggml-backend.h ggml/include/ggml-cpu.h
514
+ $(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
515
+ llavaclip_default.o: examples/llava/clip.cpp examples/llava/clip.h
516
+ $(CXX) $(CXXFLAGS) -c $< -o $@
517
+ llavaclip_cublas.o: examples/llava/clip.cpp examples/llava/clip.h
518
+ $(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
519
+ llavaclip_vulkan.o: examples/llava/clip.cpp examples/llava/clip.h
520
+ $(CXX) $(CXXFLAGS) $(VULKAN_FLAGS) -c $< -o $@
521
+
522
+ #this is only used for accelerate
523
+ ggml-blas.o: ggml/src/ggml-blas/ggml-blas.cpp ggml/include/ggml-blas.h
524
+ $(CXX) $(CXXFLAGS) -c $< -o $@
525
+
526
+ #version 3 libs
527
+ ggml_v3.o: otherarch/ggml_v3.c otherarch/ggml_v3.h
528
+ $(CC) $(FASTCFLAGS) $(FULLCFLAGS) -c $< -o $@
529
+ ggml_v3_failsafe.o: otherarch/ggml_v3.c otherarch/ggml_v3.h
530
+ $(CC) $(FASTCFLAGS) $(NONECFLAGS) -c $< -o $@
531
+ ggml_v3_noavx2.o: otherarch/ggml_v3.c otherarch/ggml_v3.h
532
+ $(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) -c $< -o $@
533
+ ggml_v3_clblast.o: otherarch/ggml_v3.c otherarch/ggml_v3.h
534
+ $(CC) $(FASTCFLAGS) $(FULLCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
535
+ ggml_v3_cublas.o: otherarch/ggml_v3.c otherarch/ggml_v3.h
536
+ $(CC) $(FASTCFLAGS) $(FULLCFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
537
+ ggml_v3_clblast_noavx2.o: otherarch/ggml_v3.c otherarch/ggml_v3.h
538
+ $(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
539
+ ggml_v3_clblast_failsafe.o: otherarch/ggml_v3.c otherarch/ggml_v3.h
540
+ $(CC) $(FASTCFLAGS) $(SIMPLERCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
541
+
542
+ #version 2 libs
543
+ ggml_v2.o: otherarch/ggml_v2.c otherarch/ggml_v2.h
544
+ $(CC) $(FASTCFLAGS) $(FULLCFLAGS) -c $< -o $@
545
+ ggml_v2_failsafe.o: otherarch/ggml_v2.c otherarch/ggml_v2.h
546
+ $(CC) $(FASTCFLAGS) $(NONECFLAGS) -c $< -o $@
547
+ ggml_v2_noavx2.o: otherarch/ggml_v2.c otherarch/ggml_v2.h
548
+ $(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) -c $< -o $@
549
+ ggml_v2_clblast.o: otherarch/ggml_v2.c otherarch/ggml_v2.h
550
+ $(CC) $(FASTCFLAGS) $(FULLCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
551
+ ggml_v2_cublas.o: otherarch/ggml_v2.c otherarch/ggml_v2.h
552
+ $(CC) $(FASTCFLAGS) $(FULLCFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
553
+ ggml_v2_clblast_noavx2.o: otherarch/ggml_v2.c otherarch/ggml_v2.h
554
+ $(CC) $(FASTCFLAGS) $(SIMPLECFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
555
+ ggml_v2_clblast_failsafe.o: otherarch/ggml_v2.c otherarch/ggml_v2.h
556
+ $(CC) $(FASTCFLAGS) $(SIMPLERCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
557
+
558
+ #extreme old version compat
559
+ ggml_v1.o: otherarch/ggml_v1.c otherarch/ggml_v1.h
560
+ $(CC) $(FASTCFLAGS) $(FULLCFLAGS) -c $< -o $@
561
+ ggml_v1_failsafe.o: otherarch/ggml_v1.c otherarch/ggml_v1.h
562
+ $(CC) $(FASTCFLAGS) $(NONECFLAGS) -c $< -o $@
563
+
564
+ #opencl
565
+ ggml-opencl.o: otherarch/ggml_v3b-opencl.cpp otherarch/ggml_v3b-opencl.h
566
+ $(CXX) $(CXXFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
567
+ ggml_v2-opencl.o: otherarch/ggml_v2-opencl.cpp otherarch/ggml_v2-opencl.h
568
+ $(CXX) $(CXXFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
569
+ ggml_v2-opencl-legacy.o: otherarch/ggml_v2-opencl-legacy.c otherarch/ggml_v2-opencl-legacy.h
570
+ $(CC) $(CFLAGS) -c $< -o $@
571
+ ggml_v3-opencl.o: otherarch/ggml_v3-opencl.cpp otherarch/ggml_v3-opencl.h
572
+ $(CXX) $(CXXFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
573
+
574
+ #vulkan
575
+ ggml-vulkan.o: ggml/src/ggml-vulkan/ggml-vulkan.cpp ggml/include/ggml-vulkan.h ggml/src/ggml-vulkan-shaders.hpp ggml/src/ggml-vulkan-shaders.cpp
576
+ $(CXX) $(CXXFLAGS) $(VULKAN_FLAGS) -c $< -o $@
577
+
578
+ # intermediate objects
579
+ llama.o: src/llama.cpp ggml/include/ggml.h ggml/include/ggml-alloc.h ggml/include/ggml-backend.h ggml/include/ggml-cuda.h ggml/include/ggml-metal.h include/llama.h otherarch/llama-util.h
580
+ $(CXX) $(CXXFLAGS) -c $< -o $@
581
+ common.o: common/common.cpp common/common.h common/log.h
582
+ $(CXX) $(CXXFLAGS) -c $< -o $@
583
+ sampling.o: common/sampling.cpp common/common.h common/sampling.h common/log.h
584
+ $(CXX) $(CXXFLAGS) -c $< -o $@
585
+ console.o: common/console.cpp common/console.h
586
+ $(CXX) $(CXXFLAGS) -c $< -o $@
587
+ expose.o: expose.cpp expose.h
588
+ $(CXX) $(CXXFLAGS) -c $< -o $@
589
+
590
+ # sd.cpp objects
591
+ sdcpp_default.o: otherarch/sdcpp/sdtype_adapter.cpp otherarch/sdcpp/stable-diffusion.h otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/util.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/thirdparty/zip.c
592
+ $(CXX) $(CXXFLAGS) -c $< -o $@
593
+ sdcpp_cublas.o: otherarch/sdcpp/sdtype_adapter.cpp otherarch/sdcpp/stable-diffusion.h otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/util.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/thirdparty/zip.c
594
+ $(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
595
+ sdcpp_vulkan.o: otherarch/sdcpp/sdtype_adapter.cpp otherarch/sdcpp/stable-diffusion.h otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/util.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/thirdparty/zip.c
596
+ $(CXX) $(CXXFLAGS) $(VULKAN_FLAGS) -c $< -o $@
597
+
598
+
599
+ #whisper objects
600
+ whispercpp_default.o: otherarch/whispercpp/whisper_adapter.cpp
601
+ $(CXX) $(CXXFLAGS) -c $< -o $@
602
+ whispercpp_cublas.o: otherarch/whispercpp/whisper_adapter.cpp
603
+ $(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
604
+
605
+ #tts objects
606
+ tts_default.o: otherarch/tts_adapter.cpp
607
+ $(CXX) $(CXXFLAGS) -c $< -o $@
608
+
609
+ # idiotic "for easier compilation"
610
+ GPTTYPE_ADAPTER = gpttype_adapter.cpp otherarch/llama_v2.cpp otherarch/llama_v3.cpp src/llama.cpp src/llama-impl.cpp src/llama-chat.cpp src/llama-mmap.cpp src/llama-context.cpp src/llama-adapter.cpp src/llama-arch.cpp src/llama-batch.cpp src/llama-vocab.cpp src/llama-grammar.cpp src/llama-sampling.cpp src/llama-kv-cache.cpp src/llama-model-loader.cpp src/llama-model.cpp src/llama-quant.cpp src/llama-hparams.cpp otherarch/gptj_v1.cpp otherarch/gptj_v2.cpp otherarch/gptj_v3.cpp otherarch/gpt2_v1.cpp otherarch/gpt2_v2.cpp otherarch/gpt2_v3.cpp otherarch/rwkv_v2.cpp otherarch/rwkv_v3.cpp otherarch/neox_v2.cpp otherarch/neox_v3.cpp otherarch/mpt_v3.cpp ggml/include/ggml.h ggml/include/ggml-cpu.h ggml/include/ggml-cuda.h include/llama.h otherarch/llama-util.h
611
+ gpttype_adapter_failsafe.o: $(GPTTYPE_ADAPTER)
612
+ $(CXX) $(CXXFLAGS) $(FAILSAFE_FLAGS) -c $< -o $@
613
+ gpttype_adapter.o: $(GPTTYPE_ADAPTER)
614
+ $(CXX) $(CXXFLAGS) -c $< -o $@
615
+ gpttype_adapter_clblast.o: $(GPTTYPE_ADAPTER)
616
+ $(CXX) $(CXXFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
617
+ gpttype_adapter_cublas.o: $(GPTTYPE_ADAPTER)
618
+ $(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
619
+ gpttype_adapter_clblast_noavx2.o: $(GPTTYPE_ADAPTER)
620
+ $(CXX) $(CXXFLAGS) $(FAILSAFE_FLAGS) $(CLBLAST_FLAGS) -c $< -o $@
621
+ gpttype_adapter_vulkan.o: $(GPTTYPE_ADAPTER)
622
+ $(CXX) $(CXXFLAGS) $(VULKAN_FLAGS) -c $< -o $@
623
+ gpttype_adapter_vulkan_noavx2.o: $(GPTTYPE_ADAPTER)
624
+ $(CXX) $(CXXFLAGS) $(FAILSAFE_FLAGS) $(VULKAN_FLAGS) -c $< -o $@
625
+
626
+ clean:
627
+ rm -vf *.o main sdmain whispermain quantize_gguf quantize_clip quantize_gpt2 quantize_gptj quantize_neox quantize_mpt vulkan-shaders-gen gguf-split gguf-split.exe vulkan-shaders-gen.exe main.exe sdmain.exe whispermain.exe quantize_clip.exe quantize_gguf.exe quantize_gptj.exe quantize_gpt2.exe quantize_neox.exe quantize_mpt.exe koboldcpp_default.dll koboldcpp_failsafe.dll koboldcpp_noavx2.dll koboldcpp_clblast.dll koboldcpp_clblast_noavx2.dll koboldcpp_clblast_failsafe.dll koboldcpp_cublas.dll koboldcpp_hipblas.dll koboldcpp_vulkan.dll koboldcpp_vulkan_noavx2.dll koboldcpp_default.so koboldcpp_failsafe.so koboldcpp_noavx2.so koboldcpp_clblast.so koboldcpp_clblast_noavx2.so koboldcpp_clblast_failsafe.so koboldcpp_cublas.so koboldcpp_hipblas.so koboldcpp_vulkan.so koboldcpp_vulkan_noavx2.so
628
+ rm -vrf ggml/src/ggml-cuda/*.o
629
+ rm -vrf ggml/src/ggml-cuda/template-instances/*.o
630
+
631
+ # useful tools
632
+ main: examples/main/main.cpp common/arg.cpp build-info.h ggml.o ggml-cpu.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
633
+ $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
634
+ sdmain: otherarch/sdcpp/util.cpp otherarch/sdcpp/main.cpp otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/thirdparty/zip.c build-info.h ggml.o ggml-cpu.o llama.o console.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
635
+ $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
636
+ whispermain: otherarch/whispercpp/main.cpp otherarch/whispercpp/whisper.cpp build-info.h ggml.o ggml-cpu.o llama.o console.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
637
+ $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
638
+ ttsmain: examples/tts/tts.cpp common/arg.cpp build-info.h ggml.o ggml-cpu.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
639
+ $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
640
+ gguf-split: examples/gguf-split/gguf-split.cpp ggml.o ggml-cpu.o llama.o build-info.h llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
641
+ $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
642
+ gemma3-cli: examples/llava/gemma3-cli.cpp common/arg.cpp build-info.h ggml.o ggml-cpu.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
643
+ $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
644
+
645
+ vulkan-shaders-gen: ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
646
+ @echo 'This command can be MANUALLY run to regenerate vulkan shaders. Normally concedo will do it, so you do not have to.'
647
+ $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
648
+ ifeq ($(OS),Windows_NT)
649
+ @echo 'Now rebuilding vulkan shaders for Windows...'
650
+ $(shell) vulkan-shaders-gen --glslc glslc --input-dir ggml/src/ggml-vulkan/vulkan-shaders --target-hpp ggml/src/ggml-vulkan-shaders.hpp --target-cpp ggml/src/ggml-vulkan-shaders.cpp
651
+ else
652
+ @echo 'Now rebuilding vulkan shaders for Linux...'
653
+ ${shell} chmod +x vulkan-shaders-gen
654
+ ${shell} chmod +x glslc-linux
655
+ $(shell) ./vulkan-shaders-gen --glslc ./glslc-linux --input-dir ggml/src/ggml-vulkan/vulkan-shaders --target-hpp ggml/src/ggml-vulkan-shaders.hpp --target-cpp ggml/src/ggml-vulkan-shaders.cpp
656
+ endif
657
+
658
+ #generated libraries
659
+ koboldcpp_default: ggml.o ggml-cpu.o ggml_v3.o ggml_v2.o ggml_v1.o expose.o gpttype_adapter.o sdcpp_default.o whispercpp_default.o tts_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
660
+ $(DEFAULT_BUILD)
661
+
662
+ ifdef FAILSAFE_BUILD
663
+ koboldcpp_failsafe: ggml_v4_failsafe.o ggml-cpu_v4_failsafe.o ggml_v3_failsafe.o ggml_v2_failsafe.o ggml_v1_failsafe.o expose.o gpttype_adapter_failsafe.o sdcpp_default.o whispercpp_default.o tts_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FAILSAFE) $(OBJS)
664
+ $(FAILSAFE_BUILD)
665
+ else
666
+ koboldcpp_failsafe:
667
+ $(DONOTHING)
668
+ endif
669
+
670
+ ifdef NOAVX2_BUILD
671
+ koboldcpp_noavx2: ggml_v4_noavx2.o ggml-cpu_v4_noavx2.o ggml_v3_noavx2.o ggml_v2_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_failsafe.o sdcpp_default.o whispercpp_default.o tts_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_SIMPLE) $(OBJS)
672
+ $(NOAVX2_BUILD)
673
+ else
674
+ koboldcpp_noavx2:
675
+ $(DONOTHING)
676
+ endif
677
+
678
+ ifdef CLBLAST_BUILD
679
+ koboldcpp_clblast: ggml_v4_clblast.o ggml-cpu_v4_clblast.o ggml_v3_clblast.o ggml_v2_clblast.o ggml_v1.o expose.o gpttype_adapter_clblast.o ggml-opencl.o ggml_v3-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o sdcpp_default.o whispercpp_default.o tts_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL) $(OBJS)
680
+ $(CLBLAST_BUILD)
681
+ ifdef NOAVX2_BUILD
682
+ koboldcpp_clblast_noavx2: ggml_v4_clblast_noavx2.o ggml-cpu_v4_clblast_noavx2.o ggml_v3_clblast_noavx2.o ggml_v2_clblast_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_clblast_noavx2.o ggml-opencl.o ggml_v3-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o sdcpp_default.o whispercpp_default.o tts_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_SIMPLE) $(OBJS)
683
+ $(CLBLAST_BUILD)
684
+ koboldcpp_clblast_failsafe: ggml_v4_clblast_failsafe.o ggml-cpu_v4_clblast_failsafe.o ggml_v3_clblast_failsafe.o ggml_v2_clblast_failsafe.o ggml_v1_failsafe.o expose.o gpttype_adapter_clblast_noavx2.o ggml-opencl.o ggml_v3-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o sdcpp_default.o whispercpp_default.o tts_default.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_SIMPLER) $(OBJS)
685
+ $(CLBLAST_BUILD)
686
+ else
687
+ koboldcpp_clblast_noavx2:
688
+ $(DONOTHING)
689
+ koboldcpp_clblast_failsafe:
690
+ $(DONOTHING)
691
+ endif
692
+ else
693
+ koboldcpp_clblast:
694
+ $(DONOTHING)
695
+ koboldcpp_clblast_noavx2:
696
+ $(DONOTHING)
697
+ koboldcpp_clblast_failsafe:
698
+ $(DONOTHING)
699
+ endif
700
+
701
+ ifdef CUBLAS_BUILD
702
+ koboldcpp_cublas: ggml_v4_cublas.o ggml-cpu.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o gpttype_adapter_cublas.o sdcpp_cublas.o whispercpp_cublas.o tts_default.o llavaclip_cublas.o llava.o ggml-backend_cublas.o ggml-backend-reg_cublas.o $(CUBLAS_OBJS) $(OBJS_FULL) $(OBJS)
703
+ $(CUBLAS_BUILD)
704
+ else
705
+ koboldcpp_cublas:
706
+ $(DONOTHING)
707
+ endif
708
+
709
+ ifdef HIPBLAS_BUILD
710
+ koboldcpp_hipblas: ggml_v4_cublas.o ggml-cpu.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o gpttype_adapter_cublas.o sdcpp_cublas.o whispercpp_cublas.o tts_default.o llavaclip_cublas.o llava.o ggml-backend_cublas.o ggml-backend-reg_cublas.o $(HIP_OBJS) $(OBJS_FULL) $(OBJS)
711
+ $(HIPBLAS_BUILD)
712
+ else
713
+ koboldcpp_hipblas:
714
+ $(DONOTHING)
715
+ endif
716
+
717
+ ifdef VULKAN_BUILD
718
+ koboldcpp_vulkan: ggml_v4_vulkan.o ggml-cpu.o ggml_v3.o ggml_v2.o ggml_v1.o expose.o gpttype_adapter_vulkan.o ggml-vulkan.o sdcpp_vulkan.o whispercpp_default.o tts_default.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o $(OBJS_FULL) $(OBJS)
719
+ $(VULKAN_BUILD)
720
+ ifdef NOAVX2_BUILD
721
+ koboldcpp_vulkan_noavx2: ggml_v4_vulkan_noavx2.o ggml-cpu_v4_noavx2.o ggml_v3_noavx2.o ggml_v2_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_vulkan_noavx2.o ggml-vulkan.o sdcpp_vulkan.o whispercpp_default.o tts_default.o llavaclip_vulkan.o llava.o ggml-backend_vulkan.o ggml-backend-reg_vulkan.o $(OBJS_SIMPLE) $(OBJS)
722
+ $(VULKAN_BUILD)
723
+ else
724
+ koboldcpp_vulkan_noavx2:
725
+ $(DONOTHING)
726
+ endif
727
+ else
728
+ koboldcpp_vulkan:
729
+ $(DONOTHING)
730
+ koboldcpp_vulkan_noavx2:
731
+ $(DONOTHING)
732
+ endif
733
+
734
+ # tools
735
+ quantize_gguf: examples/quantize/quantize.cpp ggml.o ggml-cpu.o llama.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL)
736
+ $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
737
+ quantize_gptj: otherarch/tools/gptj_quantize.cpp otherarch/tools/common-ggml.cpp ggml_v3.o ggml.o ggml-cpu.o llama.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL)
738
+ $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
739
+ quantize_gpt2: otherarch/tools/gpt2_quantize.cpp otherarch/tools/common-ggml.cpp ggml_v3.o ggml.o ggml-cpu.o llama.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL)
740
+ $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
741
+ quantize_neox: otherarch/tools/neox_quantize.cpp otherarch/tools/common-ggml.cpp ggml_v3.o ggml.o ggml-cpu.o llama.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL)
742
+ $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
743
+ quantize_mpt: otherarch/tools/mpt_quantize.cpp otherarch/tools/common-ggml.cpp ggml_v3.o ggml.o ggml-cpu.o llama.o llavaclip_default.o llava.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL)
744
+ $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
745
+ quantize_clip: examples/llava/clip.cpp examples/llava/clip.h examples/llava/quantclip.cpp ggml_v3.o ggml.o ggml-cpu.o llama.o ggml-backend_default.o ggml-backend-reg_default.o $(OBJS_FULL)
746
+ $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
747
+
748
+ #windows simple clinfo
749
+ simpleclinfo: simpleclinfo.cpp
750
+ $(CXX) $(CXXFLAGS) $^ lib/OpenCL.lib lib/clblast.lib -o $@ $(LDFLAGS)
751
+
752
+ build-info.h:
753
+ $(DONOTHING)
754
+
755
+ #phony for printing messages
756
+ finishedmsg:
757
+ $(NOTIFY_MSG)
758
+ $(DONOTHING)
OpenCL.dll ADDED
Binary file (55.8 kB).
 
README.md ADDED
@@ -0,0 +1,194 @@
1
+ # koboldcpp
2
+
3
+ KoboldCpp is an easy-to-use AI text-generation software for GGML and GGUF models, inspired by the original **KoboldAI**. It's a single self-contained distributable from Concedo that builds off llama.cpp and adds many additional powerful features.
4
+
5
+ ![Preview](media/preview.png)
6
+ ![Preview](media/preview2.png)
7
+ ![Preview](media/preview3.png)
8
+ ![Preview](media/preview4.png)
9
+
10
+ ### Features
11
+ - Single file executable, with no installation required and no external dependencies
12
+ - Runs on CPU or GPU, supports full or partial offloading
13
+ - LLM text generation (Supports all GGML and GGUF models, backwards compatibility with ALL past models)
14
+ - Image Generation (Stable Diffusion 1.5, SDXL, SD3, Flux)
15
+ - Speech-To-Text (Voice Recognition) via Whisper
16
+ - Text-To-Speech (Voice Generation) via OuteTTS
17
+ - Provides many compatible API endpoints for popular webservices (KoboldCppApi OpenAiApi OllamaApi A1111ForgeApi ComfyUiApi WhisperTranscribeApi XttsApi OpenAiSpeechApi)
18
+ - Bundled KoboldAI Lite UI with editing tools, save formats, memory, world info, author's note, characters, scenarios.
19
+ - Includes multiple modes (chat, adventure, instruct, storywriter) and UI Themes (aesthetic roleplay, classic writer, corporate assistant, messenger)
20
+ - Supports loading Tavern Character Cards, importing many different data formats from various sites, reading or exporting JSON savefiles and persistent stories.
21
+ - Many other features including new samplers, regex support, websearch, RAG via TextDB and more.
22
+ - Ready-to-use binaries for Windows, MacOS, Linux, Android (via Termux), Colab, Docker, also supports other platforms if self-compiled (like Raspberry Pi).
23
+ - [Need help finding a model? Read this!](https://github.com/LostRuins/koboldcpp/wiki#getting-an-ai-model-file)
24
+
25
+ ## Windows Usage (Precompiled Binary, Recommended)
26
+ - Windows binaries are provided in the form of **koboldcpp.exe**, which is a pyinstaller wrapper containing all necessary files. **[Download the latest koboldcpp.exe release here](https://github.com/LostRuins/koboldcpp/releases/latest)**
27
+ - To run, simply execute **koboldcpp.exe**.
28
+ - Launching with no command line arguments displays a GUI containing a subset of configurable settings. Generally you don't have to change much besides the `Presets` and `GPU Layers`. Run with `--help` for more info about each setting.
29
+ - Obtain and load a GGUF model. See [here](#Obtaining-a-GGUF-model)
30
+ - By default, you can connect to http://localhost:5001
31
+ - You can also run it using the command line. For info, please check `koboldcpp.exe --help`
32
+
33
+ ## Linux Usage (Precompiled Binary, Recommended)
34
+ On modern Linux systems, you should download the `koboldcpp-linux-x64-cuda1150` prebuilt PyInstaller binary on the **[releases page](https://github.com/LostRuins/koboldcpp/releases/latest)**. Simply download and run the binary (You may have to `chmod +x` it first).
35
+
36
+ Alternatively, you can also install koboldcpp to the current directory by running the following terminal command:
37
+ ```
38
+ curl -fLo koboldcpp https://github.com/LostRuins/koboldcpp/releases/latest/download/koboldcpp-linux-x64-cuda1150 && chmod +x koboldcpp
39
+ ```
40
+ After running this command you can launch Koboldcpp from the current directory using `./koboldcpp` in the terminal (for CLI usage, run with `--help`).
41
+ Finally, obtain and load a GGUF model. See [here](#Obtaining-a-GGUF-model)
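+ 
+ A minimal CLI launch might then look like this (the model filename is a placeholder):
+ ```
+ ./koboldcpp --model mymodel.gguf --port 5001
+ ```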
42
+
43
+ ## MacOS (Precompiled Binary)
44
+ - PyInstaller binaries for Modern ARM64 MacOS (M1, M2, M3) are now available! **[Simply download the MacOS binary](https://github.com/LostRuins/koboldcpp/releases/latest)**
45
+ - In a MacOS terminal window, set the file to executable `chmod +x koboldcpp-mac-arm64` and run it with `./koboldcpp-mac-arm64`.
46
+ - In newer MacOS you may also have to whitelist it in security settings if it's blocked. [Here's a video guide](https://youtube.com/watch?v=NOW5dyA_JgY).
47
+ - Alternatively, or for older x86 MacOS computers, you can clone the repo and compile from source code, see Compiling for MacOS below.
48
+ - Finally, obtain and load a GGUF model. See [here](#Obtaining-a-GGUF-model)
49
+
50
+ ## Run on Colab
51
+ - KoboldCpp now has an **official Colab GPU Notebook**! This is an easy way to get started without installing anything in a minute or two. [Try it here!](https://colab.research.google.com/github/LostRuins/koboldcpp/blob/concedo/colab.ipynb).
52
+ - Note that KoboldCpp is not responsible for your usage of this Colab Notebook, you should ensure that your own usage complies with Google Colab's terms of use.
53
+
54
+ ## Run on RunPod
55
+ - KoboldCpp can now be used on RunPod cloud GPUs! This is an easy way to get started without installing anything in a minute or two, and is very scalable, capable of running 70B+ models at affordable cost. [Try our RunPod image here!](https://koboldai.org/runpodcpp).
56
+
57
+ ## Run on Novita AI
58
+ KoboldCpp can now also be run on Novita AI, a newer alternative GPU cloud provider which also offers a quick-launch KoboldCpp template. [Check it out here!](https://koboldai.org/novitacpp)
59
+
60
+ ## Docker
61
+ - The official docker can be found at https://hub.docker.com/r/koboldai/koboldcpp
62
+ - If you're building your own Docker image, remember to set CUDA_DOCKER_ARCH or enable LLAMA_PORTABLE (a sketch follows below)
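+ 
+ As a rough sketch (values are illustrative; check the Makefile for how CUDA_DOCKER_ARCH is consumed), a container build step might run:
+ ```
+ make LLAMA_CUBLAS=1 CUDA_DOCKER_ARCH=all
+ ```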
63
+
64
+ ## Obtaining a GGUF model
65
+ - KoboldCpp uses GGUF models. They are not included with KoboldCpp, but you can download GGUF files from other places such as [TheBloke's Huggingface](https://huggingface.co/TheBloke). Search for "GGUF" on huggingface.co for plenty of compatible models in the `.gguf` format.
66
+ - For beginners, we recommend the models [Airoboros Mistral 7B](https://huggingface.co/TheBloke/airoboros-mistral2.2-7B-GGUF/resolve/main/airoboros-mistral2.2-7b.Q4_K_S.gguf) (smaller and weaker) or [Tiefighter 13B](https://huggingface.co/KoboldAI/LLaMA2-13B-Tiefighter-GGUF/resolve/main/LLaMA2-13B-Tiefighter.Q4_K_S.gguf) (larger model) or [Beepo 22B](https://huggingface.co/concedo/Beepo-22B-GGUF/resolve/main/Beepo-22B-Q4_K_S.gguf) (largest and most powerful)
67
+ - [Alternatively, you can download the tools to convert models to the GGUF format yourself here](https://kcpptools.concedo.workers.dev). Run `convert-hf-to-gguf.py` to convert them, then `quantize_gguf.exe` to quantize the result (see the example after this list).
68
+ - Other models for Whisper (speech recognition), Image Generation, Text to Speech or Image Recognition [can be found on the Wiki](https://github.com/LostRuins/koboldcpp/wiki#what-models-does-koboldcpp-support-what-architectures-are-supported)
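+ 
+ For example, a conversion and quantization pass might look like this (paths and the quant type are illustrative, and exact flags may differ by script version):
+ ```
+ python convert-hf-to-gguf.py ./my-hf-model --outfile my-model-f16.gguf
+ ./quantize_gguf my-model-f16.gguf my-model-Q4_K_S.gguf Q4_K_S
+ ```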
69
+
70
+ ## Improving Performance
71
+ - **GPU Acceleration**: If you're on Windows with an Nvidia GPU, you can get CUDA support out of the box using the `--usecublas` flag (Nvidia only), or `--usevulkan` (any GPU); make sure you select the correct .exe with CUDA support. A combined example appears after this list.
72
+ - **GPU Layer Offloading**: Add `--gpulayers` to offload model layers to the GPU. The more layers you offload to VRAM, the faster generation speed will become. Experiment to determine the number of layers to offload, and reduce by a few if you run out of memory.
73
+ - **Increasing Context Size**: Use `--contextsize (number)` to increase context size, allowing the model to read more text. Note that you may also need to increase the max context in the KoboldAI Lite UI (click and edit the number text field).
74
+ - **Old CPU Compatibility**: If you are having crashes or issues, you can try running in a non-avx2 compatibility mode by adding the `--noavx2` flag. You can also try turning off mmap with `--nommap` or reducing your `--blasbatchsize` (set -1 to avoid batching)
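+ 
+ Putting those flags together, a typical accelerated launch might look like this (model name, layer count and context size are placeholders):
+ ```
+ koboldcpp.exe --model mymodel.gguf --usecublas --gpulayers 32 --contextsize 8192
+ ```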
75
+
76
+ For more information, be sure to run the program with the `--help` flag, or **[check the wiki](https://github.com/LostRuins/koboldcpp/wiki).**
77
+
78
+ ## Compiling KoboldCpp From Source Code
79
+
80
+ ### Compiling on Linux (Using koboldcpp.sh automated compiler script)
81
+ When you can't use the precompiled binary directly, we provide an automated build script which uses conda to obtain all dependencies, and generates (from source) a ready-to-use pyinstaller binary for Linux users.
82
+ - Clone the repo with `git clone https://github.com/LostRuins/koboldcpp.git`
83
+ - Simply execute the build script with `./koboldcpp.sh dist` and run the generated binary. (Not recommended for systems that already have an existing installation of conda. Dependencies: curl, bzip2)
84
+ ```
85
+ ./koboldcpp.sh # This launches the GUI for easy configuration and launching (X11 required).
86
+ ./koboldcpp.sh --help # List all available terminal commands for using Koboldcpp, you can use koboldcpp.sh the same way as our python script and binaries.
87
+ ./koboldcpp.sh rebuild # Automatically generates a new conda runtime and compiles a fresh copy of the libraries. Do this after updating Koboldcpp to keep everything functional.
88
+ ./koboldcpp.sh dist # Generate your own precompiled binary (Due to the nature of Linux compiling these will only work on distributions equal or newer than your own.)
89
+ ```
90
+
91
+ ### Compiling on Linux (Manual Method)
92
+ - To compile your binaries from source, clone the repo with `git clone https://github.com/LostRuins/koboldcpp.git`
93
+ - A makefile is provided, simply run `make`.
94
+ - Optional Vulkan: Link your own install of Vulkan SDK manually with `make LLAMA_VULKAN=1`
95
+ - Optional CLBlast: Link your own install of CLBlast manually with `make LLAMA_CLBLAST=1`
96
+ - Note: for these you will need to obtain and link OpenCL and CLBlast libraries.
97
+ - For Arch Linux: Install `cblas` and `clblast`.
98
+ - For Debian: Install `libclblast-dev`.
99
+ - You can attempt a CuBLAS build with `LLAMA_CUBLAS=1` (or `LLAMA_HIPBLAS=1` for AMD). You will need the CUDA Toolkit installed. Some have also reported success with the CMake file, though that is more for Windows.
100
+ - For a full-featured build (all backends), do `make LLAMA_CLBLAST=1 LLAMA_CUBLAS=1 LLAMA_VULKAN=1`; a combined example appears after this list. (Note that `LLAMA_CUBLAS=1` will not work on Windows, you need Visual Studio)
101
+ - To make your build sharable and capable of working on other devices, you must use `LLAMA_PORTABLE=1`
102
+ - After all binaries are built, you can run the python script with the command `koboldcpp.py [ggml_model.gguf] [port]`
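+ 
+ A combined example of the above (flags are optional and illustrative):
+ ```
+ make LLAMA_CLBLAST=1 LLAMA_CUBLAS=1 LLAMA_VULKAN=1 LLAMA_PORTABLE=1 -j8
+ python koboldcpp.py mymodel.gguf 5001
+ ```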
103
+
104
+ ### Compiling on Windows
105
+ - You're encouraged to use the released .exe, but if you want to compile your binaries from source on Windows, the easiest way is:
106
+ - Get the latest release of w64devkit (https://github.com/skeeto/w64devkit). Be sure to use the "vanilla" one, not i686 or other variants. If you try, they will conflict with the precompiled libs!
107
+ - Clone the repo with `git clone https://github.com/LostRuins/koboldcpp.git`
108
+ - Make sure you are using the w64devkit integrated terminal, then run `make` in the KoboldCpp source folder. This will create the .dll files for a pure CPU native build.
109
+ - For a full-featured build (all backends), do `make LLAMA_CLBLAST=1 LLAMA_VULKAN=1`. (Note that `LLAMA_CUBLAS=1` will not work on Windows, you need Visual Studio)
110
+ - To make your build sharable and capable of working on other devices, you must use `LLAMA_PORTABLE=1`
111
+ - If you want to generate the .exe file, make sure you have the python module PyInstaller installed with pip (`pip install PyInstaller`). Then run the script `make_pyinstaller.bat`
112
+ - The koboldcpp.exe file will be in your dist folder.
113
+ - **Building with CUDA**: Visual Studio, CMake and the CUDA Toolkit are required. Clone the repo, then open the CMake file and compile it in Visual Studio. Copy the `koboldcpp_cublas.dll` generated into the same directory as the `koboldcpp.py` file. If you are bundling executables, you may need to include CUDA dynamic libraries (such as `cublasLt64_11.dll` and `cublas64_11.dll`) in order for the executable to work correctly on a different PC.
114
+ - **Replacing Libraries (Not Recommended)**: If you wish to use your own version of the additional Windows libraries (OpenCL, CLBlast, Vulkan), you can do it with:
115
+ - OpenCL - tested with https://github.com/KhronosGroup/OpenCL-SDK . If you wish to compile it, follow the repository instructions. You will need vcpkg.
116
+ - CLBlast - tested with https://github.com/CNugteren/CLBlast . If you wish to compile it you will need to reference the OpenCL files. It will only generate the ".lib" file if you compile using MSVC.
117
+ - Move the respective .lib files to the /lib folder of your project, overwriting the older files.
118
+ - Also, replace the existing versions of the corresponding .dll files located in the project directory root (e.g. clblast.dll).
119
+ - Make the KoboldCpp project using the instructions above.
120
+
121
+ ### Compiling on MacOS
122
+ - You can compile your binaries from source. You can clone the repo with `git clone https://github.com/LostRuins/koboldcpp.git`
123
+ - A makefile is provided, simply run `make`.
124
+ - If you want Metal GPU support, instead run `make LLAMA_METAL=1`; note that the MacOS Metal libraries need to be installed.
125
+ - To make your build sharable and capable of working on other devices, you must use `LLAMA_PORTABLE=1`
126
+ - After all binaries are built, you can run the python script with the command `koboldcpp.py --model [ggml_model.gguf]` (and add `--gpulayers (number of layers)` if you wish to offload layers to GPU), as sketched below.
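+ 
+ Sketched end to end (model name and layer count are placeholders):
+ ```
+ make LLAMA_METAL=1
+ python koboldcpp.py --model mymodel.gguf --gpulayers 40
+ ```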
127
+
128
+ ### Compiling on Android (Termux Installation)
129
+ - [Install and run Termux from F-Droid](https://f-droid.org/en/packages/com.termux/)
130
+ - Enter the command `termux-change-repo` and choose `Mirror by BFSU`
131
+ - Install dependencies with `pkg install wget git python` (plus any other missing packages)
132
+ - Install dependencies `apt install openssl` (if needed)
133
+ - Clone the repo `git clone https://github.com/LostRuins/koboldcpp.git`
134
+ - Navigate to the koboldcpp folder `cd koboldcpp`
135
+ - Build the project `make`
136
+ - To make your build sharable and capable of working on other devices, you must use `LLAMA_PORTABLE=1`; this disables usage of ARM intrinsics.
137
+ - Grab a small GGUF model, such as `wget https://huggingface.co/concedo/KobbleTinyV2-1.1B-GGUF/resolve/main/KobbleTiny-Q4_K.gguf`
138
+ - Start the python server `python koboldcpp.py --model KobbleTiny-Q4_K.gguf`
139
+ - Connect to `http://localhost:5001` on your mobile browser
140
+ - If you encounter any errors, make sure your packages are up-to-date with `pkg up`
141
+ - GPU acceleration for Termux may be possible but I have not explored it. If you find a good cross-device solution, do share or PR it.
142
+
143
+ ## AMD Users
144
+ - For most users, you can get very decent speeds by selecting the **Vulkan** option instead, which supports both Nvidia and AMD GPUs.
145
+ - Alternatively, you can try the ROCM fork at https://github.com/YellowRoseCx/koboldcpp-rocm
146
+
147
+ ## Third Party Resources
148
+ - These unofficial resources have been contributed by the community, and may be outdated or unmaintained. No official support will be provided for them!
149
+ - Arch Linux Packages: [CUBLAS](https://aur.archlinux.org/packages/koboldcpp-cuda), and [HIPBLAS](https://aur.archlinux.org/packages/koboldcpp-hipblas).
150
+ - Unofficial Dockers: [korewaChino](https://github.com/korewaChino/koboldCppDocker) and [noneabove1182](https://github.com/noneabove1182/koboldcpp-docker)
151
+ - Nix & NixOS: KoboldCpp is available on Nixpkgs and can be installed by adding just `koboldcpp` to your `environment.systemPackages` *(or it can also be placed in `home.packages`)*.
152
+ - [Example Nix Setup and further information](examples/nix_example.md)
153
+ - If you face any issues with running KoboldCpp on Nix, please open an issue [here](https://github.com/NixOS/nixpkgs/issues/new?assignees=&labels=0.kind%3A+bug&projects=&template=bug_report.md&title=).
154
+ - [GPTLocalhost](https://gptlocalhost.com/demo#KoboldCpp) - KoboldCpp is supported by GPTLocalhost, a local Word Add-in for you to use KoboldCpp in Microsoft Word. A local alternative to "Copilot in Word."
155
+
156
+ ## Questions and Help Wiki
157
+ - **First, please check out [The KoboldCpp FAQ and Knowledgebase](https://github.com/LostRuins/koboldcpp/wiki) which may already have answers to your questions! Also please search through past issues and discussions.**
158
+ - If you cannot find an answer, open an issue on this github, or find us on the [KoboldAI Discord](https://koboldai.org/discord).
159
+
160
+ ## KoboldCpp and KoboldAI API Documentation
161
+ - [Documentation for KoboldAI and KoboldCpp endpoints can be found here](https://lite.koboldai.net/koboldcpp_api)
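+ 
+ As a quick illustration, a minimal request against a local instance might look like this (the fields shown are a small subset of the documented schema):
+ ```
+ curl -s http://localhost:5001/api/v1/generate -H "Content-Type: application/json" \
+   -d '{"prompt": "Once upon a time", "max_length": 50}'
+ ```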
162
+
163
+ ## KoboldCpp Public Demo
164
+ - [A public KoboldCpp demo can be found at our Huggingface Space. Please do not abuse it.](https://koboldai-koboldcpp-tiefighter.hf.space/)
165
+
166
+ ## Considerations
167
+ - For Windows: No installation, single file executable (It Just Works)
168
+ - Since v1.15, CLBlast is required if enabled; the prebuilt Windows binaries are included in this repo. If not found, it will fall back to a mode without CLBlast.
169
+ - Since v1.33, you can set the context size to be above what the model supports officially. It does increase perplexity but should still work well below 4096 even on untuned models. (For GPT-NeoX, GPT-J, and Llama models) Customize this with `--ropeconfig`.
170
+ - Since v1.42, supports GGUF models for LLAMA and Falcon
171
+ - Since v1.55, lcuda paths on Linux are hardcoded and may require manual changes to the makefile if you do not use koboldcpp.sh for the compilation.
172
+ - Since v1.60, provides native image generation with StableDiffusion.cpp, you can load any SD1.5 or SDXL .safetensors model and it will provide an A1111 compatible API to use.
173
+ - **I try to keep backwards compatibility with ALL past llama.cpp models**. But you are also encouraged to reconvert/update your models if possible for best results.
174
+ - Since v1.75, openblas has been deprecated and removed in favor of the native CPU implementation.
175
+
176
+ ## License
177
+ - The original GGML library and llama.cpp by ggerganov are licensed under the MIT License
178
+ - However, KoboldAI Lite is licensed under the AGPL v3.0 License
179
+ - KoboldCpp code and other files are also under the AGPL v3.0 License unless otherwise stated
180
+
181
+ ## Notes
182
+ - If you wish, after building the koboldcpp libraries with `make`, you can rebuild the exe yourself with pyinstaller by using `make_pyinstaller.bat`
183
+ - API documentation available at `/api` (e.g. `http://localhost:5001/api`) and https://lite.koboldai.net/koboldcpp_api. An OpenAI compatible API is also provided at the `/v1` route (e.g. `http://localhost:5001/v1`).
184
+ - **All up-to-date GGUF models are supported**, and KoboldCpp also includes backward compatibility for older versions/legacy GGML `.bin` models, though some newer features might be unavailable.
185
+ - An incomplete list of architectures follows, but there are *many hundreds of other GGUF models*. In general, if it's GGUF, it should work.
186
+ - Llama / Llama2 / Llama3 / Alpaca / GPT4All / Vicuna / Koala / Pygmalion / Metharme / WizardLM / Mistral / Mixtral / Miqu / Qwen / Qwen2 / Yi / Gemma / Gemma2 / GPT-2 / Cerebras / Phi-2 / Phi-3 / GPT-NeoX / Pythia / StableLM / Dolly / RedPajama / GPT-J / RWKV4 / MPT / Falcon / Starcoder / Deepseek and many, **many** more.
187
+
188
+ # Where can I download AI model files?
189
+ - The best place to get GGUF text models is Huggingface. For image models, CivitAI has a good selection. Here are some to get started.
190
+ - Text Generation: [Airoboros Mistral 7B](https://huggingface.co/TheBloke/airoboros-mistral2.2-7B-GGUF/resolve/main/airoboros-mistral2.2-7b.Q4_K_S.gguf) (smaller and weaker) or [Tiefighter 13B](https://huggingface.co/KoboldAI/LLaMA2-13B-Tiefighter-GGUF/resolve/main/LLaMA2-13B-Tiefighter.Q4_K_S.gguf) (larger model) or [Beepo 22B](https://huggingface.co/concedo/Beepo-22B-GGUF/resolve/main/Beepo-22B-Q4_K_S.gguf) (largest and most powerful)
191
+ - Image Generation: [Anything v3](https://huggingface.co/admruul/anything-v3.0/resolve/main/Anything-V3.0-pruned-fp16.safetensors) or [Deliberate V2](https://huggingface.co/Yntec/Deliberate2/resolve/main/Deliberate_v2.safetensors) or [Dreamshaper SDXL](https://huggingface.co/Lykon/dreamshaper-xl-v2-turbo/resolve/main/DreamShaperXL_Turbo_v2_1.safetensors)
192
+ - Image Recognition MMproj: [Pick the correct one for your model architecture here](https://huggingface.co/koboldcpp/mmproj/tree/main)
193
+ - Speech Recognition: [Whisper models for Speech-To-Text](https://huggingface.co/koboldcpp/whisper/tree/main)
194
+ - Text-To-Speech: [TTS models for Narration](https://huggingface.co/koboldcpp/tts/tree/main)
Remote-Link.cmd ADDED
@@ -0,0 +1,18 @@
1
+ : # This script will help set up a cloudflared tunnel for accessing KoboldCpp over the internet
2
+ : # It should work out of the box on both linux and windows
3
+ : # ======
4
+ : # WINDOWS PORTION
5
+ :<<BATCH
6
+ @echo off
7
+ echo Starting Cloudflare Tunnel for Windows
8
+ curl -L https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-windows-amd64.exe -o cloudflared.exe
9
+ cloudflared.exe tunnel --url localhost:5001
10
+ GOTO ENDING
11
+ BATCH
12
+ : # LINUX PORTION
13
+ echo 'Starting Cloudflare Tunnel for Linux'
14
+ curl -L https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 -o 'cloudflared-linux-amd64' #
15
+ chmod +x 'cloudflared-linux-amd64' #
16
+ ./cloudflared-linux-amd64 tunnel --url http://localhost:5001 #
17
+ exit #
18
+ :ENDING
build-info.h ADDED
@@ -0,0 +1,12 @@
1
+ #ifndef BUILD_INFO_H
2
+ #define BUILD_INFO_H
3
+
4
+ #define LLAMA_BUILD_NUMBER 999
5
+ #define LLAMA_COMMIT "KOBOLDCPP"
6
+ #define LLAMA_COMPILER "KCPP"
7
+ #define LLAMA_TARGET "KCPP"
8
+ #define LLAMA_BUILD_COMMIT "KOBOLDCPP"
9
+ #define LLAMA_BUILD_COMPILER "KCPP"
10
+ #define LLAMA_BUILD_TARGET "KCPP"
11
+
12
+ #endif // BUILD_INFO_H
build-xcframework.sh ADDED
@@ -0,0 +1,519 @@
1
+ #!/bin/bash
2
+ #
3
+ # Options
4
+ IOS_MIN_OS_VERSION=16.4
5
+ MACOS_MIN_OS_VERSION=13.3
6
+ VISIONOS_MIN_OS_VERSION=1.0
7
+ TVOS_MIN_OS_VERSION=16.4
8
+
9
+ BUILD_SHARED_LIBS=OFF
10
+ LLAMA_BUILD_EXAMPLES=OFF
11
+ LLAMA_BUILD_TESTS=OFF
12
+ LLAMA_BUILD_SERVER=OFF
13
+ GGML_METAL=ON
14
+ GGML_METAL_EMBED_LIBRARY=ON
15
+ GGML_BLAS_DEFAULT=ON
16
+ GGML_METAL_USE_BF16=ON
17
+ GGML_OPENMP=OFF
18
+
19
+ COMMON_C_FLAGS="-Wno-macro-redefined -Wno-shorten-64-to-32 -Wno-unused-command-line-argument -g"
20
+ COMMON_CXX_FLAGS="-Wno-macro-redefined -Wno-shorten-64-to-32 -Wno-unused-command-line-argument -g"
21
+
22
+ # Common options for all builds
23
+ COMMON_CMAKE_ARGS=(
24
+ -DCMAKE_XCODE_ATTRIBUTE_CODE_SIGNING_REQUIRED=NO
25
+ -DCMAKE_XCODE_ATTRIBUTE_CODE_SIGN_IDENTITY=""
26
+ -DCMAKE_XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED=NO
27
+ -DCMAKE_XCODE_ATTRIBUTE_DEBUG_INFORMATION_FORMAT="dwarf-with-dsym"
28
+ -DCMAKE_XCODE_ATTRIBUTE_GCC_GENERATE_DEBUGGING_SYMBOLS=YES
29
+ -DCMAKE_XCODE_ATTRIBUTE_COPY_PHASE_STRIP=NO
30
+ -DCMAKE_XCODE_ATTRIBUTE_STRIP_INSTALLED_PRODUCT=NO
31
+ -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
32
+ -DBUILD_SHARED_LIBS=${BUILD_SHARED_LIBS}
33
+ -DLLAMA_BUILD_EXAMPLES=${LLAMA_BUILD_EXAMPLES}
34
+ -DLLAMA_BUILD_TESTS=${LLAMA_BUILD_TESTS}
35
+ -DLLAMA_BUILD_SERVER=${LLAMA_BUILD_SERVER}
36
+ -DGGML_METAL_EMBED_LIBRARY=${GGML_METAL_EMBED_LIBRARY}
37
+ -DGGML_BLAS_DEFAULT=${GGML_BLAS_DEFAULT}
38
+ -DGGML_METAL=${GGML_METAL}
39
+ -DGGML_METAL_USE_BF16=${GGML_METAL_USE_BF16}
40
+ -DGGML_NATIVE=OFF
41
+ -DGGML_OPENMP=${GGML_OPENMP}
42
+ )
43
+
44
+ check_required_tool() {
45
+ local tool=$1
46
+ local install_message=$2
47
+
48
+ if ! command -v $tool &> /dev/null; then
49
+ echo "Error: $tool is required but not found."
50
+ echo "$install_message"
51
+ exit 1
52
+ fi
53
+ }
54
+ echo "Checking for required tools..."
55
+ check_required_tool "cmake" "Please install CMake 3.28.0 or later (brew install cmake)"
56
+ check_required_tool "xcodebuild" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
57
+ check_required_tool "libtool" "Please install libtool which should be available with Xcode Command Line Tools (CLT). Make sure Xcode CLT is installed (xcode-select --install)"
58
+ check_required_tool "dsymutil" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)"
59
+
60
+ set -e
61
+
62
+ ## Clean up previous builds
63
+ rm -rf build-apple
64
+ rm -rf build-ios-sim
65
+ rm -rf build-ios-device
66
+ rm -rf build-macos
67
+ rm -rf build-visionos
68
+ rm -rf build-visionos-sim
69
+ rm -rf build-tvos-sim
70
+ rm -rf build-tvos-device
71
+
72
+ # Setup the xcframework build directory structure
73
+ setup_framework_structure() {
74
+ local build_dir=$1
75
+ local min_os_version=$2
76
+ local platform=$3 # "ios", "macos", "visionos", or "tvos"
77
+ local framework_name="llama"
78
+
79
+ echo "Creating ${platform}-style framework structure for ${build_dir}"
80
+
81
+ if [[ "$platform" == "macos" ]]; then
82
+ # macOS versioned structure uses versioned directories
83
+ mkdir -p ${build_dir}/framework/${framework_name}.framework/Versions/A/Headers
84
+ mkdir -p ${build_dir}/framework/${framework_name}.framework/Versions/A/Modules
85
+ mkdir -p ${build_dir}/framework/${framework_name}.framework/Versions/A/Resources
86
+
87
+ # Create symbolic links
88
+ ln -sf A ${build_dir}/framework/${framework_name}.framework/Versions/Current
89
+ ln -sf Versions/Current/Headers ${build_dir}/framework/${framework_name}.framework/Headers
90
+ ln -sf Versions/Current/Modules ${build_dir}/framework/${framework_name}.framework/Modules
91
+ ln -sf Versions/Current/Resources ${build_dir}/framework/${framework_name}.framework/Resources
92
+ ln -sf Versions/Current/${framework_name} ${build_dir}/framework/${framework_name}.framework/${framework_name}
93
+
94
+ # Set header and module paths
95
+ local header_path=${build_dir}/framework/${framework_name}.framework/Versions/A/Headers/
96
+ local module_path=${build_dir}/framework/${framework_name}.framework/Versions/A/Modules/
97
+ else
98
+ # iOS/VisionOS/tvOS use a flat structure
99
+ mkdir -p ${build_dir}/framework/${framework_name}.framework/Headers
100
+ mkdir -p ${build_dir}/framework/${framework_name}.framework/Modules
101
+
102
+ # Remove any existing structure to ensure clean build
103
+ rm -rf ${build_dir}/framework/${framework_name}.framework/Versions
104
+
105
+ # Set header and module paths
106
+ local header_path=${build_dir}/framework/${framework_name}.framework/Headers/
107
+ local module_path=${build_dir}/framework/${framework_name}.framework/Modules/
108
+ fi
109
+
110
+ # Copy all required headers (common for all platforms)
111
+ cp include/llama.h ${header_path}
112
+ cp ggml/include/ggml.h ${header_path}
113
+ cp ggml/include/ggml-alloc.h ${header_path}
114
+ cp ggml/include/ggml-backend.h ${header_path}
115
+ cp ggml/include/ggml-metal.h ${header_path}
116
+ cp ggml/include/ggml-cpu.h ${header_path}
117
+ cp ggml/include/ggml-blas.h ${header_path}
118
+ cp ggml/include/gguf.h ${header_path}
119
+
120
+ # Create module map (common for all platforms)
121
+ cat > ${module_path}module.modulemap << EOF
122
+ framework module llama {
123
+ header "llama.h"
124
+ header "ggml.h"
125
+ header "ggml-alloc.h"
126
+ header "ggml-backend.h"
127
+ header "ggml-metal.h"
128
+ header "ggml-cpu.h"
129
+ header "ggml-blas.h"
130
+ header "gguf.h"
131
+
132
+ link "c++"
133
+ link framework "Accelerate"
134
+ link framework "Metal"
135
+ link framework "Foundation"
136
+
137
+ export *
138
+ }
139
+ EOF
140
+
141
+ # Platform-specific settings for Info.plist
142
+ local platform_name=""
143
+ local sdk_name=""
144
+ local supported_platform=""
145
+
146
+ case "$platform" in
147
+ "ios")
148
+ platform_name="iphoneos"
149
+ sdk_name="iphoneos${min_os_version}"
150
+ supported_platform="iPhoneOS"
151
+ local plist_path="${build_dir}/framework/${framework_name}.framework/Info.plist"
152
+ local device_family=' <key>UIDeviceFamily</key>
153
+ <array>
154
+ <integer>1</integer>
155
+ <integer>2</integer>
156
+ </array>'
157
+ ;;
158
+ "macos")
159
+ platform_name="macosx"
160
+ sdk_name="macosx${min_os_version}"
161
+ supported_platform="MacOSX"
162
+ local plist_path="${build_dir}/framework/${framework_name}.framework/Versions/A/Resources/Info.plist"
163
+ local device_family=""
164
+ ;;
165
+ "visionos")
166
+ platform_name="xros"
167
+ sdk_name="xros${min_os_version}"
168
+ supported_platform="XRPlatform"
169
+ local plist_path="${build_dir}/framework/${framework_name}.framework/Info.plist"
170
+ local device_family=""
171
+ ;;
172
+ "tvos")
173
+ platform_name="appletvos"
174
+ sdk_name="appletvos${min_os_version}"
175
+ supported_platform="AppleTVOS"
176
+ local plist_path="${build_dir}/framework/${framework_name}.framework/Info.plist"
177
+ local device_family=' <key>UIDeviceFamily</key>
178
+ <array>
179
+ <integer>3</integer>
180
+ </array>'
181
+ ;;
182
+ esac
183
+
184
+ # Create Info.plist
185
+ cat > ${plist_path} << EOF
186
+ <?xml version="1.0" encoding="UTF-8"?>
187
+ <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
188
+ <plist version="1.0">
189
+ <dict>
190
+ <key>CFBundleDevelopmentRegion</key>
191
+ <string>en</string>
192
+ <key>CFBundleExecutable</key>
193
+ <string>llama</string>
194
+ <key>CFBundleIdentifier</key>
195
+ <string>org.ggml.llama</string>
196
+ <key>CFBundleInfoDictionaryVersion</key>
197
+ <string>6.0</string>
198
+ <key>CFBundleName</key>
199
+ <string>llama</string>
200
+ <key>CFBundlePackageType</key>
201
+ <string>FMWK</string>
202
+ <key>CFBundleShortVersionString</key>
203
+ <string>1.0</string>
204
+ <key>CFBundleVersion</key>
205
+ <string>1</string>
206
+ <key>MinimumOSVersion</key>
207
+ <string>${min_os_version}</string>
208
+ <key>CFBundleSupportedPlatforms</key>
209
+ <array>
210
+ <string>${supported_platform}</string>
211
+ </array>${device_family}
212
+ <key>DTPlatformName</key>
213
+ <string>${platform_name}</string>
214
+ <key>DTSDKName</key>
215
+ <string>${sdk_name}</string>
216
+ </dict>
217
+ </plist>
218
+ EOF
219
+ }
220
+
221
+ # Create dynamic libraries from static libraries.
222
+ combine_static_libraries() {
223
+ local build_dir="$1"
224
+ local release_dir="$2"
225
+ local platform="$3" # "ios", "macos", "visionos", or "tvos"
226
+ local is_simulator="$4"
227
+ local base_dir="$(pwd)"
228
+ local framework_name="llama"
229
+
230
+ # Determine output path based on platform
231
+ local output_lib=""
232
+ if [[ "$platform" == "macos" ]]; then
233
+ # macOS uses versioned structure
234
+ output_lib="${build_dir}/framework/${framework_name}.framework/Versions/A/${framework_name}"
235
+ else
236
+ # iOS, visionOS, and tvOS use a flat directory structure
237
+ output_lib="${build_dir}/framework/${framework_name}.framework/${framework_name}"
238
+ fi
239
+
240
+ local libs=(
241
+ "${base_dir}/${build_dir}/src/${release_dir}/libllama.a"
242
+ "${base_dir}/${build_dir}/ggml/src/${release_dir}/libggml.a"
243
+ "${base_dir}/${build_dir}/ggml/src/${release_dir}/libggml-base.a"
244
+ "${base_dir}/${build_dir}/ggml/src/${release_dir}/libggml-cpu.a"
245
+ "${base_dir}/${build_dir}/ggml/src/ggml-metal/${release_dir}/libggml-metal.a"
246
+ "${base_dir}/${build_dir}/ggml/src/ggml-blas/${release_dir}/libggml-blas.a"
247
+ )
248
+
249
+ # Create temporary directory for processing
250
+ local temp_dir="${base_dir}/${build_dir}/temp"
251
+ mkdir -p "${temp_dir}"
252
+
253
+ # Since we have multiple architectures, libtool will find object files that do not
254
+ # match the target architecture. We suppress these warnings.
255
+ libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null
256
+
257
+ # Determine SDK, architectures, and install_name based on platform and simulator flag.
258
+ local sdk=""
259
+ local archs=""
260
+ local min_version_flag=""
261
+ local install_name=""
262
+
263
+ case "$platform" in
264
+ "ios")
265
+ if [[ "$is_simulator" == "true" ]]; then
266
+ sdk="iphonesimulator"
267
+ archs="arm64 x86_64"
268
+ min_version_flag="-mios-simulator-version-min=${IOS_MIN_OS_VERSION}"
269
+ else
270
+ sdk="iphoneos"
271
+ archs="arm64"
272
+ min_version_flag="-mios-version-min=${IOS_MIN_OS_VERSION}"
273
+ fi
274
+ install_name="@rpath/llama.framework/llama"
275
+ ;;
276
+ "macos")
277
+ sdk="macosx"
278
+ archs="arm64 x86_64"
279
+ min_version_flag="-mmacosx-version-min=${MACOS_MIN_OS_VERSION}"
280
+ install_name="@rpath/llama.framework/Versions/Current/llama"
281
+ ;;
282
+ "visionos")
283
+ if [[ "$is_simulator" == "true" ]]; then
284
+ sdk="xrsimulator"
285
+ archs="arm64 x86_64"
286
+ min_version_flag="-mtargetos=xros${VISIONOS_MIN_OS_VERSION}-simulator"
287
+ else
288
+ sdk="xros"
289
+ archs="arm64"
290
+ min_version_flag="-mtargetos=xros${VISIONOS_MIN_OS_VERSION}"
291
+ fi
292
+ # Use flat structure for visionOS, same as iOS
293
+ install_name="@rpath/llama.framework/llama"
294
+ ;;
295
+ "tvos")
296
+ if [[ "$is_simulator" == "true" ]]; then
297
+ sdk="appletvsimulator"
298
+ archs="arm64 x86_64"
299
+ min_version_flag="-mtvos-simulator-version-min=${TVOS_MIN_OS_VERSION}"
300
+ else
301
+ sdk="appletvos"
302
+ archs="arm64"
303
+ min_version_flag="-mtvos-version-min=${TVOS_MIN_OS_VERSION}"
304
+ fi
305
+ install_name="@rpath/llama.framework/llama"
306
+ ;;
307
+ esac
308
+
309
+ # Build architecture flags
310
+ local arch_flags=""
311
+ for arch in $archs; do
312
+ arch_flags+=" -arch $arch"
313
+ done
314
+
315
+ # Create dynamic library
316
+ echo "Creating dynamic library for ${platform}."
317
+ xcrun -sdk $sdk clang++ -dynamiclib \
318
+ -isysroot $(xcrun --sdk $sdk --show-sdk-path) \
319
+ $arch_flags \
320
+ $min_version_flag \
321
+ -Wl,-force_load,"${temp_dir}/combined.a" \
322
+ -framework Foundation -framework Metal -framework Accelerate \
323
+ -install_name "$install_name" \
324
+ -o "${base_dir}/${output_lib}"
325
+
326
+ # Platform-specific post-processing for device builds
327
+ if [[ "$is_simulator" == "false" ]]; then
328
+ if command -v vtool &>/dev/null; then
329
+ case "$platform" in
330
+ "ios")
331
+ echo "Marking binary as a framework binary for iOS..."
332
+ vtool -set-build-version ios ${IOS_MIN_OS_VERSION} ${IOS_MIN_OS_VERSION} -replace \
333
+ -output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
334
+ ;;
335
+ "visionos")
336
+ echo "Marking binary as a framework binary for visionOS..."
337
+ vtool -set-build-version xros ${VISIONOS_MIN_OS_VERSION} ${VISIONOS_MIN_OS_VERSION} -replace \
338
+ -output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
339
+ ;;
340
+ "tvos")
341
+ echo "Marking binary as a framework binary for tvOS..."
342
+ vtool -set-build-version tvos ${TVOS_MIN_OS_VERSION} ${TVOS_MIN_OS_VERSION} -replace \
343
+ -output "${base_dir}/${output_lib}" "${base_dir}/${output_lib}"
344
+ ;;
345
+ esac
346
+ else
347
+ echo "Warning: vtool not found. Binary may not pass App Store validation."
348
+ fi
349
+ fi
350
+
351
+ echo "Creating properly formatted dSYM..."
352
+ # Create a separate directory for dSYMs for all platforms
353
+ mkdir -p "${base_dir}/${build_dir}/dSYMs"
354
+
355
+ # iOS and visionOS style dSYM (flat structure)
356
+ if [[ "$platform" == "ios" || "$platform" == "visionos" || "$platform" == "tvos" ]]; then
357
+ # Generate dSYM in the dSYMs directory
358
+ xcrun dsymutil "${base_dir}/${output_lib}" -o "${base_dir}/${build_dir}/dSYMs/llama.dSYM"
359
+
360
+ # Create a copy of the binary that will be stripped
361
+ cp "${base_dir}/${output_lib}" "${temp_dir}/binary_to_strip"
362
+
363
+ # Strip debug symbols from the copy
364
+ xcrun strip -S "${temp_dir}/binary_to_strip" -o "${temp_dir}/stripped_lib"
365
+
366
+ # Replace the original with the stripped version
367
+ mv "${temp_dir}/stripped_lib" "${base_dir}/${output_lib}"
368
+ else
369
+ # macOS style dSYM
370
+ # First strip debug info to a separate file
371
+ xcrun strip -S "${base_dir}/${output_lib}" -o "${temp_dir}/stripped_lib"
372
+
373
+ # Generate dSYM in the dSYMs directory
374
+ xcrun dsymutil "${base_dir}/${output_lib}" -o "${base_dir}/${build_dir}/dSYMs/llama.dSYM"
375
+
376
+ # Replace original binary with stripped version
377
+ mv "${temp_dir}/stripped_lib" "${base_dir}/${output_lib}"
378
+ fi
379
+
380
+ # Remove any automatically generated dSYM files in the framework structure as they will
381
+ # otherwise cause Invalid Bundle Structure validation errors.
382
+ if [ -d "${base_dir}/${output_lib}.dSYM" ]; then
383
+ echo "Removing generated dSYM file in framework structure: ${base_dir}/${output_lib}.dSYM"
384
+ rm -rf "${base_dir}/${output_lib}.dSYM"
385
+ fi
386
+
387
+ # Clean up
388
+ rm -rf "${temp_dir}"
389
+ }
390
+
391
+ echo "Building for iOS simulator..."
392
+ cmake -B build-ios-sim -G Xcode \
393
+ "${COMMON_CMAKE_ARGS[@]}" \
394
+ -DCMAKE_OSX_DEPLOYMENT_TARGET=${IOS_MIN_OS_VERSION} \
395
+ -DIOS=ON \
396
+ -DCMAKE_SYSTEM_NAME=iOS \
397
+ -DCMAKE_OSX_SYSROOT=iphonesimulator \
398
+ -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
399
+ -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=iphonesimulator \
400
+ -DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
401
+ -DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
402
+ -S .
403
+ cmake --build build-ios-sim --config Release -- -quiet
404
+
405
+ echo "Building for iOS devices..."
406
+ cmake -B build-ios-device -G Xcode \
407
+ "${COMMON_CMAKE_ARGS[@]}" \
408
+ -DCMAKE_OSX_DEPLOYMENT_TARGET=${IOS_MIN_OS_VERSION} \
409
+ -DCMAKE_OSX_SYSROOT=iphoneos \
410
+ -DCMAKE_OSX_ARCHITECTURES="arm64" \
411
+ -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=iphoneos \
412
+ -DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
413
+ -DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
414
+ -S .
415
+ cmake --build build-ios-device --config Release -- -quiet
416
+
417
+ echo "Building for macOS..."
418
+ cmake -B build-macos -G Xcode \
419
+ "${COMMON_CMAKE_ARGS[@]}" \
420
+ -DCMAKE_OSX_DEPLOYMENT_TARGET=${MACOS_MIN_OS_VERSION} \
421
+ -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
422
+ -DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
423
+ -DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
424
+ -S .
425
+ cmake --build build-macos --config Release -- -quiet
426
+
427
+ echo "Building for visionOS..."
428
+ cmake -B build-visionos -G Xcode \
429
+ "${COMMON_CMAKE_ARGS[@]}" \
430
+ -DCMAKE_OSX_DEPLOYMENT_TARGET=${VISIONOS_MIN_OS_VERSION} \
431
+ -DCMAKE_OSX_ARCHITECTURES="arm64" \
432
+ -DCMAKE_SYSTEM_NAME=visionOS \
433
+ -DCMAKE_OSX_SYSROOT=xros \
434
+ -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \
435
+ -DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_C_FLAGS}" \
436
+ -DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_CXX_FLAGS}" \
437
+ -S .
438
+ cmake --build build-visionos --config Release -- -quiet
439
+
440
+ echo "Building for visionOS simulator..."
441
+ cmake -B build-visionos-sim -G Xcode \
442
+ "${COMMON_CMAKE_ARGS[@]}" \
443
+ -DCMAKE_OSX_DEPLOYMENT_TARGET=${VISIONOS_MIN_OS_VERSION} \
444
+ -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
445
+ -DCMAKE_SYSTEM_NAME=visionOS \
446
+ -DCMAKE_OSX_SYSROOT=xrsimulator \
447
+ -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \
448
+ -DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_C_FLAGS}" \
449
+ -DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_CXX_FLAGS}" \
450
+ -S .
451
+ cmake --build build-visionos-sim --config Release -- -quiet
452
+
453
+ # Add tvOS builds (might need the same u_int definitions as watchOS and visionOS)
454
+ echo "Building for tvOS simulator..."
455
+ cmake -B build-tvos-sim -G Xcode \
456
+ "${COMMON_CMAKE_ARGS[@]}" \
457
+ -DCMAKE_OSX_DEPLOYMENT_TARGET=${TVOS_MIN_OS_VERSION} \
458
+ -DCMAKE_SYSTEM_NAME=tvOS \
459
+ -DCMAKE_OSX_SYSROOT=appletvsimulator \
460
+ -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" \
461
+ -DGGML_METAL=ON \
462
+ -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=appletvsimulator \
463
+ -DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
464
+ -DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
465
+ -S .
466
+ cmake --build build-tvos-sim --config Release -- -quiet
467
+
468
+ echo "Building for tvOS devices..."
469
+ cmake -B build-tvos-device -G Xcode \
470
+ "${COMMON_CMAKE_ARGS[@]}" \
471
+ -DCMAKE_OSX_DEPLOYMENT_TARGET=${TVOS_MIN_OS_VERSION} \
472
+ -DCMAKE_SYSTEM_NAME=tvOS \
473
+ -DCMAKE_OSX_SYSROOT=appletvos \
474
+ -DCMAKE_OSX_ARCHITECTURES="arm64" \
475
+ -DGGML_METAL=ON \
476
+ -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=appletvos \
477
+ -DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
478
+ -DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
479
+ -S .
480
+ cmake --build build-tvos-device --config Release -- -quiet
481
+
482
+ # Setup frameworks and copy binaries and headers
483
+ echo "Setting up framework structures..."
484
+ setup_framework_structure "build-ios-sim" ${IOS_MIN_OS_VERSION} "ios"
485
+ setup_framework_structure "build-ios-device" ${IOS_MIN_OS_VERSION} "ios"
486
+ setup_framework_structure "build-macos" ${MACOS_MIN_OS_VERSION} "macos"
487
+ setup_framework_structure "build-visionos" ${VISIONOS_MIN_OS_VERSION} "visionos"
488
+ setup_framework_structure "build-visionos-sim" ${VISIONOS_MIN_OS_VERSION} "visionos"
489
+ setup_framework_structure "build-tvos-sim" ${TVOS_MIN_OS_VERSION} "tvos"
490
+ setup_framework_structure "build-tvos-device" ${TVOS_MIN_OS_VERSION} "tvos"
491
+
492
+ # Create dynamic libraries from static libraries
493
+ echo "Creating dynamic libraries from static libraries..."
494
+ combine_static_libraries "build-ios-sim" "Release-iphonesimulator" "ios" "true"
495
+ combine_static_libraries "build-ios-device" "Release-iphoneos" "ios" "false"
496
+ combine_static_libraries "build-macos" "Release" "macos" "false"
497
+ combine_static_libraries "build-visionos" "Release-xros" "visionos" "false"
498
+ combine_static_libraries "build-visionos-sim" "Release-xrsimulator" "visionos" "true"
499
+ combine_static_libraries "build-tvos-sim" "Release-appletvsimulator" "tvos" "true"
500
+ combine_static_libraries "build-tvos-device" "Release-appletvos" "tvos" "false"
501
+
502
+ # Create XCFramework with correct debug symbols paths
503
+ echo "Creating XCFramework..."
504
+ xcodebuild -create-xcframework \
505
+ -framework $(pwd)/build-ios-sim/framework/llama.framework \
506
+ -debug-symbols $(pwd)/build-ios-sim/dSYMs/llama.dSYM \
507
+ -framework $(pwd)/build-ios-device/framework/llama.framework \
508
+ -debug-symbols $(pwd)/build-ios-device/dSYMs/llama.dSYM \
509
+ -framework $(pwd)/build-macos/framework/llama.framework \
510
+ -debug-symbols $(pwd)/build-macos/dSYMs/llama.dSYM \
511
+ -framework $(pwd)/build-visionos/framework/llama.framework \
512
+ -debug-symbols $(pwd)/build-visionos/dSYMs/llama.dSYM \
513
+ -framework $(pwd)/build-visionos-sim/framework/llama.framework \
514
+ -debug-symbols $(pwd)/build-visionos-sim/dSYMs/llama.dSYM \
515
+ -framework $(pwd)/build-tvos-device/framework/llama.framework \
516
+ -debug-symbols $(pwd)/build-tvos-device/dSYMs/llama.dSYM \
517
+ -framework $(pwd)/build-tvos-sim/framework/llama.framework \
518
+ -debug-symbols $(pwd)/build-tvos-sim/dSYMs/llama.dSYM \
519
+ -output $(pwd)/build-apple/llama.xcframework
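The end result of the script is `build-apple/llama.xcframework`, which an Xcode target can embed and link like any other framework. The following is a minimal consumer sketch, not part of this commit: the framework-style include path and the `llama_backend_init`/`llama_backend_free` calls (both declared in the bundled `llama.h`) are assumptions about how the artifact is consumed.

    // hypothetical consumer of the generated llama.xcframework (not in this commit)
    #include <llama/llama.h>  // framework-style include once the xcframework is embedded

    int main() {
        llama_backend_init();   // initialize the ggml backends shipped in the framework
        // ... load a model and run inference here ...
        llama_backend_free();   // tear the backends back down
        return 0;
    }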
clblast.dll ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0611442b931691d9b3c9bc5ebe7625f17a5c5902e1a2b9e98cbad440d1459625
3
+ size 5450752
colab.ipynb ADDED
@@ -0,0 +1,174 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "colab_type": "text",
7
+ "id": "view-in-github"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/LostRuins/koboldcpp/blob/concedo/colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "id": "2FCn5tmpn3UV"
17
+ },
18
+ "source": [
19
+ "## Welcome to the Official KoboldCpp Colab Notebook\n",
20
+ "It's really easy to get started. Just press the two **Play** buttons below, and then connect to the **Cloudflare URL** shown at the end.\n",
21
+ "You can select a model from the dropdown, or enter a **custom URL** to a GGUF model (Example: `https://huggingface.co/KoboldAI/LLaMA2-13B-Tiefighter-GGUF/resolve/main/LLaMA2-13B-Tiefighter.Q4_K_M.gguf`)\n",
22
+ "\n",
23
+ "**Keep this page open and occationally check for captcha's so that your AI is not shut down**"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": null,
29
+ "metadata": {
30
+ "id": "QNaj3u0jn3UW"
31
+ },
32
+ "outputs": [],
33
+ "source": [
34
+ "#@title <-- Tap this if you play on Mobile { display-mode: \"form\" }\n",
35
+ "%%html\n",
36
+ "<b>Press play on the music player to keep the tab alive, then start KoboldCpp below</b><br/>\n",
37
+ "<audio autoplay=\"\" src=\"https://raw.githubusercontent.com/KoboldAI/KoboldAI-Client/main/colab/silence.m4a\" loop controls>"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "metadata": {
44
+ "cellView": "form",
45
+ "id": "uJS9i_Dltv8Y"
46
+ },
47
+ "outputs": [],
48
+ "source": [
49
+ "#@title <b>v-- Enter your model below and then click this to start Koboldcpp</b>\n",
50
+ "\n",
51
+ "Model = \"https://huggingface.co/KoboldAI/LLaMA2-13B-Tiefighter-GGUF/resolve/main/LLaMA2-13B-Tiefighter.Q4_K_S.gguf\" #@param [\"https://huggingface.co/KoboldAI/LLaMA2-13B-Tiefighter-GGUF/resolve/main/LLaMA2-13B-Tiefighter.Q4_K_S.gguf\",\"https://huggingface.co/KoboldAI/LLaMA2-13B-Estopia-GGUF/resolve/main/LLaMA2-13B-Estopia.Q4_K_S.gguf\",\"https://huggingface.co/mradermacher/Fimbulvetr-11B-v2-GGUF/resolve/main/Fimbulvetr-11B-v2.Q4_K_S.gguf\",\"https://huggingface.co/TheBloke/MythoMax-L2-13B-GGUF/resolve/main/mythomax-l2-13b.Q4_K_M.gguf\",\"https://huggingface.co/TheBloke/ReMM-SLERP-L2-13B-GGUF/resolve/main/remm-slerp-l2-13b.Q4_K_M.gguf\",\"https://huggingface.co/TheBloke/Xwin-LM-13B-v0.2-GGUF/resolve/main/xwin-lm-13b-v0.2.Q4_K_M.gguf\",\"https://huggingface.co/mradermacher/mini-magnum-12b-v1.1-GGUF/resolve/main/mini-magnum-12b-v1.1.Q4_K_S.gguf\",\"https://huggingface.co/TheBloke/Stheno-L2-13B-GGUF/resolve/main/stheno-l2-13b.Q4_K_M.gguf\",\"https://huggingface.co/TheBloke/MythoMax-L2-Kimiko-v2-13B-GGUF/resolve/main/mythomax-l2-kimiko-v2-13b.Q4_K_M.gguf\",\"https://huggingface.co/bartowski/Rocinante-12B-v1.1-GGUF/resolve/main/Rocinante-12B-v1.1-Q4_K_S.gguf\",\"https://huggingface.co/KoboldAI/Llama-3.1-8B-BookAdventures-GGUF/resolve/main/Llama-3.1-8B-BookAdventures.Q4_K_S.gguf\",\"https://huggingface.co/TheBloke/MistRP-Airoboros-7B-GGUF/resolve/main/mistrp-airoboros-7b.Q4_K_S.gguf\",\"https://huggingface.co/TheBloke/airoboros-mistral2.2-7B-GGUF/resolve/main/airoboros-mistral2.2-7b.Q4_K_S.gguf\",\"https://huggingface.co/concedo/KobbleTinyV2-1.1B-GGUF/resolve/main/KobbleTiny-Q4_K.gguf\",\"https://huggingface.co/grimjim/kukulemon-7B-GGUF/resolve/main/kukulemon-7B.Q8_0.gguf\",\"https://huggingface.co/mradermacher/LemonKunoichiWizardV3-GGUF/resolve/main/LemonKunoichiWizardV3.Q4_K_M.gguf\",\"https://huggingface.co/Lewdiculous/Kunoichi-DPO-v2-7B-GGUF-Imatrix/resolve/main/Kunoichi-DPO-v2-7B-Q4_K_M-imatrix.gguf\",\"https://huggingface.co/mradermacher/L3-8B-Stheno-v3.2-i1-GGUF/resolve/main/L3-8B-Stheno-v3.2.i1-Q4_K_M.gguf\",\"https://huggingface.co/Lewdiculous/Llama-3-Lumimaid-8B-v0.1-OAS-GGUF-IQ-Imatrix/resolve/main/v2-Llama-3-Lumimaid-8B-v0.1-OAS-Q4_K_M-imat.gguf\",\"https://huggingface.co/bartowski/NeuralDaredevil-8B-abliterated-GGUF/resolve/main/NeuralDaredevil-8B-abliterated-Q4_K_M.gguf\",\"https://huggingface.co/bartowski/L3-8B-Lunaris-v1-GGUF/resolve/main/L3-8B-Lunaris-v1-Q4_K_M.gguf\",\"https://huggingface.co/mradermacher/L3-Umbral-Mind-RP-v2.0-8B-GGUF/resolve/main/L3-Umbral-Mind-RP-v2.0-8B.Q4_K_M.gguf\",\"https://huggingface.co/bartowski/TheDrummer_Cydonia-24B-v2-GGUF/resolve/main/TheDrummer_Cydonia-24B-v2-Q4_K_S.gguf\",\"https://huggingface.co/bartowski/PocketDoc_Dans-PersonalityEngine-V1.2.0-24b-GGUF/resolve/main/PocketDoc_Dans-PersonalityEngine-V1.2.0-24b-IQ4_XS.gguf\"]{allow-input: true}\n",
52
+ "Layers = 99 #@param [99]{allow-input: true}\n",
53
+ "ContextSize = 4096 #@param [4096,8192] {allow-input: true}\n",
54
+ "FlashAttention = True #@param {type:\"boolean\"}\n",
55
+ "Multiplayer = False #@param {type:\"boolean\"}\n",
56
+ "FACommand = \"\"\n",
57
+ "MPCommand = \"\"\n",
58
+ "#@markdown <hr>\n",
59
+ "LoadVisionMMProjector = False #@param {type:\"boolean\"}\n",
60
+ "Mmproj = \"https://huggingface.co/koboldcpp/mmproj/resolve/main/llama-13b-mmproj-v1.5.Q4_1.gguf\" #@param [\"https://huggingface.co/koboldcpp/mmproj/resolve/main/llama-13b-mmproj-v1.5.Q4_1.gguf\",\"https://huggingface.co/koboldcpp/mmproj/resolve/main/mistral-7b-mmproj-v1.5-Q4_1.gguf\",\"https://huggingface.co/koboldcpp/mmproj/resolve/main/llama-7b-mmproj-v1.5-Q4_0.gguf\",\"https://huggingface.co/koboldcpp/mmproj/resolve/main/LLaMA3-8B_mmproj-Q4_1.gguf\"]{allow-input: true}\n",
61
+ "VCommand = \"\"\n",
62
+ "#@markdown <hr>\n",
63
+ "LoadImgModel = False #@param {type:\"boolean\"}\n",
64
+ "ImgModel = \"https://huggingface.co/koboldcpp/imgmodel/resolve/main/imgmodel_ftuned_q4_0.gguf\" #@param [\"https://huggingface.co/koboldcpp/imgmodel/resolve/main/imgmodel_ftuned_q4_0.gguf\"]{allow-input: true}\n",
65
+ "SCommand = \"\"\n",
66
+ "#@markdown <hr>\n",
67
+ "LoadSpeechModel = False #@param {type:\"boolean\"}\n",
68
+ "SpeechModel = \"https://huggingface.co/koboldcpp/whisper/resolve/main/whisper-base.en-q5_1.bin\" #@param [\"https://huggingface.co/koboldcpp/whisper/resolve/main/whisper-base.en-q5_1.bin\"]{allow-input: true}\n",
69
+ "WCommand = \"\"\n",
70
+ "#@markdown <hr>\n",
71
+ "LoadTTSModel = False #@param {type:\"boolean\"}\n",
72
+ "TTSModel = \"https://huggingface.co/koboldcpp/tts/resolve/main/OuteTTS-0.2-500M-Q4_0.gguf\" #@param [\"https://huggingface.co/koboldcpp/tts/resolve/main/OuteTTS-0.2-500M-Q4_0.gguf\"]{allow-input: true}\n",
73
+ "WavTokModel = \"https://huggingface.co/koboldcpp/tts/resolve/main/WavTokenizer-Large-75-Q4_0.gguf\" #@param [\"https://huggingface.co/koboldcpp/tts/resolve/main/WavTokenizer-Large-75-Q4_0.gguf\"]{allow-input: true}\n",
74
+ "TTSCommand = \"\"\n",
75
+ "#@markdown <hr>\n",
76
+ "AllowSaveToGoogleDrive = False #@param {type:\"boolean\"}\n",
77
+ "SavGdriveCommand = \"\"\n",
78
+ "\n",
79
+ "import os\n",
80
+ "if not os.path.isfile(\"/opt/bin/nvidia-smi\"):\n",
81
+ " raise RuntimeError(\"⚠️Colab did not give you a GPU due to usage limits, this can take a few hours before they let you back in. Check out https://lite.koboldai.net for a free alternative (that does not provide an API link but can load KoboldAI saves and chat cards) or subscribe to Colab Pro for immediate access.⚠️\")\n",
82
+ "\n",
83
+ "if AllowSaveToGoogleDrive:\n",
84
+ " print(\"Attempting to request access to save to your google drive...\")\n",
85
+ " try:\n",
86
+ " from google.colab import drive\n",
87
+ " import os, json\n",
88
+ " drive.mount('/content/drive', force_remount=True)\n",
89
+ " if not os.path.exists(\"/content/drive/MyDrive\"):\n",
90
+ " raise RuntimeError(\"Google Drive mount failed. Please grant permissions and try again.\")\n",
91
+ " kcppdir = '/content/drive/MyDrive/koboldcpp_data'\n",
92
+ " os.makedirs(kcppdir, exist_ok=True)\n",
93
+ " savedatapath = os.path.join(kcppdir, \"koboldcpp_save_db.jsondb\")\n",
94
+ " if not os.path.exists(savedatapath):\n",
95
+ " settings_data = {}\n",
96
+ " with open(savedatapath, \"w\") as json_file:\n",
97
+ " json.dump(settings_data, json_file, indent=4)\n",
98
+ " print(f\"Created new koboldcpp_save_db.jsondb at {savedatapath}\")\n",
99
+ " else:\n",
100
+ " print(f\"Loading saved data at {savedatapath}\")\n",
101
+ " SavGdriveCommand = f\" --savedatafile {savedatapath}\"\n",
102
+ " except Exception as e:\n",
103
+ " print(f\"⚠️ Error: {e}\")\n",
104
+ " print(\"Please ensure you grant Google Drive permissions and try again.\")\n",
105
+ "\n",
106
+ "%cd /content\n",
107
+ "if Mmproj and LoadVisionMMProjector:\n",
108
+ " VCommand = \"--mmproj vmodel.gguf\"\n",
109
+ "else:\n",
110
+ " SCommand = \"\"\n",
111
+ "if ImgModel and LoadImgModel:\n",
112
+ " SCommand = \"--sdmodel imodel.gguf --sdthreads 4 --sdquant --sdclamped\"\n",
113
+ "else:\n",
114
+ " SCommand = \"\"\n",
115
+ "if SpeechModel and LoadSpeechModel:\n",
116
+ " WCommand = \"--whispermodel wmodel.bin\"\n",
117
+ "else:\n",
118
+ " WCommand = \"\"\n",
119
+ "if TTSModel and WavTokModel and LoadTTSModel:\n",
120
+ " TTSCommand = \"--ttsmodel ttsmodel.bin --ttswavtokenizer ttswavtok.bin --ttsgpu\"\n",
121
+ "else:\n",
122
+ " TTSCommand = \"\"\n",
123
+ "if FlashAttention:\n",
124
+ " FACommand = \"--flashattention\"\n",
125
+ "else:\n",
126
+ " FACommand = \"\"\n",
127
+ "if Multiplayer:\n",
128
+ " MPCommand = \"--multiplayer\"\n",
129
+ "else:\n",
130
+ " MPCommand = \"\"\n",
131
+ "\n",
132
+ "!echo Downloading KoboldCpp, please wait...\n",
133
+ "!wget -O dlfile.tmp https://kcpplinux.concedo.workers.dev && mv dlfile.tmp koboldcpp_linux\n",
134
+ "!test -f koboldcpp_linux && echo Download Successful || echo Download Failed\n",
135
+ "!chmod +x ./koboldcpp_linux\n",
136
+ "!apt update\n",
137
+ "!apt install aria2 -y\n",
138
+ "# simple fix for a common URL mistake\n",
139
+ "if \"https://huggingface.co/\" in Model and \"/blob/main/\" in Model:\n",
140
+ " Model = Model.replace(\"/blob/main/\", \"/resolve/main/\")\n",
141
+ "!aria2c -x 10 -o model.gguf --summary-interval=5 --download-result=default --allow-overwrite=true --file-allocation=none $Model\n",
142
+ "if VCommand:\n",
143
+ " !aria2c -x 10 -o vmodel.gguf --summary-interval=5 --download-result=default --allow-overwrite=true --file-allocation=none $Mmproj\n",
144
+ "if SCommand:\n",
145
+ " !aria2c -x 10 -o imodel.gguf --summary-interval=5 --download-result=default --allow-overwrite=true --file-allocation=none $ImgModel\n",
146
+ "if WCommand:\n",
147
+ " !aria2c -x 10 -o wmodel.bin --summary-interval=5 --download-result=default --allow-overwrite=true --file-allocation=none $SpeechModel\n",
148
+ "if TTSCommand:\n",
149
+ " !aria2c -x 10 -o ttsmodel.bin --summary-interval=5 --download-result=default --allow-overwrite=true --file-allocation=none $TTSModel\n",
150
+ " !aria2c -x 10 -o ttswavtok.bin --summary-interval=5 --download-result=default --allow-overwrite=true --file-allocation=none $WavTokModel\n",
151
+ "!./koboldcpp_linux model.gguf --usecublas 0 mmq --chatcompletionsadapter AutoGuess --multiuser --gpulayers $Layers --contextsize $ContextSize --websearch --quiet --remotetunnel $FACommand $MPCommand $VCommand $SCommand $WCommand $TTSCommand $SavGdriveCommand\n"
152
+ ]
153
+ }
154
+ ],
155
+ "metadata": {
156
+ "accelerator": "GPU",
157
+ "colab": {
158
+ "cell_execution_strategy": "setup",
159
+ "gpuType": "T4",
160
+ "include_colab_link": true,
161
+ "private_outputs": true,
162
+ "provenance": []
163
+ },
164
+ "kernelspec": {
165
+ "display_name": "Python 3",
166
+ "name": "python3"
167
+ },
168
+ "language_info": {
169
+ "name": "python"
170
+ }
171
+ },
172
+ "nbformat": 4,
173
+ "nbformat_minor": 0
174
+ }
common/arg.cpp ADDED
The diff for this file is too large to render. See raw diff
 
common/arg.h ADDED
@@ -0,0 +1,80 @@
1
+ #pragma once
2
+
3
+ #include "common.h"
4
+
5
+ #include <set>
6
+ #include <string>
7
+ #include <vector>
8
+
9
+ //
10
+ // CLI argument parsing
11
+ //
12
+
13
+ struct common_arg {
14
+ std::set<enum llama_example> examples = {LLAMA_EXAMPLE_COMMON};
15
+ std::set<enum llama_example> excludes = {};
16
+ std::vector<const char *> args;
17
+ const char * value_hint = nullptr; // help text or example for arg value
18
+ const char * value_hint_2 = nullptr; // for second arg value
19
+ const char * env = nullptr;
20
+ std::string help;
21
+ bool is_sparam = false; // is current arg a sampling param?
22
+ void (*handler_void) (common_params & params) = nullptr;
23
+ void (*handler_string) (common_params & params, const std::string &) = nullptr;
24
+ void (*handler_str_str)(common_params & params, const std::string &, const std::string &) = nullptr;
25
+ void (*handler_int) (common_params & params, int) = nullptr;
26
+
27
+ common_arg(
28
+ const std::initializer_list<const char *> & args,
29
+ const char * value_hint,
30
+ const std::string & help,
31
+ void (*handler)(common_params & params, const std::string &)
32
+ ) : args(args), value_hint(value_hint), help(help), handler_string(handler) {}
33
+
34
+ common_arg(
35
+ const std::initializer_list<const char *> & args,
36
+ const char * value_hint,
37
+ const std::string & help,
38
+ void (*handler)(common_params & params, int)
39
+ ) : args(args), value_hint(value_hint), help(help), handler_int(handler) {}
40
+
41
+ common_arg(
42
+ const std::initializer_list<const char *> & args,
43
+ const std::string & help,
44
+ void (*handler)(common_params & params)
45
+ ) : args(args), help(help), handler_void(handler) {}
46
+
47
+ // support 2 values for arg
48
+ common_arg(
49
+ const std::initializer_list<const char *> & args,
50
+ const char * value_hint,
51
+ const char * value_hint_2,
52
+ const std::string & help,
53
+ void (*handler)(common_params & params, const std::string &, const std::string &)
54
+ ) : args(args), value_hint(value_hint), value_hint_2(value_hint_2), help(help), handler_str_str(handler) {}
55
+
56
+ common_arg & set_examples(std::initializer_list<enum llama_example> examples);
57
+ common_arg & set_excludes(std::initializer_list<enum llama_example> excludes);
58
+ common_arg & set_env(const char * env);
59
+ common_arg & set_sparam();
60
+ bool in_example(enum llama_example ex);
61
+ bool is_exclude(enum llama_example ex);
62
+ bool get_value_from_env(std::string & output);
63
+ bool has_value_from_env();
64
+ std::string to_string();
65
+ };
66
+
67
+ struct common_params_context {
68
+ enum llama_example ex = LLAMA_EXAMPLE_COMMON;
69
+ common_params & params;
70
+ std::vector<common_arg> options;
71
+ void(*print_usage)(int, char **) = nullptr;
72
+ common_params_context(common_params & params) : params(params) {}
73
+ };
74
+
75
+ // parse input arguments from CLI
76
+ // if one argument has an invalid value, the usage of that specific argument is displayed automatically (and not the full usage message)
77
+ bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
78
+
79
+ // function to be used by test-arg-parser
80
+ common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
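Since `arg.cpp` is too large to render above, a hedged sketch of how this interface is used may help. Real option tables are assembled inside `common_params_parser_init`, and the handler fields are plain function pointers, so only free functions or captureless lambdas fit. The option names, help text, and environment variable below are illustrative, not taken from this diff.

    #include "arg.h"

    int main(int argc, char ** argv) {
        common_params params;

        // illustrative stand-alone option; in practice entries like this are
        // created inside common_params_parser_init, not by callers
        common_arg ctx_opt(
            {"-c", "--ctx-size"}, "N",
            "size of the prompt context (illustrative help text)",
            [](common_params & /*params*/, int /*value*/) { /* store the value */ });
        ctx_opt.set_env("LLAMA_ARG_CTX_SIZE"); // optional environment-variable override

        // typical entry point: parse argv into params for a given example type
        if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
            return 1;
        }
        return 0;
    }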
common/base64.hpp ADDED
@@ -0,0 +1,392 @@
1
+ /*
2
+ This is free and unencumbered software released into the public domain.
3
+
4
+ Anyone is free to copy, modify, publish, use, compile, sell, or
5
+ distribute this software, either in source code form or as a compiled
6
+ binary, for any purpose, commercial or non-commercial, and by any
7
+ means.
8
+
9
+ In jurisdictions that recognize copyright laws, the author or authors
10
+ of this software dedicate any and all copyright interest in the
11
+ software to the public domain. We make this dedication for the benefit
12
+ of the public at large and to the detriment of our heirs and
13
+ successors. We intend this dedication to be an overt act of
14
+ relinquishment in perpetuity of all present and future rights to this
15
+ software under copyright law.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
18
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
19
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
20
+ IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
21
+ OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
22
+ ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
23
+ OTHER DEALINGS IN THE SOFTWARE.
24
+
25
+ For more information, please refer to <http://unlicense.org>
26
+ */
27
+
28
+ #ifndef PUBLIC_DOMAIN_BASE64_HPP_
29
+ #define PUBLIC_DOMAIN_BASE64_HPP_
30
+
31
+ #include <cstdint>
32
+ #include <iterator>
33
+ #include <stdexcept>
34
+ #include <string>
35
+
36
+ class base64_error : public std::runtime_error
37
+ {
38
+ public:
39
+ using std::runtime_error::runtime_error;
40
+ };
41
+
42
+ class base64
43
+ {
44
+ public:
45
+ enum class alphabet
46
+ {
47
+ /** the alphabet is detected automatically */
48
+ auto_,
49
+ /** the standard base64 alphabet is used */
50
+ standard,
51
+ /** like `standard` except that the characters `+` and `/` are replaced by `-` and `_`, respectively */
52
+ url_filename_safe
53
+ };
54
+
55
+ enum class decoding_behavior
56
+ {
57
+ /** if the input is not padded, the remaining bits are ignored */
58
+ moderate,
59
+ /** if a padding character is encountered, decoding is finished */
60
+ loose
61
+ };
62
+
63
+ /**
64
+ Encodes all the elements from `in_begin` to `in_end` to `out`.
65
+
66
+ @warning The source and destination cannot overlap. The destination must be able to hold at least
67
+ `required_encode_size(std::distance(in_begin, in_end))`, otherwise the behavior depends on the output iterator.
68
+
69
+ @tparam Input_iterator the source; the returned elements are cast to `std::uint8_t` and should not be greater than
70
+ 8 bits
71
+ @tparam Output_iterator the destination; the elements written to it are from the type `char`
72
+ @param in_begin the beginning of the source
73
+ @param in_end the ending of the source
74
+ @param out the destination iterator
75
+ @param alphabet which alphabet should be used
76
+ @returns the iterator to the next element past the last element copied
77
+ @throws see `Input_iterator` and `Output_iterator`
78
+ */
79
+ template<typename Input_iterator, typename Output_iterator>
80
+ static Output_iterator encode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out,
81
+ alphabet alphabet = alphabet::standard)
82
+ {
83
+ constexpr auto pad = '=';
84
+ const char* alpha = alphabet == alphabet::url_filename_safe
85
+ ? "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
86
+ : "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
87
+
88
+ while (in_begin != in_end) {
89
+ std::uint8_t i0 = 0, i1 = 0, i2 = 0;
90
+
91
+ // first character
92
+ i0 = static_cast<std::uint8_t>(*in_begin);
93
+ ++in_begin;
94
+
95
+ *out = alpha[i0 >> 2 & 0x3f];
96
+ ++out;
97
+
98
+ // part of first character and second
99
+ if (in_begin != in_end) {
100
+ i1 = static_cast<std::uint8_t>(*in_begin);
101
+ ++in_begin;
102
+
103
+ *out = alpha[((i0 & 0x3) << 4) | (i1 >> 4 & 0x0f)];
104
+ ++out;
105
+ } else {
106
+ *out = alpha[(i0 & 0x3) << 4];
107
+ ++out;
108
+
109
+ // last padding
110
+ *out = pad;
111
+ ++out;
112
+
113
+ // last padding
114
+ *out = pad;
115
+ ++out;
116
+
117
+ break;
118
+ }
119
+
120
+ // part of second character and third
121
+ if (in_begin != in_end) {
122
+ i2 = static_cast<std::uint8_t>(*in_begin);
123
+ ++in_begin;
124
+
125
+ *out = alpha[((i1 & 0xf) << 2) | (i2 >> 6 & 0x03)];
126
+ ++out;
127
+ } else {
128
+ *out = alpha[(i1 & 0xf) << 2];
129
+ ++out;
130
+
131
+ // last padding
132
+ *out = pad;
133
+ ++out;
134
+
135
+ break;
136
+ }
137
+
138
+ // rest of third
139
+ *out = alpha[i2 & 0x3f];
140
+ ++out;
141
+ }
142
+
143
+ return out;
144
+ }
145
+ /**
146
+ Encodes a string.
147
+
148
+ @param str the string that should be encoded
149
+ @param alphabet which alphabet should be used
150
+ @returns the encoded base64 string
151
+ @throws see base64::encode()
152
+ */
153
+ static std::string encode(const std::string& str, alphabet alphabet = alphabet::standard)
154
+ {
155
+ std::string result;
156
+
157
+ result.reserve(required_encode_size(str.length()) + 1);
158
+
159
+ encode(str.begin(), str.end(), std::back_inserter(result), alphabet);
160
+
161
+ return result;
162
+ }
163
+ /**
164
+ Encodes a char array.
165
+
166
+ @param buffer the char array
167
+ @param size the size of the array
168
+ @param alphabet which alphabet should be used
169
+ @returns the encoded string
170
+ */
171
+ static std::string encode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::standard)
172
+ {
173
+ std::string result;
174
+
175
+ result.reserve(required_encode_size(size) + 1);
176
+
177
+ encode(buffer, buffer + size, std::back_inserter(result), alphabet);
178
+
179
+ return result;
180
+ }
181
+ /**
182
+ Decodes all the elements from `in_begin` to `in_end` to `out`. `in_begin` may point to the same location as `out`,
183
+ in other words: inplace decoding is possible.
184
+
185
+ @warning The destination must be able to hold at least `required_decode_size(std::distance(in_begin, in_end))`,
186
+ otherwise the behavior depends on the output iterator.
187
+
188
+ @tparam Input_iterator the source; the returned elements are cast to `char`
189
+ @tparam Output_iterator the destination; the elements written to it are from the type `std::uint8_t`
190
+ @param in_begin the beginning of the source
191
+ @param in_end the ending of the source
192
+ @param out the destination iterator
193
+ @param alphabet which alphabet should be used
194
+ @param behavior the behavior when an error was detected
195
+ @returns the iterator to the next element past the last element copied
196
+ @throws base64_error depending on the set behavior
197
+ @throws see `Input_iterator` and `Output_iterator`
198
+ */
199
+ template<typename Input_iterator, typename Output_iterator>
200
+ static Output_iterator decode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out,
201
+ alphabet alphabet = alphabet::auto_,
202
+ decoding_behavior behavior = decoding_behavior::moderate)
203
+ {
204
+ //constexpr auto pad = '=';
205
+ std::uint8_t last = 0;
206
+ auto bits = 0;
207
+
208
+ while (in_begin != in_end) {
209
+ auto c = *in_begin;
210
+ ++in_begin;
211
+
212
+ if (c == '=') {
213
+ break;
214
+ }
215
+
216
+ auto part = _base64_value(alphabet, c);
217
+
218
+ // enough bits for one byte
219
+ if (bits + 6 >= 8) {
220
+ *out = (last << (8 - bits)) | (part >> (bits - 2));
221
+ ++out;
222
+
223
+ bits -= 2;
224
+ } else {
225
+ bits += 6;
226
+ }
227
+
228
+ last = part;
229
+ }
230
+
231
+ // check padding
232
+ if (behavior != decoding_behavior::loose) {
233
+ while (in_begin != in_end) {
234
+ auto c = *in_begin;
235
+ ++in_begin;
236
+
237
+ if (c != '=') {
238
+ throw base64_error("invalid base64 character.");
239
+ }
240
+ }
241
+ }
242
+
243
+ return out;
244
+ }
245
+ /**
246
+ Decodes a string.
247
+
248
+ @param str the base64 encoded string
249
+ @param alphabet which alphabet should be used
250
+ @param behavior the behavior when an error was detected
251
+ @returns the decoded string
252
+ @throws see base64::decode()
253
+ */
254
+ static std::string decode(const std::string& str, alphabet alphabet = alphabet::auto_,
255
+ decoding_behavior behavior = decoding_behavior::moderate)
256
+ {
257
+ std::string result;
258
+
259
+ result.reserve(max_decode_size(str.length()));
260
+
261
+ decode(str.begin(), str.end(), std::back_inserter(result), alphabet, behavior);
262
+
263
+ return result;
264
+ }
265
+ /**
266
+ Decodes a string.
267
+
268
+ @param buffer the base64 encoded buffer
269
+ @param size the size of the buffer
270
+ @param alphabet which alphabet should be used
271
+ @param behavior the behavior when an error was detected
272
+ @returns the decoded string
273
+ @throws see base64::decode()
274
+ */
275
+ static std::string decode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::auto_,
276
+ decoding_behavior behavior = decoding_behavior::moderate)
277
+ {
278
+ std::string result;
279
+
280
+ result.reserve(max_decode_size(size));
281
+
282
+ decode(buffer, buffer + size, std::back_inserter(result), alphabet, behavior);
283
+
284
+ return result;
285
+ }
286
+ /**
287
+ Decodes a string inplace.
288
+
289
+ @param[in,out] str the base64 encoded string
290
+ @param alphabet which alphabet should be used
291
+ @param behavior the behavior when an error was detected
292
+ @throws base64::decode_inplace()
293
+ */
294
+ static void decode_inplace(std::string& str, alphabet alphabet = alphabet::auto_,
295
+ decoding_behavior behavior = decoding_behavior::moderate)
296
+ {
297
+ str.resize(decode(str.begin(), str.end(), str.begin(), alphabet, behavior) - str.begin());
298
+ }
299
+ /**
300
+ Decodes a char array inplace.
301
+
302
+ @param[in,out] str the string array
303
+ @param size the length of the array
304
+ @param alphabet which alphabet should be used
305
+ @param behavior the behavior when an error was detected
306
+ @returns the pointer to the next element past the last element decoded
307
+ @throws base64::decode_inplace()
308
+ */
309
+ static char* decode_inplace(char* str, std::size_t size, alphabet alphabet = alphabet::auto_,
310
+ decoding_behavior behavior = decoding_behavior::moderate)
311
+ {
312
+ return decode(str, str + size, str, alphabet, behavior);
313
+ }
314
+ /**
315
+ Returns the maximum decoded size for a given encoded size. The value is calculated with the following formula:
316
+
317
+ $$
318
+ \lceil \frac{size}{4} \rceil \cdot 3
319
+ $$
320
+
321
+ @param size the size of the encoded input
322
+ @returns the size of the resulting decoded buffer; this is the absolute maximum
323
+ */
324
+ static std::size_t max_decode_size(std::size_t size) noexcept
325
+ {
326
+ return (size / 4 + (size % 4 ? 1 : 0)) * 3;
327
+ }
328
+ /**
329
+ Returns the required encoding size for a given size. The value is calculated with the following formula:
330
+
331
+ $$
332
+ \lceil \frac{size}{3} \rceil \cdot 4
333
+ $$
334
+
335
+ @param size the size of the decoded input
336
+ @returns the size of the resulting encoded buffer
337
+ */
338
+ static std::size_t required_encode_size(std::size_t size) noexcept
339
+ {
340
+ return (size / 3 + (size % 3 ? 1 : 0)) * 4;
341
+ }
342
+
343
+ private:
344
+ static std::uint8_t _base64_value(alphabet& alphabet, char c)
345
+ {
346
+ if (c >= 'A' && c <= 'Z') {
347
+ return c - 'A';
348
+ } else if (c >= 'a' && c <= 'z') {
349
+ return c - 'a' + 26;
350
+ } else if (c >= '0' && c <= '9') {
351
+ return c - '0' + 52;
352
+ }
353
+
354
+ // comes down to alphabet
355
+ if (alphabet == alphabet::standard) {
356
+ if (c == '+') {
357
+ return 62;
358
+ } else if (c == '/') {
359
+ return 63;
360
+ }
361
+ } else if (alphabet == alphabet::url_filename_safe) {
362
+ if (c == '-') {
363
+ return 62;
364
+ } else if (c == '_') {
365
+ return 63;
366
+ }
367
+ } // auto detect
368
+ else {
369
+ if (c == '+') {
370
+ alphabet = alphabet::standard;
371
+
372
+ return 62;
373
+ } else if (c == '/') {
374
+ alphabet = alphabet::standard;
375
+
376
+ return 63;
377
+ } else if (c == '-') {
378
+ alphabet = alphabet::url_filename_safe;
379
+
380
+ return 62;
381
+ } else if (c == '_') {
382
+ alphabet = alphabet::url_filename_safe;
383
+
384
+ return 63;
385
+ }
386
+ }
387
+
388
+ throw base64_error("invalid base64 character.");
389
+ }
390
+ };
391
+
392
+ #endif // !PUBLIC_DOMAIN_BASE64_HPP_
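A short round-trip sketch of the interfaces declared above (the encoded value in the comment follows from the standard alphabet):

    #include "base64.hpp"
    #include <cassert>
    #include <string>

    int main() {
        const std::string raw = "hello";
        const std::string enc = base64::encode(raw); // "aGVsbG8="
        assert(base64::decode(enc) == raw);          // alphabet is auto-detected

        std::string buf = enc;
        base64::decode_inplace(buf);                 // in-place decode, resizes the string
        assert(buf == raw);
        return 0;
    }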
common/build-info.cpp.in ADDED
@@ -0,0 +1,4 @@
1
+ int LLAMA_BUILD_NUMBER = @BUILD_NUMBER@;
2
+ char const *LLAMA_COMMIT = "@BUILD_COMMIT@";
3
+ char const *LLAMA_COMPILER = "@BUILD_COMPILER@";
4
+ char const *LLAMA_BUILD_TARGET = "@BUILD_TARGET@";
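The `@BUILD_*@` placeholders are filled in at configure time (in llama.cpp's build this is done with CMake's `configure_file` machinery), so the generated `build-info.cpp` is ordinary C++ along these lines; the concrete values below are illustrative, not taken from this commit:

    // illustrative configured output of build-info.cpp.in
    int LLAMA_BUILD_NUMBER = 0;                          // e.g. a CI build counter
    char const *LLAMA_COMMIT = "0000000";                // short commit hash
    char const *LLAMA_COMPILER = "clang version 17.0.0"; // compiler identification string
    char const *LLAMA_BUILD_TARGET = "arm64-apple-darwin";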
common/chat.cpp ADDED
@@ -0,0 +1,1779 @@
1
+ #include "chat.h"
2
+ #include "json-schema-to-grammar.h"
3
+ #include "log.h"
4
+ #include "minja/chat-template.hpp"
5
+ #include "minja/minja.hpp"
6
+
7
+ #include <optional>
8
+
9
+ typedef minja::chat_template common_chat_template;
10
+
11
+ struct common_chat_templates {
12
+ bool has_explicit_template; // Model had builtin template or template overridde was specified.
13
+ std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
14
+ std::unique_ptr<common_chat_template> template_tool_use;
15
+ };
16
+
17
+ struct templates_params {
18
+ json messages;
19
+ json tools;
20
+ common_chat_tool_choice tool_choice;
21
+ json json_schema;
22
+ bool parallel_tool_calls;
23
+ bool stream;
24
+ std::string grammar;
25
+ bool add_generation_prompt = true;
26
+ bool extract_reasoning = true;
27
+ };
28
+
29
+ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
30
+ if (tool_choice == "auto") {
31
+ return COMMON_CHAT_TOOL_CHOICE_AUTO;
32
+ }
33
+ if (tool_choice == "none") {
34
+ return COMMON_CHAT_TOOL_CHOICE_NONE;
35
+ }
36
+ if (tool_choice == "required") {
37
+ return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
38
+ }
39
+ throw std::runtime_error("Invalid tool_choice: " + tool_choice);
40
+ }
41
+
42
+ template <>
43
+ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
44
+ std::vector<common_chat_msg> msgs;
45
+
46
+ try {
47
+
48
+ if (!messages.is_array()) {
49
+ throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
50
+ }
51
+
52
+ for (const auto & message : messages) {
53
+ if (!message.is_object()) {
54
+ throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
55
+ }
56
+
57
+ common_chat_msg msg;
58
+ if (!message.contains("role")) {
59
+ throw std::runtime_error("Missing 'role' in message: " + message.dump());
60
+ }
61
+ msg.role = message.at("role");
62
+
63
+ auto has_content = message.contains("content");
64
+ auto has_tool_calls = message.contains("tool_calls");
65
+ if (has_content) {
66
+ const auto & content = message.at("content");
67
+ if (content.is_string()) {
68
+ msg.content = content;
69
+ } else if (content.is_array()) {
70
+ for (const auto & part : content) {
71
+ if (!part.contains("type")) {
72
+ throw std::runtime_error("Missing content part type: " + part.dump());
73
+ }
74
+ const auto & type = part.at("type");
75
+ if (type != "text") {
76
+ throw std::runtime_error("Unsupported content part type: " + type.dump());
77
+ }
78
+ common_chat_msg_content_part msg_part;
79
+ msg_part.type = type;
80
+ msg_part.text = part.at("text");
81
+ msg.content_parts.push_back(msg_part);
82
+ }
83
+ } else if (!content.is_null()) {
84
+ throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
85
+ }
86
+ }
87
+ if (has_tool_calls) {
88
+ for (const auto & tool_call : message.at("tool_calls")) {
89
+ common_chat_tool_call tc;
90
+ if (!tool_call.contains("type")) {
91
+ throw std::runtime_error("Missing tool call type: " + tool_call.dump());
92
+ }
93
+ const auto & type = tool_call.at("type");
94
+ if (type != "function") {
95
+ throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
96
+ }
97
+ if (!tool_call.contains("function")) {
98
+ throw std::runtime_error("Missing tool call function: " + tool_call.dump());
99
+ }
100
+ const auto & fc = tool_call.at("function");
101
+ if (!fc.contains("name")) {
102
+ throw std::runtime_error("Missing tool call name: " + tool_call.dump());
103
+ }
104
+ tc.name = fc.at("name");
105
+ tc.arguments = fc.at("arguments");
106
+ if (tool_call.contains("id")) {
107
+ tc.id = tool_call.at("id");
108
+ }
109
+ msg.tool_calls.push_back(tc);
110
+ }
111
+ }
112
+ if (!has_content && !has_tool_calls) {
113
+ throw std::runtime_error("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)");
114
+ }
115
+ if (message.contains("reasoning_content")) {
116
+ msg.reasoning_content = message.at("reasoning_content");
117
+ }
118
+ if (message.contains("name")) {
119
+ msg.tool_name = message.at("name");
120
+ }
121
+ if (message.contains("tool_call_id")) {
122
+ msg.tool_call_id = message.at("tool_call_id");
123
+ }
124
+
125
+ msgs.push_back(msg);
126
+ }
127
+ } catch (const std::exception & e) {
128
+ throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
129
+ }
130
+
131
+ return msgs;
132
+ }
133
+
134
+ template <>
135
+ json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
136
+ json messages = json::array();
137
+ for (const auto & msg : msgs) {
138
+ if (!msg.content.empty() && !msg.content_parts.empty()) {
139
+ throw std::runtime_error("Cannot specify both content and content_parts");
140
+ }
141
+ json jmsg {
142
+ {"role", msg.role},
143
+ };
144
+ if (!msg.content.empty()) {
145
+ jmsg["content"] = msg.content;
146
+ } else if (!msg.content_parts.empty()) {
147
+ if (concat_typed_text) {
148
+ std::string text;
149
+ for (const auto & part : msg.content_parts) {
150
+ if (part.type != "text") {
151
+ LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
152
+ continue;
153
+ }
154
+ if (!text.empty()) {
155
+ text += '\n';
156
+ }
157
+ text += part.text;
158
+ }
159
+ jmsg["content"] = text;
160
+ } else {
161
+ auto & parts = jmsg["content"] = json::array();
162
+ for (const auto & part : msg.content_parts) {
163
+ parts.push_back({
164
+ {"type", part.type},
165
+ {"text", part.text},
166
+ });
167
+ }
168
+ }
169
+ } else {
170
+ jmsg["content"] = json(); // null
171
+ }
172
+ if (!msg.reasoning_content.empty()) {
173
+ jmsg["reasoning_content"] = msg.reasoning_content;
174
+ }
175
+ if (!msg.tool_name.empty()) {
176
+ jmsg["name"] = msg.tool_name;
177
+ }
178
+ if (!msg.tool_call_id.empty()) {
179
+ jmsg["tool_call_id"] = msg.tool_call_id;
180
+ }
181
+ if (!msg.tool_calls.empty()) {
182
+ auto & tool_calls = jmsg["tool_calls"] = json::array();
183
+ for (const auto & tool_call : msg.tool_calls) {
184
+ json tc {
185
+ {"type", "function"},
186
+ {"function", {
187
+ {"name", tool_call.name},
188
+ {"arguments", tool_call.arguments},
189
+ }},
190
+ };
191
+ if (!tool_call.id.empty()) {
192
+ tc["id"] = tool_call.id;
193
+ }
194
+ tool_calls.push_back(tc);
195
+ }
196
+ }
197
+ messages.push_back(jmsg);
198
+ }
199
+ return messages;
200
+ }
+
+ template <>
+ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
+ return common_chat_msgs_parse_oaicompat(json::parse(messages));
+ }
+
+ template <>
+ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
+ std::vector<common_chat_tool> result;
+
+ try {
+ if (!tools.is_null()) {
+ if (!tools.is_array()) {
+ throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
+ }
+ for (const auto & tool : tools) {
+ if (!tool.contains("type")) {
+ throw std::runtime_error("Missing tool type: " + tool.dump());
+ }
+ const auto & type = tool.at("type");
+ if (!type.is_string() || type != "function") {
+ throw std::runtime_error("Unsupported tool type: " + tool.dump());
+ }
+ if (!tool.contains("function")) {
+ throw std::runtime_error("Missing tool function: " + tool.dump());
+ }
+
+ const auto & function = tool.at("function");
+ result.push_back({
+ /* .name = */ function.at("name"),
+ /* .description = */ function.at("description"),
+ /* .parameters = */ function.at("parameters").dump(),
+ });
+ }
+ }
+ } catch (const std::exception & e) {
+ throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2));
+ }
+
+ return result;
+ }
+
+ template <>
+ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
+ return common_chat_tools_parse_oaicompat(json::parse(tools));
+ }
+
+ template <>
+ json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
+ if (tools.empty()) {
+ return json();
+ }
+
+ auto result = json::array();
+ for (const auto & tool : tools) {
+ result.push_back({
+ {"type", "function"},
+ {"function", {
+ {"name", tool.name},
+ {"description", tool.description},
+ {"parameters", json::parse(tool.parameters)},
+ }},
+ });
+ }
+ return result;
+ }
+
+ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
+ if (use_jinja) {
+ try {
+ common_chat_msg msg;
+ msg.role = "user";
+ msg.content = "test";
+
+ auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);
+
+ common_chat_templates_inputs inputs;
+ inputs.messages = {msg};
+
+ common_chat_templates_apply(tmpls.get(), inputs);
+ return true;
+ } catch (const std::exception & e) {
+ LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
+ return false;
+ }
+ }
+ llama_chat_message chat[] = {{"user", "test"}};
+ const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
+ return res >= 0;
+ }
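+
+ // Usage sketch (hypothetical call site): the function probes a template without
+ // throwing, so callers can fall back gracefully, e.g.:
+ // if (!common_chat_verify_template(template_src, /* use_jinja= */ true)) {
+ // LOG_WRN("template rejected, falling back to chatml\n");
+ // }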
+
+ std::string common_chat_format_single(
+ const struct common_chat_templates * tmpls,
+ const std::vector<common_chat_msg> & past_msg,
+ const common_chat_msg & new_msg,
+ bool add_ass,
+ bool use_jinja) {
+
+ common_chat_templates_inputs inputs;
+ inputs.use_jinja = use_jinja;
+
+ std::string fmt_past_msg;
+ if (!past_msg.empty()) {
+ inputs.messages = past_msg;
+ inputs.add_generation_prompt = false;
+ fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt;
+ }
+ std::ostringstream ss;
+ // if past_msg ends with a newline, we must preserve it in the formatted version
+ if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
+ ss << "\n";
+ }
+ // format chat with new_msg
+ inputs.messages.push_back(new_msg);
+ inputs.add_generation_prompt = add_ass;
+ auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
+ // get the diff part
+ ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
+ return ss.str();
+ }
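+
+ // Rough example of the diff-based formatting above, assuming a ChatML template:
+ // with past_msg = [user:"Hi", assistant:"Hello"] and new_msg = user:"Bye",
+ // both prompts are rendered and only the new suffix is returned, roughly
+ // "<|im_start|>user\nBye<|im_end|>\n<|im_start|>assistant\n" when add_ass is true.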
+
+ std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
+ common_chat_templates_inputs inputs;
+ inputs.use_jinja = use_jinja;
+ auto add_simple_msg = [&](auto role, auto content) {
+ common_chat_msg msg;
+ msg.role = role;
+ msg.content = content;
+ inputs.messages.push_back(msg);
+ };
+ add_simple_msg("system", "You are a helpful assistant");
+ add_simple_msg("user", "Hello");
+ add_simple_msg("assistant", "Hi there");
+ add_simple_msg("user", "How are you?");
+ return common_chat_templates_apply(tmpls, inputs).prompt;
+ }
+
+ #define CHATML_TEMPLATE_SRC \
+ "{%- for message in messages -%}\n" \
+ " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
+ "{%- endfor -%}\n" \
+ "{%- if add_generation_prompt -%}\n" \
+ " {{- '<|im_start|>assistant\n' -}}\n" \
+ "{%- endif -%}"
+
+ void common_chat_templates_free(struct common_chat_templates * tmpls) {
+ delete tmpls;
+ }
+
+ bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) {
+ return tmpls->has_explicit_template;
+ }
+
+ const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
+ if (variant != nullptr) {
+ if (strcmp(variant, "tool_use") == 0) {
+ if (tmpls->template_tool_use) {
+ return tmpls->template_tool_use->source().c_str();
+ }
+ return nullptr;
+ } else {
+ LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
+ }
+ }
+ return tmpls->template_default->source().c_str();
+ }
+
+ common_chat_templates_ptr common_chat_templates_init(
+ const struct llama_model * model,
+ const std::string & chat_template_override,
+ const std::string & bos_token_override,
+ const std::string & eos_token_override)
+ {
+ std::string default_template_src;
+ std::string template_tool_use_src;
+
+ bool has_explicit_template = !chat_template_override.empty();
+ if (chat_template_override.empty()) {
+ GGML_ASSERT(model != nullptr);
+ const auto * str = llama_model_chat_template(model, /* name */ nullptr);
+ if (str) {
+ default_template_src = str;
+ has_explicit_template = true;
+ }
+ str = llama_model_chat_template(model, /* name */ "tool_use");
+ if (str) {
+ template_tool_use_src = str;
+ has_explicit_template = true;
+ }
+ } else {
+ default_template_src = chat_template_override;
+ }
+ if (default_template_src.empty() || default_template_src == "chatml") {
+ if (!template_tool_use_src.empty()) {
+ default_template_src = template_tool_use_src;
+ } else {
+ default_template_src = CHATML_TEMPLATE_SRC;
+ }
+ }
+ std::string token_bos = bos_token_override;
+ std::string token_eos = eos_token_override;
+ if (model) {
+ const auto * vocab = llama_model_get_vocab(model);
+ const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
+ if (token == LLAMA_TOKEN_NULL) {
+ if (default_template_src.find(jinja_variable_name) != std::string::npos
+ || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
+ LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name);
+ }
+ return std::string();
+ }
+ return common_token_to_piece(vocab, token, true);
+ };
+ token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
+ token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
+ }
+ common_chat_templates_ptr tmpls(new common_chat_templates());
+ tmpls->has_explicit_template = has_explicit_template;
+ try {
+ tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
+ } catch (const std::exception & e) {
+ LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
+ tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
+ }
+ if (!template_tool_use_src.empty()) {
+ try {
+ tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
+ } catch (const std::exception & e) {
+ LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
+ }
+ }
+ return tmpls;
+ }
+
+ std::string common_chat_format_name(common_chat_format format) {
+ switch (format) {
+ case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
+ case COMMON_CHAT_FORMAT_GENERIC: return "Generic";
+ case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
+ case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
+ case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
+ case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
+ case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING: return "DeepSeek R1 (extract reasoning)";
+ case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
+ case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
+ case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
+ case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
+ case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING: return "Hermes 2 Pro (extract reasoning)";
+ case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
+ case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)";
+ default:
+ throw std::runtime_error("Unknown chat format");
+ }
+ }
+
+ static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) {
+ // https://json.nlohmann.me/features/parsing/sax_interface/
+ struct json_error_locator : public nlohmann::json_sax<json> {
+ std::size_t position;
+ bool found_error;
+
+ json_error_locator() : position(0), found_error(false) {}
+
+ bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT
+ this->position = position - 1;
+ this->found_error = true;
+ return false;
+ }
+ bool null() override { return true; } // NOLINT
+ bool boolean(bool) override { return true; } // NOLINT
+ bool number_integer(number_integer_t) override { return true; } // NOLINT
+ bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT
+ bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT
+ bool string(string_t &) override { return true; } // NOLINT
+ bool binary(binary_t &) override { return true; } // NOLINT
+ bool start_object(std::size_t) override { return true; } // NOLINT
+ bool key(string_t &) override { return true; } // NOLINT
+ bool end_object() override { return true; }
+ bool start_array(std::size_t) override { return true; } // NOLINT
+ bool end_array() override { return true; }
+ };
+ json_error_locator err_loc;
+ json::sax_parse(it, end, &err_loc);
+
+ std::string::const_iterator tentative_end;
+ if (err_loc.found_error) {
+ tentative_end = it + err_loc.position;
+ } else {
+ tentative_end = end;
+ }
+ std::string json_sub {it, tentative_end};
+ try {
+ out = json::parse(json_sub);
+ it = tentative_end;
+ return true;
+ } catch (const std::exception &) {
+ return false;
+ }
+ }
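+
+ // Behaviour sketch (hypothetical input): parse_json consumes a JSON prefix and
+ // leaves `it` at the first unparsed character, e.g.
+ // std::string s = "{\"a\": 1}</tool_call>";
+ // auto it = s.cbegin(); json j;
+ // parse_json(it, s.cend(), j); // j == {"a":1}, it now points at "</tool_call>"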
+
+ static bool parse_literal(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) {
+ auto expected_it = expected.begin();
+ auto tmp_it = it;
+ while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) {
+ ++tmp_it;
+ ++expected_it;
+ }
+ if (expected_it == expected.end()) {
+ it = tmp_it;
+ return true;
+ }
+ return false;
+ }
+
+ static std::optional<std::smatch> parse_pattern(std::string::const_iterator & it, const std::string::const_iterator & end, const std::regex & expected) {
+ std::smatch match;
+ if (std::regex_match(it, end, match, expected)) {
+ it = match.suffix().first;
+ return match;
+ }
+ return std::nullopt;
+ }
+
+ static void consume_spaces(std::string::const_iterator & it, const std::string::const_iterator & end) {
+ while (it != end && std::isspace(*it)) {
+ ++it;
+ }
+ }
+
+ /**
+ * Takes a prefix regex that must have one capture group for the function name, a closing suffix regex, and expects JSON parameters in between.
+ * Aggregates the prefix, suffix and in-between text into the content.
+ */
+ static common_chat_msg parse_json_tool_calls(
+ const std::string& input,
+ const std::optional<std::regex> & trigger_opt,
+ const std::regex & function_regex,
+ const std::regex & close_regex,
+ bool allow_raw_python = false) {
+ std::smatch match;
+
+ common_chat_msg result;
+ result.role = "assistant";
+
+ auto end = input.end();
+ auto it = input.begin();
+
+ if (trigger_opt) {
+ if (!std::regex_search(it, end, match, *trigger_opt)) {
+ result.content = input;
+ return result;
+ }
+ result.content = match.prefix().str();
+ it = match.suffix().first;
+ }
+
+ while (it != end) {
+ std::sregex_iterator rend;
+ std::sregex_iterator rit(it, end, function_regex);
+ if (rit == rend) {
+ result.content += std::string(it, end);
+ break;
+ }
+ auto name = rit->str(1);
+ result.content += std::string(it, rit->prefix().second);
+ it = rit->suffix().first;
+
+ json arguments;
+ if (parse_json(it, end, arguments)) {
+ if (!std::regex_search(it, end, match, close_regex)) {
+ throw std::runtime_error("Malformed input, missing closing pattern: " + input);
+ }
+ it = match.suffix().first;
+ result.tool_calls.push_back({name, arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), /* id= */ ""});
+ } else {
+ if (allow_raw_python && name == "python") {
+ result.tool_calls.push_back({name, json({{"code", std::string(it, end)}}).dump(), /* id= */ ""});
+ break;
+ }
+ throw std::runtime_error("Failed to parse json tool call arguments: " + input);
+ }
+ }
+
+ if (!result.tool_calls.empty()) {
+ if (!string_strip(result.content).empty()) {
+ LOG_WRN("Content found with tool calls: %s\n", result.content.c_str());
+ }
+ result.content = "";
+ }
+ return result;
+ }
+
+ static common_chat_tool_call process_tool_call(const json & tool_call) {
+ const auto & arguments = tool_call.at("arguments");
+ return {
+ /* .name = */ tool_call.at("name"),
+ /* .arguments = */ arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
+ /* .id = */ tool_call.contains("id") ? tool_call.at("id") : "",
+ };
+ }
+ static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) {
+ auto content_end = input.find(prefix);
+ size_t tc_start = std::string::npos;
+
+ common_chat_msg result;
+ result.role = "assistant";
+ if (content_end == std::string::npos) {
+ result.content = input;
+ } else {
+ tc_start = content_end + prefix.size() - rstrip_prefix;
+ result.content = input.substr(0, content_end);
+ auto tool_calls = json::parse(input.substr(tc_start));
+ for (const auto & tool_call : tool_calls) {
+ result.tool_calls.emplace_back(process_tool_call(tool_call));
+ }
+ }
+ return result;
+ }
+
+ static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
+ for (const auto & tool : tools) {
+ if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) {
+ LOG_INF("Skipping tool without function: %s\n", tool.dump(2).c_str());
+ continue;
+ }
+ fn(tool);
+ }
+ }
+
+ static std::string apply(
+ const common_chat_template & tmpl,
+ const nlohmann::ordered_json & messages,
+ const nlohmann::ordered_json & tools,
+ bool add_generation_prompt,
+ const nlohmann::ordered_json & extra_context = nlohmann::ordered_json())
+ {
+ minja::chat_template_inputs tmpl_inputs;
+ tmpl_inputs.messages = messages;
+ tmpl_inputs.tools = tools;
+ tmpl_inputs.add_generation_prompt = add_generation_prompt;
+ tmpl_inputs.extra_context = extra_context;
+ // TODO: add flag to control date/time, if only for testing purposes.
+ // tmpl_inputs.now = std::chrono::system_clock::now();
+
+ minja::chat_template_options tmpl_opts;
+ // To avoid double BOS / EOS tokens, we're manually removing beginning / trailing tokens
+ // instead of using `chat_template_options.use_bos_token = false`, since these tokens
+ // may be needed inside the template / between messages too.
+ auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
+ if (string_starts_with(result, tmpl.bos_token())) {
+ result = result.substr(tmpl.bos_token().size());
+ }
+ if (string_ends_with(result, tmpl.eos_token())) {
+ result = result.substr(0, result.size() - tmpl.eos_token().size());
+ }
+ return result;
+ }
+
+ static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ common_chat_params data;
+
+ auto tool_call_schemas = json::array();
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ auto tool_schema = json {
+ {"type", "object"},
+ {"properties", {
+ {"name", {
+ {"type", "string"},
+ {"const", function.at("name")},
+ }},
+ {"arguments", function.at("parameters")},
+ }},
+ {"required", json::array({"name", "arguments"})},
+ };
+ if (function.contains("description")) {
+ tool_schema["description"] = function.at("description");
+ }
+ if (inputs.parallel_tool_calls) {
+ tool_schema.at("properties")["id"] = {
+ {"type", "string"},
+ {"minLength", 4},
+ };
+ tool_schema.at("required").push_back("id");
+ }
+ tool_call_schemas.emplace_back(tool_schema);
+ });
+ const auto tool_call =
+ inputs.parallel_tool_calls
+ ? json {
+ {"type", "object"},
+ {"properties", {
+ {"tool_calls", {
+ {"type", "array"},
+ {"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
+ {"anyOf", tool_call_schemas},
+ }},
+ {"minItems", 1},
+ }},
+ }},
+ {"required", json::array({"tool_calls"})},
+ }
+ : json {
+ {"type", "object"},
+ {"properties", {
+ {"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json {
+ {"anyOf", tool_call_schemas},
+ }},
+ }},
+ {"required", json::array({"tool_call"})},
+ };
+ const auto schema =
+ inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED
+ ? json {
+ {"anyOf", json::array({
+ tool_call,
+ {
+ {"type", "object"},
+ {"properties", {
+ {"response", inputs.json_schema.is_null()
+ ? json {{"type", "string"}}
+ : inputs.json_schema
+ },
+ }},
+ {"required", json::array({"response"})},
+ },
+ })}
+ }
+ : tool_call;
+
+ data.grammar_lazy = false;
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ builder.add_schema("root", schema);
+ });
+
+ auto tweaked_messages = common_chat_template::add_system(
+ inputs.messages,
+ "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` (a reply to the user's request)");
+
+ data.prompt = apply(tmpl, tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+ data.format = COMMON_CHAT_FORMAT_GENERIC;
+ return data;
+ }
+ static common_chat_msg common_chat_parse_generic(const std::string & input) {
+ json data = json::parse(input);
+ common_chat_msg result;
+ result.role = "assistant";
+ if (data.contains("tool_calls")) {
+ for (const auto & tool_call : data.at("tool_calls")) {
+ result.tool_calls.push_back({
+ tool_call.at("name"),
+ tool_call.at("arguments").dump(),
+ tool_call.contains("id") ? tool_call.at("id") : "",
+ });
+ }
+ } else if (data.contains("tool_call")) {
+ result.tool_calls.push_back({
+ data.at("tool_call").at("name"),
+ data.at("tool_call").at("arguments").dump(),
+ /* id= */ "",
+ });
+ } else if (data.contains("response")) {
+ const auto & response = data.at("response");
+ result.content = response.is_string() ? response.get<std::string>() : response.dump(2);
+ }
+ return result;
+ }
+
+ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ common_chat_params data;
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ auto schemas = json::array();
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ schemas.push_back({
+ {"type", "object"},
+ {"properties", {
+ // Important note: the model is probably trained to take a JSON stringified arguments value.
+ // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object.
+ {"name", {
+ {"type", "string"},
+ {"const", function.at("name")},
+ }},
+ {"arguments", function.at("parameters")},
+ {"id", {
+ {"type", "string"},
+ // Nemo's template expects a 9-character alphanumeric ID.
+ {"pattern", "^[a-zA-Z0-9]{9}$"},
+ }},
+ }},
+ {"required", json::array({"name", "arguments", "id"})},
+ });
+ });
+ auto schema = json {
+ {"type", "array"},
+ {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
+ {"minItems", 1},
+ };
+ if (!inputs.parallel_tool_calls) {
+ schema["maxItems"] = 1;
+ }
+ builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
+ });
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"});
+ data.preserved_tokens = {
+ "[TOOL_CALLS]",
+ };
+ data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+ data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
+ return data;
+ }
+ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) {
+ return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
+ }
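+
+ // Example (hypothetical tool name and id) of the Mistral Nemo output accepted by
+ // the grammar above and parsed here:
+ // [TOOL_CALLS][{"name": "get_weather", "arguments": {"city": "Paris"}, "id": "a1b2c3d4e"}]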
+
+ static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ common_chat_params data;
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ auto schemas = json::array();
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ schemas.push_back({
+ {"type", "object"},
+ {"properties", {
+ {"tool_call_id", {
+ {"type", "string"},
+ // Command-R's template expects an integer string.
+ {"pattern", "^[0-9]{1,10}$"},
+ }},
+ {"tool_name", {
+ {"type", "string"},
+ {"const", function.at("name")},
+ }},
+ {"parameters", function.at("parameters")},
+ }},
+ {"required", json::array({"tool_call_id", "tool_name", "parameters"})},
+ });
+ });
+ auto schema = json {
+ {"type", "array"},
+ {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
+ {"minItems", 1},
+ };
+ if (!inputs.parallel_tool_calls) {
+ schema["maxItems"] = 1;
+ }
+ builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\"");
+ });
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
+ "<|START_ACTION|>",
+ });
+ data.preserved_tokens = {
+ "<|START_ACTION|>",
+ "<|END_ACTION|>",
+ "<|START_RESPONSE|>",
+ "<|END_RESPONSE|>",
+ "<|START_THINKING|>",
+ "<|END_THINKING|>",
+ };
+ auto adjusted_messages = json::array();
+ for (const auto & msg : inputs.messages) {
+ auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string();
+ auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array();
+ if (has_reasoning_content && has_tool_calls) {
+ auto adjusted_message = msg;
+ adjusted_message["tool_plan"] = msg.at("reasoning_content");
+ adjusted_message.erase("reasoning_content");
+ adjusted_messages.push_back(adjusted_message);
+ } else {
+ adjusted_messages.push_back(msg);
+ }
+ }
+ data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {});
+ data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING : COMMON_CHAT_FORMAT_COMMAND_R7B;
+ return data;
+ }
+ static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) {
+ static const std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)");
+ static const std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>");
+ static const std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>");
+
+ std::smatch match;
+
+ common_chat_msg result;
+ result.role = "assistant";
+
+ std::string rest = input;
+
+ if (std::regex_match(rest, match, thought_regex)) {
+ if (extract_reasoning) {
+ result.reasoning_content = match[2].str();
+ } else if (!match[2].str().empty()) {
+ // Let the unparsed thinking tags through in content only if their insides aren't empty.
+ result.content = match[1].str();
+ }
+ rest = match[3].str();
+ }
+ if (std::regex_match(rest, match, action_regex)) {
+ auto actions_str = match[1].str();
+ auto actions = json::parse(actions_str);
+ for (const auto & action : actions) {
+ result.tool_calls.push_back({
+ /* .name = */ action.at("tool_name"),
+ /* .arguments = */ action.at("parameters").dump(),
+ /* .id = */ action.at("tool_call_id"),
+ });
+ }
+ } else if (std::regex_match(rest, match, response_regex)) {
+ auto response = match[1].str();
+ result.content += response;
+ } else {
+ result.content += rest;
+ }
+ return result;
+ }
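+
+ // Example (hypothetical tool name) of the Command R7B output parsed above,
+ // shown on two lines for readability:
+ // <|START_THINKING|>I should look up the weather.<|END_THINKING|>
+ // <|START_ACTION|>[{"tool_call_id": "0", "tool_name": "get_weather", "parameters": {"city": "Paris"}}]<|END_ACTION|>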
+
+ static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector<std::string> & expected_properties) {
+ if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) {
+ throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties");
+ }
+ const auto & parameters_properties = parameters.at("properties");
+ const auto & parameters_required = parameters.at("required");
+ for (const auto & prop : expected_properties) {
+ if (!parameters_properties.contains(prop)) {
+ throw std::runtime_error("Parameters of tool " + name + " are missing property: " + prop); // NOLINT
+ }
+ if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
+ throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT
+ }
+ }
+ if (parameters_properties.size() != expected_properties.size()) {
+ throw std::runtime_error("Parameters of tool " + name + " must only have these properties: " + string_join(expected_properties, ", "));
+ }
+ }
+
+ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
+ auto builtin_tools = json::array();
+ common_chat_params data;
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ std::vector<std::string> tool_rules;
+
+ auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
+ if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
+ // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
+ // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
+ expect_tool_parameters(name, parameters, {"query"});
+ } else if (name == "python" || name == "code_interpreter") {
+ // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
+ expect_tool_parameters(name, parameters, {"code"});
+ } else {
+ return false;
+ }
+
+ std::vector<std::string> kvs;
+ for (const auto & [key, value] : parameters.at("properties").items()) {
+ kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
+ }
+
+ tool_rules.push_back(
+ builder.add_rule(
+ name + "-call",
+ "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
+ builtin_tools.push_back(name);
+
+ return true;
+ };
+
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ std::string name = function.at("name");
+ auto parameters = function.at("parameters");
+ builder.resolve_refs(parameters);
+
+ // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
+ if (allow_python_tag_builtin_tools) {
+ handle_builtin_tool(name, parameters);
+ }
+ tool_rules.push_back(
+ builder.add_rule(
+ name + "-call",
+ "\"{\" space "
+ "( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
+ " \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
+ " \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
+ "\"}\" space"));
+ });
+ // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+ "\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
+ });
+ if (!builtin_tools.empty()) {
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
+ data.preserved_tokens.push_back("<|python_tag|>");
+ }
+ // Allow a few empty lines on top of the usual constrained json schema space rule.
+ builder.add_rule("root", string_join(tool_rules, " | "));
+ });
+ data.additional_stops.push_back("<|eom_id|>");
+ data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
+ {"tools_in_user_message", false},
+ {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
+ });
+ data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
+ ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
+ : COMMON_CHAT_FORMAT_LLAMA_3_X;
+ return data;
+ }
+ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
+ // TODO: tighten & simplify the parser, don't accept leading text context.
+ static const std::regex function_regex(
+ "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
+ static const std::regex close_regex("\\}\\s*");
+ static const std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)");
+
+ if (with_builtin_tools) {
+ std::smatch match;
+ if (std::regex_match(input, match, builtin_call_regex)) {
+ try {
+ auto name = match[1].str();
+ auto arg_name = match[2].str();
+ auto arg_value_str = match[3].str();
+ auto arg_value = json::parse(arg_value_str);
+
+ common_chat_msg msg;
+ msg.role = "assistant";
+ msg.tool_calls.push_back({
+ /* .name = */ name,
+ /* .arguments = */ (json {
+ {arg_name, arg_value},
+ }).dump(),
+ /* .id = */ "",
+ });
+ return msg;
+ } catch (const std::exception & e) {
+ LOG_WRN("Failed to parse builtin tool call arguments (%s): %s\n", e.what(), input.c_str());
+ }
+ }
+ }
+ return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
+ }
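+
+ // Examples (hypothetical tools) of the two Llama 3.x call styles parsed above:
+ // {"name": "get_weather", "parameters": {"city": "Paris"}}
+ // <|python_tag|>brave_search.call(query="current weather in Paris")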
+
+ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ common_chat_params data;
+ if (inputs.tools.is_array() && !inputs.tools.empty()) {
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ std::vector<std::string> tool_rules;
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ std::string name = function.at("name");
+ auto parameters = function.at("parameters");
+ builder.resolve_refs(parameters);
+ tool_rules.push_back(builder.add_rule(name + "-call",
+ "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
+ "```json\\n\" " + builder.add_schema(name + "-args", parameters) + " "
+ "\"```<|tool▁call▁end|>\""));
+ });
+ // Distill Qwen 7B & 32B models seem confused about the syntax of their tool call opening tag,
+ // so we accept common variants (then it's all constrained)
+ builder.add_rule("root",
+ "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) "
+ "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
+ "\"<|tool▁calls▁end|>\""
+ " space");
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool▁calls▁begin|>"});
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls_begin|>"});
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool calls begin|>"});
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool\\_calls\\_begin|>"});
+ data.preserved_tokens = {
+ "<think>",
+ "</think>",
+ "<|tool▁calls▁begin|>",
+ "<|tool▁call▁begin|>",
+ "<|tool▁sep|>",
+ "<|tool▁call▁end|>",
+ "<|tool▁calls▁end|>",
+ };
+ });
+ }
+ auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+
+ // Hacks to fix the official (broken) prompt.
+ // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead,
+ // until the official template is fixed.
+ if (tmpl.source().find("{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}") != std::string::npos) {
+ // Don't leave the chat dangling after tool results
+ if (string_ends_with(prompt, "<|tool▁outputs▁end|>")) {
+ prompt += "<|end▁of▁sentence|>";
+ if (inputs.add_generation_prompt) {
+ prompt += "<|Assistant|>";
+ }
+ }
+ // Fix up tool call delta example added by Minja
+ prompt = std::regex_replace(
+ prompt,
+ std::regex("(<|tool▁call▁end|>)[\\s\\r\\n]*(<|tool▁outputs▁begin|>|<|User|>)"),
+ "$1<|tool▁calls▁end|><|end▁of▁sentence|>$2");
+ }
+ data.prompt = prompt;
+ data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1;
+ return data;
+ }
+ static common_chat_msg handle_think_tag_prelude(const std::string & input, bool extract_reasoning, const std::function<common_chat_msg(const std::string &)> & rest_parser) {
+ std::smatch match;
+ static const std::regex reasoning_content_regex("((?:<think>)?([\\s\\S\\r\\n]*?)</think>)?([\\s\\S\\r\\n]*)");
+ if (std::regex_match(input, match, reasoning_content_regex)) {
+ auto rest = match[3].str();
+ auto msg = rest_parser(rest);
+ auto reasoning_content = string_strip(match[2].str());
+ if (extract_reasoning) {
+ msg.reasoning_content = reasoning_content;
+ } else if (!reasoning_content.empty()) {
+ std::ostringstream content;
+ content << "<think>" << reasoning_content << "</think>" << msg.content;
+ msg.content = content.str();
+ }
+ return msg;
+ }
+ return rest_parser(input);
+ }
+ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) {
+ return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
+ static const std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n");
+ static const std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>");
+ static const std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>");
+
+ common_chat_msg msg;
+ msg.role = "assistant";
+ std::smatch match;
+ if (std::regex_search(input, match, tool_calls_regex)) {
+ auto tool_calls = match[1].str();
+ auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex);
+ msg.tool_calls = std::move(msg2.tool_calls);
+ } else {
+ msg.content = input;
+ }
+ return msg;
+ });
+ }
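+
+ // Example (hypothetical tool name) of the DeepSeek R1 output parsed above:
+ // <think>The user wants the weather.</think>
+ // <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
+ // ```json
+ // {"city": "Paris"}
+ // ```<|tool▁call▁end|><|tool▁calls▁end|>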
+
+ static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ LOG_DBG("%s\n", __func__);
+ common_chat_params data;
+ data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
+ {"datetime", "Jan 29 2025 13:00:00 GMT"},
+ {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
+ });
+ if (inputs.tools.is_array() && !inputs.tools.empty()) {
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ auto schemas = json::array();
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ schemas.push_back({
+ {"type", "object"},
+ {"properties", {
+ {"name", {
+ {"type", "string"},
+ {"const", function.at("name")},
+ }},
+ {"arguments", function.at("parameters")},
+ }},
+ {"required", json::array({"name", "arguments", "id"})},
+ });
+ });
+ auto schema = json {
+ {"type", "array"},
+ {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
+ {"minItems", 1},
+ };
+ if (!inputs.parallel_tool_calls) {
+ schema["maxItems"] = 1;
+ }
+ builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema));
+ });
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, " functools["});
+ data.preserved_tokens = {
+ " functools[",
+ };
+ data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2;
+ } else {
+ data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+ }
+ return data;
+ }
+ static common_chat_msg common_chat_parse_firefunction_v2(const std::string & input) {
+ return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
+ }
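+
+ // Example (hypothetical tool name) of the FireFunction v2 output parsed above;
+ // note the leading space before "functools[":
+ // functools[{"name": "get_weather", "arguments": {"city": "Paris"}, "id": "x1"}]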
+
+ static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
+ // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
+ common_chat_params data;
+ data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+ data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
+ if (inputs.tools.is_array() && !inputs.tools.empty()) {
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ std::vector<std::string> first_tool_rules;
+ std::vector<std::string> subsequent_tool_rules;
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ std::string name = function.at("name");
+ auto parameters = function.at("parameters");
+ builder.resolve_refs(parameters);
+ auto args_rule = builder.add_schema(name + "-args", parameters);
+ first_tool_rules.push_back(builder.add_rule(name + "-call", "( \"assistant<|end_header_id|>\\n\" )? \"" + name + "\\n\" " + args_rule));
+ subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+ regex_escape(name + "\n"),
+ });
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+ regex_escape("assistant<|end_header_id|>\n" + name + "\n"),
+ });
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
+ regex_escape(">>>" + name + "\n"),
+ });
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
+ ">>>assistant<|end_header_id|>\n" + name,
+ });
+ });
+ data.preserved_tokens = {
+ "<|end_header_id|>",
+ };
+ auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space";
+ if (inputs.parallel_tool_calls) {
+ auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space";
+ builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*");
+ } else {
+ builder.add_rule("root", first_rule);
+ }
+
+ });
+ }
+ return data;
+ }
+
+ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) {
+ static const std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)");
+ static const std::regex close_regex(R"($|(?=>>>))");
+
+ std::string content;
+ auto it = input.begin();
+ const auto end = input.end();
+
+ if (parse_literal(it, end, "all\n")) {
+ std::smatch match;
+ if (std::regex_search(it, end, match, function_regex)) {
+ auto fun_it = match.prefix().second;
+ content = std::string(it, fun_it);
+ it = fun_it;
+ } else {
+ common_chat_msg res;
+ res.role = "assistant";
+ res.content = std::string(it, end);
+ return res;
+ }
+ }
+ // TODO: tighten & simplify.
+ try {
+ auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex, /* allow_raw_python= */ true);
+ res.content = content + res.content;
+ return res;
+ } catch (const std::exception & e) {
+ LOG_ERR("Failed to parse functionary v3.2 input: %s\n", e.what());
+ common_chat_msg res;
+ res.role = "assistant";
+ res.content = input;
+ return res;
+ }
+ }
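+
+ // Example (hypothetical tools) of the Functionary v3.2 output parsed above,
+ // mixing plain content (after "all\n") with two calls:
+ // all
+ // Let me check that for you.
+ // >>>get_weather
+ // {"city": "Paris"}
+ // >>>get_time
+ // {"tz": "CET"}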
+
+ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
+ common_chat_params data;
+ json tools = inputs.tools.is_null() ? inputs.tools : json::array();
+ std::string python_code_argument_name;
+ auto has_raw_python = false;
+
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ std::vector<std::string> tool_rules;
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ const auto & parameters = function.at("parameters");
+ std::string name = function.at("name");
+ if (name == "python" || name == "ipython") {
+ if (!parameters.contains("type")) {
+ throw std::runtime_error("Missing type in python tool");
+ }
+ has_raw_python = true;
+ const auto & type = parameters.at("type");
+ if (type == "object") {
+ auto properties = parameters.at("properties");
+ for (auto it = properties.begin(); it != properties.end(); ++it) {
+ if (it.value().at("type") == "string") {
+ if (!python_code_argument_name.empty()) {
+ throw std::runtime_error("Multiple string arguments found in python tool");
+ }
+ python_code_argument_name = it.key();
+ }
+ }
+ if (python_code_argument_name.empty()) {
+ throw std::runtime_error("No string argument found in python tool");
+ }
+ } else if (type != "string") {
+ throw std::runtime_error("Invalid type in python tool: " + type.dump());
+ }
+ }
+ tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
+ });
+ if (has_raw_python) {
+ tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
+ data.preserved_tokens.push_back("<|python_tag|>");
+ }
+ auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
+ builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
+ });
+
+ data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+ // TODO: if (has_raw_python)
+ data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
+ return data;
+ }
+ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
+ // This version of Functionary still supports the llama 3.1 tool call format for the python tool.
+ static const std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)");
+ std::smatch match;
+ if (std::regex_search(input, match, python_tag_regex)) {
+ auto code = match[1].str();
+ common_chat_msg msg;
+ msg.role = "assistant";
+ msg.content = match.prefix().str();
+ msg.tool_calls.push_back({
+ /* .name = */ "python",
+ /* .arguments = */ (json {{"code", code}}).dump(),
+ /* .id = */ "",
+ });
+ return msg;
+ }
+ static const std::regex function_regex(R"(<function=(\w+)>)");
+ static const std::regex close_regex(R"(</function>)");
+ // TODO: tighten & simplify.
+ return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
+ }
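+
+ // Examples (hypothetical tool) of the two Functionary v3.1 styles parsed above:
+ // <function=get_weather>{"city": "Paris"}</function>
+ // <|python_tag|>print("hello") (raw code routed to the "python" tool)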
+
+ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ common_chat_params data;
+ // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ std::vector<std::string> tool_rules;
+ std::vector<std::string> tool_call_alts;
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ std::string name = function.at("name");
+ auto parameters = function.at("parameters");
+ builder.resolve_refs(parameters);
+ tool_rules.push_back(builder.add_schema(name + "-call", {
+ {"type", "object"},
+ {"properties", json {
+ {"name", json {{"const", name}}},
+ {"arguments", parameters},
+ }},
+ {"required", json::array({"name", "arguments"})},
+ }));
+ tool_call_alts.push_back(builder.add_rule(
+ name + "-function-tag",
+ "\"<function\" ( \"=" + name + "\" | \" name=\\\"" + name + "\\\"\" ) \">\" space " +
+ builder.add_schema(name + "-args", parameters) + " "
+ "\"</function>\" space"));
+
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
+ "<function=" + name + ">",
+ });
+ auto escaped_name = regex_escape(name);
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
+ "<function\\s+name\\s*=\\s*\"" + escaped_name + "\"",
+ });
+ });
+ auto any_tool_call = builder.add_rule("any_tool_call", "( " + string_join(tool_rules, " | ") + " ) space");
+ std::vector<std::string> alt_tags {
+ any_tool_call,
+ "\"<tool_call>\" space " + any_tool_call + " \"</tool_call>\"",
+ // The rest is just to accommodate common "good bad" outputs.
+ "\"<function_call>\" space " + any_tool_call + " \"</function_call>\"",
+ "\"<response>\" space " + any_tool_call + " \"</response>\"",
+ "\"<tools>\" space " + any_tool_call + " \"</tools>\"",
+ "\"<json>\" space " + any_tool_call + " \"</json>\"",
+ "\"<xml>\" space " + any_tool_call + " \"</xml>\"",
+ "\"<JSON>\" space " + any_tool_call + " \"</JSON>\"",
+ };
+ auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space");
+ tool_call_alts.push_back(wrappable_tool_call);
+ tool_call_alts.push_back(
+ "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space ");
+ auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | "));
+ builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<tool_call>"});
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function"});
+ // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
+ data.grammar_triggers.push_back({
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+ "(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?\\s*\\{\\s*\"", //name\"\\s*:\\s*\"" + escaped_name + "\"",
+ });
+ data.preserved_tokens = {
+ "<think>",
+ "</think>",
+ "<tool_call>",
+ "</tool_call>",
+ "<function",
+ "<tools>",
+ "</tools>",
+ "<response>",
+ "</response>",
+ "<function_call>",
+ "</function_call>",
+ "<json>",
+ "</json>",
+ "<JSON>",
+ "</JSON>",
+ "```",
+ "```json",
+ "```xml",
+ };
+ });
+
+ data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+ data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING : COMMON_CHAT_FORMAT_HERMES_2_PRO;
+ return data;
+ }
+ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input, bool extract_reasoning) {
+ return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) {
+ static const std::regex open_regex(
+ "(?:"
+ "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
+ "(<tool_call>" // match 2 (open_tag)
+ "|<function_call>"
+ "|<tool>"
+ "|<tools>"
+ "|<response>"
+ "|<json>"
+ "|<xml>"
+ "|<JSON>"
+ ")?"
+ "(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)" // match 3 (named tool call + rest)
+ ")"
+ "|"
+ "(?:<function=([^>]+)>" // match 4 (function name)
+ "|<function name=\"([^\"]+)\">)" // match 5 (function name again)
+ "([\\s\\S]*)" // match 6 (function arguments + rest)
+ );
+
+ try {
+ common_chat_msg msg;
+ msg.role = "assistant";
+
+ std::string::const_iterator it = input.begin();
+ const std::string::const_iterator end = input.end();
+ std::smatch match;
+
+ while (it != end) {
+ if (std::regex_search(it, end, match, open_regex)) {
+ // Add content before the match
+ msg.content += std::string(it, match[0].first);
+
+ auto block_start = match[1].str();
+ std::string block_end = block_start.empty() ? "" : "```";
+
+ auto open_tag = match[2].str();
+ std::string close_tag;
+
+ if (match[3].matched) {
+ close_tag = open_tag.empty() ? "" : "</" + open_tag.substr(1);
+ auto json_it = match[3].first;
+ json tool_call;
+ if (parse_json(json_it, end, tool_call) && tool_call.contains("name") && tool_call.contains("arguments")) {
+ msg.tool_calls.emplace_back(process_tool_call(tool_call));
+ it = json_it; // Move iterator past parsed JSON
+
+ // Handle close tags
+ consume_spaces(it, end);
+ if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
+ throw std::runtime_error("Failed to parse closing tag");
+ }
+ consume_spaces(it, end);
+ if (!block_end.empty() && !parse_literal(it, end, block_end)) {
+ throw std::runtime_error("Failed to parse block end");
+ }
+ consume_spaces(it, end);
+ } else {
+ // Not a valid tool call, treat as content
+ msg.content += std::string(match[0].first, match[0].second);
+ it = match[0].second;
+ }
+ } else {
+ auto function_name = match[4].str();
+ if (function_name.empty()) {
+ function_name = match[5].str();
+ }
+ GGML_ASSERT(!function_name.empty());
+
+ close_tag = "</function>";
+ // Start parsing from after the opening tags
+ auto json_it = match[6].first;
+ json arguments;
+ if (parse_json(json_it, end, arguments)) {
+ msg.tool_calls.emplace_back(process_tool_call({
+ {"name", function_name},
+ {"arguments", arguments},
+ }));
+ it = json_it; // Move iterator past parsed JSON
+
+ // Handle close tags
+ consume_spaces(it, end);
+ if (!close_tag.empty() && !parse_literal(it, end, close_tag)) {
+ throw std::runtime_error("Failed to parse closing tag");
+ }
+ consume_spaces(it, end);
+ if (!block_end.empty() && !parse_literal(it, end, block_end)) {
+ throw std::runtime_error("Failed to parse block end");
+ }
+ consume_spaces(it, end);
+ } else {
+ // Not a valid tool call, treat as content
+ msg.content += std::string(match[0].first, match[0].second);
+ it = match[0].second;
+ }
+ }
+ } else {
+ // Add remaining content
+ msg.content += std::string(it, end);
+ break;
+ }
+ }
+ return msg;
+ } catch (const std::exception & e) {
+ LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
+ common_chat_msg msg;
+ msg.role = "assistant";
+ msg.content = input;
+ return msg;
+ }
+ });
+ }
1561
+
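+ // For reference, the regex above accepts tool-call blocks shaped like the following
+ // (illustrative inputs; the function name and arguments are made up):
+ //
+ //   <tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>
+ //   <function=get_weather>{"city": "Paris"}</function>
+ //
+ // and the same JSON payload optionally preceded by a ```json / ```xml code fence
+ // (match 1) and/or one of the listed open tags (match 2).
+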
+ static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
+     common_chat_params data;
+     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
+     data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+     data.grammar_lazy = false;
+     if (!inputs.json_schema.is_null()) {
+         if (!inputs.grammar.empty()) {
+             throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
+         }
+         data.grammar = json_schema_to_grammar(inputs.json_schema);
+     } else {
+         data.grammar = inputs.grammar;
+     }
+     return data;
+ }
+
+ static common_chat_params common_chat_templates_apply_jinja(
+     const struct common_chat_templates * tmpls,
+     const struct common_chat_templates_inputs & inputs)
+ {
+     templates_params params;
+     params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
+     const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
+         ? *tmpls->template_tool_use
+         : *tmpls->template_default;
+     const auto & src = tmpl.source();
+     const auto & caps = tmpl.original_caps();
+     params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
+     params.add_generation_prompt = inputs.add_generation_prompt;
+     params.extract_reasoning = inputs.extract_reasoning;
+     params.tool_choice = inputs.tool_choice;
+     params.grammar = inputs.grammar;
+     if (!inputs.json_schema.empty()) {
+         params.json_schema = json::parse(inputs.json_schema);
+     }
+
+     if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
+         LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
+         params.parallel_tool_calls = false;
+     } else {
+         params.parallel_tool_calls = inputs.parallel_tool_calls;
+     }
+
+     if (params.tools.is_array()) {
+         if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
+             throw std::runtime_error("Cannot specify grammar with tools");
+         }
+         if (caps.supports_tool_calls && !caps.supports_tools) {
+             LOG_WRN("Template supports tool calls but does not natively describe tools. The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n");
+         }
+     }
+
+     // DeepSeek R1: use handler in all cases except json schema (thinking / tools).
+     if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) {
+         return common_chat_params_init_deepseek_r1(tmpl, params);
+     }
+
+     // Command R7B: use handler in all cases except json schema (thinking / tools).
+     if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
+         return common_chat_params_init_command_r7b(tmpl, params);
+     }
+
+     // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
+     if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
+         return common_chat_params_init_hermes_2_pro(tmpl, params);
+     }
+
+     // Use generic handler when mixing tools + JSON schema.
+     // TODO: support that mix in handlers below.
+     if (params.tools.is_array() && params.json_schema.is_object()) {
+         return common_chat_params_init_generic(tmpl, params);
+     }
+
+     // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
+     if (src.find(">>>all") != std::string::npos) {
+         return common_chat_params_init_functionary_v3_2(tmpl, params);
+     }
+
+     // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
+     if (src.find(" functools[") != std::string::npos) {
+         return common_chat_params_init_firefunction_v2(tmpl, params);
+     }
+
+     // Plain handler (no tools)
+     if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
+         return common_chat_params_init_without_tools(tmpl, params);
+     }
+
+     // Functionary v3.1 (w/ tools)
+     if (src.find("<|start_header_id|>") != std::string::npos
+         && src.find("<function=") != std::string::npos) {
+         return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
+     }
+
+     // Llama 3.1, 3.2, 3.3 (w/ tools)
+     if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
+         auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
+         return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
+     }
+
+     // Mistral Nemo (w/ tools)
+     if (src.find("[TOOL_CALLS]") != std::string::npos) {
+         return common_chat_params_init_mistral_nemo(tmpl, params);
+     }
+
+     // Generic fallback
+     return common_chat_params_init_generic(tmpl, params);
+ }
+
+ // Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template.
+ static common_chat_params common_chat_templates_apply_legacy(
+     const struct common_chat_templates * tmpls,
+     const struct common_chat_templates_inputs & inputs)
+ {
+     int alloc_size = 0;
+     std::vector<llama_chat_message> chat;
+     std::vector<std::string> contents;
+     for (const auto & msg : inputs.messages) {
+         auto content = msg.content;
+         for (const auto & part : msg.content_parts) {
+             if (part.type != "text") {
+                 LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
+                 continue;
+             }
+             if (!content.empty()) {
+                 content += "\n";
+             }
+             content += part.text;
+         }
+         contents.emplace_back(std::move(content));
+     }
+     for (size_t i = 0; i < contents.size(); ++i) {
+         const auto & msg = inputs.messages[i];
+         const auto & content = contents[i];
+         chat.push_back({msg.role.c_str(), content.c_str()});
+         alloc_size += (msg.role.size() + content.size()) * 1.25;
+     }
+
+     std::vector<char> buf(alloc_size);
+
+     // run the first time to get the total output length
+     const auto & src = tmpls->template_default->source();
+     int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
+
+     // error: chat template is not supported
+     if (res < 0) {
+         // if the custom "tmpl" is not supported, we throw an error
+         // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
+         throw std::runtime_error("this custom template is not supported");
+     }
+
+     // if it turns out that our buffer is too small, we resize it
+     if ((size_t) res > buf.size()) {
+         buf.resize(res);
+         res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
+     }
+
+     common_chat_params params;
+     params.prompt = std::string(buf.data(), res);
+     if (!inputs.json_schema.empty()) {
+         params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema));
+     } else {
+         params.grammar = inputs.grammar;
+     }
+     return params;
+ }
+
+ common_chat_params common_chat_templates_apply(
+     const struct common_chat_templates * tmpls,
+     const struct common_chat_templates_inputs & inputs)
+ {
+     GGML_ASSERT(tmpls != nullptr);
+     return inputs.use_jinja
+         ? common_chat_templates_apply_jinja(tmpls, inputs)
+         : common_chat_templates_apply_legacy(tmpls, inputs);
+ }
+
+ static common_chat_msg common_chat_parse_content_only(const std::string & input) {
+     common_chat_msg msg;
+     msg.role = "assistant";
+     msg.content = input;
+     return msg;
+ }
+
+ common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
+     switch (format) {
+         case COMMON_CHAT_FORMAT_CONTENT_ONLY:
+             return common_chat_parse_content_only(input);
+         case COMMON_CHAT_FORMAT_GENERIC:
+             return common_chat_parse_generic(input);
+         case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
+             return common_chat_parse_mistral_nemo(input);
+         case COMMON_CHAT_FORMAT_LLAMA_3_X:
+             return common_chat_parse_llama_3_1(input);
+         case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
+             return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true);
+         case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
+             return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ false);
+         case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING:
+             return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ true);
+         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
+             return common_chat_parse_functionary_v3_2(input);
+         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
+             return common_chat_parse_functionary_v3_1_llama_3_1(input);
+         case COMMON_CHAT_FORMAT_HERMES_2_PRO:
+             return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ false);
+         case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING:
+             return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ true);
+         case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
+             return common_chat_parse_firefunction_v2(input);
+         case COMMON_CHAT_FORMAT_COMMAND_R7B:
+             return common_chat_parse_command_r7b(input, /* extract_reasoning= */ false);
+         case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING:
+             return common_chat_parse_command_r7b(input, /* extract_reasoning= */ true);
+         default:
+             throw std::runtime_error("Unsupported format: " + common_chat_format_name(format));
+     }
+ }
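+
+ // A minimal usage sketch of the parser entry point above (illustrative; the
+ // tool-call string is a hypothetical Hermes-style model output):
+ //
+ //     common_chat_msg msg = common_chat_parse(
+ //         "<tool_call>{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Paris\"}}</tool_call>",
+ //         COMMON_CHAT_FORMAT_HERMES_2_PRO);
+ //     // msg.tool_calls[0].name holds "get_weather" and msg.tool_calls[0].arguments
+ //     // holds the arguments object serialized back to a JSON string.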
common/chat.h ADDED
@@ -0,0 +1,135 @@
+ // Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
+
+ #pragma once
+
+ #include "common.h"
+ #include <string>
+ #include <vector>
+
+ struct common_chat_templates;
+
+ struct common_chat_tool_call {
+     std::string name;
+     std::string arguments;
+     std::string id;
+ };
+
+ struct common_chat_msg_content_part {
+     std::string type;
+     std::string text;
+ };
+
+ struct common_chat_msg {
+     std::string role;
+     std::string content;
+     std::vector<common_chat_msg_content_part> content_parts = {};
+     std::vector<common_chat_tool_call> tool_calls = {};
+     std::string reasoning_content;
+     std::string tool_name;
+     std::string tool_call_id;
+ };
+
+ struct common_chat_tool {
+     std::string name;
+     std::string description;
+     std::string parameters;
+ };
+
+ enum common_chat_tool_choice {
+     COMMON_CHAT_TOOL_CHOICE_AUTO,
+     COMMON_CHAT_TOOL_CHOICE_REQUIRED,
+     COMMON_CHAT_TOOL_CHOICE_NONE,
+ };
+
+ enum common_chat_format {
+     COMMON_CHAT_FORMAT_CONTENT_ONLY,
+     COMMON_CHAT_FORMAT_GENERIC,
+     COMMON_CHAT_FORMAT_MISTRAL_NEMO,
+     COMMON_CHAT_FORMAT_LLAMA_3_X,
+     COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
+     COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+     COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
+     COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
+     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
+     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
+     COMMON_CHAT_FORMAT_HERMES_2_PRO,
+     COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING,
+     COMMON_CHAT_FORMAT_COMMAND_R7B,
+     COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
+
+     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
+ };
+
+ struct common_chat_templates_inputs {
+     std::vector<common_chat_msg> messages;
+     std::string grammar;
+     std::string json_schema;
+     bool add_generation_prompt = true;
+     bool use_jinja = true;
+     // Parameters below only supported when use_jinja is true
+     std::vector<common_chat_tool> tools;
+     common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
+     bool parallel_tool_calls = false;
+     bool extract_reasoning = true;
+ };
+
+ struct common_chat_params {
+     common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+     std::string prompt;
+     std::string grammar;
+     bool grammar_lazy = false;
+     std::vector<common_grammar_trigger> grammar_triggers;
+     std::vector<std::string> preserved_tokens;
+     std::vector<std::string> additional_stops;
+ };
+
+ // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
+ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
+
+ void common_chat_templates_free(struct common_chat_templates * tmpls);
+
+ struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };
+
+ typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
+
+ common_chat_templates_ptr common_chat_templates_init(
+     const struct llama_model * model,
+     const std::string & chat_template_override,
+     const std::string & bos_token_override = "",
+     const std::string & eos_token_override = "");
+
+ bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
+ const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
+
+ struct common_chat_params common_chat_templates_apply(
+     const struct common_chat_templates * tmpls,
+     const struct common_chat_templates_inputs & inputs);
+
+ // Format single message, while taking into account the position of that message in chat history
+ std::string common_chat_format_single(
+     const struct common_chat_templates * tmpls,
+     const std::vector<common_chat_msg> & past_msg,
+     const common_chat_msg & new_msg,
+     bool add_ass,
+     bool use_jinja);
+
+ // Returns an example of formatted chat
+ std::string common_chat_format_example(
+     const struct common_chat_templates * tmpls,
+     bool use_jinja);
+
+ std::string common_chat_format_name(common_chat_format format);
+ common_chat_msg common_chat_parse(const std::string & input, common_chat_format format);
+
+ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
+
+ // Parses a JSON array of messages in OpenAI's chat completion API format.
+ // T can be std::string containing JSON or nlohmann::ordered_json
+ template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
+ template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
+
+ // Parses a JSON array of tools in OpenAI's chat completion tool call API format.
+ // T can be std::string containing JSON or nlohmann::ordered_json
+ template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
+ template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
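+
+ // End-to-end usage sketch of this API (illustrative only, kept out of the build
+ // with #if 0; assumes `model` is an already-loaded llama_model and the message
+ // text is a made-up placeholder):
+ #if 0
+ common_chat_templates_ptr tmpls = common_chat_templates_init(model, /* chat_template_override= */ "");
+
+ common_chat_templates_inputs inputs;
+ common_chat_msg user_msg;
+ user_msg.role    = "user";
+ user_msg.content = "Hello!";
+ inputs.messages.push_back(user_msg);
+
+ common_chat_params cparams = common_chat_templates_apply(tmpls.get(), inputs);
+ // cparams.prompt is fed to the model; cparams.format later tells
+ // common_chat_parse() how to interpret the model's raw output.
+ #endif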
common/common.cpp ADDED
@@ -0,0 +1,2058 @@
+ #if defined(_MSC_VER)
+ #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
+ #endif
+
+ #include "ggml.h"
+ #include "gguf.h"
+
+ #include "common.h"
+ #include "log.h"
+ #include "build-info.h"
+ #include "log.cpp"
+ // Change JSON_ASSERT from assert() to GGML_ASSERT:
+ #define JSON_ASSERT GGML_ASSERT
+ #include "json.hpp"
+ #include "json-schema-to-grammar.cpp"
+ #include "llama.h"
+ #include "chat.cpp"
+
+ #include <algorithm>
+ #include <cinttypes>
+ #include <climits>
+ #include <cmath>
+ #include <codecvt>
+ #include <cstdarg>
+ #include <cstring>
+ #include <ctime>
+ #include <filesystem>
+ #include <fstream>
+ #include <iostream>
+ #include <iterator>
+ #include <regex>
+ #include <sstream>
+ #include <string>
+ #include <thread>
+ #include <unordered_map>
+ #include <unordered_set>
+ #include <vector>
+
+ #if defined(__APPLE__) && defined(__MACH__)
+ #include <sys/types.h>
+ #include <sys/sysctl.h>
+ #endif
+
+ #if defined(_WIN32)
+ #define WIN32_LEAN_AND_MEAN
+ #ifndef NOMINMAX
+ # define NOMINMAX
+ #endif
+ #include <locale>
+ #include <windows.h>
+ #include <fcntl.h>
+ #include <io.h>
+ #else
+ #include <sys/ioctl.h>
+ #include <sys/stat.h>
+ #include <unistd.h>
+ #endif
+ #if defined(LLAMA_USE_CURL)
+ #include <curl/curl.h>
+ #include <curl/easy.h>
+ #include <future>
+ #endif
+
+ #if defined(_MSC_VER)
+ #pragma warning(disable: 4244 4267) // possible loss of data
+ #endif
+
+ #if defined(LLAMA_USE_CURL)
+ #ifdef __linux__
+ #include <linux/limits.h>
+ #elif defined(_WIN32)
+ # if !defined(PATH_MAX)
+ # define PATH_MAX MAX_PATH
+ # endif
+ #else
+ #include <sys/syslimits.h>
+ #endif
+ #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
+
+ //
+ // CURL utils
+ //
+
+ using curl_ptr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
+
+ // cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one
+ struct curl_slist_ptr {
+     struct curl_slist * ptr = nullptr;
+     ~curl_slist_ptr() {
+         if (ptr) {
+             curl_slist_free_all(ptr);
+         }
+     }
+ };
+ #endif // LLAMA_USE_CURL
+
+ using json = nlohmann::ordered_json;
+
+ //
+ // CPU utils
+ //
+
+ int32_t cpu_get_num_physical_cores() {
+ #ifdef __linux__
+     // enumerate the set of thread siblings, num entries is num cores
+     std::unordered_set<std::string> siblings;
+     for (uint32_t cpu = 0; cpu < UINT32_MAX; ++cpu) {
+         std::ifstream thread_siblings("/sys/devices/system/cpu/cpu"
+             + std::to_string(cpu) + "/topology/thread_siblings");
+         if (!thread_siblings.is_open()) {
+             break; // no more cpus
+         }
+         std::string line;
+         if (std::getline(thread_siblings, line)) {
+             siblings.insert(line);
+         }
+     }
+     if (!siblings.empty()) {
+         return static_cast<int32_t>(siblings.size());
+     }
+ #elif defined(__APPLE__) && defined(__MACH__)
+     int32_t num_physical_cores;
+     size_t len = sizeof(num_physical_cores);
+     int result = sysctlbyname("hw.perflevel0.physicalcpu", &num_physical_cores, &len, NULL, 0);
+     if (result == 0) {
+         return num_physical_cores;
+     }
+     result = sysctlbyname("hw.physicalcpu", &num_physical_cores, &len, NULL, 0);
+     if (result == 0) {
+         return num_physical_cores;
+     }
+ #elif defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later
+     // TODO: windows + arm64 + mingw64
+     unsigned int n_threads_win = std::thread::hardware_concurrency();
+     unsigned int default_threads = n_threads_win > 0 ? (n_threads_win <= 4 ? n_threads_win : n_threads_win / 2) : 4;
+
+     DWORD buffer_size = 0;
+     if (!GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &buffer_size)) {
+         if (GetLastError() != ERROR_INSUFFICIENT_BUFFER) {
+             return default_threads;
+         }
+     }
+
+     std::vector<char> buffer(buffer_size);
+     if (!GetLogicalProcessorInformationEx(RelationProcessorCore, reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data()), &buffer_size)) {
+         return default_threads;
+     }
+
+     int32_t num_physical_cores = 0;
+     PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data());
+     while (buffer_size > 0) {
+         if (info->Relationship == RelationProcessorCore) {
+             num_physical_cores += info->Processor.GroupCount;
+         }
+         buffer_size -= info->Size;
+         info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(reinterpret_cast<char *>(info) + info->Size);
+     }
+
+     return num_physical_cores > 0 ? num_physical_cores : default_threads;
+ #endif
+     unsigned int n_threads = std::thread::hardware_concurrency();
+     return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
+ }
+
+ #if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__)
+ #include <pthread.h>
+
+ static void cpuid(unsigned leaf, unsigned subleaf,
+                   unsigned *eax, unsigned *ebx, unsigned *ecx, unsigned *edx) {
+     __asm__("movq\t%%rbx,%%rsi\n\t"
+             "cpuid\n\t"
+             "xchgq\t%%rbx,%%rsi"
+             : "=a"(*eax), "=S"(*ebx), "=c"(*ecx), "=d"(*edx)
+             : "0"(leaf), "2"(subleaf));
+ }
+
+ static int pin_cpu(int cpu) {
+     cpu_set_t mask;
+     CPU_ZERO(&mask);
+     CPU_SET(cpu, &mask);
+     return pthread_setaffinity_np(pthread_self(), sizeof(mask), &mask);
+ }
+
+ static bool is_hybrid_cpu(void) {
+     unsigned eax, ebx, ecx, edx;
+     cpuid(7, 0, &eax, &ebx, &ecx, &edx);
+     return !!(edx & (1u << 15));
+ }
+
+ static bool is_running_on_efficiency_core(void) {
+     unsigned eax, ebx, ecx, edx;
+     cpuid(0x1a, 0, &eax, &ebx, &ecx, &edx);
+     int intel_atom = 0x20;
+     int core_type = (eax & 0xff000000u) >> 24;
+     return core_type == intel_atom;
+ }
+
+ static int cpu_count_math_cpus(int n_cpu) {
+     int result = 0;
+     for (int cpu = 0; cpu < n_cpu; ++cpu) {
+         if (pin_cpu(cpu)) {
+             return -1;
+         }
+         if (is_running_on_efficiency_core()) {
+             continue; // efficiency cores harm lockstep threading
+         }
+         ++cpu; // hyperthreading isn't useful for linear algebra
+         ++result;
+     }
+     return result;
+ }
+
+ #endif // __x86_64__ && __linux__
+
+ /**
+  * Returns number of CPUs on system that are useful for math.
+  */
+ int32_t cpu_get_num_math() {
+ #if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__)
+     int n_cpu = sysconf(_SC_NPROCESSORS_ONLN);
+     if (n_cpu < 1) {
+         return cpu_get_num_physical_cores();
+     }
+     if (is_hybrid_cpu()) {
+         cpu_set_t affinity;
+         if (!pthread_getaffinity_np(pthread_self(), sizeof(affinity), &affinity)) {
+             int result = cpu_count_math_cpus(n_cpu);
+             pthread_setaffinity_np(pthread_self(), sizeof(affinity), &affinity);
+             if (result > 0) {
+                 return result;
+             }
+         }
+     }
+ #endif
+     return cpu_get_num_physical_cores();
+ }
+
+ // Helper for setting process priority
+
+ #if defined(_WIN32)
+
+ bool set_process_priority(enum ggml_sched_priority prio) {
+     if (prio == GGML_SCHED_PRIO_NORMAL) {
+         return true;
+     }
+
+     DWORD p = NORMAL_PRIORITY_CLASS;
+     switch (prio) {
+         case GGML_SCHED_PRIO_NORMAL:   p = NORMAL_PRIORITY_CLASS;       break;
+         case GGML_SCHED_PRIO_MEDIUM:   p = ABOVE_NORMAL_PRIORITY_CLASS; break;
+         case GGML_SCHED_PRIO_HIGH:     p = HIGH_PRIORITY_CLASS;         break;
+         case GGML_SCHED_PRIO_REALTIME: p = REALTIME_PRIORITY_CLASS;     break;
+     }
+
+     if (!SetPriorityClass(GetCurrentProcess(), p)) {
+         LOG_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError());
+         return false;
+     }
+
+     return true;
+ }
+
+ #else // MacOS and POSIX
+ #include <sys/types.h>
+ #include <sys/resource.h>
+
+ bool set_process_priority(enum ggml_sched_priority prio) {
+     if (prio == GGML_SCHED_PRIO_NORMAL) {
+         return true;
+     }
+
+     int p = 0;
+     switch (prio) {
+         case GGML_SCHED_PRIO_NORMAL:   p =   0; break;
+         case GGML_SCHED_PRIO_MEDIUM:   p =  -5; break;
+         case GGML_SCHED_PRIO_HIGH:     p = -10; break;
+         case GGML_SCHED_PRIO_REALTIME: p = -20; break;
+     }
+
+     // setpriority() returns 0 on success and -1 on error
+     if (setpriority(PRIO_PROCESS, 0, p) != 0) {
+         LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
+         return false;
+     }
+     return true;
+ }
+
+ #endif
+
+ //
+ // CLI argument parsing
+ //
+
+ void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model) {
+     int32_t n_set = 0;
+
+     if (cpuparams.n_threads < 0) {
+         // Assuming everything about cpuparams is invalid
+         if (role_model != nullptr) {
+             cpuparams = *role_model;
+         } else {
+             cpuparams.n_threads = cpu_get_num_math();
+         }
+     }
+
+     for (int32_t i = 0; i < GGML_MAX_N_THREADS; i++) {
+         if (cpuparams.cpumask[i]) {
+             n_set++;
+         }
+     }
+
+     if (n_set && n_set < cpuparams.n_threads) {
+         // Not enough set bits, may experience performance issues.
+         LOG_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads);
+     }
+ }
+
+ bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THREADS]) {
+     size_t dash_loc = range.find('-');
+     if (dash_loc == std::string::npos) {
+         LOG_ERR("Format of CPU range is invalid! Expected [<start>]-[<end>].\n");
+         return false;
+     }
+
+     size_t start_i;
+     size_t end_i;
+
+     if (dash_loc == 0) {
+         start_i = 0;
+     } else {
+         start_i = std::stoull(range.substr(0, dash_loc));
+         if (start_i >= GGML_MAX_N_THREADS) {
+             LOG_ERR("Start index out of bounds!\n");
+             return false;
+         }
+     }
+
+     if (dash_loc == range.length() - 1) {
+         end_i = GGML_MAX_N_THREADS - 1;
+     } else {
+         end_i = std::stoull(range.substr(dash_loc + 1));
+         if (end_i >= GGML_MAX_N_THREADS) {
+             LOG_ERR("End index out of bounds!\n");
+             return false;
+         }
+     }
+
+     for (size_t i = start_i; i <= end_i; i++) {
+         boolmask[i] = true;
+     }
+
+     return true;
+ }
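+
+ // For example (hypothetical inputs): parse_cpu_range("0-7", mask) sets mask[0..7],
+ // "4-" covers CPU 4 through GGML_MAX_N_THREADS - 1, and "-3" covers CPUs 0..3.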
+
+ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREADS]) {
+     // Discard potential 0x prefix
+     size_t start_i = 0;
+     if (mask.length() >= 2 && mask.substr(0, 2) == "0x") {
+         start_i = 2;
+     }
+
+     size_t num_digits = mask.length() - start_i;
+     if (num_digits > 128) {
+         num_digits = 128;
+     }
+
+     size_t end_i = num_digits + start_i;
+
+     for (size_t i = start_i, n = (num_digits * 4 - 1); i < end_i; i++, n -= 4) {
+         char c = mask.at(i);
+         int8_t id = c;
+
+         if ((c >= '0' && c <= '9')) {
+             id -= '0';
+         } else if (c >= 'a' && c <= 'f') {
+             id -= 'a' - 10;
+         } else if (c >= 'A' && c <= 'F') {
+             id -= 'A' - 10;
+         } else {
+             LOG_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i));
+             return false;
+         }
+
+         boolmask[n    ] = boolmask[n    ] || ((id & 8) != 0);
+         boolmask[n - 1] = boolmask[n - 1] || ((id & 4) != 0);
+         boolmask[n - 2] = boolmask[n - 2] || ((id & 2) != 0);
+         boolmask[n - 3] = boolmask[n - 3] || ((id & 1) != 0);
+     }
+
+     return true;
+ }
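+
+ // Worked example (hypothetical input): parse_cpu_mask("0x5", mask) reads one hex
+ // digit, 0x5 = 0101b, so mask[0] and mask[2] (CPUs 0 and 2) are set.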
+
+ void common_init() {
+     llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
+         if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
+             common_log_add(common_log_main(), level, "%s", text);
+         }
+     }, NULL);
+
+ #ifdef NDEBUG
+     const char * build_type = "";
+ #else
+     const char * build_type = " (debug)";
+ #endif
+
+     LOG_INF("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type);
+ }
+
+ std::string common_params_get_system_info(const common_params & params) {
+     std::ostringstream os;
+
+     os << "system_info: n_threads = " << params.cpuparams.n_threads;
+     if (params.cpuparams_batch.n_threads != -1) {
+         os << " (n_threads_batch = " << params.cpuparams_batch.n_threads << ")";
+     }
+ #if defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later
+     // TODO: windows + arm64 + mingw64
+     DWORD logicalProcessorCount = GetActiveProcessorCount(ALL_PROCESSOR_GROUPS);
+     os << " / " << logicalProcessorCount << " | " << llama_print_system_info();
+ #else
+     os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info();
+ #endif
+
+     return os.str();
+ }
+
+ //
+ // String utils
+ //
+
+ std::string string_format(const char * fmt, ...) {
+     va_list ap;
+     va_list ap2;
+     va_start(ap, fmt);
+     va_copy(ap2, ap);
+     int size = vsnprintf(NULL, 0, fmt, ap);
+     GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+     std::vector<char> buf(size + 1);
+     int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
+     GGML_ASSERT(size2 == size);
+     va_end(ap2);
+     va_end(ap);
+     return std::string(buf.data(), size);
+ }
+
+ std::string string_strip(const std::string & str) {
+     size_t start = 0;
+     size_t end = str.size();
+     while (start < end && std::isspace(str[start])) {
+         start++;
+     }
+     while (end > start && std::isspace(str[end - 1])) {
+         end--;
+     }
+     return str.substr(start, end - start);
+ }
+
+ std::string string_get_sortable_timestamp() {
+     using clock = std::chrono::system_clock;
+
+     const clock::time_point current_time = clock::now();
+     const time_t as_time_t = clock::to_time_t(current_time);
+     char timestamp_no_ns[100];
+     std::strftime(timestamp_no_ns, 100, "%Y_%m_%d-%H_%M_%S", std::localtime(&as_time_t));
+
+     const int64_t ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
+         current_time.time_since_epoch() % 1000000000).count();
+     char timestamp_ns[11];
+     snprintf(timestamp_ns, 11, "%09" PRId64, ns);
+
+     return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns);
+ }
+
+ void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
+     if (search.empty()) {
+         return;
+     }
+     std::string builder;
+     builder.reserve(s.length());
+     size_t pos = 0;
+     size_t last_pos = 0;
+     while ((pos = s.find(search, last_pos)) != std::string::npos) {
+         builder.append(s, last_pos, pos - last_pos);
+         builder.append(replace);
+         last_pos = pos + search.length();
+     }
+     builder.append(s, last_pos, std::string::npos);
+     s = std::move(builder);
+ }
+
+ std::string regex_escape(const std::string & s) {
+     static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
+     return std::regex_replace(s, special_chars, "\\$0");
+ }
+
+ std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
+     std::ostringstream result;
+     for (size_t i = 0; i < values.size(); ++i) {
+         if (i > 0) {
+             result << separator;
+         }
+         result << values[i];
+     }
+     return result.str();
+ }
+
+ std::vector<std::string> string_split(const std::string & str, const std::string & delimiter) {
+     std::vector<std::string> parts;
+     size_t start = 0;
+     size_t end = str.find(delimiter);
+
+     while (end != std::string::npos) {
+         parts.push_back(str.substr(start, end - start));
+         start = end + delimiter.length();
+         end = str.find(delimiter, start);
+     }
+
+     parts.push_back(str.substr(start));
+
+     return parts;
+ }
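+
+ // For example: string_split("a,b,c", ",") yields {"a", "b", "c"}; note that
+ // string_split("", ",") yields {""} (one empty part), since the trailing
+ // remainder is always pushed.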
+
+ std::string string_repeat(const std::string & str, size_t n) {
+     if (n == 0) {
+         return "";
+     }
+
+     std::string result;
+     result.reserve(str.length() * n);
+
+     for (size_t i = 0; i < n; ++i) {
+         result += str;
+     }
+
+     return result;
+ }
+
+ std::string string_from(bool value) {
+     return value ? "true" : "false";
+ }
+
+ std::string string_from(const std::vector<int> & values) {
+     std::stringstream buf;
+
+     buf << "[ ";
+     bool first = true;
+     for (auto e : values) {
+         if (first) {
+             first = false;
+         } else {
+             buf << ", ";
+         }
+         buf << std::to_string(e);
+     }
+     buf << " ]";
+
+     return buf.str();
+ }
+
+ std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens) {
+     std::stringstream buf;
+
+     buf << "[ ";
+
+     bool first = true;
+     for (const auto & token : tokens) {
+         if (!first) {
+             buf << ", ";
+         } else {
+             first = false;
+         }
+
+         auto detokenized = common_token_to_piece(ctx, token);
+
+         detokenized.erase(
+             std::remove_if(
+                 detokenized.begin(),
+                 detokenized.end(),
+                 [](const unsigned char c) { return !std::isprint(c); }),
+             detokenized.end());
+
+         buf << "'" << detokenized << "'"
+             << ":" << std::to_string(token);
+     }
+
+     buf << " ]";
+
+     return buf.str();
+ }
+
+ std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) {
+     std::stringstream buf;
+
+     buf << "[ ";
+
+     bool first = true;
+     for (int i = 0; i < batch.n_tokens; ++i) {
+         if (!first) {
+             buf << ", ";
+         } else {
+             first = false;
+         }
+
+         auto detokenized = common_token_to_piece(ctx, batch.token[i]);
+
+         detokenized.erase(
+             std::remove_if(
+                 detokenized.begin(),
+                 detokenized.end(),
+                 [](const unsigned char c) { return !std::isprint(c); }),
+             detokenized.end());
+
+         buf << "\n" << std::to_string(i)
+             << ", token '" << detokenized << "'"
+             << ", pos " << std::to_string(batch.pos[i])
+             << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
+             << ", seq_id " << std::to_string(batch.seq_id[i][0])
+             << ", logits " << std::to_string(batch.logits[i]);
+     }
+
+     buf << " ]";
+
+     return buf.str();
+ }
+
+ void string_process_escapes(std::string & input) {
+     std::size_t input_len = input.length();
+     std::size_t output_idx = 0;
+
+     for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) {
+         if (input[input_idx] == '\\' && input_idx + 1 < input_len) {
+             switch (input[++input_idx]) {
+                 case 'n':  input[output_idx++] = '\n'; break;
+                 case 'r':  input[output_idx++] = '\r'; break;
+                 case 't':  input[output_idx++] = '\t'; break;
+                 case '\'': input[output_idx++] = '\''; break;
+                 case '\"': input[output_idx++] = '\"'; break;
+                 case '\\': input[output_idx++] = '\\'; break;
+                 case 'x':
+                     // Handle \x12, etc
+                     if (input_idx + 2 < input_len) {
+                         const char x[3] = { input[input_idx + 1], input[input_idx + 2], 0 };
+                         char * err_p = nullptr;
+                         const long val = std::strtol(x, &err_p, 16);
+                         if (err_p == x + 2) {
+                             input_idx += 2;
+                             input[output_idx++] = char(val);
+                             break;
+                         }
+                     }
+                     // fall through
+                 default:
+                     input[output_idx++] = '\\';
+                     input[output_idx++] = input[input_idx];
+                     break;
+             }
+         } else {
+             input[output_idx++] = input[input_idx];
+         }
+     }
+
+     input.resize(output_idx);
+ }
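+
+ // For example: given the literal input R"(hello\nworld)" this rewrites it in
+ // place to "hello<newline>world", and R"(\x41)" becomes "A"; unknown escapes
+ // such as R"(\q)" are kept verbatim as backslash + character.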
+
+ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
+     const char * sep = strchr(data, '=');
+     if (sep == nullptr || sep - data >= 128) {
+         LOG_ERR("%s: malformed KV override '%s'\n", __func__, data);
+         return false;
+     }
+     llama_model_kv_override kvo;
+     std::strncpy(kvo.key, data, sep - data);
+     kvo.key[sep - data] = 0;
+     sep++;
+     if (strncmp(sep, "int:", 4) == 0) {
+         sep += 4;
+         kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
+         kvo.val_i64 = std::atol(sep);
+     } else if (strncmp(sep, "float:", 6) == 0) {
+         sep += 6;
+         kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
+         kvo.val_f64 = std::atof(sep);
+     } else if (strncmp(sep, "bool:", 5) == 0) {
+         sep += 5;
+         kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
+         if (std::strcmp(sep, "true") == 0) {
+             kvo.val_bool = true;
+         } else if (std::strcmp(sep, "false") == 0) {
+             kvo.val_bool = false;
+         } else {
+             LOG_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data);
+             return false;
+         }
+     } else if (strncmp(sep, "str:", 4) == 0) {
+         sep += 4;
+         kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
+         if (strlen(sep) > 127) {
+             LOG_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
+             return false;
+         }
+         strncpy(kvo.val_str, sep, 127);
+         kvo.val_str[127] = '\0';
+     } else {
+         LOG_ERR("%s: invalid type for KV override '%s'\n", __func__, data);
+         return false;
+     }
+     overrides.emplace_back(std::move(kvo));
+     return true;
+ }
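+
+ // Accepted override strings look like "<key>=<type>:<value>", e.g. (hypothetical keys):
+ //   "tokenizer.ggml.add_bos_token=bool:false"
+ //   "general.name=str:my-model"
+ //   "some.count=int:42"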
+
+ //
+ // Filesystem utils
+ //
+
+ // Validate if a filename is safe to use
+ // To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
+ bool fs_validate_filename(const std::string & filename) {
+     if (!filename.length()) {
+         // Empty filename invalid
+         return false;
+     }
+     if (filename.length() > 255) {
+         // Limit at common largest possible filename on Linux filesystems
+         // to avoid unnecessary further validation
+         // (On systems with smaller limits it will be caught by the OS)
+         return false;
+     }
+
+     std::u32string filename_utf32;
+     try {
+ #if defined(__clang__)
+         // disable C++17 deprecation warning for std::codecvt_utf8
+ #    pragma clang diagnostic push
+ #    pragma clang diagnostic ignored "-Wdeprecated-declarations"
+ #endif
+         std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
+
+ #if defined(__clang__)
+ #    pragma clang diagnostic pop
+ #endif
+
+         filename_utf32 = converter.from_bytes(filename);
+
+         // If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
+         // or invalid encodings were encountered. Reject such attempts
+         std::string filename_reencoded = converter.to_bytes(filename_utf32);
+         if (filename_reencoded != filename) {
+             return false;
+         }
+     } catch (const std::exception &) {
+         return false;
+     }
+
+     // Check for forbidden codepoints:
+     // - Control characters
+     // - Unicode equivalents of illegal characters
+     // - UTF-16 surrogate pairs
+     // - UTF-8 replacement character
+     // - Byte order mark (BOM)
+     // - Illegal characters: / \ : * ? " < > |
+     for (char32_t c : filename_utf32) {
+         if (c <= 0x1F // Control characters (C0)
+             || c == 0x7F // Control characters (DEL)
+             || (c >= 0x80 && c <= 0x9F) // Control characters (C1)
+             || c == 0xFF0E // Fullwidth Full Stop (period equivalent)
+             || c == 0x2215 // Division Slash (forward slash equivalent)
+             || c == 0x2216 // Set Minus (backslash equivalent)
+             || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
+             || c == 0xFFFD // Replacement Character (UTF-8)
+             || c == 0xFEFF // Byte Order Mark (BOM)
+             || c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
+             || c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
+             return false;
+         }
+     }
+
+     // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
+     // Unicode and other whitespace is not affected, only 0x20 space
+     if (filename.front() == ' ' || filename.back() == ' ' || filename.back() == '.') {
+         return false;
+     }
+
+     // Reject any ".." (currently stricter than necessary, it should be fine to just check for == ".." instead)
+     if (filename.find("..") != std::string::npos) {
+         return false;
+     }
+
+     // Reject "."
+     if (filename == ".") {
+         return false;
+     }
+
+     return true;
+ }
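+
+ // For example: "model.gguf" passes, while "..", "con:", "name.", " name" and
+ // "a/b" are all rejected (path traversal, illegal character, trailing dot,
+ // leading space, and path separator, in that order).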
+
+ // returns true if successful, false otherwise
+ bool fs_create_directory_with_parents(const std::string & path) {
+ #ifdef _WIN32
+     std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+     std::wstring wpath = converter.from_bytes(path);
+
+     // if the path already exists, check whether it's a directory
+     const DWORD attributes = GetFileAttributesW(wpath.c_str());
+     if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) {
+         return true;
+     }
+
+     size_t pos_slash = 0;
+
+     // process path from front to back, procedurally creating directories
+     while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
+         const std::wstring subpath = wpath.substr(0, pos_slash);
+         const wchar_t * test = subpath.c_str();
+
+         const bool success = CreateDirectoryW(test, NULL);
+         if (!success) {
+             const DWORD error = GetLastError();
+
+             // if the path already exists, ensure that it's a directory
+             if (error == ERROR_ALREADY_EXISTS) {
+                 const DWORD attributes = GetFileAttributesW(subpath.c_str());
+                 if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) {
+                     return false;
+                 }
+             } else {
+                 return false;
+             }
+         }
+
+         pos_slash += 1;
+     }
+
+     return true;
+ #else
+     // if the path already exists, check whether it's a directory
+     struct stat info;
+     if (stat(path.c_str(), &info) == 0) {
+         return S_ISDIR(info.st_mode);
+     }
+
+     size_t pos_slash = 1; // skip leading slashes for directory creation
+
+     // process path from front to back, procedurally creating directories
+     while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) {
+         const std::string subpath = path.substr(0, pos_slash);
+         struct stat info;
+
+         // if the path already exists, ensure that it's a directory
+         if (stat(subpath.c_str(), &info) == 0) {
+             if (!S_ISDIR(info.st_mode)) {
+                 return false;
+             }
+         } else {
+             // create parent directories
+             const int ret = mkdir(subpath.c_str(), 0755);
+             if (ret != 0) {
+                 return false;
+             }
+         }
+
+         pos_slash += 1;
+     }
+
+     return true;
+ #endif // _WIN32
+ }
+
+ std::string fs_get_cache_directory() {
+     std::string cache_directory = "";
+     auto ensure_trailing_slash = [](std::string p) {
+         // Make sure to add trailing slash
+         if (p.back() != DIRECTORY_SEPARATOR) {
+             p += DIRECTORY_SEPARATOR;
+         }
+         return p;
+     };
+     if (getenv("LLAMA_CACHE")) {
+         cache_directory = std::getenv("LLAMA_CACHE");
+     } else {
+ #ifdef __linux__
+         if (std::getenv("XDG_CACHE_HOME")) {
+             cache_directory = std::getenv("XDG_CACHE_HOME");
+         } else {
+             cache_directory = std::getenv("HOME") + std::string("/.cache/");
+         }
+ #elif defined(__APPLE__)
+         cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
+ #elif defined(_WIN32)
+         cache_directory = std::getenv("LOCALAPPDATA");
+ #endif // __linux__
+         cache_directory = ensure_trailing_slash(cache_directory);
+         cache_directory += "llama.cpp";
+     }
+     return ensure_trailing_slash(cache_directory);
+ }
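+
+ // Resolution examples (assuming the respective environment variables):
+ //   LLAMA_CACHE=/tmp/llm          -> "/tmp/llm/"
+ //   Linux, XDG_CACHE_HOME unset   -> "$HOME/.cache/llama.cpp/"
+ //   Windows                       -> "%LOCALAPPDATA%\llama.cpp\"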
+
+ std::string fs_get_cache_file(const std::string & filename) {
+     GGML_ASSERT(filename.find(DIRECTORY_SEPARATOR) == std::string::npos);
+     std::string cache_directory = fs_get_cache_directory();
+     const bool success = fs_create_directory_with_parents(cache_directory);
+     if (!success) {
+         throw std::runtime_error("failed to create cache directory: " + cache_directory);
+     }
+     return cache_directory + filename;
+ }
+
+ //
+ // Model utils
+ //
+
+ struct common_init_result common_init_from_params(common_params & params) {
+     common_init_result iparams;
+     auto mparams = common_model_params_to_llama(params);
+
+     llama_model * model = nullptr;
+
+     if (!params.hf_repo.empty() && !params.hf_file.empty()) {
+         model = common_load_model_from_hf(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams);
+     } else if (!params.model_url.empty()) {
+         model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams);
+     } else {
+         model = llama_model_load_from_file(params.model.c_str(), mparams);
+     }
+
+     if (model == NULL) {
+         LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.c_str());
+         return iparams;
+     }
+
+     const llama_vocab * vocab = llama_model_get_vocab(model);
+
+     if (params.reranking) {
+         bool ok = true;
+
+         if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
+             LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
+             ok = false;
+         }
+
+         if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
+             LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__);
+             ok = false;
+         }
+
+         if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) {
+             LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
+             ok = false;
+         }
+
+         if (!ok) {
+             llama_model_free(model);
+
+             return iparams;
+         }
+     }
+
+     auto cparams = common_context_params_to_llama(params);
+
+     llama_context * lctx = llama_init_from_model(model, cparams);
+     if (lctx == NULL) {
+         LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
+         llama_model_free(model);
+         return iparams;
+     }
+
+     if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
+         LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
+         params.ctx_shift = false;
+     }
+
+     if (!params.control_vectors.empty()) {
+         if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
+         if (params.control_vector_layer_end   <= 0) params.control_vector_layer_end   = llama_model_n_layer(model);
+
+         const auto cvec = common_control_vector_load(params.control_vectors);
+         if (cvec.n_embd == -1) {
+             llama_free(lctx);
+             llama_model_free(model);
+
+             return iparams;
+         }
+
+         int err = llama_apply_adapter_cvec(
+             lctx,
+             cvec.data.data(),
+             cvec.data.size(),
+             cvec.n_embd,
+             params.control_vector_layer_start,
+             params.control_vector_layer_end);
+         if (err) {
+             llama_free(lctx);
+             llama_model_free(model);
+
+             return iparams;
+         }
+     }
+
+     // load and optionally apply lora adapters
+     for (auto & la : params.lora_adapters) {
+         llama_adapter_lora_ptr lora;
+         lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
+         if (lora == nullptr) {
+             LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
+             llama_free(lctx);
+             llama_model_free(model);
+             return iparams;
+         }
+
+         la.ptr = lora.get();
+         iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
+     }
+
+     if (!params.lora_init_without_apply) {
+         common_set_adapter_lora(lctx, params.lora_adapters);
+     }
+
+     if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
+         LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
+         params.sampling.ignore_eos = false;
+     }
+
+     if (params.sampling.ignore_eos) {
+         for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+             if (llama_vocab_is_eog(vocab, i)) {
+                 LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
+                 params.sampling.logit_bias.push_back({i, -INFINITY});
+             }
+         }
+     }
+
+     if (params.sampling.penalty_last_n == -1) {
+         LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+         params.sampling.penalty_last_n = llama_n_ctx(lctx);
+     }
+
+     if (params.sampling.dry_penalty_last_n == -1) {
+         LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+         params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
+     }
+
+     if (params.warmup) {
+         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
+
+         std::vector<llama_token> tmp;
+         llama_token bos = llama_vocab_bos(vocab);
+         llama_token eos = llama_vocab_eos(vocab);
+
+         // some models (e.g. T5) don't have a BOS token
+         if (bos != LLAMA_TOKEN_NULL) {
+             tmp.push_back(bos);
+         }
+         if (eos != LLAMA_TOKEN_NULL) {
+             tmp.push_back(eos);
+         }
+         if (tmp.empty()) {
+             tmp.push_back(0);
+         }
+
+         if (llama_model_has_encoder(model)) {
+             llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
+             llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+             if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
+                 decoder_start_token_id = bos;
+             }
+             tmp.clear();
+             tmp.push_back(decoder_start_token_id);
+         }
+         if (llama_model_has_decoder(model)) {
+             llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
+         }
+         llama_kv_cache_clear(lctx);
+         llama_synchronize(lctx);
+         llama_perf_context_reset(lctx);
+     }
+
+     iparams.model.reset(model);
+     iparams.context.reset(lctx);
+
+     return iparams;
+ }
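+
+ // Typical call sequence (sketch, kept out of the build with #if 0;
+ // "model.gguf" is a placeholder path):
+ #if 0
+ common_params params;
+ params.model = "model.gguf";
+ common_init_result init = common_init_from_params(params);
+ if (init.model == nullptr || init.context == nullptr) {
+     // loading failed; both pointers stay null on every early-return path above
+ }
+ #endif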
1077
+
1078
+ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
1079
+ llama_clear_adapter_lora(ctx);
1080
+ for (auto & la : lora) {
1081
+ if (la.scale != 0.0f) {
1082
+ llama_set_adapter_lora(ctx, la.ptr, la.scale);
1083
+ }
1084
+ }
1085
+ }
1086
+
1087
+ struct llama_model_params common_model_params_to_llama(common_params & params) {
1088
+ auto mparams = llama_model_default_params();
1089
+
1090
+ if (!params.devices.empty()) {
1091
+ mparams.devices = params.devices.data();
1092
+ }
1093
+ if (params.n_gpu_layers != -1) {
1094
+ mparams.n_gpu_layers = params.n_gpu_layers;
1095
+ }
1096
+ mparams.main_gpu = params.main_gpu;
1097
+ mparams.split_mode = params.split_mode;
1098
+ mparams.tensor_split = params.tensor_split;
1099
+ mparams.use_mmap = params.use_mmap;
1100
+ mparams.use_mlock = params.use_mlock;
1101
+ mparams.check_tensors = params.check_tensors;
1102
+ if (params.kv_overrides.empty()) {
1103
+ mparams.kv_overrides = NULL;
1104
+ } else {
1105
+ GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
1106
+ mparams.kv_overrides = params.kv_overrides.data();
1107
+ }
1108
+
1109
+ return mparams;
1110
+ }
1111
+
1112
+ struct llama_context_params common_context_params_to_llama(const common_params & params) {
1113
+ auto cparams = llama_context_default_params();
1114
+
1115
+ cparams.n_ctx = params.n_ctx;
1116
+ cparams.n_seq_max = params.n_parallel;
1117
+ cparams.n_batch = params.n_batch;
1118
+ cparams.n_ubatch = params.n_ubatch;
1119
+ cparams.n_threads = params.cpuparams.n_threads;
1120
+ cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
1121
+ params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
1122
+ cparams.logits_all = params.logits_all;
1123
+ cparams.embeddings = params.embedding;
1124
+ cparams.rope_scaling_type = params.rope_scaling_type;
1125
+ cparams.rope_freq_base = params.rope_freq_base;
1126
+ cparams.rope_freq_scale = params.rope_freq_scale;
1127
+ cparams.yarn_ext_factor = params.yarn_ext_factor;
1128
+ cparams.yarn_attn_factor = params.yarn_attn_factor;
1129
+ cparams.yarn_beta_fast = params.yarn_beta_fast;
1130
+ cparams.yarn_beta_slow = params.yarn_beta_slow;
1131
+ cparams.yarn_orig_ctx = params.yarn_orig_ctx;
1132
+ cparams.pooling_type = params.pooling_type;
1133
+ cparams.attention_type = params.attention_type;
1134
+ cparams.defrag_thold = params.defrag_thold;
1135
+ cparams.cb_eval = params.cb_eval;
1136
+ cparams.cb_eval_user_data = params.cb_eval_user_data;
1137
+ cparams.offload_kqv = !params.no_kv_offload;
1138
+ cparams.flash_attn = params.flash_attn;
1139
+ cparams.no_perf = params.no_perf;
1140
+
1141
+ if (params.reranking) {
1142
+ cparams.embeddings = true;
1143
+ cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
1144
+ }
1145
+
1146
+ cparams.type_k = params.cache_type_k;
1147
+ cparams.type_v = params.cache_type_v;
1148
+
1149
+ return cparams;
1150
+ }
1151
+
1152
+ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params) {
1153
+ struct ggml_threadpool_params tpp;
1154
+
1155
+ ggml_threadpool_params_init(&tpp, params.n_threads); // setup the defaults
1156
+
1157
+ if (params.mask_valid) {
1158
+ std::memcpy(&tpp.cpumask, &params.cpumask, GGML_MAX_N_THREADS);
1159
+ }
1160
+
1161
+ tpp.prio = params.priority;
1162
+ tpp.poll = params.poll;
1163
+ tpp.strict_cpu = params.strict_cpu;
1164
+
1165
+ return tpp;
1166
+ }
1167
+
1168
+ #ifdef LLAMA_USE_CURL
1169
+
1170
+ #define CURL_MAX_RETRY 3
1171
+ #define CURL_RETRY_DELAY_SECONDS 2
1172
+
1173
+ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
1174
+ int remaining_attempts = max_attempts;
1175
+
1176
+ while (remaining_attempts > 0) {
1177
+ LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
1178
+
1179
+ CURLcode res = curl_easy_perform(curl);
1180
+ if (res == CURLE_OK) {
1181
+ return true;
1182
+ }
1183
+
1184
+ int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
1185
+ LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);
1186
+
1187
+ remaining_attempts--;
1188
+ std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
1189
+ }
1190
+
1191
+ LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
1192
+
1193
+ return false;
1194
+ }
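The backoff above grows geometrically in the retry delay rather than linearly. A minimal standalone sketch of the schedule it produces with the defaults (CURL_MAX_RETRY = 3, CURL_RETRY_DELAY_SECONDS = 2); this is an illustration, not part of the patch:

    #include <cmath>
    #include <cstdio>

    int main() {
        const int max_attempts = 3;        // mirrors CURL_MAX_RETRY
        const int retry_delay_seconds = 2; // mirrors CURL_RETRY_DELAY_SECONDS
        for (int remaining = max_attempts; remaining > 0; remaining--) {
            // same formula as in curl_perform_with_retry
            const int delay_ms = (int) std::pow(retry_delay_seconds, max_attempts - remaining) * 1000;
            printf("attempt %d fails -> sleep %d ms\n", max_attempts - remaining + 1, delay_ms);
        }
        return 0; // prints 1000, 2000, 4000
    }

Note that the loop in curl_perform_with_retry also sleeps once after the final failed attempt before giving up.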
1195
+
1196
+ static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
1197
+ // Initialize libcurl
1198
+ curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
1199
+ curl_slist_ptr http_headers;
1200
+ if (!curl) {
1201
+ LOG_ERR("%s: error initializing libcurl\n", __func__);
1202
+ return false;
1203
+ }
1204
+
1205
+ bool force_download = false;
1206
+
1207
+ // Set the URL and allow following HTTP redirects
1208
+ curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
1209
+ curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
1210
+
1211
+ // Check if hf-token or bearer-token was specified
1212
+ if (!hf_token.empty()) {
1213
+ std::string auth_header = "Authorization: Bearer " + hf_token;
1214
+ http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
1215
+ curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
1216
+ }
1217
+
1218
+ #if defined(_WIN32)
1219
+ // CURLSSLOPT_NATIVE_CA tells libcurl to use the standard certificate store of
1220
+ // the operating system. Currently implemented only on MS-Windows.
1221
+ curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
1222
+ #endif
1223
+
1224
+ // Check if the file already exists locally
1225
+ auto file_exists = std::filesystem::exists(path);
1226
+
1227
+ // If the file exists, check its JSON metadata companion file.
1228
+ std::string metadata_path = path + ".json";
1229
+ nlohmann::json metadata;
1230
+ std::string etag;
1231
+ std::string last_modified;
1232
+
1233
+ if (file_exists) {
1234
+ // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
1235
+ std::ifstream metadata_in(metadata_path);
1236
+ if (metadata_in.good()) {
1237
+ try {
1238
+ metadata_in >> metadata;
1239
+ LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
1240
+ if (metadata.contains("url") && metadata.at("url").is_string()) {
1241
+ auto previous_url = metadata.at("url").get<std::string>();
1242
+ if (previous_url != url) {
1243
+ LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str());
1244
+ return false;
1245
+ }
1246
+ }
1247
+ if (metadata.contains("etag") && metadata.at("etag").is_string()) {
1248
+ etag = metadata.at("etag");
1249
+ }
1250
+ if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
1251
+ last_modified = metadata.at("lastModified");
1252
+ }
1253
+ } catch (const nlohmann::json::exception & e) {
1254
+ LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
1255
+ return false;
1256
+ }
1257
+ }
1258
+ } else {
1259
+ LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
1260
+ }
1261
+
1262
+ // Send a HEAD request to retrieve the etag and last-modified headers
1263
+ struct common_load_model_from_url_headers {
1264
+ std::string etag;
1265
+ std::string last_modified;
1266
+ };
1267
+
1268
+ common_load_model_from_url_headers headers;
1269
+
1270
+ {
1271
+ typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
1272
+ auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
1273
+ common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
1274
+
1275
+ static std::regex header_regex("([^:]+): (.*)\r\n");
1276
+ static std::regex etag_regex("ETag", std::regex_constants::icase);
1277
+ static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
1278
+
1279
+ std::string header(buffer, n_items);
1280
+ std::smatch match;
1281
+ if (std::regex_match(header, match, header_regex)) {
1282
+ const std::string & key = match[1];
1283
+ const std::string & value = match[2];
1284
+ if (std::regex_match(key, match, etag_regex)) {
1285
+ headers->etag = value;
1286
+ } else if (std::regex_match(key, match, last_modified_regex)) {
1287
+ headers->last_modified = value;
1288
+ }
1289
+ }
1290
+ return n_items;
1291
+ };
1292
+
1293
+ curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
1294
+ curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
1295
+ curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
1296
+ curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
1297
+
1298
+ bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
1299
+ if (!was_perform_successful) {
1300
+ return false;
1301
+ }
1302
+
1303
+ long http_code = 0;
1304
+ curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
1305
+ if (http_code != 200) {
1306
+ // HEAD not supported, we don't know if the file has changed
1307
+ // force trigger downloading
1308
+ force_download = true;
1309
+ LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
1310
+ }
1311
+ }
1312
+
1313
+ bool should_download = !file_exists || force_download;
1314
+ if (!should_download) {
1315
+ if (!etag.empty() && etag != headers.etag) {
1316
+ LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
1317
+ should_download = true;
1318
+ } else if (!last_modified.empty() && last_modified != headers.last_modified) {
1319
+ LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
1320
+ should_download = true;
1321
+ }
1322
+ }
1323
+ if (should_download) {
1324
+ std::string path_temporary = path + ".downloadInProgress";
1325
+ if (file_exists) {
1326
+ LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
1327
+ if (remove(path.c_str()) != 0) {
1328
+ LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
1329
+ return false;
1330
+ }
1331
+ }
1332
+
1333
+ // Set the output file
1334
+
1335
+ struct FILE_deleter {
1336
+ void operator()(FILE * f) const {
1337
+ fclose(f);
1338
+ }
1339
+ };
1340
+
1341
+ std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
1342
+ if (!outfile) {
1343
+ LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
1344
+ return false;
1345
+ }
1346
+
1347
+ typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
1348
+ auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
1349
+ return fwrite(data, size, nmemb, (FILE *)fd);
1350
+ };
1351
+ curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
1352
+ curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
1353
+ curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
1354
+
1355
+ // display download progress
1356
+ curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
1357
+
1358
+ // helper function to hide password in URL
1359
+ auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
1360
+ std::size_t protocol_pos = url.find("://");
1361
+ if (protocol_pos == std::string::npos) {
1362
+ return url; // Malformed URL
1363
+ }
1364
+
1365
+ std::size_t at_pos = url.find('@', protocol_pos + 3);
1366
+ if (at_pos == std::string::npos) {
1367
+ return url; // No password in URL
1368
+ }
1369
+
1370
+ return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
1371
+ };
1372
+
1373
+ // start the download
1374
+ LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
1375
+ llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
1376
+ bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
1377
+ if (!was_perform_successful) {
1378
+ return false;
1379
+ }
1380
+
1381
+ long http_code = 0;
1382
+ curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
1383
+ if (http_code < 200 || http_code >= 400) {
1384
+ LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
1385
+ return false;
1386
+ }
1387
+
1388
+ // Close the file explicitly here, before we rename it.
1389
+ outfile.reset();
1390
+
1391
+ // Write the updated JSON metadata file.
1392
+ metadata.update({
1393
+ {"url", url},
1394
+ {"etag", headers.etag},
1395
+ {"lastModified", headers.last_modified}
1396
+ });
1397
+ std::ofstream(metadata_path) << metadata.dump(4);
1398
+ LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
1399
+
1400
+ if (rename(path_temporary.c_str(), path.c_str()) != 0) {
1401
+ LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
1402
+ return false;
1403
+ }
1404
+ }
1405
+
1406
+ return true;
1407
+ }
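In short, each downloaded file <path> gets a sidecar <path>.json recording the url, etag and lastModified values; on the next call a HEAD request is compared against the sidecar, and the body is only re-fetched when a validator changed, the file is missing, or HEAD is unsupported. A sketch of the sidecar written by metadata.dump(4) above (field values are made-up examples):

    {
        "etag": "\"abc123\"",
        "lastModified": "Thu, 01 Jan 2025 00:00:00 GMT",
        "url": "https://huggingface.co/.../model.gguf"
    }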
1408
+
1409
+ struct llama_model * common_load_model_from_url(
1410
+ const std::string & model_url,
1411
+ const std::string & local_path,
1412
+ const std::string & hf_token,
1413
+ const struct llama_model_params & params) {
1414
+ // Basic validation of the model_url
1415
+ if (model_url.empty()) {
1416
+ LOG_ERR("%s: invalid model_url\n", __func__);
1417
+ return NULL;
1418
+ }
1419
+
1420
+ if (!common_download_file(model_url, local_path, hf_token)) {
1421
+ return NULL;
1422
+ }
1423
+
1424
+ // check for additional GGUFs split to download
1425
+ int n_split = 0;
1426
+ {
1427
+ struct gguf_init_params gguf_params = {
1428
+ /*.no_alloc = */ true,
1429
+ /*.ctx = */ NULL,
1430
+ };
1431
+ auto * ctx_gguf = gguf_init_from_file(local_path.c_str(), gguf_params);
1432
+ if (!ctx_gguf) {
1433
+ LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, local_path.c_str());
1434
+ return NULL;
1435
+ }
1436
+
1437
+ auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
1438
+ if (key_n_split >= 0) {
1439
+ n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
1440
+ }
1441
+
1442
+ gguf_free(ctx_gguf);
1443
+ }
1444
+
1445
+ if (n_split > 1) {
1446
+ char split_prefix[PATH_MAX] = {0};
1447
+ char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0};
1448
+
1449
+ // Verify the first split file format
1450
+ // and extract split URL and PATH prefixes
1451
+ {
1452
+ if (!llama_split_prefix(split_prefix, sizeof(split_prefix), local_path.c_str(), 0, n_split)) {
1453
+ LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, local_path.c_str(), n_split);
1454
+ return NULL;
1455
+ }
1456
+
1457
+ if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url.c_str(), 0, n_split)) {
1458
+ LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url.c_str(), n_split);
1459
+ return NULL;
1460
+ }
1461
+ }
1462
+
1463
+ // Prepare download in parallel
1464
+ std::vector<std::future<bool>> futures_download;
1465
+ for (int idx = 1; idx < n_split; idx++) {
1466
+ futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool {
1467
+ char split_path[PATH_MAX] = {0};
1468
+ llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);
1469
+
1470
+ char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
1471
+ llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
1472
+
1473
+ return common_download_file(split_url, split_path, hf_token);
1474
+ }, idx));
1475
+ }
1476
+
1477
+ // Wait for all downloads to complete
1478
+ for (auto & f : futures_download) {
1479
+ if (!f.get()) {
1480
+ return NULL;
1481
+ }
1482
+ }
1483
+ }
1484
+
1485
+ return llama_model_load_from_file(local_path.c_str(), params);
1486
+ }
1487
+
1488
+ struct llama_model * common_load_model_from_hf(
1489
+ const std::string & repo,
1490
+ const std::string & remote_path,
1491
+ const std::string & local_path,
1492
+ const std::string & hf_token,
1493
+ const struct llama_model_params & params) {
1494
+ // construct hugging face model url:
1495
+ //
1496
+ // --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
1497
+ // https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
1498
+ //
1499
+ // --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
1500
+ // https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
1501
+ //
1502
+
1503
+ std::string model_url = "https://huggingface.co/";
1504
+ model_url += repo;
1505
+ model_url += "/resolve/main/";
1506
+ model_url += remote_path;
1507
+
1508
+ return common_load_model_from_url(model_url, local_path, hf_token, params);
1509
+ }
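A hedged usage sketch of the function above; the repo and file names are taken from the comment's own examples, and the local cache path is an assumption:

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = common_load_model_from_hf(
        "ggml-org/models",                    // repo
        "tinyllama-1.1b/ggml-model-f16.gguf", // remote_path
        "/tmp/ggml-model-f16.gguf",           // local_path (hypothetical cache location)
        /*hf_token =*/ "",                    // empty for public repos
        mparams);
    if (model == NULL) {
        // the download or the model load failed
    }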
1510
+
1511
+ /**
1512
+ * Allow getting the HF file from the HF repo with tag (like ollama), for example:
1513
+ * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
1514
+ * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
1515
+ * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
1516
+ * The tag is optional and defaults to "latest" (meaning it checks for Q4_K_M first, then Q4, and if neither is found, returns the first GGUF file in the repo)
1517
+ *
1518
+ * Return pair of <repo, file> (with "repo" already having tag removed)
1519
+ *
1520
+ * Note: we use the Ollama-compatible HF API, but we do not use the blobId. Instead, we use the special "ggufFile" field, which returns the value for "hf_file". This keeps backward compatibility with existing cache files.
1521
+ */
1522
+ std::pair<std::string, std::string> common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) {
1523
+ auto parts = string_split<std::string>(hf_repo_with_tag, ':');
1524
+ std::string tag = parts.size() > 1 ? parts.back() : "latest";
1525
+ std::string hf_repo = parts[0];
1526
+ if (string_split<std::string>(hf_repo, '/').size() != 2) {
1527
+ throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
1528
+ }
1529
+
1530
+ // fetch model info from Hugging Face Hub API
1531
+ json model_info;
1532
+ curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
1533
+ curl_slist_ptr http_headers;
1534
+ std::string res_str;
1535
+ std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag;
1536
+ curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
1537
+ curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
1538
+ typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
1539
+ auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
1540
+ static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
1541
+ return size * nmemb;
1542
+ };
1543
+ curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
1544
+ curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
1545
+ #if defined(_WIN32)
1546
+ curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
1547
+ #endif
1548
+ if (!hf_token.empty()) {
1549
+ std::string auth_header = "Authorization: Bearer " + hf_token;
1550
+ http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
1551
+ }
1552
+ // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
1553
+ http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
1554
+ http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
1555
+ curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
1556
+
1557
+ CURLcode res = curl_easy_perform(curl.get());
1558
+
1559
+ if (res != CURLE_OK) {
1560
+ throw std::runtime_error("error: cannot make GET request to HF API");
1561
+ }
1562
+
1563
+ long res_code;
1564
+ curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
1565
+ if (res_code == 200) {
1566
+ model_info = json::parse(res_str);
1567
+ } else if (res_code == 401) {
1568
+ throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
1569
+ } else {
1570
+ throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
1571
+ }
1572
+
1573
+ // check response
1574
+ if (!model_info.contains("ggufFile")) {
1575
+ throw std::runtime_error("error: model does not have ggufFile");
1576
+ }
1577
+ json & gguf_file = model_info.at("ggufFile");
1578
+ if (!gguf_file.contains("rfilename")) {
1579
+ throw std::runtime_error("error: ggufFile does not have rfilename");
1580
+ }
1581
+
1582
+ return std::make_pair(hf_repo, gguf_file.at("rfilename"));
1583
+ }
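For illustration, a sketch of how the repo:tag form resolves; the repo name comes from the doc comment above, and a real call site would catch, since the function throws on failure:

    // "bartowski/Llama-3.2-3B-Instruct-GGUF:q4" is split on ':' into repo + tag, then
    // https://huggingface.co/v2/<repo>/manifests/q4 is queried and the response's
    // "ggufFile.rfilename" field becomes the file name.
    const std::string hf_token = ""; // assumed: public repo
    auto [repo, file] = common_get_hf_file("bartowski/Llama-3.2-3B-Instruct-GGUF:q4", hf_token);
    // repo == "bartowski/Llama-3.2-3B-Instruct-GGUF", file == whatever rfilename the API returned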
1584
+
1585
+ #else
1586
+
1587
+ struct llama_model * common_load_model_from_url(
1588
+ const std::string & /*model_url*/,
1589
+ const std::string & /*local_path*/,
1590
+ const std::string & /*hf_token*/,
1591
+ const struct llama_model_params & /*params*/) {
1592
+ LOG_WRN("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
1593
+ return nullptr;
1594
+ }
1595
+
1596
+ struct llama_model * common_load_model_from_hf(
1597
+ const std::string & /*repo*/,
1598
+ const std::string & /*remote_path*/,
1599
+ const std::string & /*local_path*/,
1600
+ const std::string & /*hf_token*/,
1601
+ const struct llama_model_params & /*params*/) {
1602
+ LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
1603
+ return nullptr;
1604
+ }
1605
+
1606
+ std::pair<std::string, std::string> common_get_hf_file(const std::string &, const std::string &) {
1607
+ LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
1608
+ return std::make_pair("", "");
1609
+ }
1610
+
1611
+ #endif // LLAMA_USE_CURL
1612
+
1613
+ //
1614
+ // Batch utils
1615
+ //
1616
+
1617
+ void common_batch_clear(struct llama_batch & batch) {
1618
+ batch.n_tokens = 0;
1619
+ }
1620
+
1621
+ void common_batch_add(
1622
+ struct llama_batch & batch,
1623
+ llama_token id,
1624
+ llama_pos pos,
1625
+ const std::vector<llama_seq_id> & seq_ids,
1626
+ bool logits) {
1627
+ GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
1628
+
1629
+ batch.token [batch.n_tokens] = id;
1630
+ batch.pos [batch.n_tokens] = pos;
1631
+ batch.n_seq_id[batch.n_tokens] = seq_ids.size();
1632
+ for (size_t i = 0; i < seq_ids.size(); ++i) {
1633
+ batch.seq_id[batch.n_tokens][i] = seq_ids[i];
1634
+ }
1635
+ batch.logits [batch.n_tokens] = logits;
1636
+
1637
+ batch.n_tokens++;
1638
+ }
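A minimal sketch of filling a batch with these helpers; the token ids are made up and the batch sizes are arbitrary:

    // one sequence, room for up to 512 tokens, no embeddings
    llama_batch batch = llama_batch_init(512, 0, 1);
    common_batch_clear(batch);

    const std::vector<llama_token> prompt = { 1, 15043, 1781 }; // placeholder ids
    for (size_t i = 0; i < prompt.size(); ++i) {
        // request logits only for the last token
        common_batch_add(batch, prompt[i], (llama_pos) i, { 0 }, i == prompt.size() - 1);
    }
    // ... llama_decode(ctx, batch) ...
    llama_batch_free(batch);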
1639
+
1640
+ //
1641
+ // Token utils
1642
+ //
1643
+
1644
+ size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
1645
+ size_t i;
1646
+ for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
1647
+
1648
+ return i;
1649
+ }
1650
+
1651
+ size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
1652
+ // check for empty sequences
1653
+ if (a.empty() || b.empty()) {
1654
+ return 0;
1655
+ }
1656
+
1657
+ // get the lengths of the input sequences
1658
+ size_t a_len = a.size();
1659
+ size_t b_len = b.size();
1660
+
1661
+ // initialize the maximum length of the longest common subsequence (LCS)
1662
+ size_t max_length = 0;
1663
+
1664
+ // use two rows instead of a 2D matrix to optimize space
1665
+ std::vector<size_t> prev_row(b_len + 1, 0);
1666
+ std::vector<size_t> curr_row(b_len + 1, 0);
1667
+
1668
+ // iterate through the elements of a
1669
+ for (size_t i = 1; i <= a_len; i++) {
1670
+ // iterate through the elements of b
1671
+ for (size_t j = 1; j <= b_len; j++) {
1672
+ // if elements at the current positions match
1673
+ if (a[i - 1] == b[j - 1]) {
1674
+ // if it's the first element of either sequence, set the LCS length to 1
1675
+ if (i == 1 || j == 1) {
1676
+ curr_row[j] = 1;
1677
+ } else {
1678
+ // increment LCS length by 1 compared to the previous element
1679
+ curr_row[j] = prev_row[j - 1] + 1;
1680
+ }
1681
+
1682
+ // update max_length if necessary
1683
+ if (curr_row[j] > max_length) {
1684
+ max_length = curr_row[j];
1685
+ }
1686
+ } else {
1687
+ // reset LCS length if elements don't match
1688
+ curr_row[j] = 0;
1689
+ }
1690
+ }
1691
+
1692
+ // update the previous row for the next iteration
1693
+ prev_row = curr_row;
1694
+ }
1695
+
1696
+ // return the maximum length of the LCS
1697
+ return max_length;
1698
+ }
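Note that the reset to zero on a mismatch means common_lcs measures the longest common contiguous run of tokens, while common_lcp measures the shared prefix. A small worked example with made-up token ids:

    llama_tokens a = {7, 1, 2, 3, 9};
    llama_tokens b = {1, 2, 3, 4};
    // {1, 2, 3} appears contiguously in both, and the first tokens differ:
    GGML_ASSERT(common_lcs(a, b) == 3);
    GGML_ASSERT(common_lcp(a, b) == 0);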
1699
+
1700
+ //
1701
+ // Vocab utils
1702
+ //
1703
+
1704
+ std::vector<llama_token> common_tokenize(
1705
+ const struct llama_context * ctx,
1706
+ const std::string & text,
1707
+ bool add_special,
1708
+ bool parse_special) {
1709
+ const llama_model * model = llama_get_model(ctx);
1710
+ const llama_vocab * vocab = llama_model_get_vocab(model);
1711
+ return common_tokenize(vocab, text, add_special, parse_special);
1712
+ }
1713
+
1714
+ std::vector<llama_token> common_tokenize(
1715
+ const struct llama_vocab * vocab,
1716
+ const std::string & text,
1717
+ bool add_special,
1718
+ bool parse_special) {
1719
+ // upper limit for the number of tokens
1720
+ int n_tokens = text.length() + 2 * add_special;
1721
+ std::vector<llama_token> result(n_tokens);
1722
+ n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
1723
+ if (n_tokens < 0) {
1724
+ result.resize(-n_tokens);
1725
+ int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
1726
+ GGML_ASSERT(check == -n_tokens);
1727
+ } else {
1728
+ result.resize(n_tokens);
1729
+ }
1730
+ return result;
1731
+ }
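The negative return value is the usual llama.cpp convention for "buffer too small, -n is the required size", which is why the helper tokenizes at most twice. A hedged usage sketch, assuming a vocab obtained via llama_model_get_vocab:

    std::vector<llama_token> toks =
        common_tokenize(vocab, "Hello world", /*add_special =*/ true, /*parse_special =*/ false);
    // toks.size() is now exactly the number of tokens produced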
1732
+
1733
+ std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
1734
+ const llama_model * model = llama_get_model(ctx);
1735
+ const llama_vocab * vocab = llama_model_get_vocab(model);
1736
+ return common_token_to_piece(vocab, token, special);
1737
+ }
1738
+
1739
+ std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) {
1740
+ std::string piece;
1741
+ piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
1742
+ const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
1743
+ if (n_chars < 0) {
1744
+ piece.resize(-n_chars);
1745
+ int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
1746
+ GGML_ASSERT(check == -n_chars);
1747
+ }
1748
+ else {
1749
+ piece.resize(n_chars);
1750
+ }
1751
+
1752
+ return piece;
1753
+ }
1754
+
1755
+ std::string common_detokenize(const struct llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
1756
+ const llama_model * model = llama_get_model(ctx);
1757
+ const llama_vocab * vocab = llama_model_get_vocab(model);
1758
+ return common_detokenize(vocab, tokens, special);
1759
+ }
1760
+
1761
+ std::string common_detokenize(const struct llama_vocab * vocab, const std::vector<llama_token> & tokens, bool special) {
1762
+ std::string text;
1763
+ text.resize(std::max(text.capacity(), tokens.size()));
1764
+ int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
1765
+ if (n_chars < 0) {
1766
+ text.resize(-n_chars);
1767
+ n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
1768
+ GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
1769
+ }
1770
+
1771
+ text.resize(n_chars);
1772
+
1773
+ // NOTE: the original tokenizer decodes bytes after collecting the pieces.
1774
+ return text;
1775
+ }
1776
+
1777
+ //
1778
+ // KV cache utils
1779
+ //
1780
+
1781
+ void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) {
1782
+ static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
1783
+
1784
+ printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
1785
+ view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
1786
+
1787
+ llama_kv_cache_view_cell * c_curr = view.cells;
1788
+ llama_seq_id * cs_curr = view.cells_sequences;
1789
+
1790
+ for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
1791
+ if (i % row_size == 0) {
1792
+ printf("\n%5d: ", i);
1793
+ }
1794
+ int seq_count = 0;
1795
+ for (int j = 0; j < view.n_seq_max; j++) {
1796
+ if (cs_curr[j] >= 0) { seq_count++; }
1797
+ }
1798
+ putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
1799
+ }
1800
+
1801
+ printf("\n=== Done dumping\n");
1802
+ }
1803
+
1804
+ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) {
1805
+ static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
1806
+
1807
+ printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
1808
+ view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
1809
+
1810
+ std::unordered_map<llama_seq_id, size_t> seqs;
1811
+ llama_kv_cache_view_cell * c_curr = view.cells;
1812
+ llama_seq_id * cs_curr = view.cells_sequences;
1813
+
1814
+ for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
1815
+ for (int j = 0; j < view.n_seq_max; j++) {
1816
+ if (cs_curr[j] < 0) { continue; }
1817
+ if (seqs.find(cs_curr[j]) == seqs.end()) {
1818
+ if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
1819
+ const size_t sz = seqs.size();
1820
+ seqs[cs_curr[j]] = sz;
1821
+ }
1822
+ }
1823
+ if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
1824
+ }
1825
+
1826
+ printf("=== Sequence legend: ");
1827
+ for (const auto & it : seqs) {
1828
+ printf("%zu=%d, ", it.second, it.first);
1829
+ }
1830
+ printf("'+'=other sequence ids");
1831
+
1832
+ c_curr = view.cells;
1833
+ cs_curr = view.cells_sequences;
1834
+ for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
1835
+ if (i % row_size == 0) {
1836
+ printf("\n%5d: ", i);
1837
+ }
1838
+ for (int j = 0; j < view.n_seq_max; j++) {
1839
+ if (cs_curr[j] >= 0) {
1840
+ const auto & it = seqs.find(cs_curr[j]);
1841
+ putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
1842
+ } else {
1843
+ putchar('.');
1844
+ }
1845
+ }
1846
+ putchar(' ');
1847
+ }
1848
+
1849
+ printf("\n=== Done dumping\n");
1850
+ }
1851
+
1852
+ //
1853
+ // Embedding utils
1854
+ //
1855
+
1856
+ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm) {
1857
+ double sum = 0.0;
1858
+
1859
+ switch (embd_norm) {
1860
+ case -1: // no normalisation
1861
+ sum = 1.0;
1862
+ break;
1863
+ case 0: // max absolute
1864
+ for (int i = 0; i < n; i++) {
1865
+ if (sum < std::abs(inp[i])) {
1866
+ sum = std::abs(inp[i]);
1867
+ }
1868
+ }
1869
+ sum /= 32760.0; // make an int16 range
1870
+ break;
1871
+ case 2: // euclidean
1872
+ for (int i = 0; i < n; i++) {
1873
+ sum += inp[i] * inp[i];
1874
+ }
1875
+ sum = std::sqrt(sum);
1876
+ break;
1877
+ default: // p-norm (euclidean is p-norm p=2)
1878
+ for (int i = 0; i < n; i++) {
1879
+ sum += std::pow(std::abs(inp[i]), embd_norm);
1880
+ }
1881
+ sum = std::pow(sum, 1.0 / embd_norm);
1882
+ break;
1883
+ }
1884
+
1885
+ const float norm = sum > 0.0 ? 1.0 / sum : 0.0f;
1886
+
1887
+ for (int i = 0; i < n; i++) {
1888
+ out[i] = inp[i] * norm;
1889
+ }
1890
+ }
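As a concrete check of the default euclidean case (embd_norm == 2), assuming a toy 3-dimensional vector:

    float inp[3] = { 3.0f, 0.0f, 4.0f };
    float out[3];
    common_embd_normalize(inp, out, 3, /*embd_norm =*/ 2);
    // sum = sqrt(9 + 0 + 16) = 5, so out == { 0.6f, 0.0f, 0.8f }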
1891
+
1892
+ float common_embd_similarity_cos(const float * embd1, const float * embd2, int n){
1893
+ double sum = 0.0;
1894
+ double sum1 = 0.0;
1895
+ double sum2 = 0.0;
1896
+
1897
+ for (int i = 0; i < n; i++) {
1898
+ sum += embd1[i] * embd2[i];
1899
+ sum1 += embd1[i] * embd1[i];
1900
+ sum2 += embd2[i] * embd2[i];
1901
+ }
1902
+
1903
+ // Handle the case where one or both vectors are zero vectors
1904
+ if (sum1 == 0.0 || sum2 == 0.0) {
1905
+ if (sum1 == 0.0 && sum2 == 0.0) {
1906
+ return 1.0f; // two zero vectors are similar
1907
+ }
1908
+ return 0.0f;
1909
+ }
1910
+
1911
+ return sum / (sqrt(sum1) * sqrt(sum2));
1912
+ }
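A quick sanity check of the cosine similarity above with hand-picked vectors:

    float a[2] = { 1.0f, 0.0f };
    float b[2] = { 0.0f, 1.0f };
    float c[2] = { 2.0f, 0.0f };
    // orthogonal -> 0, parallel -> 1 (up to float rounding)
    float s_ab = common_embd_similarity_cos(a, b, 2); // 0.0f
    float s_ac = common_embd_similarity_cos(a, c, 2); // 1.0f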
1913
+
1914
+ //
1915
+ // Control vector utils
1916
+ //
1917
+
1918
+ static common_control_vector_data common_control_vector_load_one(const common_control_vector_load_info & load_info) {
1919
+ common_control_vector_data result = { -1, {} };
1920
+
1921
+ ggml_context * ctx = nullptr;
1922
+ struct gguf_init_params meta_gguf_params = {
1923
+ /* .no_alloc = */ false,
1924
+ /* .ctx = */ &ctx,
1925
+ };
1926
+ struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params);
1927
+ if (!ctx_gguf) {
1928
+ LOG_ERR("%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str());
1929
+ return result;
1930
+ }
1931
+
1932
+ int32_t n_tensors = gguf_get_n_tensors(ctx_gguf);
1933
+ if (n_tensors == 0) {
1934
+ LOG_WRN("%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str());
1935
+ }
1936
+
1937
+ for (int i = 0; i < n_tensors; i++) {
1938
+ std::string name = gguf_get_tensor_name(ctx_gguf, i);
1939
+
1940
+ int layer_idx = -1;
1941
+
1942
+ // split on '.'
1943
+ size_t dotpos = name.find('.');
1944
+ if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") {
1945
+ try {
1946
+ layer_idx = std::stoi(name.substr(dotpos + 1));
1947
+ } catch (...) {
1948
+ layer_idx = -1;
1949
+ }
1950
+ }
1951
+ if (layer_idx < 0) {
1952
+ LOG_ERR("%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
1953
+ result.n_embd = -1;
1954
+ break;
1955
+ } else if (layer_idx == 0) {
1956
+ LOG_ERR("%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
1957
+ result.n_embd = -1;
1958
+ break;
1959
+ }
1960
+
1961
+ struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str());
1962
+ if (tensor->type != GGML_TYPE_F32) {
1963
+ LOG_ERR("%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str());
1964
+ result.n_embd = -1;
1965
+ break;
1966
+ }
1967
+ if (ggml_n_dims(tensor) != 1) {
1968
+ LOG_ERR("%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str());
1969
+ result.n_embd = -1;
1970
+ break;
1971
+ }
1972
+
1973
+ if (result.n_embd == -1) {
1974
+ result.n_embd = ggml_nelements(tensor);
1975
+ } else if (ggml_nelements(tensor) != result.n_embd) {
1976
+ LOG_ERR("%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str());
1977
+ result.n_embd = -1;
1978
+ break;
1979
+ }
1980
+
1981
+ // extend if necessary - do not store data for layer 0 (it's not used)
1982
+ result.data.resize(std::max(result.data.size(), static_cast<size_t>(result.n_embd * layer_idx)), 0.0f);
1983
+
1984
+ const float * src = (const float *) tensor->data;
1985
+ float * dst = result.data.data() + result.n_embd * (layer_idx - 1); // layer 1 at [0]
1986
+ for (int j = 0; j < result.n_embd; j++) {
1987
+ dst[j] += src[j] * load_info.strength; // allows multiple directions for same layer in same file
1988
+ }
1989
+
1990
+ }
1991
+
1992
+ if (result.n_embd == -1) {
1993
+ LOG_WRN("%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str());
1994
+ result.data.clear();
1995
+ }
1996
+
1997
+ gguf_free(ctx_gguf);
1998
+ ggml_free(ctx);
1999
+
2000
+ return result;
2001
+ }
2002
+
2003
+ common_control_vector_data common_control_vector_load(const std::vector<common_control_vector_load_info> & load_infos) {
2004
+ common_control_vector_data result = { -1, {} };
2005
+
2006
+ for (const auto & info : load_infos) {
2007
+ auto cur = common_control_vector_load_one(info);
2008
+
2009
+ if (cur.n_embd == -1) {
2010
+ result.n_embd = -1;
2011
+ break;
2012
+ }
2013
+ if (result.n_embd != -1 && result.n_embd != cur.n_embd) {
2014
+ LOG_ERR("%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str());
2015
+ result.n_embd = -1;
2016
+ break;
2017
+ }
2018
+
2019
+ if (result.n_embd == -1) {
2020
+ result = std::move(cur);
2021
+ } else {
2022
+ result.data.resize(std::max(result.data.size(), cur.data.size()), 0.0f); // extend if necessary
2023
+ for (size_t i = 0; i < cur.data.size(); i++) {
2024
+ result.data[i] += cur.data[i];
2025
+ }
2026
+ }
2027
+ }
2028
+
2029
+ if (result.n_embd == -1) {
2030
+ LOG_ERR("%s: no valid control vector files passed\n", __func__);
2031
+ result.data.clear();
2032
+ }
2033
+
2034
+ return result;
2035
+ }
2036
+
2037
+ template <>
2038
+ json common_grammar_trigger::to_json() const {
2039
+ json out {
2040
+ {"type", (int) type},
2041
+ {"value", value},
2042
+ };
2043
+ if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
2044
+ out["token"] = (int) token;
2045
+ }
2046
+ return out;
2047
+ }
2048
+
2049
+ template <>
2050
+ common_grammar_trigger common_grammar_trigger::from_json(const json & in) {
2051
+ common_grammar_trigger out;
2052
+ out.type = (common_grammar_trigger_type) in.at("type").get<int>();
2053
+ out.value = in.at("value").get<std::string>();
2054
+ if (out.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
2055
+ out.token = (llama_token) in.at("token").get<int>();
2056
+ }
2057
+ return out;
2058
+ }
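A round-trip sketch of the (de)serialization pair above, assuming the json alias used in this file (nlohmann::ordered_json, per the note in common.h) and a hypothetical trigger word:

    common_grammar_trigger t;
    t.type  = COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
    t.value = "<tool_call>"; // hypothetical trigger word
    json j = t.to_json<json>(); // {"type": 1, "value": "<tool_call>"}; no "token" key for non-token triggers
    common_grammar_trigger u = common_grammar_trigger::from_json(j);
    // u.type == t.type && u.value == t.value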
common/common.h ADDED
@@ -0,0 +1,681 @@
1
+ // Various helper functions and utilities
2
+
3
+ #pragma once
4
+
5
+ #include "llama-cpp.h"
6
+
7
+ #include <set>
8
+ #include <string>
9
+ #include <vector>
10
+ #include <sstream>
11
+
12
+ #ifdef _WIN32
13
+ #define DIRECTORY_SEPARATOR '\\'
14
+ #else
15
+ #define DIRECTORY_SEPARATOR '/'
16
+ #endif // _WIN32
17
+
18
+ #define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0)
19
+ #define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
20
+
21
+ #define print_build_info() do { \
22
+ fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \
23
+ fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
24
+ } while(0)
25
+
26
+ #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
27
+
28
+ struct common_adapter_lora_info {
29
+ std::string path;
30
+ float scale;
31
+
32
+ struct llama_adapter_lora * ptr;
33
+ };
34
+
35
+ using llama_tokens = std::vector<llama_token>;
36
+
37
+ // build info
38
+
39
+ struct common_control_vector_load_info;
40
+
41
+ //
42
+ // CPU utils
43
+ //
44
+
45
+ struct cpu_params {
46
+ int n_threads = -1;
47
+ bool cpumask[GGML_MAX_N_THREADS] = {false}; // CPU affinity mask.
48
+ bool mask_valid = false; // Default: any CPU
49
+ enum ggml_sched_priority priority = GGML_SCHED_PRIO_NORMAL; // Scheduling prio : (0 - normal, 1 - medium, 2 - high, 3 - realtime)
50
+ bool strict_cpu = false; // Use strict CPU placement
51
+ uint32_t poll = 50; // Polling (busywait) level (0 - no polling, 100 - mostly polling)
52
+ };
53
+
54
+ int32_t cpu_get_num_physical_cores();
55
+ int32_t cpu_get_num_math();
56
+
57
+ //
58
+ // Common params
59
+ //
60
+
61
+ enum llama_example {
62
+ LLAMA_EXAMPLE_COMMON,
63
+ LLAMA_EXAMPLE_SPECULATIVE,
64
+ LLAMA_EXAMPLE_MAIN,
65
+ LLAMA_EXAMPLE_INFILL,
66
+ LLAMA_EXAMPLE_EMBEDDING,
67
+ LLAMA_EXAMPLE_PERPLEXITY,
68
+ LLAMA_EXAMPLE_RETRIEVAL,
69
+ LLAMA_EXAMPLE_PASSKEY,
70
+ LLAMA_EXAMPLE_IMATRIX,
71
+ LLAMA_EXAMPLE_BENCH,
72
+ LLAMA_EXAMPLE_SERVER,
73
+ LLAMA_EXAMPLE_CVECTOR_GENERATOR,
74
+ LLAMA_EXAMPLE_EXPORT_LORA,
75
+ LLAMA_EXAMPLE_LLAVA,
76
+ LLAMA_EXAMPLE_LOOKUP,
77
+ LLAMA_EXAMPLE_PARALLEL,
78
+ LLAMA_EXAMPLE_TTS,
79
+
80
+ LLAMA_EXAMPLE_COUNT,
81
+ };
82
+
83
+ enum common_sampler_type {
84
+ COMMON_SAMPLER_TYPE_NONE = 0,
85
+ COMMON_SAMPLER_TYPE_DRY = 1,
86
+ COMMON_SAMPLER_TYPE_TOP_K = 2,
87
+ COMMON_SAMPLER_TYPE_TOP_P = 3,
88
+ COMMON_SAMPLER_TYPE_MIN_P = 4,
89
+ //COMMON_SAMPLER_TYPE_TFS_Z = 5,
90
+ COMMON_SAMPLER_TYPE_TYPICAL_P = 6,
91
+ COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
92
+ COMMON_SAMPLER_TYPE_XTC = 8,
93
+ COMMON_SAMPLER_TYPE_INFILL = 9,
94
+ COMMON_SAMPLER_TYPE_PENALTIES = 10,
95
+ };
96
+
97
+ // dimensionality reduction methods, used by cvector-generator
98
+ enum dimre_method {
99
+ DIMRE_METHOD_PCA,
100
+ DIMRE_METHOD_MEAN,
101
+ };
102
+
103
+ enum common_conversation_mode {
104
+ COMMON_CONVERSATION_MODE_DISABLED = 0,
105
+ COMMON_CONVERSATION_MODE_ENABLED = 1,
106
+ COMMON_CONVERSATION_MODE_AUTO = 2,
107
+ };
108
+
109
+ enum common_grammar_trigger_type {
110
+ COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
111
+ COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
112
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
113
+ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
114
+ };
115
+
116
+ struct common_grammar_trigger {
117
+ common_grammar_trigger_type type;
118
+ std::string value;
119
+ llama_token token = LLAMA_TOKEN_NULL;
120
+
121
+ // T can only be nlohmann::ordered_json
122
+ template <class T> T to_json() const;
123
+ template <class T> static common_grammar_trigger from_json(const T & in);
124
+ };
125
+
126
+ // sampling parameters
127
+ struct common_params_sampling {
128
+ uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
129
+
130
+ int32_t n_prev = 64; // number of previous tokens to remember
131
+ int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
132
+ int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
133
+ int32_t top_k = 40; // <= 0 to use vocab size
134
+ float top_p = 0.95f; // 1.0 = disabled
135
+ float min_p = 0.05f; // 0.0 = disabled
136
+ float xtc_probability = 0.00f; // 0.0 = disabled
137
+ float xtc_threshold = 0.10f; // > 0.5 disables XTC
138
+ float typ_p = 1.00f; // typical_p, 1.0 = disabled
139
+ float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
140
+ float dynatemp_range = 0.00f; // 0.0 = disabled
141
+ float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
142
+ int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
143
+ float penalty_repeat = 1.00f; // 1.0 = disabled
144
+ float penalty_freq = 0.00f; // 0.0 = disabled
145
+ float penalty_present = 0.00f; // 0.0 = disabled
146
+ float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
147
+ float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
148
+ int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
149
+ int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
150
+ int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
151
+ float top_n_sigma = -1.00f; // -1.0 = disabled
152
+ float mirostat_tau = 5.00f; // target entropy
153
+ float mirostat_eta = 0.10f; // learning rate
154
+ bool ignore_eos = false;
155
+ bool no_perf = false; // disable performance metrics
156
+ bool timing_per_token = false;
157
+
158
+ std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
159
+
160
+
161
+ std::vector<enum common_sampler_type> samplers = {
162
+ COMMON_SAMPLER_TYPE_PENALTIES,
163
+ COMMON_SAMPLER_TYPE_DRY,
164
+ COMMON_SAMPLER_TYPE_TOP_K,
165
+ COMMON_SAMPLER_TYPE_TYPICAL_P,
166
+ COMMON_SAMPLER_TYPE_TOP_P,
167
+ COMMON_SAMPLER_TYPE_MIN_P,
168
+ COMMON_SAMPLER_TYPE_XTC,
169
+ COMMON_SAMPLER_TYPE_TEMPERATURE,
170
+ };
171
+
172
+ std::string grammar; // optional BNF-like grammar to constrain sampling
173
+ bool grammar_lazy = false;
174
+ std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
175
+ std::set<llama_token> preserved_tokens;
176
+
177
+ std::vector<llama_logit_bias> logit_bias; // logit biases to apply
178
+
179
+ // print the parameters into a string
180
+ std::string print() const;
181
+ };
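As a worked example of the DRY fields above: with dry_multiplier = 0.8, dry_base = 1.75 and dry_allowed_length = 2, a token that would extend an already-repeated sequence of length 4 receives a penalty of 0.8 * 1.75^(4 - 2) = 2.45, and dry_penalty_last_n = -1 makes the scan window track the context size (see the setup code near the top of common.cpp above).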
182
+
183
+ struct common_params_speculative {
184
+ std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
185
+
186
+ int32_t n_ctx = 0; // draft context size
187
+ int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
188
+ int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
189
+ int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
190
+ float p_split = 0.1f; // speculative decoding split probability
191
+ float p_min = 0.75f; // minimum speculative decoding probability (greedy)
192
+
193
+ struct cpu_params cpuparams;
194
+ struct cpu_params cpuparams_batch;
195
+
196
+ std::string hf_repo = ""; // HF repo // NOLINT
197
+ std::string hf_file = ""; // HF file // NOLINT
198
+
199
+ std::string model = ""; // draft model for speculative decoding // NOLINT
200
+ std::string model_url = ""; // model url to download // NOLINT
201
+ };
202
+
203
+ struct common_params_vocoder {
204
+ std::string hf_repo = ""; // HF repo // NOLINT
205
+ std::string hf_file = ""; // HF file // NOLINT
206
+
207
+ std::string model = ""; // model path // NOLINT
208
+ std::string model_url = ""; // model url to download // NOLINT
209
+
210
+ std::string speaker_file = ""; // speaker file path // NOLINT
211
+
212
+ bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
213
+ };
214
+
215
+ enum common_reasoning_format {
216
+ COMMON_REASONING_FORMAT_NONE,
217
+ COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`
218
+ };
219
+
220
+ struct common_params {
221
+ int32_t n_predict = -1; // new tokens to predict
222
+ int32_t n_ctx = 4096; // context size
223
+ int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
224
+ int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
225
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
226
+ int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
227
+ int32_t n_parallel = 1; // number of parallel sequences to decode
228
+ int32_t n_sequences = 1; // number of sequences to decode
229
+ int32_t grp_attn_n = 1; // group-attention factor
230
+ int32_t grp_attn_w = 512; // group-attention width
231
+ int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
232
+ float rope_freq_base = 0.0f; // RoPE base frequency
233
+ float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
234
+ float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
235
+ float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
236
+ float yarn_beta_fast = 32.0f; // YaRN low correction dim
237
+ float yarn_beta_slow = 1.0f; // YaRN high correction dim
238
+ int32_t yarn_orig_ctx = 0; // YaRN original context length
239
+ float defrag_thold = 0.1f; // KV cache defragmentation threshold
240
+
241
+ // offload params
242
+ std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
243
+
244
+ int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
245
+ int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
246
+ float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
247
+
248
+ enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
249
+
250
+ struct cpu_params cpuparams;
251
+ struct cpu_params cpuparams_batch;
252
+
253
+ ggml_backend_sched_eval_callback cb_eval = nullptr;
254
+ void * cb_eval_user_data = nullptr;
255
+
256
+ ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
257
+
258
+ enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
259
+ enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
260
+ enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
261
+
262
+ struct common_params_sampling sampling;
263
+ struct common_params_speculative speculative;
264
+ struct common_params_vocoder vocoder;
265
+
266
+ std::string model = ""; // model path // NOLINT
267
+ std::string model_alias = ""; // model alias // NOLINT
268
+ std::string model_url = ""; // model url to download // NOLINT
269
+ std::string hf_token = ""; // HF token // NOLINT
270
+ std::string hf_repo = ""; // HF repo // NOLINT
271
+ std::string hf_file = ""; // HF file // NOLINT
272
+ std::string prompt = ""; // NOLINT
273
+ std::string system_prompt = ""; // NOLINT
274
+ std::string prompt_file = ""; // store the external prompt file name // NOLINT
275
+ std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
276
+ std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
277
+ std::string input_suffix = ""; // string to suffix user inputs with // NOLINT
278
+ std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
279
+ std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
280
+ std::string logits_file = ""; // file for saving *all* logits // NOLINT
281
+
282
+ std::vector<std::string> in_files; // all input files
283
+ std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
284
+ std::vector<llama_model_kv_override> kv_overrides;
285
+
286
+ bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply)
287
+ std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale
288
+
289
+ std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
290
+
291
+ int32_t verbosity = 0;
292
+ int32_t control_vector_layer_start = -1; // layer range for control vector
293
+ int32_t control_vector_layer_end = -1; // layer range for control vector
294
+
295
+ int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
296
+ int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
297
+ // (which is more convenient to use for plotting)
298
+ //
299
+ bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
300
+ size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
301
+
302
+ bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt
303
+ size_t winogrande_tasks = 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed
304
+
305
+ bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt
306
+ size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed
307
+
308
+ bool kl_divergence = false; // compute KL divergence
309
+
310
+ bool usage = false; // print usage
311
+ bool completion = false; // print source-able completion script
312
+ bool use_color = false; // use color to distinguish generations and inputs
313
+ bool special = false; // enable special token output
314
+ bool interactive = false; // interactive mode
315
+ bool interactive_first = false; // wait for user input immediately
316
+ bool prompt_cache_all = false; // save user input and generations to prompt cache
317
+ bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
318
+
319
+ bool escape = true; // escape "\n", "\r", "\t", "\'", "\"", and "\\"
320
+ bool multiline_input = false; // reverse the usage of `\`
321
+ bool simple_io = false; // improves compatibility with subprocesses and limited consoles
322
+ bool cont_batching = true; // insert new sequences for decoding on-the-fly
323
+ bool flash_attn = false; // flash attention
324
+ bool no_perf = false; // disable performance metrics
325
+ bool ctx_shift = true; // context shift on infinite text generation
326
+
327
+ bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
328
+ bool logits_all = false; // return logits for all tokens in the batch
329
+ bool use_mmap = true; // use mmap for faster loads
330
+ bool use_mlock = false; // use mlock to keep model in memory
331
+ bool verbose_prompt = false; // print prompt tokens before generation
332
+ bool display_prompt = true; // print prompt before generation
333
+ bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
334
+ bool no_kv_offload = false; // disable KV offloading
335
+ bool warmup = true; // warmup run
336
+ bool check_tensors = false; // validate tensor data
337
+
338
+ bool single_turn = false; // single turn chat conversation
339
+
340
+ ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
341
+ ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
342
+
343
+ common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
344
+
345
+ // multimodal models (see examples/llava)
346
+ std::string mmproj = ""; // path to multimodal projector // NOLINT
347
+ std::vector<std::string> image; // path to image file(s)
348
+
349
+ // embedding
350
+ bool embedding = false; // get only sentence embedding
351
+ int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
352
+ std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
353
+ std::string embd_sep = "\n"; // separator of embeddings
354
+ bool reranking = false; // enable reranking support on server
355
+
356
+ // server params
357
+ int32_t port = 8080; // server listens on this network port
358
+ int32_t timeout_read = 600; // http read timeout in seconds
359
+ int32_t timeout_write = timeout_read; // http write timeout in seconds
360
+ int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
361
+ int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
362
+
363
+ std::string hostname = "127.0.0.1";
364
+ std::string public_path = ""; // NOLINT
365
+ std::string chat_template = ""; // NOLINT
366
+ bool use_jinja = false; // NOLINT
367
+ bool enable_chat_template = true;
368
+ common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
369
+
370
+ std::vector<std::string> api_keys;
371
+
372
+ std::string ssl_file_key = ""; // NOLINT
373
+ std::string ssl_file_cert = ""; // NOLINT
374
+
375
+ // "advanced" endpoints are disabled by default for better security
376
+ bool webui = true;
377
+ bool endpoint_slots = false;
378
+ bool endpoint_props = false; // only control POST requests, not GET
379
+ bool endpoint_metrics = false;
380
+
381
+ bool log_json = false;
382
+
383
+ std::string slot_save_path;
384
+
385
+ float slot_prompt_similarity = 0.5f;
386
+
387
+ // batched-bench params
388
+ bool is_pp_shared = false;
389
+
390
+ std::vector<int32_t> n_pp;
391
+ std::vector<int32_t> n_tg;
392
+ std::vector<int32_t> n_pl;
393
+
394
+ // retrieval params
395
+ std::vector<std::string> context_files; // context files to embed
396
+
397
+ int32_t chunk_size = 64; // chunk size for context embedding
398
+
399
+ std::string chunk_separator = "\n"; // chunk separator for context embedding
400
+
401
+ // passkey params
402
+ int32_t n_junk = 250; // number of times to repeat the junk text
403
+ int32_t i_pos = -1; // position of the passkey in the junk text
404
+
405
+ // imatrix params
406
+ int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
407
+ int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
408
+ int32_t i_chunk = 0; // start processing from this chunk
409
+
410
+ bool process_output = false; // collect data for the output tensor
411
+ bool compute_ppl = true; // whether to compute perplexity
412
+
413
+ // cvector-generator params
414
+ int n_pca_batch = 100;
415
+ int n_pca_iterations = 1000;
416
+ dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
417
+ std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
418
+ std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
419
+
420
+ bool spm_infill = false; // suffix/prefix/middle pattern for infill
421
+
422
+ // batched-bench params
423
+ bool batched_bench_output_jsonl = false;
424
+
425
+ // common params
426
+ std::string out_file; // output filename for all example programs
427
+ };
428
+
429
+ // call once at the start of a program if it uses libcommon
430
+ // initializes the logging system and prints info about the build
431
+ void common_init();
432
+
433
+ std::string common_params_get_system_info(const common_params & params);
434
+
435
+ bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]);
436
+ bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
437
+ void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr);
438
+ bool set_process_priority(enum ggml_sched_priority prio);
439
+
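For illustration (not part of this diff), a minimal sketch of the CPU-affinity helpers above. The exact mask/range string formats are assumed to mirror upstream llama.cpp's --cpu-mask and --cpu-range CLI options:

    bool boolmask[GGML_MAX_N_THREADS] = { false };
    if (parse_cpu_mask("0xF", boolmask)) {   // hex bitmask: enable CPUs 0-3
        // apply boolmask to a cpu_params instance...
    }
    if (parse_cpu_range("0-3", boolmask)) {  // inclusive range form
        // same effect as the mask above
    }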
440
+ //
441
+ // String utils
442
+ //
443
+
444
+ #ifdef __GNUC__
445
+ # if defined(__MINGW32__) && !defined(__clang__)
446
+ # define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
447
+ # else
448
+ # define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
449
+ # endif
450
+ #else
451
+ # define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
452
+ #endif
453
+
454
+ LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
455
+ std::string string_format(const char * fmt, ...);
456
+
457
+ std::string string_strip(const std::string & str);
458
+ std::string string_get_sortable_timestamp();
459
+
460
+ std::string string_join(const std::vector<std::string> & values, const std::string & separator);
461
+ std::vector<std::string> string_split(const std::string & str, const std::string & delimiter);
462
+ std::string string_repeat(const std::string & str, size_t n);
463
+
464
+ void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
465
+
466
+ std::string regex_escape(const std::string & s);
467
+
468
+ template<class T>
469
+ static std::vector<T> string_split(const std::string & str, char delim) {
470
+ static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
471
+ std::vector<T> values;
472
+ std::istringstream str_stream(str);
473
+ std::string token;
474
+ while (std::getline(str_stream, token, delim)) {
475
+ T value;
476
+ std::istringstream token_stream(token);
477
+ token_stream >> value;
478
+ values.push_back(value);
479
+ }
480
+ return values;
481
+ }
482
+
483
+ template<>
484
+ std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
485
+ {
486
+ std::vector<std::string> parts;
487
+ size_t begin_pos = 0;
488
+ size_t separator_pos = input.find(separator);
489
+ while (separator_pos != std::string::npos) {
490
+ std::string part = input.substr(begin_pos, separator_pos - begin_pos);
491
+ parts.emplace_back(part);
492
+ begin_pos = separator_pos + 1;
493
+ separator_pos = input.find(separator, begin_pos);
494
+ }
495
+ parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
496
+ return parts;
497
+ }
498
+
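For illustration, usage of the two string_split overloads above; note that the std::string specialization preserves empty fields while the generic template parses each token through a stream:

    std::vector<int> ids = string_split<int>("1,2,3", ',');                  // {1, 2, 3}
    std::vector<std::string> parts = string_split<std::string>("a,,b", ','); // {"a", "", "b"}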
499
+ static bool string_starts_with(const std::string & str,
500
+ const std::string & prefix) { // While we wait for C++20's std::string::starts_with...
501
+ return str.rfind(prefix, 0) == 0;
502
+ }
503
+
504
+ static bool string_ends_with(const std::string & str,
505
+ const std::string & suffix) { // While we wait for C++20's std::string::ends_with...
506
+ return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
507
+ }
508
+
509
+ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
510
+ void string_process_escapes(std::string & input);
511
+
512
+ std::string string_from(bool value);
513
+ std::string string_from(const std::vector<int> & values);
514
+ std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens);
515
+ std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);
516
+
517
+ //
518
+ // Filesystem utils
519
+ //
520
+
521
+ bool fs_validate_filename(const std::string & filename);
522
+ bool fs_create_directory_with_parents(const std::string & path);
523
+
524
+ std::string fs_get_cache_directory();
525
+ std::string fs_get_cache_file(const std::string & filename);
526
+
527
+ //
528
+ // Model utils
529
+ //
530
+
531
+ // note: defines object's lifetime
532
+ struct common_init_result {
533
+ llama_model_ptr model;
534
+ llama_context_ptr context;
535
+
536
+ std::vector<llama_adapter_lora_ptr> lora;
537
+ };
538
+
539
+ struct common_init_result common_init_from_params(common_params & params);
540
+
541
+ struct llama_model_params common_model_params_to_llama ( common_params & params);
542
+ struct llama_context_params common_context_params_to_llama(const common_params & params);
543
+ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
544
+
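A minimal sketch of the initialization flow defined above; it assumes this vintage of common_params still exposes the model path as a plain `model` string field, as in upstream llama.cpp of the same period:

    common_init(); // logging + build info, once per program

    common_params params;
    params.model = "models/model.gguf"; // hypothetical path, assumed field

    common_init_result llama_init = common_init_from_params(params);
    llama_model   * model = llama_init.model.get();
    llama_context * ctx   = llama_init.context.get();
    if (model == nullptr || ctx == nullptr) {
        // load failed; on success both objects stay owned by llama_init (see lifetime note above)
    }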
545
+ struct llama_model * common_load_model_from_url(
546
+ const std::string & model_url,
547
+ const std::string & local_path,
548
+ const std::string & hf_token,
549
+ const struct llama_model_params & params);
550
+
551
+ struct llama_model * common_load_model_from_hf(
552
+ const std::string & repo,
553
+ const std::string & remote_path,
554
+ const std::string & local_path,
555
+ const std::string & hf_token,
556
+ const struct llama_model_params & params);
557
+
558
+ std::pair<std::string, std::string> common_get_hf_file(
559
+ const std::string & hf_repo_with_tag,
560
+ const std::string & hf_token);
561
+
562
+ // clear LoRA adapters from context, then apply new list of adapters
563
+ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
564
+
565
+ //
566
+ // Batch utils
567
+ //
568
+
569
+ void common_batch_clear(struct llama_batch & batch);
570
+
571
+ void common_batch_add(
572
+ struct llama_batch & batch,
573
+ llama_token id,
574
+ llama_pos pos,
575
+ const std::vector<llama_seq_id> & seq_ids,
576
+ bool logits);
577
+
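For illustration, the usual decode-loop pattern with these helpers (llama_batch_init/llama_batch_free are the standard llama.h allocators; `tokens` is assumed to hold the prompt):

    llama_batch batch = llama_batch_init(512, 0, 1); // capacity 512 tokens, 1 sequence

    common_batch_clear(batch);
    for (size_t i = 0; i < tokens.size(); ++i) {
        // request logits only for the final token
        common_batch_add(batch, tokens[i], (llama_pos) i, { 0 }, i == tokens.size() - 1);
    }
    // llama_decode(ctx, batch); ...

    llama_batch_free(batch);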
578
+ //
579
+ // Token utils
580
+ //
581
+
582
+ // longest common prefix
583
+ size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
584
+
585
+ // longest common subsequence
586
+ size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
587
+
588
+ //
589
+ // Vocab utils
590
+ //
591
+
592
+ // tokenizes a string into a vector of tokens
593
+ // should work similarly to Python's `tokenizer.encode`
594
+ std::vector<llama_token> common_tokenize(
595
+ const struct llama_context * ctx,
596
+ const std::string & text,
597
+ bool add_special,
598
+ bool parse_special = false);
599
+
600
+ std::vector<llama_token> common_tokenize(
601
+ const struct llama_vocab * vocab,
602
+ const std::string & text,
603
+ bool add_special,
604
+ bool parse_special = false);
605
+
606
+ // converts a token into a piece of text, optionally rendering special/control tokens
607
+ // should work similarly to Python's `tokenizer.id_to_piece`
608
+ std::string common_token_to_piece(
609
+ const struct llama_context * ctx,
610
+ llama_token token,
611
+ bool special = true);
612
+
613
+ std::string common_token_to_piece(
614
+ const struct llama_vocab * vocab,
615
+ llama_token token,
616
+ bool special = true);
617
+
618
+ // detokenizes a vector of tokens into a string
619
+ // should work similarly to Python's `tokenizer.decode`
620
+ // optionally renders special/control tokens
621
+ std::string common_detokenize(
622
+ const struct llama_context * ctx,
623
+ const std::vector<llama_token> & tokens,
624
+ bool special = true);
625
+
626
+ std::string common_detokenize(
627
+ const struct llama_vocab * vocab,
628
+ const std::vector<llama_token> & tokens,
629
+ bool special = true);
630
+
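A round-trip sketch with the vocab utils above (assumes an initialized `ctx`):

    std::vector<llama_token> toks = common_tokenize(ctx, "Hello world", /* add_special */ true);

    for (llama_token t : toks) {
        printf("%d -> '%s'\n", t, common_token_to_piece(ctx, t).c_str());
    }

    std::string text = common_detokenize(ctx, toks); // back to (approximately) the input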
631
+ //
632
+ // KV cache utils
633
+ //
634
+
635
+ // Dump the KV cache view with the number of sequences per cell.
636
+ void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80);
637
+
638
+ // Dump the KV cache view showing individual sequences in each cell (long output).
639
+ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
640
+
641
+ //
642
+ // Embedding utils
643
+ //
644
+
645
+ // TODO: replace embd_norm with an enum
646
+ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
647
+
648
+ float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);
649
+
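For illustration, comparing two raw embeddings of dimension n_embd (float buffers `a` and `b` assumed populated elsewhere):

    std::vector<float> a_norm(n_embd), b_norm(n_embd);
    common_embd_normalize(a, a_norm.data(), n_embd, 2); // 2 = euclidean, per the comment above
    common_embd_normalize(b, b_norm.data(), n_embd, 2);

    float sim = common_embd_similarity_cos(a_norm.data(), b_norm.data(), n_embd);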
650
+ //
651
+ // Control vector utils
652
+ //
653
+
654
+ struct common_control_vector_data {
655
+ int n_embd;
656
+
657
+ // stores data for layers [1, n_layer] where n_layer = data.size() / n_embd
658
+ std::vector<float> data;
659
+ };
660
+
661
+ struct common_control_vector_load_info {
662
+ float strength;
663
+
664
+ std::string fname;
665
+ };
666
+
667
+ // Load control vectors, scale each by strength, and add them together.
668
+ // On error, returns {-1, empty}
669
+ common_control_vector_data common_control_vector_load(const std::vector<common_control_vector_load_info> & load_infos);
670
+
671
+ //
672
+ // Split utils
673
+ //
674
+
675
+ namespace {
676
+
677
+ const char * const LLM_KV_SPLIT_NO = "split.no";
678
+ const char * const LLM_KV_SPLIT_COUNT = "split.count";
679
+ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
680
+
681
+ }
common/console.cpp ADDED
@@ -0,0 +1,504 @@
1
+ #include "console.h"
2
+ #include <vector>
3
+ #include <iostream>
4
+
5
+ #if defined(_WIN32)
6
+ #define WIN32_LEAN_AND_MEAN
7
+ #ifndef NOMINMAX
8
+ #define NOMINMAX
9
+ #endif
10
+ #include <windows.h>
11
+ #include <fcntl.h>
12
+ #include <io.h>
13
+ #ifndef ENABLE_VIRTUAL_TERMINAL_PROCESSING
14
+ #define ENABLE_VIRTUAL_TERMINAL_PROCESSING 0x0004
15
+ #endif
16
+ #else
17
+ #include <climits>
18
+ #include <sys/ioctl.h>
19
+ #include <unistd.h>
20
+ #include <wchar.h>
21
+ #include <stdio.h>
22
+ #include <stdlib.h>
23
+ #include <signal.h>
24
+ #include <termios.h>
25
+ #endif
26
+
27
+ #define ANSI_COLOR_RED "\x1b[31m"
28
+ #define ANSI_COLOR_GREEN "\x1b[32m"
29
+ #define ANSI_COLOR_YELLOW "\x1b[33m"
30
+ #define ANSI_COLOR_BLUE "\x1b[34m"
31
+ #define ANSI_COLOR_MAGENTA "\x1b[35m"
32
+ #define ANSI_COLOR_CYAN "\x1b[36m"
33
+ #define ANSI_COLOR_RESET "\x1b[0m"
34
+ #define ANSI_BOLD "\x1b[1m"
35
+
36
+ namespace console {
37
+
38
+ //
39
+ // Console state
40
+ //
41
+
42
+ static bool advanced_display = false;
43
+ static bool simple_io = true;
44
+ static display_t current_display = reset;
45
+
46
+ static FILE* out = stdout;
47
+
48
+ #if defined (_WIN32)
49
+ static void* hConsole;
50
+ #else
51
+ static FILE* tty = nullptr;
52
+ static termios initial_state;
53
+ #endif
54
+
55
+ //
56
+ // Init and cleanup
57
+ //
58
+
59
+ void init(bool use_simple_io, bool use_advanced_display) {
60
+ advanced_display = use_advanced_display;
61
+ simple_io = use_simple_io;
62
+ #if defined(_WIN32)
63
+ // Windows-specific console initialization
64
+ DWORD dwMode = 0;
65
+ hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
66
+ if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) {
67
+ hConsole = GetStdHandle(STD_ERROR_HANDLE);
68
+ if (hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(hConsole, &dwMode))) {
69
+ hConsole = nullptr;
70
+ simple_io = true;
71
+ }
72
+ }
73
+ if (hConsole) {
74
+ // Check conditions combined to reduce nesting
75
+ if (advanced_display && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING) &&
76
+ !SetConsoleMode(hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING)) {
77
+ advanced_display = false;
78
+ }
79
+ // Set console output codepage to UTF8
80
+ SetConsoleOutputCP(CP_UTF8);
81
+ }
82
+ HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE);
83
+ if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) {
84
+ // Set console input codepage to UTF16
85
+ _setmode(_fileno(stdin), _O_WTEXT);
86
+
87
+ // Set ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT)
88
+ if (simple_io) {
89
+ dwMode |= ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT;
90
+ } else {
91
+ dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT);
92
+ }
93
+ if (!SetConsoleMode(hConIn, dwMode)) {
94
+ simple_io = true;
95
+ }
96
+ }
97
+ if (simple_io) {
98
+ _setmode(_fileno(stdin), _O_U8TEXT);
99
+ }
100
+ #else
101
+ // POSIX-specific console initialization
102
+ if (!simple_io) {
103
+ struct termios new_termios;
104
+ tcgetattr(STDIN_FILENO, &initial_state);
105
+ new_termios = initial_state;
106
+ new_termios.c_lflag &= ~(ICANON | ECHO);
107
+ new_termios.c_cc[VMIN] = 1;
108
+ new_termios.c_cc[VTIME] = 0;
109
+ tcsetattr(STDIN_FILENO, TCSANOW, &new_termios);
110
+
111
+ tty = fopen("/dev/tty", "w+");
112
+ if (tty != nullptr) {
113
+ out = tty;
114
+ }
115
+ }
116
+
117
+ setlocale(LC_ALL, "");
118
+ #endif
119
+ }
120
+
121
+ void cleanup() {
122
+ // Reset console display
123
+ set_display(reset);
124
+
125
+ #if !defined(_WIN32)
126
+ // Restore settings on POSIX systems
127
+ if (!simple_io) {
128
+ if (tty != nullptr) {
129
+ out = stdout;
130
+ fclose(tty);
131
+ tty = nullptr;
132
+ }
133
+ tcsetattr(STDIN_FILENO, TCSANOW, &initial_state);
134
+ }
135
+ #endif
136
+ }
137
+
138
+ //
139
+ // Display and IO
140
+ //
141
+
142
+ // Keep track of current display and only emit ANSI code if it changes
143
+ void set_display(display_t display) {
144
+ if (advanced_display && current_display != display) {
145
+ fflush(stdout);
146
+ switch(display) {
147
+ case reset:
148
+ fprintf(out, ANSI_COLOR_RESET);
149
+ break;
150
+ case prompt:
151
+ fprintf(out, ANSI_COLOR_YELLOW);
152
+ break;
153
+ case user_input:
154
+ fprintf(out, ANSI_BOLD ANSI_COLOR_GREEN);
155
+ break;
156
+ case error:
157
+ fprintf(out, ANSI_BOLD ANSI_COLOR_RED);
158
+ }
159
+ current_display = display;
160
+ fflush(out);
161
+ }
162
+ }
163
+
164
+ static char32_t getchar32() {
165
+ #if defined(_WIN32)
166
+ HANDLE hConsole = GetStdHandle(STD_INPUT_HANDLE);
167
+ wchar_t high_surrogate = 0;
168
+
169
+ while (true) {
170
+ INPUT_RECORD record;
171
+ DWORD count;
172
+ if (!ReadConsoleInputW(hConsole, &record, 1, &count) || count == 0) {
173
+ return WEOF;
174
+ }
175
+
176
+ if (record.EventType == KEY_EVENT && record.Event.KeyEvent.bKeyDown) {
177
+ wchar_t wc = record.Event.KeyEvent.uChar.UnicodeChar;
178
+ if (wc == 0) {
179
+ continue;
180
+ }
181
+
182
+ if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate
183
+ high_surrogate = wc;
184
+ continue;
185
+ }
186
+ if ((wc >= 0xDC00) && (wc <= 0xDFFF)) { // Check if wc is a low surrogate
187
+ if (high_surrogate != 0) { // Check if we have a high surrogate
188
+ return ((high_surrogate - 0xD800) << 10) + (wc - 0xDC00) + 0x10000;
189
+ }
190
+ }
191
+
192
+ high_surrogate = 0; // Reset the high surrogate
193
+ return static_cast<char32_t>(wc);
194
+ }
195
+ }
196
+ #else
197
+ wchar_t wc = getwchar();
198
+ if (static_cast<wint_t>(wc) == WEOF) {
199
+ return WEOF;
200
+ }
201
+
202
+ #if WCHAR_MAX == 0xFFFF
203
+ if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate
204
+ wchar_t low_surrogate = getwchar();
205
+ if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate
206
+ return (static_cast<char32_t>(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000;
207
+ }
208
+ }
209
+ if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair
210
+ return 0xFFFD; // Return the replacement character U+FFFD
211
+ }
212
+ #endif
213
+
214
+ return static_cast<char32_t>(wc);
215
+ #endif
216
+ }
217
+
218
+ static void pop_cursor() {
219
+ #if defined(_WIN32)
220
+ if (hConsole != NULL) {
221
+ CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
222
+ GetConsoleScreenBufferInfo(hConsole, &bufferInfo);
223
+
224
+ COORD newCursorPosition = bufferInfo.dwCursorPosition;
225
+ if (newCursorPosition.X == 0) {
226
+ newCursorPosition.X = bufferInfo.dwSize.X - 1;
227
+ newCursorPosition.Y -= 1;
228
+ } else {
229
+ newCursorPosition.X -= 1;
230
+ }
231
+
232
+ SetConsoleCursorPosition(hConsole, newCursorPosition);
233
+ return;
234
+ }
235
+ #endif
236
+ putc('\b', out);
237
+ }
238
+
239
+ static int estimateWidth(char32_t codepoint) {
240
+ #if defined(_WIN32)
241
+ (void)codepoint;
242
+ return 1;
243
+ #else
244
+ return wcwidth(codepoint);
245
+ #endif
246
+ }
247
+
248
+ static int put_codepoint(const char* utf8_codepoint, size_t length, int expectedWidth) {
249
+ #if defined(_WIN32)
250
+ CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
251
+ if (!GetConsoleScreenBufferInfo(hConsole, &bufferInfo)) {
252
+ // go with the default
253
+ return expectedWidth;
254
+ }
255
+ COORD initialPosition = bufferInfo.dwCursorPosition;
256
+ DWORD nNumberOfChars = length;
257
+ WriteConsole(hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL);
258
+
259
+ CONSOLE_SCREEN_BUFFER_INFO newBufferInfo;
260
+ GetConsoleScreenBufferInfo(hConsole, &newBufferInfo);
261
+
262
+ // Figure out our real position if we're in the last column
263
+ if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) {
264
+ DWORD nNumberOfChars;
265
+ WriteConsole(hConsole, &" \b", 2, &nNumberOfChars, NULL);
266
+ GetConsoleScreenBufferInfo(hConsole, &newBufferInfo);
267
+ }
268
+
269
+ int width = newBufferInfo.dwCursorPosition.X - initialPosition.X;
270
+ if (width < 0) {
271
+ width += newBufferInfo.dwSize.X;
272
+ }
273
+ return width;
274
+ #else
275
+ // We can trust expectedWidth if we've got one
276
+ if (expectedWidth >= 0 || tty == nullptr) {
277
+ fwrite(utf8_codepoint, length, 1, out);
278
+ return expectedWidth;
279
+ }
280
+
281
+ fputs("\033[6n", tty); // Query cursor position
282
+ int x1;
283
+ int y1;
284
+ int x2;
285
+ int y2;
286
+ int results = 0;
287
+ results = fscanf(tty, "\033[%d;%dR", &y1, &x1);
288
+
289
+ fwrite(utf8_codepoint, length, 1, tty);
290
+
291
+ fputs("\033[6n", tty); // Query cursor position
292
+ results += fscanf(tty, "\033[%d;%dR", &y2, &x2);
293
+
294
+ if (results != 4) {
295
+ return expectedWidth;
296
+ }
297
+
298
+ int width = x2 - x1;
299
+ if (width < 0) {
300
+ // Calculate the width considering text wrapping
301
+ struct winsize w;
302
+ ioctl(STDOUT_FILENO, TIOCGWINSZ, &w);
303
+ width += w.ws_col;
304
+ }
305
+ return width;
306
+ #endif
307
+ }
308
+
309
+ static void replace_last(char ch) {
310
+ #if defined(_WIN32)
311
+ pop_cursor();
312
+ put_codepoint(&ch, 1, 1);
313
+ #else
314
+ fprintf(out, "\b%c", ch);
315
+ #endif
316
+ }
317
+
318
+ static void append_utf8(char32_t ch, std::string & out) {
319
+ if (ch <= 0x7F) {
320
+ out.push_back(static_cast<unsigned char>(ch));
321
+ } else if (ch <= 0x7FF) {
322
+ out.push_back(static_cast<unsigned char>(0xC0 | ((ch >> 6) & 0x1F)));
323
+ out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
324
+ } else if (ch <= 0xFFFF) {
325
+ out.push_back(static_cast<unsigned char>(0xE0 | ((ch >> 12) & 0x0F)));
326
+ out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 6) & 0x3F)));
327
+ out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
328
+ } else if (ch <= 0x10FFFF) {
329
+ out.push_back(static_cast<unsigned char>(0xF0 | ((ch >> 18) & 0x07)));
330
+ out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 12) & 0x3F)));
331
+ out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 6) & 0x3F)));
332
+ out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
333
+ } else {
334
+ // Invalid Unicode code point
335
+ }
336
+ }
337
+
338
+ // Helper function to remove the last UTF-8 character from a string
339
+ static void pop_back_utf8_char(std::string & line) {
340
+ if (line.empty()) {
341
+ return;
342
+ }
343
+
344
+ size_t pos = line.length() - 1;
345
+
346
+ // Find the start of the last UTF-8 character (checking up to 4 bytes back)
347
+ for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) {
348
+ if ((line[pos] & 0xC0) != 0x80) {
349
+ break; // Found the start of the character
350
+ }
351
+ }
352
+ line.erase(pos);
353
+ }
354
+
355
+ static bool readline_advanced(std::string & line, bool multiline_input) {
356
+ if (out != stdout) {
357
+ fflush(stdout);
358
+ }
359
+
360
+ line.clear();
361
+ std::vector<int> widths;
362
+ bool is_special_char = false;
363
+ bool end_of_stream = false;
364
+
365
+ char32_t input_char;
366
+ while (true) {
367
+ fflush(out); // Ensure all output is displayed before waiting for input
368
+ input_char = getchar32();
369
+
370
+ if (input_char == '\r' || input_char == '\n') {
371
+ break;
372
+ }
373
+
374
+ if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D*/) {
375
+ end_of_stream = true;
376
+ break;
377
+ }
378
+
379
+ if (is_special_char) {
380
+ set_display(user_input);
381
+ replace_last(line.back());
382
+ is_special_char = false;
383
+ }
384
+
385
+ if (input_char == '\033') { // Escape sequence
386
+ char32_t code = getchar32();
387
+ if (code == '[' || code == 0x1B) {
388
+ // Discard the rest of the escape sequence
389
+ while ((code = getchar32()) != (char32_t) WEOF) {
390
+ if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') {
391
+ break;
392
+ }
393
+ }
394
+ }
395
+ } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace
396
+ if (!widths.empty()) {
397
+ int count;
398
+ do {
399
+ count = widths.back();
400
+ widths.pop_back();
401
+ // Move cursor back, print space, and move cursor back again
402
+ for (int i = 0; i < count; i++) {
403
+ replace_last(' ');
404
+ pop_cursor();
405
+ }
406
+ pop_back_utf8_char(line);
407
+ } while (count == 0 && !widths.empty());
408
+ }
409
+ } else {
410
+ int offset = line.length();
411
+ append_utf8(input_char, line);
412
+ int width = put_codepoint(line.c_str() + offset, line.length() - offset, estimateWidth(input_char));
413
+ if (width < 0) {
414
+ width = 0;
415
+ }
416
+ widths.push_back(width);
417
+ }
418
+
419
+ if (!line.empty() && (line.back() == '\\' || line.back() == '/')) {
420
+ set_display(prompt);
421
+ replace_last(line.back());
422
+ is_special_char = true;
423
+ }
424
+ }
425
+
426
+ bool has_more = multiline_input;
427
+ if (is_special_char) {
428
+ replace_last(' ');
429
+ pop_cursor();
430
+
431
+ char last = line.back();
432
+ line.pop_back();
433
+ if (last == '\\') {
434
+ line += '\n';
435
+ fputc('\n', out);
436
+ has_more = !has_more;
437
+ } else {
438
+ // llama will just eat the single space, it won't act as a space
439
+ if (line.length() == 1 && line.back() == ' ') {
440
+ line.clear();
441
+ pop_cursor();
442
+ }
443
+ has_more = false;
444
+ }
445
+ } else {
446
+ if (end_of_stream) {
447
+ has_more = false;
448
+ } else {
449
+ line += '\n';
450
+ fputc('\n', out);
451
+ }
452
+ }
453
+
454
+ fflush(out);
455
+ return has_more;
456
+ }
457
+
458
+ static bool readline_simple(std::string & line, bool multiline_input) {
459
+ #if defined(_WIN32)
460
+ std::wstring wline;
461
+ if (!std::getline(std::wcin, wline)) {
462
+ // Input stream is bad or EOF received
463
+ line.clear();
464
+ GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0);
465
+ return false;
466
+ }
467
+
468
+ int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
469
+ line.resize(size_needed);
470
+ WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
471
+ #else
472
+ if (!std::getline(std::cin, line)) {
473
+ // Input stream is bad or EOF received
474
+ line.clear();
475
+ return false;
476
+ }
477
+ #endif
478
+ if (!line.empty()) {
479
+ char last = line.back();
480
+ if (last == '/') { // Always return control on '/' symbol
481
+ line.pop_back();
482
+ return false;
483
+ }
484
+ if (last == '\\') { // '\\' changes the default action
485
+ line.pop_back();
486
+ multiline_input = !multiline_input;
487
+ }
488
+ }
489
+ line += '\n';
490
+
491
+ // By default, continue input if multiline_input is set
492
+ return multiline_input;
493
+ }
494
+
495
+ bool readline(std::string & line, bool multiline_input) {
496
+ set_display(user_input);
497
+
498
+ if (simple_io) {
499
+ return readline_simple(line, multiline_input);
500
+ }
501
+ return readline_advanced(line, multiline_input);
502
+ }
503
+
504
+ }
common/console.h ADDED
@@ -0,0 +1,19 @@
1
+ // Console functions
2
+
3
+ #pragma once
4
+
5
+ #include <string>
6
+
7
+ namespace console {
8
+ enum display_t {
9
+ reset = 0,
10
+ prompt,
11
+ user_input,
12
+ error
13
+ };
14
+
15
+ void init(bool use_simple_io, bool use_advanced_display);
16
+ void cleanup();
17
+ void set_display(display_t display);
18
+ bool readline(std::string & line, bool multiline_input);
19
+ }
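For illustration (not part of this diff), the call pattern these four functions are designed for, mirroring how llama.cpp's interactive examples drive the console:

    console::init(/* use_simple_io */ false, /* use_advanced_display */ true);

    console::set_display(console::user_input);
    std::string buffer;
    std::string line;
    bool another_line = true;
    do {
        another_line = console::readline(line, /* multiline_input */ false);
        buffer += line; // readline appends a trailing '\n' per line
    } while (another_line);
    console::set_display(console::reset);

    console::cleanup();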
common/json-schema-to-grammar.cpp ADDED
@@ -0,0 +1,1024 @@
1
+ #include "json-schema-to-grammar.h"
2
+ #include "common.h"
3
+
4
+ #include <algorithm>
5
+ #include <fstream>
6
+ #include <map>
7
+ #include <regex>
8
+ #include <sstream>
9
+ #include <string>
10
+ #include <unordered_map>
11
+ #include <unordered_set>
12
+ #include <vector>
13
+
14
+ using json = nlohmann::ordered_json;
15
+
16
+ static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
17
+ auto has_max = max_items != std::numeric_limits<int>::max();
18
+
19
+ if (min_items == 0 && max_items == 1) {
20
+ return item_rule + "?";
21
+ }
22
+
23
+ if (separator_rule.empty()) {
24
+ if (min_items == 1 && !has_max) {
25
+ return item_rule + "+";
26
+ } else if (min_items == 0 && !has_max) {
27
+ return item_rule + "*";
28
+ } else {
29
+ return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}";
30
+ }
31
+ }
32
+
33
+ auto result = item_rule + " " + build_repetition("(" + separator_rule + " " + item_rule + ")", min_items == 0 ? 0 : min_items - 1, has_max ? max_items - 1 : max_items);
34
+ if (min_items == 0) {
35
+ result = "(" + result + ")?";
36
+ }
37
+ return result;
38
+ }
39
+
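Worked examples of what build_repetition emits, traced from the branches above (shown as comments; not part of the diff):

    // build_repetition("item", 0, 1)                      -> "item?"
    // build_repetition("item", 1, INT_MAX)                -> "item+"
    // build_repetition("item", 2, 5)                      -> "item{2,5}"
    // build_repetition("item", 0, INT_MAX, "\",\" space") -> "(item (\",\" space item)*)?"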
40
+ /* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */
41
+ class string_view {
42
+ const std::string & _str;
43
+ const size_t _start;
44
+ const size_t _end;
45
+ public:
46
+ string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}
47
+
48
+ size_t size() const {
49
+ return _end - _start;
50
+ }
51
+
52
+ size_t length() const {
53
+ return size();
54
+ }
55
+
56
+ operator std::string() const {
57
+ return str();
58
+ }
59
+
60
+ std::string str() const {
61
+ return _str.substr(_start, _end - _start);
62
+ }
63
+
64
+ string_view substr(size_t pos, size_t len = std::string::npos) const {
65
+ return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len);
66
+ }
67
+
68
+ char operator[](size_t pos) const {
69
+ auto index = _start + pos;
70
+ if (index >= _end) {
71
+ throw std::out_of_range("string_view index out of range");
72
+ }
73
+ return _str[_start + pos];
74
+ }
75
+
76
+ bool operator==(const string_view & other) const {
77
+ std::string this_str = *this;
78
+ std::string other_str = other;
79
+ return this_str == other_str;
80
+ }
81
+ };
82
+
83
+ static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
84
+ auto has_min = min_value != std::numeric_limits<int>::min();
85
+ auto has_max = max_value != std::numeric_limits<int>::max();
86
+
87
+ auto digit_range = [&](char from, char to) {
88
+ out << "[";
89
+ if (from == to) {
90
+ out << from;
91
+ } else {
92
+ out << from << "-" << to;
93
+ }
94
+ out << "]";
95
+ };
96
+ auto more_digits = [&](int min_digits, int max_digits) {
97
+ out << "[0-9]";
98
+ if (min_digits == max_digits && min_digits == 1) {
99
+ return;
100
+ }
101
+ out << "{";
102
+ out << min_digits;
103
+ if (max_digits != min_digits) {
104
+ out << ",";
105
+ if (max_digits != std::numeric_limits<int>::max()) {
106
+ out << max_digits;
107
+ }
108
+ }
109
+ out << "}";
110
+ };
111
+ std::function<void(const string_view &, const string_view &)> uniform_range =
112
+ [&](const string_view & from, const string_view & to) {
113
+ size_t i = 0;
114
+ while (i < from.length() && i < to.length() && from[i] == to[i]) {
115
+ i++;
116
+ }
117
+ if (i > 0) {
118
+ out << "\"" << from.substr(0, i).str() << "\"";
119
+ }
120
+ if (i < from.length() && i < to.length()) {
121
+ if (i > 0) {
122
+ out << " ";
123
+ }
124
+ auto sub_len = from.length() - i - 1;
125
+ if (sub_len > 0) {
126
+ auto from_sub = from.substr(i + 1);
127
+ auto to_sub = to.substr(i + 1);
128
+ auto sub_zeros = string_repeat("0", sub_len);
129
+ auto sub_nines = string_repeat("9", sub_len);
130
+
131
+ auto to_reached = false;
132
+ out << "(";
133
+ if (from_sub == sub_zeros) {
134
+ digit_range(from[i], to[i] - 1);
135
+ out << " ";
136
+ more_digits(sub_len, sub_len);
137
+ } else {
138
+ out << "[" << from[i] << "] ";
139
+ out << "(";
140
+ uniform_range(from_sub, sub_nines);
141
+ out << ")";
142
+ if (from[i] < to[i] - 1) {
143
+ out << " | ";
144
+ if (to_sub == sub_nines) {
145
+ digit_range(from[i] + 1, to[i]);
146
+ to_reached = true;
147
+ } else {
148
+ digit_range(from[i] + 1, to[i] - 1);
149
+ }
150
+ out << " ";
151
+ more_digits(sub_len, sub_len);
152
+ }
153
+ }
154
+ if (!to_reached) {
155
+ out << " | ";
156
+ digit_range(to[i], to[i]);
157
+ out << " ";
158
+ uniform_range(sub_zeros, to_sub);
159
+ }
160
+ out << ")";
161
+ } else {
162
+ out << "[" << from[i] << "-" << to[i] << "]";
163
+ }
164
+ }
165
+ };
166
+
167
+ if (has_min && has_max) {
168
+ if (min_value < 0 && max_value < 0) {
169
+ out << "\"-\" (";
170
+ _build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true);
171
+ out << ")";
172
+ return;
173
+ }
174
+
175
+ if (min_value < 0) {
176
+ out << "\"-\" (";
177
+ _build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true);
178
+ out << ") | ";
179
+ min_value = 0;
180
+ }
181
+
182
+ auto min_s = std::to_string(min_value);
183
+ auto max_s = std::to_string(max_value);
184
+ auto min_digits = min_s.length();
185
+ auto max_digits = max_s.length();
186
+
187
+ for (auto digits = min_digits; digits < max_digits; digits++) {
188
+ uniform_range(min_s, string_repeat("9", digits));
189
+ min_s = "1" + string_repeat("0", digits);
190
+ out << " | ";
191
+ }
192
+ uniform_range(min_s, max_s);
193
+ return;
194
+ }
195
+
196
+ auto less_decimals = std::max(decimals_left - 1, 1);
197
+
198
+ if (has_min) {
199
+ if (min_value < 0) {
200
+ out << "\"-\" (";
201
+ _build_min_max_int(std::numeric_limits<int>::min(), -min_value, out, decimals_left, /* top_level= */ false);
202
+ out << ") | [0] | [1-9] ";
203
+ more_digits(0, decimals_left - 1);
204
+ } else if (min_value == 0) {
205
+ if (top_level) {
206
+ out << "[0] | [1-9] ";
207
+ more_digits(0, less_decimals);
208
+ } else {
209
+ more_digits(1, decimals_left);
210
+ }
211
+ } else if (min_value <= 9) {
212
+ char c = '0' + min_value;
213
+ auto range_start = top_level ? '1' : '0';
214
+ if (c > range_start) {
215
+ digit_range(range_start, c - 1);
216
+ out << " ";
217
+ more_digits(1, less_decimals);
218
+ out << " | ";
219
+ }
220
+ digit_range(c, '9');
221
+ out << " ";
222
+ more_digits(0, less_decimals);
223
+ } else {
224
+ auto min_s = std::to_string(min_value);
225
+ auto len = min_s.length();
226
+ auto c = min_s[0];
227
+
228
+ if (c > '1') {
229
+ digit_range(top_level ? '1' : '0', c - 1);
230
+ out << " ";
231
+ more_digits(len, less_decimals);
232
+ out << " | ";
233
+ }
234
+ digit_range(c, c);
235
+ out << " (";
236
+ _build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits<int>::max(), out, less_decimals, /* top_level= */ false);
237
+ out << ")";
238
+ if (c < '9') {
239
+ out << " | ";
240
+ digit_range(c + 1, '9');
241
+ out << " ";
242
+ more_digits(len - 1, less_decimals);
243
+ }
244
+ }
245
+ return;
246
+ }
247
+
248
+ if (has_max) {
249
+ if (max_value >= 0) {
250
+ if (top_level) {
251
+ out << "\"-\" [1-9] ";
252
+ more_digits(0, less_decimals);
253
+ out << " | ";
254
+ }
255
+ _build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
256
+ } else {
257
+ out << "\"-\" (";
258
+ _build_min_max_int(-max_value, std::numeric_limits<int>::max(), out, decimals_left, /* top_level= */ false);
259
+ out << ")";
260
+ }
261
+ return;
262
+ }
263
+
264
+ throw std::runtime_error("At least one of min_value or max_value must be set");
265
+ }
266
+
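A hedged usage sketch: the function streams a GBNF fragment into `out`, so a caller can materialize a bounded-integer rule like so:

    std::stringstream ss;
    _build_min_max_int(0, 255, ss);
    // ss.str() now holds a grammar alternation matching exactly the integers 0..255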
267
+ const std::string SPACE_RULE = "| \" \" | \"\\n\"{1,2} [ \\t]{0,20}";
268
+
269
+ struct BuiltinRule {
270
+ std::string content;
271
+ std::vector<std::string> deps;
272
+ };
273
+
274
+ std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
275
+ {"boolean", {"(\"true\" | \"false\") space", {}}},
276
+ {"decimal-part", {"[0-9]{1,16}", {}}},
277
+ {"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}},
278
+ {"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}},
279
+ {"integer", {"(\"-\"? integral-part) space", {"integral-part"}}},
280
+ {"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}},
281
+ {"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}},
282
+ {"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}},
283
+ {"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}},
284
+ {"char", {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}},
285
+ {"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}},
286
+ {"null", {"\"null\" space", {}}},
287
+ };
288
+
289
+ std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
290
+ {"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}},
291
+ {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}},
292
+ {"date-time", {"date \"T\" time", {"date", "time"}}},
293
+ {"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}},
294
+ {"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}},
295
+ {"date-time-string", {"\"\\\"\" date-time \"\\\"\" space", {"date-time"}}}
296
+ };
297
+
298
+ static bool is_reserved_name(const std::string & name) {
299
+ static std::unordered_set<std::string> RESERVED_NAMES;
300
+ if (RESERVED_NAMES.empty()) {
301
+ RESERVED_NAMES.insert("root");
302
+ for (const auto &p : PRIMITIVE_RULES) RESERVED_NAMES.insert(p.first);
303
+ for (const auto &p : STRING_FORMAT_RULES) RESERVED_NAMES.insert(p.first);
304
+ }
305
+ return RESERVED_NAMES.find(name) != RESERVED_NAMES.end();
306
+ }
307
+
308
+ std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
309
+ std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");
310
+ std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
311
+ std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
312
+ {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}
313
+ };
314
+
315
+ std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
316
+ std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
317
+
318
+ static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
319
+ std::smatch match;
320
+ std::string result;
321
+
322
+ std::string::const_iterator searchStart(input.cbegin());
323
+ std::string::const_iterator searchEnd(input.cend());
324
+
325
+ while (std::regex_search(searchStart, searchEnd, match, regex)) {
326
+ result.append(searchStart, searchStart + match.position());
327
+ result.append(replacement(match));
328
+ searchStart = match.suffix().first;
329
+ }
330
+
331
+ result.append(searchStart, searchEnd);
332
+
333
+ return result;
334
+ }
335
+
336
+ static std::string format_literal(const std::string & literal) {
337
+ std::string escaped = replacePattern(literal, GRAMMAR_LITERAL_ESCAPE_RE, [&](const std::smatch & match) {
338
+ char c = match.str()[0];
339
+ return GRAMMAR_LITERAL_ESCAPES.at(c);
340
+ });
341
+ return "\"" + escaped + "\"";
342
+ }
343
+
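Traced from the escape table above, for example:

    // format_literal(R"(a"b)") yields the grammar text "a\"b" (embedded quote escaped, result wrapped in quotes)
    // a literal newline in the input becomes the two characters \n in the output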
344
+ class SchemaConverter {
345
+ private:
346
+ friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
347
+ std::function<json(const std::string &)> _fetch_json;
348
+ bool _dotall;
349
+ std::map<std::string, std::string> _rules;
350
+ std::unordered_map<std::string, json> _refs;
351
+ std::unordered_set<std::string> _refs_being_resolved;
352
+ std::vector<std::string> _errors;
353
+ std::vector<std::string> _warnings;
354
+
355
+ std::string _add_rule(const std::string & name, const std::string & rule) {
356
+ std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-");
357
+ if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) {
358
+ _rules[esc_name] = rule;
359
+ return esc_name;
360
+ } else {
361
+ int i = 0;
362
+ while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) {
363
+ i++;
364
+ }
365
+ std::string key = esc_name + std::to_string(i);
366
+ _rules[key] = rule;
367
+ return key;
368
+ }
369
+ }
370
+
371
+ std::string _generate_union_rule(const std::string & name, const std::vector<json> & alt_schemas) {
372
+ std::vector<std::string> rules;
373
+ for (size_t i = 0; i < alt_schemas.size(); i++) {
374
+ rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
375
+ }
376
+ return string_join(rules, " | ");
377
+ }
378
+
379
+ std::string _visit_pattern(const std::string & pattern, const std::string & name) {
380
+ if (!(pattern.front() == '^' && pattern.back() == '$')) {
381
+ _errors.push_back("Pattern must start with '^' and end with '$'");
382
+ return "";
383
+ }
384
+ std::string sub_pattern = pattern.substr(1, pattern.length() - 2);
385
+ std::unordered_map<std::string, std::string> sub_rule_ids;
386
+
387
+ size_t i = 0;
388
+ size_t length = sub_pattern.length();
389
+
390
+ using literal_or_rule = std::pair<std::string, bool>;
391
+ auto to_rule = [&](const literal_or_rule & ls) {
392
+ auto is_literal = ls.second;
393
+ auto s = ls.first;
394
+ return is_literal ? "\"" + s + "\"" : s;
395
+ };
396
+ std::function<literal_or_rule()> transform = [&]() -> literal_or_rule {
397
+ size_t start = i;
398
+ std::vector<literal_or_rule> seq;
399
+
400
+ auto get_dot = [&]() {
401
+ std::string rule;
402
+ if (_dotall) {
403
+ rule = "[\\U00000000-\\U0010FFFF]";
404
+ } else {
405
+ rule = "[^\\x0A\\x0D]";
406
+ }
407
+ return _add_rule("dot", rule);
408
+ };
409
+
410
+ // Joins the sequence, merging consecutive literals together.
411
+ auto join_seq = [&]() {
412
+ std::vector<literal_or_rule> ret;
413
+
414
+ std::string literal;
415
+ auto flush_literal = [&]() {
416
+ if (literal.empty()) {
417
+ return false;
418
+ }
419
+ ret.emplace_back(literal, true);
420
+ literal.clear();
421
+ return true;
422
+ };
423
+
424
+ for (const auto & item : seq) {
425
+ auto is_literal = item.second;
426
+ if (is_literal) {
427
+ literal += item.first;
428
+ } else {
429
+ flush_literal();
430
+ ret.push_back(item);
431
+ }
432
+ }
433
+ flush_literal();
434
+
435
+ std::vector<std::string> results;
436
+ for (const auto & item : ret) {
437
+ results.push_back(to_rule(item));
438
+ }
439
+ return std::make_pair(string_join(results, " "), false);
440
+ };
441
+
442
+ while (i < length) {
443
+ char c = sub_pattern[i];
444
+ if (c == '.') {
445
+ seq.emplace_back(get_dot(), false);
446
+ i++;
447
+ } else if (c == '(') {
448
+ i++;
449
+ if (i < length) {
450
+ if (sub_pattern[i] == '?') {
451
+ _warnings.push_back("Unsupported pattern syntax");
452
+ }
453
+ }
454
+ seq.emplace_back("(" + to_rule(transform()) + ")", false);
455
+ } else if (c == ')') {
456
+ i++;
457
+ if (start > 0 && sub_pattern[start - 1] != '(') {
458
+ _errors.push_back("Unbalanced parentheses");
459
+ }
460
+ return join_seq();
461
+ } else if (c == '[') {
462
+ std::string square_brackets = std::string(1, c);
463
+ i++;
464
+ while (i < length && sub_pattern[i] != ']') {
465
+ if (sub_pattern[i] == '\\') {
466
+ square_brackets += sub_pattern.substr(i, 2);
467
+ i += 2;
468
+ } else {
469
+ square_brackets += sub_pattern[i];
470
+ i++;
471
+ }
472
+ }
473
+ if (i >= length) {
474
+ _errors.push_back("Unbalanced square brackets");
475
+ }
476
+ square_brackets += ']';
477
+ i++;
478
+ seq.emplace_back(square_brackets, false);
479
+ } else if (c == '|') {
480
+ seq.emplace_back("|", false);
481
+ i++;
482
+ } else if (c == '*' || c == '+' || c == '?') {
483
+ seq.back() = std::make_pair(to_rule(seq.back()) + c, false);
484
+ i++;
485
+ } else if (c == '{') {
486
+ std::string curly_brackets = std::string(1, c);
487
+ i++;
488
+ while (i < length && sub_pattern[i] != '}') {
489
+ curly_brackets += sub_pattern[i];
490
+ i++;
491
+ }
492
+ if (i >= length) {
493
+ _errors.push_back("Unbalanced curly brackets");
494
+ }
495
+ curly_brackets += '}';
496
+ i++;
497
+ auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
498
+ int min_times = 0;
499
+ int max_times = std::numeric_limits<int>::max();
500
+ try {
501
+ if (nums.size() == 1) {
502
+ min_times = max_times = std::stoi(nums[0]);
503
+ } else if (nums.size() != 2) {
504
+ _errors.push_back("Wrong number of values in curly brackets");
505
+ } else {
506
+ if (!nums[0].empty()) {
507
+ min_times = std::stoi(nums[0]);
508
+ }
509
+ if (!nums[1].empty()) {
510
+ max_times = std::stoi(nums[1]);
511
+ }
512
+ }
513
+ } catch (const std::invalid_argument & e) {
514
+ _errors.push_back("Invalid number in curly brackets");
515
+ return std::make_pair("", false);
516
+ }
517
+ auto &last = seq.back();
518
+ auto &sub = last.first;
519
+ auto sub_is_literal = last.second;
520
+
521
+ if (!sub_is_literal) {
522
+ std::string & sub_id = sub_rule_ids[sub];
523
+ if (sub_id.empty()) {
524
+ sub_id = _add_rule(name + "-" + std::to_string(sub_rule_ids.size()), sub);
525
+ }
526
+ sub = sub_id;
527
+ }
528
+ seq.back().first = build_repetition(
529
+ sub_is_literal ? "\"" + sub + "\"" : sub,
530
+ min_times,
531
+ max_times,
532
+ ""
533
+ );
534
+ seq.back().second = false;
535
+ } else {
536
+ std::string literal;
537
+ auto is_non_literal = [&](char c) {
538
+ return NON_LITERAL_SET.find(c) != NON_LITERAL_SET.end();
539
+ };
540
+ while (i < length) {
541
+ if (sub_pattern[i] == '\\' && i < length - 1) {
542
+ char next = sub_pattern[i + 1];
543
+ if (ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.find(next) != ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.end()) {
544
+ i++;
545
+ literal += sub_pattern[i];
546
+ i++;
547
+ } else {
548
+ literal += sub_pattern.substr(i, 2);
549
+ i += 2;
550
+ }
551
+ } else if (sub_pattern[i] == '"') {
552
+ literal += "\\\"";
553
+ i++;
554
+ } else if (!is_non_literal(sub_pattern[i]) &&
555
+ (i == length - 1 || literal.empty() || sub_pattern[i + 1] == '.' || !is_non_literal(sub_pattern[i + 1]))) {
556
+ literal += sub_pattern[i];
557
+ i++;
558
+ } else {
559
+ break;
560
+ }
561
+ }
562
+ if (!literal.empty()) {
563
+ seq.emplace_back(literal, true);
564
+ }
565
+ }
566
+ }
567
+ return join_seq();
568
+ };
569
+ return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space");
570
+ }
571
+
572
+ /*
573
+ Returns a rule that matches a JSON string that is none of the provided strings
574
+
575
+ not_strings({"a"})
576
+ -> ["] ( [a] char+ | [^"a] char* )? ["] space
577
+ not_strings({"and", "also"})
578
+ -> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
579
+ */
580
+ std::string _not_strings(const std::vector<std::string> & strings) {
581
+
582
+ struct TrieNode {
583
+ std::map<char, TrieNode> children;
584
+ bool is_end_of_string;
585
+
586
+ TrieNode() : is_end_of_string(false) {}
587
+
588
+ void insert(const std::string & string) {
589
+ auto node = this;
590
+ for (char c : string) {
591
+ node = &node->children[c];
592
+ }
593
+ node->is_end_of_string = true;
594
+ }
595
+ };
596
+
597
+ TrieNode trie;
598
+ for (const auto & s : strings) {
599
+ trie.insert(s);
600
+ }
601
+
602
+ std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
603
+ std::ostringstream out;
604
+ out << "[\"] ( ";
605
+ std::function<void(const TrieNode &)> visit = [&](const TrieNode & node) {
606
+ std::ostringstream rejects;
607
+ auto first = true;
608
+ for (const auto & kv : node.children) {
609
+ rejects << kv.first;
610
+ if (first) {
611
+ first = false;
612
+ } else {
613
+ out << " | ";
614
+ }
615
+ out << "[" << kv.first << "]";
616
+ if (!kv.second.children.empty()) {
617
+ out << " (";
618
+ visit(kv.second);
619
+ out << ")";
620
+ } else if (kv.second.is_end_of_string) {
621
+ out << " " << char_rule << "+";
622
+ }
623
+ }
624
+ if (!node.children.empty()) {
625
+ if (!first) {
626
+ out << " | ";
627
+ }
628
+ out << "[^\"" << rejects.str() << "] " << char_rule << "*";
629
+ }
630
+ };
631
+ visit(trie);
632
+
633
+ out << " )";
634
+ if (!trie.is_end_of_string) {
635
+ out << "?";
636
+ }
637
+ out << " [\"] space";
638
+ return out.str();
639
+ }
640
+
641
+ std::string _resolve_ref(const std::string & ref) {
642
+ std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
643
+ if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
644
+ _refs_being_resolved.insert(ref);
645
+ json resolved = _refs[ref];
646
+ ref_name = visit(resolved, ref_name);
647
+ _refs_being_resolved.erase(ref);
648
+ }
649
+ return ref_name;
650
+ }
651
+
652
+ std::string _build_object_rule(
653
+ const std::vector<std::pair<std::string, json>> & properties,
654
+ const std::unordered_set<std::string> & required,
655
+ const std::string & name,
656
+ const json & additional_properties)
657
+ {
658
+ std::vector<std::string> required_props;
659
+ std::vector<std::string> optional_props;
660
+ std::unordered_map<std::string, std::string> prop_kv_rule_names;
661
+ std::vector<std::string> prop_names;
662
+ for (const auto & kv : properties) {
663
+ const auto &prop_name = kv.first;
664
+ const auto &prop_schema = kv.second;
665
+
666
+ std::string prop_rule_name = visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name);
667
+ prop_kv_rule_names[prop_name] = _add_rule(
668
+ name + (name.empty() ? "" : "-") + prop_name + "-kv",
669
+ format_literal(json(prop_name).dump()) + " space \":\" space " + prop_rule_name
670
+ );
671
+ if (required.find(prop_name) != required.end()) {
672
+ required_props.push_back(prop_name);
673
+ } else {
674
+ optional_props.push_back(prop_name);
675
+ }
676
+ prop_names.push_back(prop_name);
677
+ }
678
+ if ((additional_properties.is_boolean() && additional_properties.get<bool>()) || additional_properties.is_object()) {
679
+ std::string sub_name = name + (name.empty() ? "" : "-") + "additional";
680
+ std::string value_rule =
681
+ additional_properties.is_object() ? visit(additional_properties, sub_name + "-value")
682
+ : _add_primitive("value", PRIMITIVE_RULES.at("value"));
683
+
684
+ auto key_rule =
685
+ prop_names.empty() ? _add_primitive("string", PRIMITIVE_RULES.at("string"))
686
+ : _add_rule(sub_name + "-k", _not_strings(prop_names));
687
+ std::string kv_rule = _add_rule(sub_name + "-kv", key_rule + " \":\" space " + value_rule);
688
+ prop_kv_rule_names["*"] = kv_rule;
689
+ optional_props.push_back("*");
690
+ }
691
+
692
+ std::string rule = "\"{\" space ";
693
+ for (size_t i = 0; i < required_props.size(); i++) {
694
+ if (i > 0) {
695
+ rule += " \",\" space ";
696
+ }
697
+ rule += prop_kv_rule_names[required_props[i]];
698
+ }
699
+
700
+ if (!optional_props.empty()) {
701
+ rule += " (";
702
+ if (!required_props.empty()) {
703
+ rule += " \",\" space ( ";
704
+ }
705
+
706
+ std::function<std::string(const std::vector<std::string> &, bool)> get_recursive_refs = [&](const std::vector<std::string> & ks, bool first_is_optional) {
707
+ std::string res;
708
+ if (ks.empty()) {
709
+ return res;
710
+ }
711
+ std::string k = ks[0];
712
+ std::string kv_rule_name = prop_kv_rule_names[k];
713
+ std::string comma_ref = "( \",\" space " + kv_rule_name + " )";
714
+ if (first_is_optional) {
715
+ res = comma_ref + (k == "*" ? "*" : "?");
716
+ } else {
717
+ res = kv_rule_name + (k == "*" ? " " + comma_ref + "*" : "");
718
+ }
719
+ if (ks.size() > 1) {
720
+ res += " " + _add_rule(
721
+ name + (name.empty() ? "" : "-") + k + "-rest",
722
+ get_recursive_refs(std::vector<std::string>(ks.begin() + 1, ks.end()), true)
723
+ );
724
+ }
725
+ return res;
726
+ };
727
+
728
+ for (size_t i = 0; i < optional_props.size(); i++) {
729
+ if (i > 0) {
730
+ rule += " | ";
731
+ }
732
+ rule += get_recursive_refs(std::vector<std::string>(optional_props.begin() + i, optional_props.end()), false);
733
+ }
734
+ if (!required_props.empty()) {
735
+ rule += " )";
736
+ }
737
+ rule += " )?";
738
+ }
739
+
740
+ rule += " \"}\" space";
741
+
742
+ return rule;
743
+ }
744
+
745
+ std::string _add_primitive(const std::string & name, const BuiltinRule & rule) {
746
+ auto n = _add_rule(name, rule.content);
747
+ for (const auto & dep : rule.deps) {
748
+ BuiltinRule dep_rule;
749
+ auto it = PRIMITIVE_RULES.find(dep);
750
+ if (it == PRIMITIVE_RULES.end()) {
751
+ it = STRING_FORMAT_RULES.find(dep);
752
+ if (it == STRING_FORMAT_RULES.end()) {
753
+ _errors.push_back("Rule " + dep + " not known");
754
+ continue;
755
+ }
756
+ }
757
+ if (_rules.find(dep) == _rules.end()) {
758
+ _add_primitive(dep, it->second);
759
+ }
760
+ }
761
+ return n;
762
+ }
763
+
764
+ public:
765
+ SchemaConverter(
766
+ const std::function<json(const std::string &)> & fetch_json,
767
+ bool dotall)
768
+ : _fetch_json(fetch_json), _dotall(dotall)
769
+ {
770
+ _rules["space"] = SPACE_RULE;
771
+ }
772
+
773
+ void resolve_refs(json & schema, const std::string & url) {
774
+ /*
775
+ * Resolves all $ref fields in the given schema, fetching any remote schemas,
776
+ * replacing each $ref with absolute reference URL and populates _refs with the
777
+ * respective referenced (sub)schema dictionaries.
778
+ */
779
+ std::function<void(json &)> visit_refs = [&](json & n) {
780
+ if (n.is_array()) {
781
+ for (auto & x : n) {
782
+ visit_refs(x);
783
+ }
784
+ } else if (n.is_object()) {
785
+ if (n.contains("$ref")) {
786
+ std::string ref = n["$ref"];
787
+ if (_refs.find(ref) == _refs.end()) {
788
+ json target;
789
+ if (ref.find("https://") == 0) {
790
+ std::string base_url = ref.substr(0, ref.find('#'));
791
+ auto it = _refs.find(base_url);
792
+ if (it != _refs.end()) {
793
+ target = it->second;
794
+ } else {
795
+ // Fetch the referenced schema and resolve its refs
796
+ auto referenced = _fetch_json(ref);
797
+ resolve_refs(referenced, base_url);
798
+ _refs[base_url] = referenced;
799
+ }
800
+ if (ref.find('#') == std::string::npos || ref.substr(ref.find('#') + 1).empty()) {
801
+ return;
802
+ }
803
+ } else if (ref.find("#/") == 0) {
804
+ target = schema;
805
+ n["$ref"] = url + ref;
806
+ ref = url + ref;
807
+ } else {
808
+ _errors.push_back("Unsupported ref: " + ref);
809
+ return;
810
+ }
811
+ std::string pointer = ref.substr(ref.find('#') + 1);
812
+ std::vector<std::string> tokens = string_split(pointer, "/");
813
+ for (size_t i = 1; i < tokens.size(); ++i) {
814
+ std::string sel = tokens[i];
815
+ if (target.is_null() || !target.contains(sel)) {
816
+ _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
817
+ return;
818
+ }
819
+ target = target[sel];
820
+ }
821
+ _refs[ref] = target;
822
+ }
823
+ } else {
824
+ for (auto & kv : n.items()) {
825
+ visit_refs(kv.value());
826
+ }
827
+ }
828
+ }
829
+ };
830
+
831
+ visit_refs(schema);
832
+ }
833
+
834
+ std::string _generate_constant_rule(const json & value) {
835
+ return format_literal(value.dump());
836
+ }
837
+
838
+ std::string visit(const json & schema, const std::string & name) {
839
+ json schema_type = schema.contains("type") ? schema["type"] : json();
840
+ std::string schema_format = schema.contains("format") ? schema["format"].get<std::string>() : "";
841
+ std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name;
842
+
843
+ if (schema.contains("$ref")) {
844
+ return _add_rule(rule_name, _resolve_ref(schema["$ref"]));
845
+ } else if (schema.contains("oneOf") || schema.contains("anyOf")) {
846
+ std::vector<json> alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>() : schema["anyOf"].get<std::vector<json>>();
847
+ return _add_rule(rule_name, _generate_union_rule(name, alt_schemas));
848
+ } else if (schema_type.is_array()) {
849
+ std::vector<json> schema_types;
850
+ for (const auto & t : schema_type) {
851
+ json schema_copy(schema);
852
+ schema_copy["type"] = t;
853
+ schema_types.push_back(schema_copy);
854
+ }
855
+ return _add_rule(rule_name, _generate_union_rule(name, schema_types));
856
+ } else if (schema.contains("const")) {
857
+ return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space");
858
+ } else if (schema.contains("enum")) {
859
+ std::vector<std::string> enum_values;
860
+ for (const auto & v : schema["enum"]) {
861
+ enum_values.push_back(_generate_constant_rule(v));
862
+ }
863
+ return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
864
+ } else if ((schema_type.is_null() || schema_type == "object")
865
+ && (schema.contains("properties") ||
866
+ (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
867
+ std::unordered_set<std::string> required;
868
+ if (schema.contains("required") && schema["required"].is_array()) {
869
+ for (const auto & item : schema["required"]) {
870
+ if (item.is_string()) {
871
+ required.insert(item.get<std::string>());
872
+ }
873
+ }
874
+ }
875
+ std::vector<std::pair<std::string, json>> properties;
876
+ if (schema.contains("properties")) {
877
+ for (const auto & prop : schema["properties"].items()) {
878
+ properties.emplace_back(prop.key(), prop.value());
879
+ }
880
+ }
881
+ return _add_rule(rule_name,
882
+ _build_object_rule(
883
+ properties, required, name,
884
+ schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
885
+ } else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
886
+ std::unordered_set<std::string> required;
887
+ std::vector<std::pair<std::string, json>> properties;
888
+ std::string hybrid_name = name;
889
+ std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
890
+ if (comp_schema.contains("$ref")) {
891
+ add_component(_refs[comp_schema["$ref"]], is_required);
892
+ } else if (comp_schema.contains("properties")) {
893
+ for (const auto & prop : comp_schema["properties"].items()) {
894
+ properties.emplace_back(prop.key(), prop.value());
895
+ if (is_required) {
896
+ required.insert(prop.key());
897
+ }
898
+ }
899
+ } else {
900
+ // TODO: warn about unsupported allOf component (no $ref or properties)
901
+ }
902
+ };
903
+ for (auto & t : schema["allOf"]) {
904
+ if (t.contains("anyOf")) {
905
+ for (auto & tt : t["anyOf"]) {
906
+ add_component(tt, false);
907
+ }
908
+ } else {
909
+ add_component(t, true);
910
+ }
911
+ }
912
+ return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
913
+ } else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
914
+ json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
915
+ if (items.is_array()) {
916
+ std::string rule = "\"[\" space ";
917
+ for (size_t i = 0; i < items.size(); i++) {
918
+ if (i > 0) {
919
+ rule += " \",\" space ";
920
+ }
921
+ rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i));
922
+ }
923
+ rule += " \"]\" space";
924
+ return _add_rule(rule_name, rule);
925
+ } else {
926
+ std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
927
+ int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0;
928
+ json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
929
+ int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
930
+
931
+ return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
932
+ }
933
+ } else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
934
+ return _visit_pattern(schema["pattern"], rule_name);
935
+ } else if ((schema_type.is_null() || schema_type == "string") && std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) {
936
+ return _add_primitive(rule_name == "root" ? "root" : schema_format, PRIMITIVE_RULES.at("uuid"));
937
+ } else if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) {
938
+ auto prim_name = schema_format + "-string";
939
+ return _add_rule(rule_name, _add_primitive(prim_name, STRING_FORMAT_RULES.at(prim_name)));
940
+ } else if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) {
941
+ std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
942
+ int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
943
+ int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
944
+ return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
945
+ } else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
946
+ int min_value = std::numeric_limits<int>::min();
947
+ int max_value = std::numeric_limits<int>::max();
948
+ if (schema.contains("minimum")) {
949
+ min_value = schema["minimum"].get<int>();
950
+ } else if (schema.contains("exclusiveMinimum")) {
951
+ min_value = schema["exclusiveMinimum"].get<int>() + 1;
952
+ }
953
+ if (schema.contains("maximum")) {
954
+ max_value = schema["maximum"].get<int>();
955
+ } else if (schema.contains("exclusiveMaximum")) {
956
+ max_value = schema["exclusiveMaximum"].get<int>() - 1;
957
+ }
958
+ std::stringstream out;
959
+ out << "(";
960
+ _build_min_max_int(min_value, max_value, out);
961
+ out << ") space";
962
+ return _add_rule(rule_name, out.str());
963
+ } else if (schema.empty() || schema_type == "object") {
964
+ return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
965
+ } else {
966
+ if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) {
967
+ _errors.push_back("Unrecognized schema: " + schema.dump());
968
+ return "";
969
+ }
970
+ // TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
971
+ return _add_primitive(rule_name == "root" ? "root" : schema_type.get<std::string>(), PRIMITIVE_RULES.at(schema_type.get<std::string>()));
972
+ }
973
+ }
974
+
975
+ void check_errors() {
976
+ if (!_errors.empty()) {
977
+ throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
978
+ }
979
+ if (!_warnings.empty()) {
980
+ fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
981
+ }
982
+ }
983
+
984
+ std::string format_grammar() {
985
+ std::stringstream ss;
986
+ for (const auto & kv : _rules) {
987
+ ss << kv.first << " ::= " << kv.second << std::endl;
988
+ }
989
+ return ss.str();
990
+ }
991
+ };
992
+
993
+ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
994
+ #ifdef LLAMA_USE_LLGUIDANCE
995
+ if (!force_gbnf) {
996
+ return "%llguidance {}\nstart: %json " + schema.dump();
997
+ }
998
+ #else
999
+ (void)force_gbnf;
1000
+ #endif // LLAMA_USE_LLGUIDANCE
1001
+ return build_grammar([&](const common_grammar_builder & callbacks) {
1002
+ auto copy = schema;
1003
+ callbacks.resolve_refs(copy);
1004
+ callbacks.add_schema("", copy);
1005
+ });
1006
+ }
1007
+
1008
+ std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
1009
+ SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
1010
+ common_grammar_builder builder {
1011
+ /* .add_rule = */ [&](const std::string & name, const std::string & rule) {
1012
+ return converter._add_rule(name, rule);
1013
+ },
1014
+ /* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) {
1015
+ return converter.visit(schema, name == "root" ? "" : name);
1016
+ },
1017
+ /* .resolve_refs = */ [&](nlohmann::ordered_json & schema) {
1018
+ converter.resolve_refs(schema, "");
1019
+ }
1020
+ };
1021
+ cb(builder);
1022
+ converter.check_errors();
1023
+ return converter.format_grammar();
1024
+ }
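
Note: for a quick sanity check of the converter above, the public entry point can be driven directly. A minimal sketch; the schema contents are invented for illustration:

    #include "json-schema-to-grammar.h"
    #include <cstdio>

    int main() {
        nlohmann::ordered_json schema {
            {"type", "object"},
            {"properties", {
                {"name", {{"type", "string"}}},
                {"age",  {{"type", "integer"}}},
            }},
            {"required", nlohmann::ordered_json::array({"name"})},
        };
        // prints the generated grammar (GBNF rules such as `root ::= ...` in the default build)
        printf("%s\n", json_schema_to_grammar(schema).c_str());
        return 0;
    }
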
common/json-schema-to-grammar.h ADDED
@@ -0,0 +1,21 @@
1
+ #pragma once
2
+
3
+ #include "ggml.h"
4
+ // Change JSON_ASSERT from assert() to GGML_ASSERT:
5
+ #define JSON_ASSERT GGML_ASSERT
6
+ #include "json.hpp"
7
+
8
+ std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
9
+ bool force_gbnf = false);
10
+
11
+ struct common_grammar_builder {
12
+ std::function<std::string(const std::string &, const std::string &)> add_rule;
13
+ std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
14
+ std::function<void(nlohmann::ordered_json &)> resolve_refs;
15
+ };
16
+
17
+ struct common_grammar_options {
18
+ bool dotall = false;
19
+ };
20
+
21
+ std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});
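
Note: the callbacks in common_grammar_builder mirror the SchemaConverter methods wired up by build_grammar above, so hand-written rules and sub-schemas can be mixed in one grammar. A minimal sketch; the rule body is an invented GBNF fragment:

    std::string grammar = build_grammar([](const common_grammar_builder & b) {
        // a hand-written root rule; b.add_schema(...) could be used alongside it
        b.add_rule("root", "\"yes\" | \"no\"");
    });
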
common/json.hpp ADDED
The diff for this file is too large to render. See raw diff
 
common/llguidance.cpp ADDED
@@ -0,0 +1,270 @@
1
+ #include "sampling.h"
2
+ #include "log.h"
3
+
4
+ #ifdef LLAMA_USE_LLGUIDANCE
5
+
6
+ # include "llguidance.h"
7
+ # include <cmath>
8
+
9
+ struct llama_sampler_llg {
10
+ const llama_vocab * vocab;
11
+ std::string grammar_kind;
12
+ std::string grammar_data;
13
+ LlgTokenizer * tokenizer;
14
+ LlgConstraint * grammar;
15
+ LlgMaskResult llg_res;
16
+ bool has_llg_res;
17
+ };
18
+
19
+ static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
20
+ const char * grammar_data) {
21
+ LlgConstraintInit cinit;
22
+ llg_constraint_init_set_defaults(&cinit, tokenizer);
23
+ const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
24
+ if (log_level && *log_level) {
25
+ cinit.log_stderr_level = atoi(log_level);
26
+ }
27
+ auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
28
+ if (llg_get_error(c)) {
29
+ LOG_ERR("llg error: %s\n", llg_get_error(c));
30
+ llg_free_constraint(c);
31
+ return nullptr;
32
+ }
33
+ return c;
34
+ }
35
+
36
+ static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
37
+ return "llguidance";
38
+ }
39
+
40
+ static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
41
+ auto * ctx = (llama_sampler_llg *) smpl->ctx;
42
+ if (ctx->grammar) {
43
+ LlgCommitResult res;
44
+ llg_commit_token(ctx->grammar, token, &res);
45
+ ctx->has_llg_res = false;
46
+ }
47
+ }
48
+
49
+ static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
50
+ auto * ctx = (llama_sampler_llg *) smpl->ctx;
51
+ if (ctx->grammar) {
52
+ if (!ctx->has_llg_res) {
53
+ if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
54
+ ctx->has_llg_res = true;
55
+ } else {
56
+ LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
57
+ llg_free_constraint(ctx->grammar);
58
+ ctx->grammar = nullptr;
59
+ }
60
+ }
61
+ if (ctx->has_llg_res) {
62
+ if (ctx->llg_res.is_stop) {
63
+ for (size_t i = 0; i < cur_p->size; ++i) {
64
+ if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
65
+ cur_p->data[i].logit = -INFINITY;
66
+ }
67
+ }
68
+ } else {
69
+ const uint32_t * mask = ctx->llg_res.sample_mask;
70
+ for (size_t i = 0; i < cur_p->size; ++i) {
71
+ auto token = cur_p->data[i].id;
72
+ if ((mask[token / 32] & (1 << (token % 32))) == 0) {
73
+ cur_p->data[i].logit = -INFINITY;
74
+ }
75
+ }
76
+ }
77
+ }
78
+ }
79
+ }
80
+
81
+ static void llama_sampler_llg_reset(llama_sampler * smpl) {
82
+ auto * ctx = (llama_sampler_llg *) smpl->ctx;
83
+ if (!ctx->grammar) {
84
+ return;
85
+ }
86
+
87
+ auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
88
+ llg_free_constraint(ctx->grammar);
89
+ ctx->grammar = grammar_new;
90
+ ctx->has_llg_res = false;
91
+ }
92
+
93
+ static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
94
+ const auto * ctx = (const llama_sampler_llg *) smpl->ctx;
95
+
96
+ auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);
97
+
98
+ // copy the state
99
+ {
100
+ auto * result_ctx = (llama_sampler_llg *) result->ctx;
101
+
102
+ if (ctx->grammar) {
103
+ result_ctx->grammar_kind = ctx->grammar_kind;
104
+ result_ctx->grammar_data = ctx->grammar_data;
105
+ result_ctx->grammar = llg_clone_constraint(ctx->grammar);
106
+ result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
107
+ }
108
+ }
109
+
110
+ return result;
111
+ }
112
+
113
+ static void llama_sampler_llg_free(llama_sampler * smpl) {
114
+ const auto * ctx = (llama_sampler_llg *) smpl->ctx;
115
+
116
+ if (ctx->grammar) {
117
+ llg_free_constraint(ctx->grammar);
118
+ llg_free_tokenizer(ctx->tokenizer);
119
+ }
120
+
121
+ delete ctx;
122
+ }
123
+
124
+ static llama_sampler_i llama_sampler_llg_i = {
125
+ /* .name = */ llama_sampler_llg_name,
126
+ /* .accept = */ llama_sampler_llg_accept_impl,
127
+ /* .apply = */ llama_sampler_llg_apply,
128
+ /* .reset = */ llama_sampler_llg_reset,
129
+ /* .clone = */ llama_sampler_llg_clone,
130
+ /* .free = */ llama_sampler_llg_free,
131
+ };
132
+
133
+ static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
134
+ uint32_t * output_tokens, size_t output_tokens_len) {
135
+ const llama_vocab * vocab = (const llama_vocab *) user_data;
136
+ int r = 0;
137
+ try {
138
+ r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
139
+ true);
140
+ } catch (const std::exception & e) {
141
+ GGML_ABORT("llama_tokenize failed: %s\n", e.what());
142
+ }
143
+ if (r < 0) {
144
+ return -r;
145
+ }
146
+ return r;
147
+ }
148
+
149
+ static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
150
+ // TODO store the tokenizer in the vocab somehow
151
+ static const llama_vocab * vocab_cache;
152
+ static LlgTokenizer * tokenizer_cache;
153
+
154
+ if (vocab_cache == vocab) {
155
+ return llg_clone_tokenizer(tokenizer_cache);
156
+ }
157
+
158
+ auto tok_eos = llama_vocab_eot(vocab);
159
+ if (tok_eos == LLAMA_TOKEN_NULL) {
160
+ tok_eos = llama_vocab_eos(vocab);
161
+ }
162
+
163
+ size_t vocab_size = llama_vocab_n_tokens(vocab);
164
+
165
+ auto token_lens = new uint32_t[vocab_size];
166
+ // we typically have ~7 bytes per token; let's go on the safe side here
167
+ auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
168
+ auto token_bytes = new uint8_t[token_bytes_size];
169
+
170
+ size_t offset = 0;
171
+ for (size_t i = 0; i < vocab_size; i++) {
172
+ size_t max_token = 1024;
173
+ if (token_bytes_size - offset < max_token) {
174
+ GGML_ABORT("token_bytes buffer too small\n");
175
+ }
176
+
177
+ llama_token token = i;
178
+ auto dp = (char *) token_bytes + offset;
179
+ auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
180
+ if (size < 0) {
181
+ GGML_ABORT("llama_detokenize failed\n");
182
+ }
183
+ if (size == 0) {
184
+ size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
185
+ if (size < 0) {
186
+ GGML_ABORT("llama_detokenize failed\n");
187
+ }
188
+ if (size != 0) {
189
+ *dp = '\xff'; // special token prefix marker
190
+ size += 1;
191
+ }
192
+ }
193
+
194
+ token_lens[i] = size;
195
+ offset += size;
196
+ }
197
+
198
+ LlgTokenizerInit tinit = {
199
+ /* .vocab_size = */ (uint32_t) vocab_size,
200
+ /* .tok_eos = */ (uint32_t) tok_eos,
201
+ /* .token_lens = */ token_lens,
202
+ /* .token_bytes = */ token_bytes,
203
+ /* .tokenizer_json = */ nullptr,
204
+ /* .tokenize_assumes_string = */ true,
205
+ /* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
206
+ /* .use_approximate_greedy_tokenize_fn = */ false,
207
+ /* .tokenize_user_data = */ vocab,
208
+ };
209
+
210
+ char error_buffer[1024];
211
+ LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
212
+
213
+ delete[] token_bytes;
214
+ delete[] token_lens;
215
+
216
+ if (tokenizer == nullptr) {
217
+ LOG_ERR("llg tokenizer error: %s\n", error_buffer);
218
+ return tokenizer;
219
+ }
220
+
221
+ if (tokenizer_cache) {
222
+ llg_free_tokenizer(tokenizer_cache);
223
+ }
224
+ vocab_cache = vocab;
225
+ tokenizer_cache = tokenizer;
226
+
227
+ return llg_clone_tokenizer(tokenizer_cache);
228
+ }
229
+
230
+ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
231
+ const char * grammar_data) {
232
+ auto * ctx = new llama_sampler_llg;
233
+
234
+ if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
235
+ auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
236
+ *ctx = {
237
+ /* .vocab = */ vocab,
238
+ /* .grammar_kind = */ grammar_kind,
239
+ /* .grammar_data = */ grammar_data,
240
+ /* .tokenizer = */ tokenizer,
241
+ /* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
242
+ /* .llg_res = */ {},
243
+ /* .has_llg_res = */ false,
244
+ };
245
+ } else {
246
+ *ctx = {
247
+ /* .vocab = */ vocab,
248
+ /* .grammar_kind = */ {},
249
+ /* .grammar_data = */ {},
250
+ /* .tokenizer = */ nullptr,
251
+ /* .grammar = */ nullptr,
252
+ /* .llg_res = */ {},
253
+ /* .has_llg_res = */ false,
254
+ };
255
+ }
256
+
257
+ return llama_sampler_init(
258
+ /* .iface = */ &llama_sampler_llg_i,
259
+ /* .ctx = */ ctx
260
+ );
261
+ }
262
+
263
+ #else
264
+
265
+ llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
266
+ LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
267
+ return nullptr;
268
+ }
269
+
270
+ #endif // LLAMA_USE_LLGUIDANCE
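
Note: a hypothetical wiring of this sampler; `model` and `chain` are placeholders, the "lark" grammar kind is an assumption about llguidance's accepted kinds, and a usable sampler is only returned when the build defines LLAMA_USE_LLGUIDANCE:

    // assumed context: a loaded llama_model * model and a llama_sampler * chain
    const llama_vocab * vocab = llama_model_get_vocab(model);
    llama_sampler * smpl = llama_sampler_init_llg(vocab, "lark", "start: \"yes\" | \"no\"");
    if (smpl != nullptr) {
        llama_sampler_chain_add(chain, smpl); // constrain generation with the grammar
    }
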
common/log.cpp ADDED
@@ -0,0 +1,393 @@
1
+ #include "log.h"
2
+
3
+ #include <chrono>
4
+ #include <condition_variable>
5
+ #include <cstdarg>
6
+ #include <cstdio>
7
+ #include <mutex>
8
+ #include <sstream>
9
+ #include <thread>
10
+ #include <vector>
11
+
12
+ int common_log_verbosity_thold = LOG_DEFAULT_LLAMA;
13
+
14
+ void common_log_set_verbosity_thold(int verbosity) {
15
+ common_log_verbosity_thold = verbosity;
16
+ }
17
+
18
+ static int64_t t_us() {
19
+ return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
20
+ }
21
+
22
+ // colors
23
+ enum common_log_col : int {
24
+ COMMON_LOG_COL_DEFAULT = 0,
25
+ COMMON_LOG_COL_BOLD,
26
+ COMMON_LOG_COL_RED,
27
+ COMMON_LOG_COL_GREEN,
28
+ COMMON_LOG_COL_YELLOW,
29
+ COMMON_LOG_COL_BLUE,
30
+ COMMON_LOG_COL_MAGENTA,
31
+ COMMON_LOG_COL_CYAN,
32
+ COMMON_LOG_COL_WHITE,
33
+ };
34
+
35
+ // disable colors by default
36
+ static std::vector<const char *> g_col = {
37
+ "",
38
+ "",
39
+ "",
40
+ "",
41
+ "",
42
+ "",
43
+ "",
44
+ "",
45
+ "",
46
+ };
47
+
48
+ struct common_log_entry {
49
+ enum ggml_log_level level;
50
+
51
+ bool prefix;
52
+
53
+ int64_t timestamp;
54
+
55
+ std::vector<char> msg;
56
+
57
+ // signals the worker thread to stop
58
+ bool is_end;
59
+
60
+ void print(FILE * file = nullptr) const {
61
+ FILE * fcur = file;
62
+ if (!fcur) {
63
+ // stderr displays DBG messages only when their verbosity level is not higher than the threshold
64
+ // these messages will still be logged to a file
65
+ if (level == GGML_LOG_LEVEL_DEBUG && common_log_verbosity_thold < LOG_DEFAULT_DEBUG) {
66
+ return;
67
+ }
68
+
69
+ fcur = stdout;
70
+
71
+ if (level != GGML_LOG_LEVEL_NONE) {
72
+ fcur = stderr;
73
+ }
74
+ }
75
+
76
+ if (level != GGML_LOG_LEVEL_NONE && level != GGML_LOG_LEVEL_CONT && prefix) {
77
+ if (timestamp) {
78
+ // [M.s.ms.us]
79
+ fprintf(fcur, "%s%d.%02d.%03d.%03d%s ",
80
+ g_col[COMMON_LOG_COL_BLUE],
81
+ (int) (timestamp / 1000000 / 60),
82
+ (int) (timestamp / 1000000 % 60),
83
+ (int) (timestamp / 1000 % 1000),
84
+ (int) (timestamp % 1000),
85
+ g_col[COMMON_LOG_COL_DEFAULT]);
86
+ }
87
+
88
+ switch (level) {
89
+ case GGML_LOG_LEVEL_INFO: fprintf(fcur, "%sI %s", g_col[COMMON_LOG_COL_GREEN], g_col[COMMON_LOG_COL_DEFAULT]); break;
90
+ case GGML_LOG_LEVEL_WARN: fprintf(fcur, "%sW %s", g_col[COMMON_LOG_COL_MAGENTA], "" ); break;
91
+ case GGML_LOG_LEVEL_ERROR: fprintf(fcur, "%sE %s", g_col[COMMON_LOG_COL_RED], "" ); break;
92
+ case GGML_LOG_LEVEL_DEBUG: fprintf(fcur, "%sD %s", g_col[COMMON_LOG_COL_YELLOW], "" ); break;
93
+ default:
94
+ break;
95
+ }
96
+ }
97
+
98
+ fprintf(fcur, "%s", msg.data());
99
+
100
+ if (level == GGML_LOG_LEVEL_WARN || level == GGML_LOG_LEVEL_ERROR || level == GGML_LOG_LEVEL_DEBUG) {
101
+ fprintf(fcur, "%s", g_col[COMMON_LOG_COL_DEFAULT]);
102
+ }
103
+
104
+ fflush(fcur);
105
+ }
106
+ };
107
+
108
+ struct common_log {
109
+ // default capacity - will be expanded if needed
110
+ common_log() : common_log(256) {}
111
+
112
+ common_log(size_t capacity) {
113
+ file = nullptr;
114
+ prefix = false;
115
+ timestamps = false;
116
+ running = false;
117
+ t_start = t_us();
118
+
119
+ // initial message size - will be expanded if longer messages arrive
120
+ entries.resize(capacity);
121
+ for (auto & entry : entries) {
122
+ entry.msg.resize(256);
123
+ }
124
+
125
+ head = 0;
126
+ tail = 0;
127
+
128
+ resume();
129
+ }
130
+
131
+ ~common_log() {
132
+ pause();
133
+ if (file) {
134
+ fclose(file);
135
+ }
136
+ }
137
+
138
+ private:
139
+ std::mutex mtx;
140
+ std::thread thrd;
141
+ std::condition_variable cv;
142
+
143
+ FILE * file;
144
+
145
+ bool prefix;
146
+ bool timestamps;
147
+ bool running;
148
+
149
+ int64_t t_start;
150
+
151
+ // ring buffer of entries
152
+ std::vector<common_log_entry> entries;
153
+ size_t head;
154
+ size_t tail;
155
+
156
+ // worker thread copies into this
157
+ common_log_entry cur;
158
+
159
+ public:
160
+ void add(enum ggml_log_level level, const char * fmt, va_list args) {
161
+ std::lock_guard<std::mutex> lock(mtx);
162
+
163
+ if (!running) {
164
+ // discard messages while the worker thread is paused
165
+ return;
166
+ }
167
+
168
+ auto & entry = entries[tail];
169
+
170
+ {
171
+ // cannot use args twice, so make a copy in case we need to expand the buffer
172
+ va_list args_copy;
173
+ va_copy(args_copy, args);
174
+
175
+ #if 1
176
+ const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args);
177
+ if (n >= entry.msg.size()) {
178
+ entry.msg.resize(n + 1);
179
+ vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args_copy);
180
+ }
181
+ #else
182
+ // hack for bolding arguments
183
+
184
+ std::stringstream ss;
185
+ for (int i = 0; fmt[i] != 0; i++) {
186
+ if (fmt[i] == '%') {
187
+ ss << LOG_COL_BOLD;
188
+ while (fmt[i] != ' ' && fmt[i] != ')' && fmt[i] != ']' && fmt[i] != 0) ss << fmt[i++];
189
+ ss << LOG_COL_DEFAULT;
190
+ if (fmt[i] == 0) break;
191
+ }
192
+ ss << fmt[i];
193
+ }
194
+ const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args);
195
+ if (n >= entry.msg.size()) {
196
+ entry.msg.resize(n + 1);
197
+ vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args_copy);
198
+ }
199
+ #endif
200
+ va_end(args_copy);
201
+ }
202
+
203
+ entry.level = level;
204
+ entry.prefix = prefix;
205
+ entry.timestamp = 0;
206
+ if (timestamps) {
207
+ entry.timestamp = t_us() - t_start;
208
+ }
209
+ entry.is_end = false;
210
+
211
+ tail = (tail + 1) % entries.size();
212
+ if (tail == head) {
213
+ // expand the buffer
214
+ std::vector<common_log_entry> new_entries(2*entries.size());
215
+
216
+ size_t new_tail = 0;
217
+
218
+ do {
219
+ new_entries[new_tail] = std::move(entries[head]);
220
+
221
+ head = (head + 1) % entries.size();
222
+ new_tail = (new_tail + 1);
223
+ } while (head != tail);
224
+
225
+ head = 0;
226
+ tail = new_tail;
227
+
228
+ for (size_t i = tail; i < new_entries.size(); i++) {
229
+ new_entries[i].msg.resize(256);
230
+ }
231
+
232
+ entries = std::move(new_entries);
233
+ }
234
+
235
+ cv.notify_one();
236
+ }
237
+
238
+ void resume() {
239
+ std::lock_guard<std::mutex> lock(mtx);
240
+
241
+ if (running) {
242
+ return;
243
+ }
244
+
245
+ running = true;
246
+
247
+ thrd = std::thread([this]() {
248
+ while (true) {
249
+ {
250
+ std::unique_lock<std::mutex> lock(mtx);
251
+ cv.wait(lock, [this]() { return head != tail; });
252
+
253
+ cur = entries[head];
254
+
255
+ head = (head + 1) % entries.size();
256
+ }
257
+
258
+ if (cur.is_end) {
259
+ break;
260
+ }
261
+
262
+ cur.print(); // stdout and stderr
263
+
264
+ if (file) {
265
+ cur.print(file);
266
+ }
267
+ }
268
+ });
269
+ }
270
+
271
+ void pause() {
272
+ {
273
+ std::lock_guard<std::mutex> lock(mtx);
274
+
275
+ if (!running) {
276
+ return;
277
+ }
278
+
279
+ running = false;
280
+
281
+ // push an entry to signal the worker thread to stop
282
+ {
283
+ auto & entry = entries[tail];
284
+ entry.is_end = true;
285
+
286
+ tail = (tail + 1) % entries.size();
287
+ }
288
+
289
+ cv.notify_one();
290
+ }
291
+
292
+ thrd.join();
293
+ }
294
+
295
+ void set_file(const char * path) {
296
+ pause();
297
+
298
+ if (file) {
299
+ fclose(file);
300
+ }
301
+
302
+ if (path) {
303
+ file = fopen(path, "w");
304
+ } else {
305
+ file = nullptr;
306
+ }
307
+
308
+ resume();
309
+ }
310
+
311
+ void set_colors(bool colors) {
312
+ pause();
313
+
314
+ if (colors) {
315
+ g_col[COMMON_LOG_COL_DEFAULT] = LOG_COL_DEFAULT;
316
+ g_col[COMMON_LOG_COL_BOLD] = LOG_COL_BOLD;
317
+ g_col[COMMON_LOG_COL_RED] = LOG_COL_RED;
318
+ g_col[COMMON_LOG_COL_GREEN] = LOG_COL_GREEN;
319
+ g_col[COMMON_LOG_COL_YELLOW] = LOG_COL_YELLOW;
320
+ g_col[COMMON_LOG_COL_BLUE] = LOG_COL_BLUE;
321
+ g_col[COMMON_LOG_COL_MAGENTA] = LOG_COL_MAGENTA;
322
+ g_col[COMMON_LOG_COL_CYAN] = LOG_COL_CYAN;
323
+ g_col[COMMON_LOG_COL_WHITE] = LOG_COL_WHITE;
324
+ } else {
325
+ for (size_t i = 0; i < g_col.size(); i++) {
326
+ g_col[i] = "";
327
+ }
328
+ }
329
+
330
+ resume();
331
+ }
332
+
333
+ void set_prefix(bool prefix) {
334
+ std::lock_guard<std::mutex> lock(mtx);
335
+
336
+ this->prefix = prefix;
337
+ }
338
+
339
+ void set_timestamps(bool timestamps) {
340
+ std::lock_guard<std::mutex> lock(mtx);
341
+
342
+ this->timestamps = timestamps;
343
+ }
344
+ };
345
+
346
+ //
347
+ // public API
348
+ //
349
+
350
+ struct common_log * common_log_init() {
351
+ return new common_log;
352
+ }
353
+
354
+ struct common_log * common_log_main() {
355
+ static struct common_log log;
356
+
357
+ return &log;
358
+ }
359
+
360
+ void common_log_pause(struct common_log * log) {
361
+ log->pause();
362
+ }
363
+
364
+ void common_log_resume(struct common_log * log) {
365
+ log->resume();
366
+ }
367
+
368
+ void common_log_free(struct common_log * log) {
369
+ delete log;
370
+ }
371
+
372
+ void common_log_add(struct common_log * log, enum ggml_log_level level, const char * fmt, ...) {
373
+ va_list args;
374
+ va_start(args, fmt);
375
+ log->add(level, fmt, args);
376
+ va_end(args);
377
+ }
378
+
379
+ void common_log_set_file(struct common_log * log, const char * file) {
380
+ log->set_file(file);
381
+ }
382
+
383
+ void common_log_set_colors(struct common_log * log, bool colors) {
384
+ log->set_colors(colors);
385
+ }
386
+
387
+ void common_log_set_prefix(struct common_log * log, bool prefix) {
388
+ log->set_prefix(prefix);
389
+ }
390
+
391
+ void common_log_set_timestamps(struct common_log * log, bool timestamps) {
392
+ log->set_timestamps(timestamps);
393
+ }
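
Note: the public API above can be exercised through the process-wide singleton; a minimal sketch:

    #include "log.h"

    int main() {
        common_log * log = common_log_main();  // singleton, destroyed automatically on exit
        common_log_set_prefix(log, true);      // prepend the I/W/E/D level marker
        common_log_set_timestamps(log, true);  // prepend M.s.ms.us timestamps
        common_log_set_file(log, "run.log");   // mirror all output to a file
        common_log_add(log, GGML_LOG_LEVEL_INFO, "offloaded %d layers\n", 32);
        return 0;
    }
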
common/log.h ADDED
@@ -0,0 +1,103 @@
1
+ #pragma once
2
+
3
+ #include "ggml.h" // for ggml_log_level
4
+
5
+ #define LOG_CLR_TO_EOL "\033[K\r"
6
+ #define LOG_COL_DEFAULT "\033[0m"
7
+ #define LOG_COL_BOLD "\033[1m"
8
+ #define LOG_COL_RED "\033[31m"
9
+ #define LOG_COL_GREEN "\033[32m"
10
+ #define LOG_COL_YELLOW "\033[33m"
11
+ #define LOG_COL_BLUE "\033[34m"
12
+ #define LOG_COL_MAGENTA "\033[35m"
13
+ #define LOG_COL_CYAN "\033[36m"
14
+ #define LOG_COL_WHITE "\033[37m"
15
+
16
+ #ifndef __GNUC__
17
+ # define LOG_ATTRIBUTE_FORMAT(...)
18
+ #elif defined(__MINGW32__) && !defined(__clang__)
19
+ # define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
20
+ #else
21
+ # define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
22
+ #endif
23
+
24
+ #define LOG_DEFAULT_DEBUG 1
25
+ #define LOG_DEFAULT_LLAMA 0
26
+
27
+ // needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity threshold
+ // (set via common_log_set_verbosity_thold()) is lower than the message's verbosity
29
+ extern int common_log_verbosity_thold;
30
+
31
+ void common_log_set_verbosity_thold(int verbosity); // not thread-safe
32
+
33
+ // the common_log uses an internal worker thread to print/write log messages
34
+ // when the worker thread is paused, incoming log messages are discarded
35
+ struct common_log;
36
+
37
+ struct common_log * common_log_init();
38
+ struct common_log * common_log_main(); // singleton, automatically destroys itself on exit
39
+ void common_log_pause (struct common_log * log); // pause the worker thread, not thread-safe
40
+ void common_log_resume(struct common_log * log); // resume the worker thread, not thread-safe
41
+ void common_log_free (struct common_log * log);
42
+
43
+ LOG_ATTRIBUTE_FORMAT(3, 4)
44
+ void common_log_add(struct common_log * log, enum ggml_log_level level, const char * fmt, ...);
45
+
46
+ // defaults: file = NULL, colors = false, prefix = false, timestamps = false
47
+ //
48
+ // regular log output:
49
+ //
50
+ // ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34)
51
+ // llm_load_tensors: ggml ctx size = 0.27 MiB
52
+ // llm_load_tensors: offloading 32 repeating layers to GPU
53
+ // llm_load_tensors: offloading non-repeating layers to GPU
54
+ //
55
+ // with prefix = true, timestamps = true, the log output will look like this:
56
+ //
57
+ // 0.00.035.060 D ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34)
58
+ // 0.00.035.064 I llm_load_tensors: ggml ctx size = 0.27 MiB
59
+ // 0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU
60
+ // 0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU
61
+ //
62
+ // I - info (stdout, V = 0)
63
+ // W - warning (stderr, V = 0)
64
+ // E - error (stderr, V = 0)
65
+ // D - debug (stderr, V = LOG_DEFAULT_DEBUG)
66
+ //
67
+
68
+ void common_log_set_file (struct common_log * log, const char * file); // not thread-safe
69
+ void common_log_set_colors (struct common_log * log, bool colors); // not thread-safe
70
+ void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log
71
+ void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix
72
+
73
+ // helper macros for logging
74
+ // use these to avoid computing log arguments if the verbosity of the log is higher than the threshold
75
+ //
76
+ // for example:
77
+ //
78
+ // LOG_DBG("this is a debug message: %d\n", expensive_function());
79
+ //
80
+ // this will avoid calling expensive_function() if LOG_DEFAULT_DEBUG > common_log_verbosity_thold
81
+ //
82
+
83
+ #define LOG_TMPL(level, verbosity, ...) \
84
+ do { \
85
+ if ((verbosity) <= common_log_verbosity_thold) { \
86
+ common_log_add(common_log_main(), (level), __VA_ARGS__); \
87
+ } \
88
+ } while (0)
89
+
90
+ #define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, 0, __VA_ARGS__)
91
+ #define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)
92
+
93
+ #define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, 0, __VA_ARGS__)
94
+ #define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, 0, __VA_ARGS__)
95
+ #define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, 0, __VA_ARGS__)
96
+ #define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__)
97
+ #define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, 0, __VA_ARGS__)
98
+
99
+ #define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__)
100
+ #define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__)
101
+ #define LOG_ERRV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, verbosity, __VA_ARGS__)
102
+ #define LOG_DBGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, verbosity, __VA_ARGS__)
103
+ #define LOG_CNTV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_CONT, verbosity, __VA_ARGS__)
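
Note: a short sketch of the lazy-argument behavior documented above; with the threshold left at 0, the LOG_DBG arguments are never evaluated:

    int n_calls = 0;
    auto expensive_function = [&]() { return ++n_calls; };

    common_log_set_verbosity_thold(0);
    LOG_INF("starting up\n");                     // verbosity 0: always printed
    LOG_DBG("value: %d\n", expensive_function()); // verbosity 1 > threshold: skipped, n_calls stays 0
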
common/minja/chat-template.hpp ADDED
@@ -0,0 +1,529 @@
1
+ /*
2
+ Copyright 2024 Google LLC
3
+
4
+ Use of this source code is governed by an MIT-style
5
+ license that can be found in the LICENSE file or at
6
+ https://opensource.org/licenses/MIT.
7
+ */
8
+ // SPDX-License-Identifier: MIT
9
+ #pragma once
10
+
11
+ #include "minja.hpp"
12
+ #include <json.hpp>
13
+ #include <string>
14
+ #include <vector>
15
+
16
+ using json = nlohmann::ordered_json;
17
+
18
+ namespace minja {
19
+
20
+ struct chat_template_caps {
21
+ bool supports_tools = false;
22
+ bool supports_tool_calls = false;
23
+ bool supports_tool_responses = false;
24
+ bool supports_system_role = false;
25
+ bool supports_parallel_tool_calls = false;
26
+ bool supports_tool_call_id = false;
27
+ // meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object.
28
+ // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
29
+ bool requires_object_arguments = false;
30
+ // CohereForAI/c4ai-command-r-plus simple variant
31
+ bool requires_non_null_content = false;
32
+ // MiniMaxAI/MiniMax-Text-01 special
33
+ bool requires_typed_content = false;
34
+ };
35
+
36
+ struct chat_template_inputs {
37
+ nlohmann::ordered_json messages;
38
+ nlohmann::ordered_json tools;
39
+ bool add_generation_prompt = true;
40
+ nlohmann::ordered_json extra_context;
41
+ std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
42
+ };
43
+
44
+ struct chat_template_options {
45
+ bool apply_polyfills = true;
46
+ bool use_bos_token = true;
47
+ bool use_eos_token = true;
48
+ bool define_strftime_now = true;
49
+
50
+ bool polyfill_tools = true;
51
+ bool polyfill_tool_call_examples = true;
52
+ bool polyfill_tool_calls = true;
53
+ bool polyfill_tool_responses = true;
54
+ bool polyfill_system_role = true;
55
+ bool polyfill_object_arguments = true;
56
+ bool polyfill_typed_content = true;
57
+ };
58
+
59
+ class chat_template {
60
+
61
+ private:
62
+ chat_template_caps caps_;
63
+ std::string source_;
64
+ std::string bos_token_;
65
+ std::string eos_token_;
66
+ std::shared_ptr<minja::TemplateNode> template_root_;
67
+ std::string tool_call_example_;
68
+
69
+ std::string try_raw_render(
70
+ const nlohmann::ordered_json & messages,
71
+ const nlohmann::ordered_json & tools,
72
+ bool add_generation_prompt,
73
+ const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
74
+ {
75
+ try {
76
+ chat_template_inputs inputs;
77
+ inputs.messages = messages;
78
+ inputs.tools = tools;
79
+ inputs.add_generation_prompt = add_generation_prompt;
80
+ inputs.extra_context = extra_context;
81
+ // Use fixed date for tests
82
+ inputs.now = std::chrono::system_clock::from_time_t(0);
83
+
84
+ chat_template_options opts;
85
+ opts.apply_polyfills = false;
86
+
87
+ auto prompt = apply(inputs, opts);
88
+ // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
89
+ return prompt;
90
+ } catch (const std::exception & e) {
91
+ // fprintf(stderr, "try_raw_render error: %s\n", e.what());
92
+ return "";
93
+ }
94
+ }
95
+
96
+ public:
97
+
98
+ chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
99
+ : source_(source), bos_token_(bos_token), eos_token_(eos_token)
100
+ {
101
+ template_root_ = minja::Parser::parse(source_, {
102
+ /* .trim_blocks = */ true,
103
+ /* .lstrip_blocks = */ true,
104
+ /* .keep_trailing_newline = */ false,
105
+ });
106
+
107
+ auto contains = [](const std::string & haystack, const std::string & needle) {
108
+ return haystack.find(needle) != std::string::npos;
109
+ };
110
+
111
+ const std::string user_needle = "<User Needle>";
112
+ const std::string sys_needle = "<System Needle>";
113
+ const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}};
114
+ const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}};
115
+
116
+ caps_.requires_typed_content =
117
+ !contains(try_raw_render(json::array({dummy_str_user_msg}), {}, false), user_needle)
118
+ && contains(try_raw_render(json::array({dummy_typed_user_msg}), {}, false), user_needle);
119
+
120
+ const auto dummy_user_msg = caps_.requires_typed_content
121
+ ? dummy_typed_user_msg
122
+ : dummy_str_user_msg;
123
+ const json needle_system_msg = {
124
+ {"role", "system"},
125
+ {"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)},
126
+ };
127
+
128
+ caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle);
129
+
130
+ auto out = try_raw_render(json::array({
131
+ dummy_user_msg
132
+ }), json::array({
133
+ {
134
+ {"name", "some_tool"},
135
+ {"type", "function"},
136
+ {"function", {
137
+ {"name", "some_tool"},
138
+ {"description", "Some tool."},
139
+ {"parameters", {
140
+ {"type", "object"},
141
+ {"properties", {
142
+ {"arg", {
143
+ {"type", "string"},
144
+ {"description", "Some argument."},
145
+ }},
146
+ }},
147
+ {"required", json::array({ "arg" })},
148
+ }},
149
+ }},
150
+ },
151
+ }), false);
152
+ caps_.supports_tools = contains(out, "some_tool");
153
+
154
+ auto make_tool_calls_msg = [&](const json & tool_calls) {
155
+ return json {
156
+ {"role", "assistant"},
157
+ {"content", nullptr},
158
+ {"tool_calls", tool_calls},
159
+ };
160
+ };
161
+ auto make_tool_call = [](const std::string & tool_name, const json & arguments) {
162
+ return json {
163
+ {"id", "call_1___"},
164
+ {"type", "function"},
165
+ {"function", {
166
+ {"arguments", arguments},
167
+ {"name", tool_name},
168
+ }},
169
+ };
170
+ };
171
+ const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}};
172
+
173
+ // Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want.
174
+ out = try_raw_render(json::array({
175
+ dummy_user_msg,
176
+ make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})),
177
+ }), {}, false);
178
+ auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
179
+ out = try_raw_render(json::array({
180
+ dummy_user_msg,
181
+ make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})),
182
+ }), {}, false);
183
+ auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':");
184
+
185
+ caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments;
186
+ caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments;
187
+ auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false);
188
+ auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false);
189
+ caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle);
190
+
191
+ if (caps_.supports_tool_calls) {
192
+ auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump());
193
+ auto tc1 = make_tool_call("test_tool1", dummy_args);
194
+ auto tc2 = make_tool_call("test_tool2", dummy_args);
195
+ auto out = try_raw_render(json::array({
196
+ dummy_user_msg,
197
+ make_tool_calls_msg(json::array({tc1, tc2})),
198
+ }), {}, false);
199
+ caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2");
200
+
201
+ out = try_raw_render(json::array({
202
+ dummy_user_msg,
203
+ make_tool_calls_msg(json::array({tc1})),
204
+ {
205
+ {"role", "tool"},
206
+ {"name", "test_tool1"},
207
+ {"content", "Some response!"},
208
+ {"tool_call_id", "call_911_"},
209
+ }
210
+ }), {}, false);
211
+ caps_.supports_tool_responses = contains(out, "Some response!");
212
+ caps_.supports_tool_call_id = contains(out, "call_911_");
213
+ }
214
+
215
+ try {
216
+ if (!caps_.supports_tools) {
217
+ const json user_msg {
218
+ {"role", "user"},
219
+ {"content", "Hey"},
220
+ };
221
+ const json args {
222
+ {"arg1", "some_value"},
223
+ };
224
+ const json tool_call_msg {
225
+ {"role", "assistant"},
226
+ {"content", nullptr},
227
+ {"tool_calls", json::array({
228
+ {
229
+ // TODO: detect if requires numerical id or fixed length == 6 like Nemo
230
+ {"id", "call_1___"},
231
+ {"type", "function"},
232
+ {"function", {
233
+ {"name", "tool_name"},
234
+ {"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))},
235
+ }},
236
+ },
237
+ })},
238
+ };
239
+ std::string prefix, full;
240
+ {
241
+ chat_template_inputs inputs;
242
+ inputs.messages = json::array({user_msg});
243
+ inputs.add_generation_prompt = true;
244
+ prefix = apply(inputs);
245
+ }
246
+ {
247
+ chat_template_inputs inputs;
248
+ inputs.messages = json::array({user_msg, tool_call_msg});
249
+ inputs.add_generation_prompt = false;
250
+ full = apply(inputs);
251
+ }
252
+ auto eos_pos_last = full.rfind(eos_token_);
253
+ if (eos_pos_last == prefix.size() - eos_token_.size() ||
254
+ (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
255
+ full = full.substr(0, eos_pos_last);
256
+ }
257
+ size_t common_prefix_length = 0;
258
+ for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
259
+ if (prefix[i] != full[i]) {
260
+ break;
261
+ }
262
+ if (prefix[i] == '<') {
263
+ // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
264
+ // but it removes thinking tags for past messages.
265
+ // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, so we avoid consuming the leading <.
266
+ continue;
267
+ }
268
+ common_prefix_length = i + 1;
269
+ }
270
+ auto example = full.substr(common_prefix_length);
271
+ if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
272
+ fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
273
+ } else {
274
+ tool_call_example_ = example;
275
+ }
276
+ }
277
+ } catch (const std::exception & e) {
278
+ fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
279
+ }
280
+ }
281
+
282
+ const std::string & source() const { return source_; }
283
+ const std::string & bos_token() const { return bos_token_; }
284
+ const std::string & eos_token() const { return eos_token_; }
285
+ const chat_template_caps & original_caps() const { return caps_; }
286
+
287
+ // Deprecated, please use the form with chat_template_inputs and chat_template_options
288
+ std::string apply(
289
+ const nlohmann::ordered_json & messages,
290
+ const nlohmann::ordered_json & tools,
291
+ bool add_generation_prompt,
292
+ const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
293
+ bool apply_polyfills = true)
294
+ {
295
+ fprintf(stderr, "[%s] Deprecated!\n", __func__);
296
+ chat_template_inputs inputs;
297
+ inputs.messages = messages;
298
+ inputs.tools = tools;
299
+ inputs.add_generation_prompt = add_generation_prompt;
300
+ inputs.extra_context = extra_context;
301
+ inputs.now = std::chrono::system_clock::now();
302
+
303
+ chat_template_options opts;
304
+ opts.apply_polyfills = apply_polyfills;
305
+
306
+ return apply(inputs, opts);
307
+ }
308
+
309
+ std::string apply(
310
+ const chat_template_inputs & inputs,
311
+ const chat_template_options & opts = chat_template_options()) const
312
+ {
313
+ json actual_messages;
314
+
315
+ auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
316
+ auto has_tool_calls = false;
317
+ auto has_tool_responses = false;
318
+ auto has_string_content = false;
319
+ for (const auto & message : inputs.messages) {
320
+ if (message.contains("tool_calls") && !message["tool_calls"].is_null()) {
321
+ has_tool_calls = true;
322
+ }
323
+ if (message.contains("role") && message["role"] == "tool") {
324
+ has_tool_responses = true;
325
+ }
326
+ if (message.contains("content") && message["content"].is_string()) {
327
+ has_string_content = true;
328
+ }
329
+ }
330
+
331
+ auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role;
332
+ auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools;
333
+ auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples;
334
+ auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls;
335
+ auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses;
336
+ auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments;
337
+ auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content;
338
+
339
+ auto needs_polyfills = opts.apply_polyfills && (false
340
+ || polyfill_system_role
341
+ || polyfill_tools
342
+ || polyfill_tool_calls
343
+ || polyfill_tool_responses
344
+ || polyfill_object_arguments
345
+ || polyfill_typed_content
346
+ );
347
+
348
+ if (needs_polyfills) {
349
+ actual_messages = json::array();
350
+
351
+ auto add_message = [&](const json & msg) {
352
+ if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
353
+ actual_messages.push_back({
354
+ {"role", msg.at("role")},
355
+ {"content", {{
356
+ {"type", "text"},
357
+ {"text", msg.at("content")},
358
+ }}},
359
+ });
360
+ } else {
361
+ actual_messages.push_back(msg);
362
+ }
363
+ };
364
+
365
+ std::string pending_system;
366
+ auto flush_sys = [&]() {
367
+ if (!pending_system.empty()) {
368
+ add_message({
369
+ {"role", "user"},
370
+ {"content", pending_system},
371
+ });
372
+ pending_system.clear();
373
+ }
374
+ };
375
+
376
+ json adjusted_messages;
377
+ if (polyfill_tools) {
378
+ adjusted_messages = add_system(inputs.messages,
379
+ "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
380
+ (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
381
+ } else {
382
+ adjusted_messages = inputs.messages;
383
+ }
384
+
385
+ for (const auto & message_ : adjusted_messages) {
386
+ auto message = message_;
387
+ if (!message.contains("role") || !message.contains("content")) {
388
+ throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
389
+ }
390
+ std::string role = message.at("role");
391
+
392
+ if (message.contains("tool_calls")) {
393
+ if (polyfill_object_arguments || polyfill_tool_calls) {
394
+ for (auto & tool_call : message.at("tool_calls")) {
395
+ if (tool_call["type"] == "function") {
396
+ auto & function = tool_call.at("function");
397
+ auto & arguments = function.at("arguments");
398
+ if (arguments.is_string()) {
399
+ try {
400
+ arguments = json::parse(arguments.get<std::string>());
401
+ } catch (const std::exception & ecvt) {
402
+ fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what());
403
+ }
404
+ }
405
+ }
406
+ }
407
+ }
408
+ if (polyfill_tool_calls) {
409
+ auto content = message.at("content");
410
+ auto tool_calls = json::array();
411
+ for (const auto & tool_call : message.at("tool_calls")) {
412
+ if (tool_call.at("type") != "function") {
413
+ continue;
414
+ }
415
+ const auto & function = tool_call.at("function");
416
+ auto tc = json {
417
+ {"name", function.at("name")},
418
+ {"arguments", function.at("arguments")},
419
+ };
420
+ if (tool_call.contains("id")) {
421
+ tc["id"] = tool_call["id"];
422
+ }
423
+ tool_calls.push_back(tc);
424
+ }
425
+ auto obj = json {
426
+ {"tool_calls", tool_calls},
427
+ };
428
+ if (!content.is_null() && content != "") {
429
+ obj["content"] = content;
430
+ }
431
+ message["content"] = obj.dump(2);
432
+ message.erase("tool_calls");
433
+ }
434
+ }
435
+ if (polyfill_tool_responses && role == "tool") {
436
+ message["role"] = "user";
437
+ auto obj = json {
438
+ {"tool_response", {
439
+ {"content", message.at("content")},
440
+ }},
441
+ };
442
+ if (message.contains("name")) {
443
+ obj["tool_response"]["name"] = message.at("name");
444
+ }
445
+ if (message.contains("tool_call_id")) {
446
+ obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
447
+ }
448
+ message["content"] = obj.dump(2);
449
+ message.erase("name");
450
+ }
451
+
452
+ if (!message["content"].is_null() && polyfill_system_role) {
453
+ std::string content = message.at("content");
454
+ if (role == "system") {
455
+ if (!pending_system.empty()) pending_system += "\n";
456
+ pending_system += content;
457
+ continue;
458
+ } else {
459
+ if (role == "user") {
460
+ if (!pending_system.empty()) {
461
+ message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
462
+ pending_system.clear();
463
+ }
464
+ } else {
465
+ flush_sys();
466
+ }
467
+ }
468
+ }
469
+ add_message(message);
470
+ }
471
+ flush_sys();
472
+ } else {
473
+ actual_messages = inputs.messages;
474
+ }
475
+
476
+ auto context = minja::Context::make(json({
477
+ {"messages", actual_messages},
478
+ {"add_generation_prompt", inputs.add_generation_prompt},
479
+ }));
480
+ context->set("bos_token", opts.use_bos_token ? bos_token_ : "");
481
+ context->set("eos_token", opts.use_eos_token ? eos_token_ : "");
482
+ if (opts.define_strftime_now) {
483
+ auto now = inputs.now;
484
+ context->set("strftime_now", Value::callable([now](const std::shared_ptr<minja::Context> &, minja::ArgumentsValue & args) {
485
+ args.expectArgs("strftime_now", {1, 1}, {0, 0});
486
+ auto format = args.args[0].get<std::string>();
487
+
488
+ auto time = std::chrono::system_clock::to_time_t(now);
489
+ auto local_time = *std::localtime(&time);
490
+ std::ostringstream ss;
491
+ ss << std::put_time(&local_time, format.c_str());
492
+ return ss.str();
493
+ }));
494
+ }
495
+ if (!inputs.tools.is_null()) {
496
+ context->set("tools", minja::Value(inputs.tools));
497
+ }
498
+ if (!inputs.extra_context.is_null()) {
499
+ for (auto & kv : inputs.extra_context.items()) {
500
+ context->set(kv.key(), minja::Value(kv.value()));
501
+ }
502
+ }
503
+
504
+ auto ret = template_root_->render(context);
505
+ // fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str());
506
+ // fprintf(stderr, "apply: %s\n\n", ret.c_str());
507
+ return ret;
508
+ }
509
+
510
+ static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) {
511
+ json messages_with_system = messages;
512
+
513
+ if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") {
514
+ std::string existing_system = messages_with_system.at(0).at("content");
515
+ messages_with_system[0] = json {
516
+ {"role", "system"},
517
+ {"content", existing_system + "\n\n" + system_prompt},
518
+ };
519
+ } else {
520
+ messages_with_system.insert(messages_with_system.begin(), json {
521
+ {"role", "system"},
522
+ {"content", system_prompt},
523
+ });
524
+ }
525
+ return messages_with_system;
526
+ }
527
+ };
528
+
529
+ } // namespace minja
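
Note: a minimal sketch of rendering a prompt with this class; `tmpl_src` stands in for the Jinja template string normally read from the model's GGUF metadata, and the BOS/EOS strings are placeholders:

    using json = nlohmann::ordered_json;

    minja::chat_template tmpl(tmpl_src, /* bos_token = */ "<s>", /* eos_token = */ "</s>");

    minja::chat_template_inputs inputs;
    inputs.messages = json::array({
        {{"role", "system"}, {"content", "You are a helpful assistant."}},
        {{"role", "user"},   {"content", "Hello!"}},
    });
    inputs.add_generation_prompt = true;

    std::string prompt = tmpl.apply(inputs); // polyfills applied per detected caps
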
common/minja/minja.hpp ADDED
The diff for this file is too large to render. See raw diff
 
common/ngram-cache.cpp ADDED
@@ -0,0 +1,286 @@
1
+ #include "ngram-cache.h"
2
+ #include "common.h"
3
+ #include "log.h"
4
+
5
+ #include <cinttypes>
6
+ #include <cstdint>
7
+ #include <cstdio>
8
+ #include <fstream>
9
+ #include <thread>
10
+ #include <algorithm>
11
+
12
+ void common_ngram_cache_update(common_ngram_cache & ngram_cache, int ngram_min, int ngram_max,
13
+ std::vector<llama_token> & inp, int nnew, bool print_progress) {
14
+ const int64_t t_start_ms = ggml_time_ms();
15
+ const int64_t inp_size = inp.size();
16
+
17
+ const int64_t n_todo = inp_size * (ngram_max - ngram_min + 1);
18
+ int64_t n_done = 0;
19
+
20
+ for (int64_t ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
21
+ const int64_t i_start = std::max(inp_size - nnew, ngram_size);
22
+ for (int64_t i = i_start; i < inp_size; ++i) {
23
+ const int64_t ngram_start = i - ngram_size;
24
+ common_ngram ngram(&inp[ngram_start], ngram_size);
25
+ const llama_token token = inp[i];
26
+
27
+ common_ngram_cache::iterator part_it = ngram_cache.find(ngram);
28
+ if (part_it == ngram_cache.end()) {
29
+ common_ngram_cache_part part;
30
+ part.emplace(token, 1);
31
+ ngram_cache.emplace(ngram, part);
32
+ } else {
33
+ common_ngram_cache_part::iterator token_count_it = part_it->second.find(token);
34
+ if (token_count_it == part_it->second.end()) {
35
+ part_it->second.emplace(token, 1);
36
+ } else {
37
+ token_count_it->second++;
38
+ }
39
+ }
40
+ ++n_done;
41
+
42
+ if (print_progress && n_done % 10000000 == 0) {
43
+ const int64_t t_now_ms = ggml_time_ms();
44
+ const int64_t eta_ms = (inp_size*(ngram_max-ngram_min+1) - n_done) * (t_now_ms - t_start_ms) / n_done;
45
+ const int64_t eta_min = eta_ms / (60*1000);
46
+ const int64_t eta_s = (eta_ms - 60*1000*eta_min) / 1000;
47
+
48
+ fprintf(stderr, "%s: %" PRId64 "/%" PRId64 " done, ETA: %02" PRId64 ":%02" PRId64 "\n", __func__, n_done, n_todo, eta_min, eta_s);
49
+ }
50
+ }
51
+ }
52
+ }
53
+
54
+ // Helper function to get a token from the combined, speculative sequence of inp and draft.
55
+ static llama_token get_token(const std::vector<llama_token> & inp, const std::vector<llama_token> & draft, const size_t i) {
56
+ return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
57
+ }
58
+
59
+ // If the sample size or percentage is below these thresholds, the draft is aborted early:
60
+ constexpr int draft_min_sample_size_lax[LLAMA_NGRAM_MAX] = { 2, 2, 1, 1};
61
+ constexpr int draft_min_percent_lax[LLAMA_NGRAM_MAX] = {66, 50, 50, 50};
62
+ constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4, 3, 2, 2};
63
+ constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {75, 66, 66, 66};
64
+
65
+ // Helper function that tries to draft a token from only the static ngram cache:
66
+ static llama_token try_draft(common_ngram_cache & nc_static, const common_ngram ngram_static) {
67
+ common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
68
+ if (part_static_it == nc_static.end()) {
69
+ return LLAMA_TOKEN_NULL;
70
+ }
71
+ const common_ngram_cache_part part_static = part_static_it->second;
72
+
73
+ int max_count_static = 0;
74
+ int sum_count_static = 0;
75
+ llama_token max_token = LLAMA_TOKEN_NULL;
76
+
77
+ for (std::pair<llama_token, int> token_count_static : part_static) {
78
+ const llama_token token = token_count_static.first;
79
+ const int32_t count_static = token_count_static.second;
80
+
81
+ if (count_static > max_count_static) {
82
+ max_token = token;
83
+ max_count_static = count_static;
84
+ }
85
+ sum_count_static += count_static;
86
+ }
87
+
88
+ if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) {
89
+ return LLAMA_TOKEN_NULL;
90
+ }
91
+ if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) {
92
+ return LLAMA_TOKEN_NULL;
93
+ }
94
+ return max_token;
95
+ }
96
+
97
+ // Try to draft a token from primary cache (context/dynamic), validate with static cache:
98
+ static llama_token try_draft(
99
+ common_ngram_cache & nc_primary, const std::vector<common_ngram> & ngrams_primary, common_ngram_cache_part & part_static,
100
+ const int * min_sample_size, const int * min_percent) {
101
+
102
+ llama_token drafted_token = LLAMA_TOKEN_NULL;
103
+
104
+ for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == LLAMA_TOKEN_NULL; --i) {
105
+ const common_ngram ngram_primary = ngrams_primary[i];
106
+
107
+ common_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
108
+ if (part_primary_it == nc_primary.end()) {
109
+ continue;
110
+ }
111
+ const common_ngram_cache_part part_primary = part_primary_it->second;
112
+
113
+ int max_count_primary = 0;
114
+ int max_count_static = 0;
115
+ int sum_count_primary = 0;
116
+ llama_token max_token = LLAMA_TOKEN_NULL;
117
+
118
+ for (std::pair<llama_token, int> token_count_primary : part_primary) {
119
+ const llama_token token = token_count_primary.first;
120
+
121
+ common_ngram_cache_part::iterator token_count_static_it = part_static.find(token);
122
+
123
+ const int32_t count_primary = token_count_primary.second;
124
+ const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1;
125
+
126
+ if (count_primary*count_static > max_count_primary*max_count_static) {
127
+ max_token = token;
128
+ max_count_primary = count_primary;
129
+ max_count_static = count_static;
130
+ }
131
+ sum_count_primary += count_primary;
132
+ }
133
+
134
+ if (sum_count_primary < min_sample_size[i]) {
135
+ continue;
136
+ }
137
+ if (100*max_count_primary < min_percent[i]*sum_count_primary) {
138
+ continue;;
139
+ }
140
+ drafted_token = max_token;
141
+ }
142
+
143
+ return drafted_token;
144
+ }
145
+
146
+ void common_ngram_cache_draft(
147
+ std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
148
+ common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static
149
+ ) {
150
+ GGML_ASSERT(draft.size() == 1);
151
+ const int inp_size = inp.size();
152
+
153
+ if (inp_size < LLAMA_NGRAM_STATIC) {
154
+ return;
155
+ }
156
+
157
+ while ((int) draft.size()-1 < n_draft) {
158
+ llama_token drafted_token = LLAMA_TOKEN_NULL;
159
+
160
+ const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
161
+ common_ngram ngram_static;
162
+ for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
163
+ ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j);
164
+ }
165
+ common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
166
+ common_ngram_cache_part part_static;
167
+ if (part_static_it != nc_static.end()) {
168
+ part_static = part_static_it->second;
169
+ }
170
+
171
+ // cd = context + dynamic
172
+ std::vector<common_ngram> ngrams_cd;
173
+ for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) {
174
+ const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
175
+ common_ngram ngram_cd;
176
+ for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
177
+ ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j);
178
+ }
179
+ ngrams_cd.push_back(ngram_cd);
180
+ }
181
+ if (drafted_token == LLAMA_TOKEN_NULL) {
182
+ drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax);
183
+ }
184
+ if (drafted_token == LLAMA_TOKEN_NULL) {
185
+ drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict);
186
+ }
187
+ if (drafted_token == LLAMA_TOKEN_NULL) {
188
+ drafted_token = try_draft(nc_static, ngram_static);
189
+ }
190
+
191
+ if (drafted_token == LLAMA_TOKEN_NULL) {
192
+ break;
193
+ }
194
+
195
+ LOG(" - draft candidate: token=%d\n", drafted_token);
196
+ draft.push_back(drafted_token);
197
+ }
198
+ }
199
+
200
+ void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) {
201
+ std::ofstream file_out(filename, std::ios::binary);
202
+ for (std::pair<common_ngram, common_ngram_cache_part> item : ngram_cache) {
203
+ const common_ngram ngram = item.first;
204
+ common_ngram_cache_part token_counts = item.second;
205
+ GGML_ASSERT(!token_counts.empty());
206
+ const int32_t ntokens = token_counts.size();
207
+ GGML_ASSERT(ntokens > 0);
208
+
209
+ file_out.write(reinterpret_cast<const char *>(&ngram), sizeof(common_ngram));
210
+ file_out.write(reinterpret_cast<const char *>(&ntokens), sizeof(int32_t));
211
+ for (std::pair<llama_token, int32_t> item2 : token_counts) {
212
+ const llama_token token = item2.first;
213
+ const int32_t count = item2.second;
214
+ GGML_ASSERT(count > 0);
215
+
216
+ file_out.write(reinterpret_cast<const char *>(&token), sizeof(llama_token));
217
+ file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
218
+ }
219
+ }
220
+
221
+ }
222
+
223
+ common_ngram_cache common_ngram_cache_load(std::string & filename) {
224
+ std::ifstream hashmap_file(filename, std::ios::binary);
225
+ if (!hashmap_file) {
226
+ throw std::ifstream::failure("Unable to open file " + filename);
227
+ }
228
+ common_ngram_cache ngram_cache;
229
+
230
+ common_ngram ngram;
231
+ int32_t ntokens;
232
+ llama_token token;
233
+ int32_t count;
234
+
235
+ char * ngramc = reinterpret_cast<char*>(&ngram);
236
+ char * ntokensc = reinterpret_cast<char*>(&ntokens);
237
+ char * tokenc = reinterpret_cast<char*>(&token);
238
+ char * countc = reinterpret_cast<char*>(&count);
239
+ while(hashmap_file.read(ngramc, sizeof(common_ngram))) {
240
+ GGML_ASSERT(!hashmap_file.eof());
241
+ GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t)));
242
+ GGML_ASSERT(ntokens > 0);
243
+ common_ngram_cache_part token_counts;
244
+
245
+ for (int i = 0; i < ntokens; ++i) {
246
+ GGML_ASSERT(!hashmap_file.eof());
247
+ GGML_ASSERT(hashmap_file.read(tokenc, sizeof(llama_token)));
248
+ GGML_ASSERT(!hashmap_file.eof());
249
+ GGML_ASSERT(hashmap_file.read(countc, sizeof(int32_t)));
250
+ GGML_ASSERT(count > 0);
251
+ token_counts.emplace(token, count);
252
+ }
253
+
254
+ ngram_cache.emplace(ngram, token_counts);
255
+ }
256
+ GGML_ASSERT(hashmap_file.eof());
257
+
258
+ return ngram_cache;
259
+ }
260
+
261
+ void common_ngram_cache_merge(common_ngram_cache & ngram_cache_target, common_ngram_cache & ngram_cache_add) {
262
+ for (std::pair<common_ngram, common_ngram_cache_part> ngram_part : ngram_cache_add) {
263
+ const common_ngram ngram = ngram_part.first;
264
+ common_ngram_cache_part part = ngram_part.second;
265
+
266
+ common_ngram_cache::iterator part_merged_it = ngram_cache_target.find(ngram);
267
+ if (part_merged_it == ngram_cache_target.end()) {
268
+ ngram_cache_target.emplace(ngram, part);
269
+ continue;
270
+ }
271
+
272
+ for (std::pair<llama_token, int32_t> token_count : part) {
273
+ const llama_token token = token_count.first;
274
+ const int32_t count = token_count.second;
275
+ GGML_ASSERT(count > 0);
276
+
277
+ common_ngram_cache_part::iterator token_count_merged_it = part_merged_it->second.find(token);
278
+ if (token_count_merged_it == part_merged_it->second.end()) {
279
+ part_merged_it->second.emplace(token, count);
280
+ continue;
281
+ }
282
+
283
+ token_count_merged_it->second += count;
284
+ }
285
+ }
286
+ }
common/ngram-cache.h ADDED
@@ -0,0 +1,101 @@
+ #pragma once
+
+ #include "llama.h"
+
+ #include <unordered_map>
+ #include <string>
+ #include <vector>
+
+ #define LLAMA_NGRAM_MIN    1
+ #define LLAMA_NGRAM_MAX    4
+ #define LLAMA_NGRAM_STATIC 2
+
+ // Data structures to map n-grams to empirical token probabilities:
+
+ struct common_ngram {
+     llama_token tokens[LLAMA_NGRAM_MAX];
+
+     common_ngram() {
+         for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
+             tokens[i] = LLAMA_TOKEN_NULL;
+         }
+     }
+
+     common_ngram(const llama_token * input, const int ngram_size) {
+         for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
+             tokens[i] = i < ngram_size ? input[i] : LLAMA_TOKEN_NULL;
+         }
+     }
+
+     bool operator==(const common_ngram & other) const {
+         for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
+             if (tokens[i] != other.tokens[i]) {
+                 return false;
+             }
+         }
+         return true;
+     }
+ };
+
+ struct common_token_hash_function {
+     size_t operator()(const llama_token token) const {
+         // see https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/
+         return token * 11400714819323198485llu;
+     }
+ };
+
+ struct common_ngram_hash_function {
+     size_t operator()(const common_ngram & ngram) const {
+         size_t hash = common_token_hash_function{}(ngram.tokens[0]);
+         for (int i = 1; i < LLAMA_NGRAM_MAX; ++i) {
+             hash ^= common_token_hash_function{}(ngram.tokens[i]);
+         }
+         return hash;
+     }
+ };
+
+ // token -> number of times token has been seen
+ typedef std::unordered_map<llama_token, int32_t> common_ngram_cache_part;
+
+ // n-gram -> empirical distribution of following tokens
+ typedef std::unordered_map<common_ngram, common_ngram_cache_part, common_ngram_hash_function> common_ngram_cache;
+
+
+ // Update an ngram cache with tokens.
+ // ngram_cache:         the cache to modify.
+ // ngram_min/ngram_max: the min/max size of the ngrams to extract from inp_data.
+ // inp_data:            the token sequence with which to update ngram_cache.
+ // nnew:                how many new tokens have been appended to inp_data since the last call to this function.
+ // print_progress:      whether to print progress to stderr.
+ //
+ // In order to get correct results inp_data can ONLY BE APPENDED TO.
+ // Changes in the middle need a complete rebuild.
+ void common_ngram_cache_update(
+     common_ngram_cache & ngram_cache, int ngram_min, int ngram_max, std::vector<llama_token> & inp_data, int nnew, bool print_progress);
+
+ // Try to draft tokens from ngram caches.
+ // inp:                 the tokens generated so far.
+ // draft:               the token sequence to draft. Expected to initially contain the previously sampled token.
+ // n_draft:             maximum number of tokens to add to draft.
+ // ngram_min/ngram_max: the min/max size of the ngrams in nc_context and nc_dynamic.
+ // nc_context:          ngram cache based on current context.
+ // nc_dynamic:          ngram cache based on previous user generations.
+ // nc_static:           ngram cache generated from a large text corpus, used for validation.
+ void common_ngram_cache_draft(
+     std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
+     common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static);
+
+ // Save an ngram cache to a file.
+ // ngram_cache: the ngram cache to save.
+ // filename:    the path under which to save the ngram cache.
+ void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename);
+
+ // Load an ngram cache saved with common_ngram_cache_save.
+ // filename: the path from which to load the ngram cache.
+ // returns:  an ngram cache containing the information saved to filename.
+ common_ngram_cache common_ngram_cache_load(std::string & filename);
+
+ // Merge two ngram caches.
+ // ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.
+ // ngram_cache_add:    the ngram cache to add to ngram_cache_target.
+ void common_ngram_cache_merge(common_ngram_cache & ngram_cache_target, common_ngram_cache & ngram_cache_add);
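Editor's note: for orientation, a minimal sketch of how the ngram-cache API above fits into a lookup-decoding loop. This is not part of the commit; the variable names, the cache file path, and the caller function are illustrative assumptions.

    #include "ngram-cache.h"

    // Hypothetical driver: index the tokens seen so far, then draft up to 8
    // continuation tokens validated against a corpus-derived static cache.
    void example_draft(std::vector<llama_token> & tokens, llama_token last_sampled) {
        common_ngram_cache nc_context;
        common_ngram_cache nc_dynamic; // empty here; normally persisted across runs
        std::string path_static = "ngrams-static.bin"; // illustrative path
        common_ngram_cache nc_static = common_ngram_cache_load(path_static);

        // first call indexes everything; later calls pass only the token delta as nnew
        common_ngram_cache_update(nc_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
                                  tokens, (int) tokens.size(), /*print_progress=*/false);

        std::vector<llama_token> draft = { last_sampled }; // must start with the last token
        common_ngram_cache_draft(tokens, draft, /*n_draft=*/8, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
                                 nc_context, nc_dynamic, nc_static);
        // draft[1..] now holds the proposed continuation (possibly empty)
    }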
common/sampling.cpp ADDED
@@ -0,0 +1,570 @@
+ #include "sampling.h"
+
+ #include "common.h"
+
+ #include <cmath>
+ #include <unordered_map>
+ #include <algorithm>
+
+ // the ring buffer works similarly to std::deque, but with a fixed capacity
+ // TODO: deduplicate with llama-impl.h
+ template<typename T>
+ struct ring_buffer {
+     ring_buffer(size_t cap) : capacity(cap), data(cap) {}
+
+     T & front() {
+         if (sz == 0) {
+             throw std::runtime_error("ring buffer is empty");
+         }
+         return data[first];
+     }
+
+     const T & front() const {
+         if (sz == 0) {
+             throw std::runtime_error("ring buffer is empty");
+         }
+         return data[first];
+     }
+
+     T & back() {
+         if (sz == 0) {
+             throw std::runtime_error("ring buffer is empty");
+         }
+         return data[pos];
+     }
+
+     const T & back() const {
+         if (sz == 0) {
+             throw std::runtime_error("ring buffer is empty");
+         }
+         return data[pos];
+     }
+
+     void push_back(const T & value) {
+         if (sz == capacity) {
+             // advance the start when buffer is full
+             first = (first + 1) % capacity;
+         } else {
+             sz++;
+         }
+         data[pos] = value;
+         pos = (pos + 1) % capacity;
+     }
+
+     T pop_front() {
+         if (sz == 0) {
+             throw std::runtime_error("ring buffer is empty");
+         }
+         T value = data[first];
+         first = (first + 1) % capacity;
+         sz--;
+         return value;
+     }
+
+     const T & rat(size_t i) const {
+         if (i >= sz) {
+             throw std::runtime_error("ring buffer: index out of bounds");
+         }
+         return data[(first + sz - i - 1) % capacity];
+     }
+
+     std::vector<T> to_vector() const {
+         std::vector<T> result;
+         result.reserve(sz);
+         for (size_t i = 0; i < sz; i++) {
+             result.push_back(data[(first + i) % capacity]);
+         }
+         return result;
+     }
+
+     void clear() {
+         // here only reset the status of the buffer
+         sz = 0;
+         first = 0;
+         pos = 0;
+     }
+
+     bool empty() const {
+         return sz == 0;
+     }
+
+     size_t size() const {
+         return sz;
+     }
+
+     size_t capacity = 0;
+     size_t sz = 0;
+     size_t first = 0;
+     size_t pos = 0;
+     std::vector<T> data;
+ };
+
+ struct common_sampler {
+     common_params_sampling params;
+
+     struct llama_sampler * grmr;
+     struct llama_sampler * chain;
+
+     ring_buffer<llama_token> prev;
+
+     std::vector<llama_token_data> cur;
+
+     llama_token_data_array cur_p;
+
+     void set_logits(struct llama_context * ctx, int idx) {
+         const auto * logits = llama_get_logits_ith(ctx, idx);
+
+         const llama_model * model = llama_get_model(ctx);
+         const llama_vocab * vocab = llama_model_get_vocab(model);
+
+         const int n_vocab = llama_vocab_n_tokens(vocab);
+
+         cur.resize(n_vocab);
+
+         for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+             cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+         }
+
+         cur_p = { cur.data(), cur.size(), -1, false };
+     }
+ };
+
+ std::string common_params_sampling::print() const {
+     char result[1024];
+
+     snprintf(result, sizeof(result),
+             "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
+             "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
+             "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
+             "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
+             penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
+             dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
+             top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
+             mirostat, mirostat_eta, mirostat_tau);
+
+     return std::string(result);
+ }
+
+ struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
+     const llama_vocab * vocab = llama_model_get_vocab(model);
+
+     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
+
+     lparams.no_perf = params.no_perf;
+
+     struct llama_sampler * grmr;
+     if (params.grammar.compare(0, 11, "%llguidance") == 0) {
+ #ifdef LLAMA_USE_LLGUIDANCE
+         grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
+ #else
+         GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
+ #endif // LLAMA_USE_LLGUIDANCE
+     } else {
+         std::vector<std::string> patterns_at_start;
+         std::vector<std::string> patterns_anywhere;
+         std::vector<llama_token> trigger_tokens;
+         for (const auto & trigger : params.grammar_triggers) {
+             switch (trigger.type) {
+                 case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
+                 {
+                     const auto & word = trigger.value;
+                     patterns_anywhere.push_back(regex_escape(word));
+                     break;
+                 }
+                 case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
+                 case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
+                 {
+                     const auto & pattern = trigger.value;
+                     (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
+                     break;
+                 }
+                 case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
+                 {
+                     const auto token = trigger.token;
+                     trigger_tokens.push_back(token);
+                     break;
+                 }
+                 default:
+                     GGML_ASSERT(false && "unknown trigger type");
+             }
+         }
+
+         std::vector<std::string> trigger_patterns;
+         if (!patterns_at_start.empty()) {
+             trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
+         }
+         if (!patterns_anywhere.empty()) {
+             trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
+         }
+
+         std::vector<const char *> trigger_patterns_c;
+         trigger_patterns_c.reserve(trigger_patterns.size());
+         for (const auto & regex : trigger_patterns) {
+             trigger_patterns_c.push_back(regex.c_str());
+         }
+
+         grmr = params.grammar_lazy
+              ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
+                                                         trigger_patterns_c.data(), trigger_patterns_c.size(),
+                                                         trigger_tokens.data(), trigger_tokens.size())
+              : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
+     }
+
+     auto * result = new common_sampler {
+         /* .params = */ params,
+         /* .grmr   = */ grmr,
+         /* .chain  = */ llama_sampler_chain_init(lparams),
+         /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
+         /* .cur    = */ {},
+         /* .cur_p  = */ {},
+     };
+
+     llama_sampler_chain_add(result->chain,
+             llama_sampler_init_logit_bias(
+                 llama_vocab_n_tokens(vocab),
+                 params.logit_bias.size(),
+                 params.logit_bias.data()));
+
+     if (params.mirostat == 0) {
+         if (params.top_n_sigma >= 0) {
+             llama_sampler_chain_add(result->chain, llama_sampler_init_top_k       (params.top_k));
+             llama_sampler_chain_add(result->chain, llama_sampler_init_temp        (params.temp));
+             llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
+         } else {
+             for (const auto & cnstr : params.samplers) {
+                 switch (cnstr) {
+                     case COMMON_SAMPLER_TYPE_DRY:
+                         {
+                             std::vector<const char *> c_breakers;
+                             c_breakers.reserve(params.dry_sequence_breakers.size());
+                             for (const auto & str : params.dry_sequence_breakers) {
+                                 c_breakers.push_back(str.c_str());
+                             }
+
+                             llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                         }
+                         break;
+                     case COMMON_SAMPLER_TYPE_TOP_K:
+                         llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
+                         break;
+                     case COMMON_SAMPLER_TYPE_TOP_P:
+                         llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
+                         break;
+                     case COMMON_SAMPLER_TYPE_MIN_P:
+                         llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
+                         break;
+                     case COMMON_SAMPLER_TYPE_XTC:
+                         llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                         break;
+                     case COMMON_SAMPLER_TYPE_TYPICAL_P:
+                         llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
+                         break;
+                     case COMMON_SAMPLER_TYPE_TEMPERATURE:
+                         llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                         break;
+                     case COMMON_SAMPLER_TYPE_INFILL:
+                         llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (vocab));
+                         break;
+                     case COMMON_SAMPLER_TYPE_PENALTIES:
+                         llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                         break;
+                     default:
+                         GGML_ASSERT(false && "unknown sampler type");
+                 }
+             }
+         }
+         llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+     } else if (params.mirostat == 1) {
+         llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+         llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+     } else if (params.mirostat == 2) {
+         llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+         llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
+     } else {
+         GGML_ASSERT(false && "unknown mirostat version");
+     }
+
+     return result;
+ }
+
+ void common_sampler_free(struct common_sampler * gsmpl) {
+     if (gsmpl) {
+         llama_sampler_free(gsmpl->grmr);
+
+         llama_sampler_free(gsmpl->chain);
+
+         delete gsmpl;
+     }
+ }
+
+ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
+     if (accept_grammar) {
+         llama_sampler_accept(gsmpl->grmr, token);
+     }
+
+     llama_sampler_accept(gsmpl->chain, token);
+
+     gsmpl->prev.push_back(token);
+ }
+
+ void common_sampler_reset(struct common_sampler * gsmpl) {
+     llama_sampler_reset(gsmpl->grmr);
+
+     llama_sampler_reset(gsmpl->chain);
+ }
+
+ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
+     return new common_sampler {
+         /* .params = */ gsmpl->params,
+         /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
+         /* .chain  = */ llama_sampler_clone(gsmpl->chain),
+         /* .prev   = */ gsmpl->prev,
+         /* .cur    = */ gsmpl->cur,
+         /* .cur_p  = */ gsmpl->cur_p,
+     };
+ }
+
+ void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
+     // TODO: measure grammar performance
+
+     if (gsmpl) {
+         llama_perf_sampler_print(gsmpl->chain);
+     }
+     if (ctx) {
+         llama_perf_context_print(ctx);
+     }
+ }
+
+ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
+     gsmpl->set_logits(ctx, idx);
+
+     auto & grmr  = gsmpl->grmr;
+     auto & chain = gsmpl->chain;
+     auto & cur_p = gsmpl->cur_p; // initialized by set_logits
+
+     if (grammar_first) {
+         llama_sampler_apply(grmr, &cur_p);
+     }
+
+     llama_sampler_apply(chain, &cur_p);
+
+     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
+
+     const llama_token id = cur_p.data[cur_p.selected].id;
+
+     if (grammar_first) {
+         return id;
+     }
+
+     // check if the sampled token fits the grammar
+     {
+         llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
+         llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
+
+         llama_sampler_apply(grmr, &single_token_data_array);
+
+         const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+         if (is_valid) {
+             return id;
+         }
+     }
+
+     // resampling:
+     // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
+     gsmpl->set_logits(ctx, idx);
+
+     llama_sampler_apply(grmr,  &cur_p);
+     llama_sampler_apply(chain, &cur_p);
+
+     GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
+
+     return cur_p.data[cur_p.selected].id;
+ }
+
+ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
+
+     std::vector<llama_token> result;
+     result.reserve(idxs.size());
+
+     size_t i = 0;
+     for (; i < draft.size(); i++) {
+         const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+
+         common_sampler_accept(gsmpl, id, true);
+
+         result.push_back(id);
+
+         if (draft[i] != id) {
+             break;
+         }
+     }
+
+     if (i == draft.size()) {
+         const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+
+         common_sampler_accept(gsmpl, id, true);
+
+         result.push_back(id);
+     }
+
+     return result;
+ }
+
+ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+     std::vector<int> idxs(draft.size() + 1);
+     for (size_t i = 0; i < idxs.size(); ++i) {
+         idxs[i] = i;
+     }
+
+     return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+ }
+
+ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
+     return llama_sampler_get_seed(gsmpl->chain);
+ }
+
+ // helpers
+
+ llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
+     return &gsmpl->cur_p;
+ }
+
+ llama_token common_sampler_last(const struct common_sampler * gsmpl) {
+     return gsmpl->prev.rat(0);
+ }
+
+ std::string common_sampler_print(const struct common_sampler * gsmpl) {
+     std::string result = "logits ";
+
+     for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
+         const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
+         result += std::string("-> ") + llama_sampler_name(smpl) + " ";
+     }
+
+     return result;
+ }
+
+ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) {
+     n = std::min(n, (int) gsmpl->prev.size());
+
+     if (n <= 0) {
+         return "";
+     }
+
+     std::string result;
+     result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab
+
+     for (int i = n - 1; i >= 0; i--) {
+         const llama_token id = gsmpl->prev.rat(i);
+
+         GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");
+
+         result += common_token_to_piece(ctx_main, id);
+     }
+
+     return result;
+ }
+
+ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
+     switch (cnstr) {
+         case COMMON_SAMPLER_TYPE_DRY:         return 'd';
+         case COMMON_SAMPLER_TYPE_TOP_K:       return 'k';
+         case COMMON_SAMPLER_TYPE_TYPICAL_P:   return 'y';
+         case COMMON_SAMPLER_TYPE_TOP_P:       return 'p';
+         case COMMON_SAMPLER_TYPE_MIN_P:       return 'm';
+         case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
+         case COMMON_SAMPLER_TYPE_XTC:         return 'x';
+         case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
+         case COMMON_SAMPLER_TYPE_PENALTIES:   return 'e';
+         default:                              return '?';
+     }
+ }
+
+ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
+     switch (cnstr) {
+         case COMMON_SAMPLER_TYPE_DRY:         return "dry";
+         case COMMON_SAMPLER_TYPE_TOP_K:       return "top_k";
+         case COMMON_SAMPLER_TYPE_TYPICAL_P:   return "typ_p";
+         case COMMON_SAMPLER_TYPE_TOP_P:       return "top_p";
+         case COMMON_SAMPLER_TYPE_MIN_P:       return "min_p";
+         case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
+         case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
+         case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
+         case COMMON_SAMPLER_TYPE_PENALTIES:   return "penalties";
+         default:                              return "";
+     }
+ }
+
+ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
+     std::unordered_map<std::string, common_sampler_type> sampler_canonical_name_map {
+         { "dry",         COMMON_SAMPLER_TYPE_DRY },
+         { "top_k",       COMMON_SAMPLER_TYPE_TOP_K },
+         { "top_p",       COMMON_SAMPLER_TYPE_TOP_P },
+         { "typ_p",       COMMON_SAMPLER_TYPE_TYPICAL_P },
+         { "min_p",       COMMON_SAMPLER_TYPE_MIN_P },
+         { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
+         { "xtc",         COMMON_SAMPLER_TYPE_XTC },
+         { "infill",      COMMON_SAMPLER_TYPE_INFILL },
+         { "penalties",   COMMON_SAMPLER_TYPE_PENALTIES },
+     };
+
+     // since sampler names are written multiple ways
+     // make it ready for both system names and input names
+     std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
+         { "top-k",     COMMON_SAMPLER_TYPE_TOP_K },
+         { "top-p",     COMMON_SAMPLER_TYPE_TOP_P },
+         { "nucleus",   COMMON_SAMPLER_TYPE_TOP_P },
+         { "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
+         { "typical",   COMMON_SAMPLER_TYPE_TYPICAL_P },
+         { "typ-p",     COMMON_SAMPLER_TYPE_TYPICAL_P },
+         { "typ",       COMMON_SAMPLER_TYPE_TYPICAL_P },
+         { "min-p",     COMMON_SAMPLER_TYPE_MIN_P },
+         { "temp",      COMMON_SAMPLER_TYPE_TEMPERATURE },
+     };
+
+     std::vector<common_sampler_type> samplers;
+     samplers.reserve(names.size());
+
+     for (const auto & name : names) {
+         auto sampler = sampler_canonical_name_map.find(name);
+         if (sampler != sampler_canonical_name_map.end()) {
+             samplers.push_back(sampler->second);
+         } else {
+             if (allow_alt_names) {
+                 sampler = sampler_alt_name_map.find(name);
+                 if (sampler != sampler_alt_name_map.end()) {
+                     samplers.push_back(sampler->second);
+                 }
+             }
+         }
+     }
+
+     return samplers;
+ }
+
+ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::string & chars) {
+     std::unordered_map<char, common_sampler_type> sampler_name_map = {
+         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY),         COMMON_SAMPLER_TYPE_DRY },
+         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K),       COMMON_SAMPLER_TYPE_TOP_K },
+         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P),   COMMON_SAMPLER_TYPE_TYPICAL_P },
+         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P },
+         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P },
+         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
+         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
+         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
+         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES),   COMMON_SAMPLER_TYPE_PENALTIES },
+     };
+
+     std::vector<common_sampler_type> samplers;
+     samplers.reserve(chars.size());
+
+     for (const auto & c : chars) {
+         const auto sampler = sampler_name_map.find(c);
+         if (sampler != sampler_name_map.end()) {
+             samplers.push_back(sampler->second);
+         }
+     }
+
+     return samplers;
+ }
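Editor's note: the two lookup functions above let a user select and order the sampler chain with either names or a compact character string. A hypothetical sketch of how an order string maps through them (the string "kypmxt" is an arbitrary example, not a project default):

    #include "sampling.h"
    #include <cassert>

    static void example_order_string() {
        // one character per sampler, in chain order: k=top_k, y=typ_p, p=top_p,
        // m=min_p, x=xtc, t=temperature (see common_sampler_type_to_chr above)
        const std::string order = "kypmxt"; // hypothetical user input
        std::vector<common_sampler_type> chain = common_sampler_types_from_chars(order);
        assert(chain.size() == 6 && chain[0] == COMMON_SAMPLER_TYPE_TOP_K);

        // the name-based variant also accepts aliases such as "nucleus"
        // when allow_alt_names is true
        std::vector<common_sampler_type> named =
            common_sampler_types_from_names({"top-k", "typical", "nucleus"}, /*allow_alt_names=*/true);
        (void) named;
    }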
common/sampling.h ADDED
@@ -0,0 +1,107 @@
+ #pragma once
+
+ #include "llama.h"
+
+ #include "common.h"
+
+ #include <string>
+ #include <vector>
+
+ // common_sampler extends llama_sampler with additional functionality:
+ //
+ //  - grammar support
+ //  - custom sampler logic based on the parameters
+ //  - history of the last accepted tokens
+ //  - performance metrics
+ //
+ // The goal is to have a common implementation of the sampling logic shared across the examples.
+ // For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
+ // complex (top-k, top-p, etc).
+ //
+ // Another example is related to the grammar. In general, the grammar constraints applied on the full
+ // vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled
+ // token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the
+ // grammar constraints are applied to the full vocabulary and the token is resampled.
+ //
+ // The common_sampler also maintains a container with the last accepted tokens. In the future, this can
+ // be moved into the core llama library.
+ //
+ // For convenience, the common_sampler also maintains a container with the current candidate tokens.
+ // This can be used to access the probabilities of the rest of the non-sampled tokens.
+ //
+ // TODO: measure grammar performance
+ //
+
+ struct common_sampler;
+
+ // llama_sampler API overloads
+
+ struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
+
+ void common_sampler_free(struct common_sampler * gsmpl);
+
+ // if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
+ void                    common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
+ void                    common_sampler_reset (struct common_sampler * gsmpl);
+ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
+
+ // arguments can be nullptr to skip printing
+ void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
+
+ // extended sampling implementation:
+ //
+ // - set logits
+ // - apply the configured sampler chain
+ // - check if the token fits the grammar (if any)
+ // - if not: resample by first applying the grammar constraints and then sampling again (slower path)
+ //
+ // if grammar_first is true, the grammar is applied before the samplers (slower)
+ // useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
+ //
+ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
+
+ // generalized version of common_sampler_sample
+ //
+ // will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
+ // if the sampler disagrees at some point, we stop and return the accepted tokens up to now
+ //
+ //      common_sampler_sample_n(gsmpl, ctx, { idx }, {});
+ //
+ //   is equivalent to
+ //
+ //      common_sampler_sample(gsmpl, ctx, idx);
+ //      common_sampler_accept(gsmpl, token, true);
+ //
+ // requires: idxs.size() == draft.size() + 1
+ //
+ // returns at least 1 token, up to idxs.size()
+ //
+ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
+
+ // assume idxs == [ 0, 1, 2, ..., draft.size() ]
+ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
+
+ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
+
+ // helpers
+
+ // access the internal list of current candidate tokens
+ llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl);
+
+ // get the last accepted token
+ llama_token common_sampler_last(const struct common_sampler * gsmpl);
+
+ // print the sampler chain into a string
+ std::string common_sampler_print(const struct common_sampler * gsmpl);
+
+ // get a string representation of the last accepted tokens
+ std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, int n);
+
+ char        common_sampler_type_to_chr(enum common_sampler_type cnstr);
+ std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
+
+ std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
+ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
+
+ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
+     const char * grammar_kind, const char * grammar_data);
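Editor's note: a minimal sketch of the intended sample/accept call pattern, assuming a loaded llama_model, a llama_context with an already-decoded prompt, and default sampling parameters; the decode step and error handling are elided, and the function name is illustrative.

    #include "sampling.h"

    static void generate_n(llama_model * model, llama_context * ctx, int n_gen) {
        common_params_sampling sparams;   // defaults: temp, top_k, top_p, etc.
        common_sampler * smpl = common_sampler_init(model, sparams);

        for (int i = 0; i < n_gen; ++i) {
            // sample from the logits of the last evaluated token (idx = -1)
            const llama_token id = common_sampler_sample(smpl, ctx, /*idx=*/-1);
            common_sampler_accept(smpl, id, /*accept_grammar=*/true);

            // ... decode `id` with llama_decode() and stream the piece ...
        }

        common_perf_print(ctx, smpl);
        common_sampler_free(smpl);
    }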
common/speculative.cpp ADDED
@@ -0,0 +1,278 @@
+ #include "speculative.h"
+
+ #include "log.h"
+ #include "common.h"
+ #include "sampling.h"
+
+ #include <cstring>
+ #include <algorithm>
+
+ #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
+ #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
+
+ struct common_speculative {
+     struct llama_context * ctx;
+     struct common_sampler * smpl;
+
+     llama_batch batch;
+     llama_tokens prompt;
+ };
+
+ struct common_speculative * common_speculative_init(
+         struct llama_context * ctx_dft) {
+     auto * result = new common_speculative {
+         /* .ctx    = */ ctx_dft,
+         /* .smpl   = */ nullptr,
+         /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+         /* .prompt = */ {},
+     };
+
+     // TODO: optimize or pass from outside?
+ #if 0
+     {
+         common_params_sampling params;
+         params.no_perf = false;
+
+         params.top_k = 40;
+         params.top_p = 0.9;
+
+         params.samplers = {
+             COMMON_SAMPLER_TYPE_TOP_K,
+             COMMON_SAMPLER_TYPE_TOP_P,
+             COMMON_SAMPLER_TYPE_INFILL,
+         };
+
+         result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
+     }
+ #else
+     {
+         common_params_sampling params;
+         params.no_perf = false;
+
+         params.top_k = 10;
+
+         params.samplers = {
+             COMMON_SAMPLER_TYPE_TOP_K,
+         };
+
+         result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
+     }
+ #endif
+
+     return result;
+ }
+
+ void common_speculative_free(struct common_speculative * spec) {
+     if (spec == nullptr) {
+         return;
+     }
+
+     common_sampler_free(spec->smpl);
+
+     llama_batch_free(spec->batch);
+
+     delete spec;
+ }
+
+ bool common_speculative_are_compatible(
+         const struct llama_context * ctx_tgt,
+         const struct llama_context * ctx_dft) {
+     const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
+     const struct llama_model * model_dft = llama_get_model(ctx_dft);
+
+     const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
+     const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
+
+     const enum llama_vocab_type vocab_type_tgt = llama_vocab_type(vocab_tgt);
+     LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
+
+     const enum llama_vocab_type vocab_type_dft = llama_vocab_type(vocab_dft);
+     LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
+
+     if (vocab_type_tgt != vocab_type_dft) {
+         LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
+                 "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
+         return false;
+     }
+
+     if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
+         llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
+         llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
+         llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) {
+         LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__);
+         LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt));
+         LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft));
+         return false;
+     }
+
+     {
+         const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
+         const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
+
+         const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
+
+         if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+             LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
+                     "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                     __func__, n_vocab_tgt, n_vocab_dft, vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+             return false;
+         }
+
+         for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+             const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
+             const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
+             if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
+                 LOG_ERR("%s: draft vocab must match target vocab to use speculation but "
+                         "token %d content differs - target '%s', draft '%s'\n", __func__, i,
+                         common_token_to_piece(ctx_tgt, i).c_str(),
+                         common_token_to_piece(ctx_dft, i).c_str());
+                 return false;
+             }
+         }
+     }
+
+     return true;
+ }
+
+ llama_tokens common_speculative_gen_draft(
+         struct common_speculative * spec,
+         struct common_speculative_params params,
+         const llama_tokens & prompt_tgt,
+         llama_token id_last) {
+     auto & batch  = spec->batch;
+     auto & ctx    = spec->ctx;
+     auto & smpl   = spec->smpl;
+     auto & prompt = spec->prompt;
+
+     int reuse_i = 0;
+     int reuse_n = 0;
+
+     const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
+
+     const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
+
+     // reuse as much as possible from the old draft context
+     // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
+     for (int i = 0; i < (int) prompt.size(); ++i) {
+         int cur = 0;
+         while (i_start + cur < (int) prompt_tgt.size() &&
+                i       + cur < (int) prompt.size() &&
+                prompt_tgt[i_start + cur] == prompt[i + cur]) {
+             cur++;
+         }
+
+         if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) {
+             reuse_i = i;
+             reuse_n = cur;
+         }
+     }
+
+     LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
+
+     llama_tokens result;
+     result.reserve(params.n_draft);
+
+     if (reuse_n == 0) {
+         llama_kv_cache_clear(ctx);
+
+         prompt.clear();
+     } else {
+         // this happens when a previous draft has been discarded (for example, due to being too small), but the
+         // target model agreed with it. in this case, we simply pass back the previous results to save compute
+         if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
+             for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
+                 result.push_back(prompt[i]);
+
+                 if (params.n_draft <= (int) result.size()) {
+                     break;
+                 }
+             }
+
+             return result;
+         }
+
+         if (reuse_i > 0) {
+             llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
+             llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+
+             prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+         }
+
+         if (reuse_n < (int) prompt.size()) {
+             llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
+
+             prompt.erase(prompt.begin() + reuse_n, prompt.end());
+         }
+     }
+
+     // prepare a batch to evaluate any new tokens in the prompt
+     common_batch_clear(batch);
+
+     for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
+         //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
+         common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
+
+         prompt.push_back(prompt_tgt[i]);
+     }
+
+     // we should rarely end-up here during normal decoding
+     if (batch.n_tokens > 0) {
+         //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
+
+         llama_decode(ctx, batch);
+     }
+
+     const llama_pos n_past = prompt.size();
+
+     LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+
+     common_batch_clear(batch);
+     common_batch_add (batch, id_last, n_past, { 0 }, true);
+
+     prompt.push_back(id_last);
+
+     //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
+
+     llama_decode(ctx, batch);
+
+     common_sampler_reset(smpl);
+
+     // sample n_draft tokens from the draft model
+     for (int i = 0; i < params.n_draft; ++i) {
+         common_batch_clear(batch);
+
+         common_sampler_sample(smpl, ctx, 0, true);
+
+         const auto * cur_p = common_sampler_get_candidates(smpl);
+
+         for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+             LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+                     k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+         }
+
+         // add drafted token for each sequence
+         const llama_token id = cur_p->data[0].id;
+
+         common_sampler_accept(smpl, id, true);
+
+         result.push_back(id);
+
+         if (params.n_draft <= (int) result.size()) {
+             break;
+         }
+
+         // only collect very high-confidence draft tokens
+         if (cur_p->data[0].p < params.p_min) {
+             break;
+         }
+
+         common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
+
+         // evaluate the drafted tokens on the draft model
+         llama_decode(ctx, batch);
+
+         prompt.push_back(id);
+     }
+
+     return result;
+ }
common/speculative.h ADDED
@@ -0,0 +1,28 @@
+ #pragma once
+
+ #include "llama.h"
+ #include "common.h"
+
+ struct common_speculative;
+
+ struct common_speculative_params {
+     int n_draft = 16;  // max drafted tokens
+     int n_reuse = 256;
+
+     float p_min = 0.75f; // min probability required to accept a token in the draft
+ };
+
+ struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
+
+ void common_speculative_free(struct common_speculative * spec);
+
+ bool common_speculative_are_compatible(
+         const struct llama_context * ctx_tgt,
+         const struct llama_context * ctx_dft);
+
+ // sample up to n_draft tokens and add them to the batch using the draft model
+ llama_tokens common_speculative_gen_draft(
+         struct common_speculative * spec,
+         struct common_speculative_params params,
+         const llama_tokens & prompt,
+         llama_token id_last);
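Editor's note: a minimal sketch of one target/draft step using this API, assuming two loaded contexts (ctx_tgt for the target model, ctx_dft for the draft model) and a tokenized target-side prompt. The function name is illustrative and the target-side batch plumbing is elided.

    #include "speculative.h"
    #include "sampling.h"

    static void speculative_step(llama_context * ctx_tgt, llama_context * ctx_dft,
                                 const llama_tokens & prompt_tgt, llama_token id_last) {
        if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
            return; // fall back to plain decoding
        }

        common_speculative * spec = common_speculative_init(ctx_dft);

        common_speculative_params params;
        params.n_draft = 16;
        params.p_min   = 0.75f;

        // propose a continuation with the draft model ...
        llama_tokens draft = common_speculative_gen_draft(spec, params, prompt_tgt, id_last);

        // ... then decode id_last plus the draft on the target model in one batch and
        // let common_sampler_sample_and_accept_n() keep the longest agreeing prefix.

        common_speculative_free(spec);
    }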