Spaces:
Build error
Build error
File size: 5,591 Bytes
69fb50e edc20ac 69fb50e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
#include "common.h"
#include "llama.h"
#include "build-info.h"
#include <vector>
#include <cstdio>
#include <chrono>
// Save/load-state example.
//
// Flow:
//   1. Parse parameters, load the model and create a context.
//   2. Tokenize and evaluate the prompt.
//   3. Dump the context state to "dump_state.bin" and remember n_past / last tokens.
//   4. Generate params.n_predict tokens ("first run").
//   5. Free the context, create a fresh one, restore the state from the file.
//   6. Generate params.n_predict tokens again ("second run") from the restored state.
//
// Returns 0 on success and 1 on any failure (argument parsing, model/context
// creation, tokenization, evaluation, or state file I/O).
int main(int argc, char ** argv) {
    gpt_params params;
    params.seed          = 42;
    params.n_threads     = 4;
    params.repeat_last_n = 64;
    params.prompt        = "The quick brown fox";

    if (gpt_params_parse(argc, argv, params) == false) {
        return 1;
    }

    fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);

    if (params.n_predict < 0) {
        params.n_predict = 16;
    }

    auto lparams = llama_context_default_params();

    lparams.n_ctx     = params.n_ctx;
    lparams.n_gqa     = params.n_gqa;
    lparams.seed      = params.seed;
    lparams.f16_kv    = params.memory_f16;
    lparams.use_mmap  = params.use_mmap;
    lparams.use_mlock = params.use_mlock;

    auto n_past = 0;
    auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);

    // init
    auto model = llama_load_model_from_file(params.model.c_str(), lparams);
    if (model == nullptr) {
        return 1;
    }

    auto ctx = llama_new_context_with_model(model, lparams);
    if (ctx == nullptr) {
        llama_free_model(model);
        return 1;
    }

    auto tokens = std::vector<llama_token>(params.n_ctx);
    auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true);

    if (n_prompt_tokens < 1) {
        fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
        llama_free(ctx);
        llama_free_model(model);
        return 1;
    }

    // evaluate prompt; a failure here would make the saved state meaningless
    if (llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads)) {
        fprintf(stderr, "\n%s : failed to evaluate prompt\n", __func__);
        llama_free(ctx);
        llama_free_model(model);
        return 1;
    }

    last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
    n_past += n_prompt_tokens;

    const size_t state_size = llama_get_state_size(ctx);

    // Owned by a vector so that every early return below releases it automatically.
    std::vector<uint8_t> state_mem(state_size);

    // Save state (rng, logits, embedding and kv_cache) to file
    {
        FILE * fp_write = fopen("dump_state.bin", "wb");
        if (fp_write == nullptr) {
            fprintf(stderr, "\n%s : failed to open dump_state.bin for writing\n", __func__);
            llama_free(ctx);
            llama_free_model(model);
            return 1;
        }

        llama_copy_state_data(ctx, state_mem.data()); // could also copy directly to memory mapped file

        const size_t n_written = fwrite(state_mem.data(), 1, state_size, fp_write);
        const bool   closed_ok = fclose(fp_write) == 0; // close before checking so the handle never leaks

        if (n_written != state_size || !closed_ok) {
            fprintf(stderr, "\n%s : failed to write state\n", __func__);
            llama_free(ctx);
            llama_free_model(model);
            return 1;
        }
    }

    // save state (last tokens)
    const auto last_n_tokens_data_saved = std::vector<llama_token>(last_n_tokens_data);
    const auto n_past_saved = n_past;

    // Samples and evaluates params.n_predict tokens on the given context,
    // printing each token and advancing n_past / last_n_tokens_data.
    // Returns false as soon as an evaluation fails. Shared by both runs so
    // they are guaranteed to execute the same steps.
    const auto generate = [&](decltype(ctx) lctx) -> bool {
        for (auto i = 0; i < params.n_predict; i++) {
            auto logits  = llama_get_logits(lctx);
            auto n_vocab = llama_n_vocab(lctx);

            std::vector<llama_token_data> candidates;
            candidates.reserve(n_vocab);
            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
                candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
            }
            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

            auto next_token     = llama_sample_token(lctx, &candidates_p);
            auto next_token_str = llama_token_to_str(lctx, next_token);
            last_n_tokens_data.push_back(next_token);

            printf("%s", next_token_str);

            if (llama_eval(lctx, &next_token, 1, n_past, params.n_threads)) {
                return false;
            }
            n_past += 1;
        }
        return true;
    };

    // first run
    printf("\n%s", params.prompt.c_str());

    if (!generate(ctx)) {
        fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
        llama_free(ctx);
        llama_free_model(model);
        return 1;
    }

    printf("\n\n");

    // free old context
    llama_free(ctx);

    // make new context
    auto ctx2 = llama_new_context_with_model(model, lparams);
    if (ctx2 == nullptr) {
        fprintf(stderr, "\n%s : failed to create new context\n", __func__);
        llama_free_model(model);
        return 1;
    }

    // Load state (rng, logits, embedding and kv_cache) from file
    {
        // Validate before opening the file so this exit path has no handle to release.
        if (state_size != llama_get_state_size(ctx2)) {
            fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
            llama_free(ctx2);
            llama_free_model(model);
            return 1;
        }

        FILE * fp_read = fopen("dump_state.bin", "rb");
        if (fp_read == nullptr) {
            fprintf(stderr, "\n%s : failed to open dump_state.bin for reading\n", __func__);
            llama_free(ctx2);
            llama_free_model(model);
            return 1;
        }

        const size_t ret = fread(state_mem.data(), 1, state_size, fp_read);
        fclose(fp_read); // close before the check so the handle is released on failure too

        if (ret != state_size) {
            fprintf(stderr, "\n%s : failed to read state\n", __func__);
            llama_free(ctx2);
            llama_free_model(model);
            return 1;
        }

        llama_set_state_data(ctx2, state_mem.data()); // could also read directly from memory mapped file
    }

    // The buffer is no longer needed; release it before the second run.
    std::vector<uint8_t>().swap(state_mem);

    // restore state (last tokens)
    last_n_tokens_data = last_n_tokens_data_saved;
    n_past = n_past_saved;

    // second run
    if (!generate(ctx2)) {
        fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
        llama_free(ctx2);
        llama_free_model(model);
        return 1;
    }

    printf("\n\n");

    llama_free(ctx2);
    llama_free_model(model);

    return 0;
}
|