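// embedding: load a model, evaluate a prompt, and print the resulting
// embedding vector (one float per dimension) to stdout.
//
// example invocation (the model path is illustrative):
//
//   ./embedding -m models/7B/ggml-model-q4_0.gguf -p "Hello world"
//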
#include "build-info.h"
#include "common.h"
#include "llama.h"

#include <ctime>
#include <random> // std::mt19937
#include <tuple>  // std::tie

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

int main(int argc, char ** argv) {
    gpt_params params;

    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }

    // run in embedding mode so evaluating the prompt produces an embedding vector
    params.embedding = true;

    print_build_info();

    // use a time-based seed unless one was provided on the command line
    if (params.seed == LLAMA_DEFAULT_SEED) {
        params.seed = time(NULL);
    }

    fprintf(stderr, "%s: seed  = %u\n", __func__, params.seed);

    // the RNG is only used when --random-prompt requests a generated prompt
    std::mt19937 rng(params.seed);
    if (params.random_prompt) {
        params.prompt = gpt_random_prompt(rng);
    }

    // initialize the llama/ggml backend (with optional NUMA support)
    llama_backend_init(params.numa);

    llama_model * model;
    llama_context * ctx;

    // load the model
    std::tie(model, ctx) = llama_init_from_gpt_params(params);
    if (model == NULL) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        return 1;
    }

    const int n_ctx_train = llama_n_ctx_train(model);
    const int n_ctx = llama_n_ctx(ctx);

    if (n_ctx > n_ctx_train) {
        fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
                __func__, n_ctx_train, n_ctx);
    }

    // print system information
    {
        fprintf(stderr, "\n");
        fprintf(stderr, "%s\n", get_system_info(params).c_str());
    }

    // number of tokens already evaluated into the context
    int n_past = 0;

    // tokenize the prompt (true prepends the BOS token)
    auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);

    if (params.verbose_prompt) {
        fprintf(stderr, "\n");
        fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
        fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
        for (int i = 0; i < (int) embd_inp.size(); i++) {
            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
        }
        fprintf(stderr, "\n");
    }

    if (embd_inp.size() > (size_t)n_ctx) {
        fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
                __func__, embd_inp.size(), n_ctx);
        return 1;
    }

    // evaluate the prompt in chunks of at most n_batch tokens
    while (!embd_inp.empty()) {
        int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
        if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) {
            fprintf(stderr, "%s : failed to eval\n", __func__);
            return 1;
        }
        n_past += n_tokens;
        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
    }

    // after evaluation the context exposes a single embedding vector of size n_embd
    const int n_embd = llama_n_embd(model);
    const auto * embeddings = llama_get_embeddings(ctx);
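
    // defensive check; this sketch assumes llama_get_embeddings() can return
    // NULL when the context holds no embeddings
    if (embeddings == NULL) {
        fprintf(stderr, "%s: error: failed to get embeddings\n", __func__);
        return 1;
    }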

    for (int i = 0; i < n_embd; i++) {
        printf("%f ", embeddings[i]);
    }
    printf("\n");

    llama_print_timings(ctx);
    llama_free(ctx);
    llama_free_model(model);

    llama_backend_free();

    return 0;
}