File size: 2,248 Bytes
7def60a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
package main

// This is a wrapper to satisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
	"fmt"
	"path/filepath"

	"github.com/donomii/go-rwkv.cpp"
	"github.com/mudler/LocalAI/pkg/grpc/base"
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)

const tokenizerSuffix = ".tokenizer.json"

// LLM adapts an rwkv.cpp model to the LocalAI gRPC backend interface.
// It embeds base.SingleThread for the shared single-threaded backend
// plumbing and holds the loaded rwkv model state.
type LLM struct {
	base.SingleThread

	// rwkv is the model state created by Load; nil until Load succeeds.
	rwkv *rwkv.RwkvState
}

// Load initializes the rwkv model described by opts and stores its state
// on the receiver. When opts.Tokenizer is empty, the tokenizer is assumed
// to sit next to the model file, named "<model file name>.tokenizer.json".
// It returns an error if the model cannot be loaded.
func (llm *LLM) Load(opts *pb.ModelOptions) error {
	tokenizer := opts.Tokenizer
	if tokenizer == "" {
		// Derive the tokenizer file name from the model file name.
		tokenizer = filepath.Base(opts.ModelFile) + tokenizerSuffix
	}
	tokenizerPath := filepath.Join(filepath.Dir(opts.ModelFile), tokenizer)

	state := rwkv.LoadFiles(opts.ModelFile, tokenizerPath, uint32(opts.GetThreads()))
	if state == nil {
		return fmt.Errorf("rwkv could not load model")
	}

	llm.rwkv = state
	return nil
}

// Predict feeds opts.Prompt to the model and returns the generated text.
// The first entry of opts.StopPrompts (or "\n" when none is given) is used
// as the stop word for generation.
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
	stop := "\n"
	if prompts := opts.StopPrompts; len(prompts) > 0 {
		stop = prompts[0]
	}

	if err := llm.rwkv.ProcessInput(opts.Prompt); err != nil {
		return "", err
	}

	out := llm.rwkv.GenerateResponse(int(opts.Tokens), stop, float32(opts.Temperature), float32(opts.TopP), nil)
	return out, nil
}

// PredictStream generates a response for opts.Prompt asynchronously,
// sending each produced token on results. The channel is always closed
// when generation finishes — including on the input-processing error
// path — so the consumer is never left blocked. The first entry of
// opts.StopPrompts (or "\n" when none is given) is the stop word.
// The returned error is always nil; processing errors are logged inside
// the goroutine.
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
	go func() {
		// BUG FIX: the original only closed results after a successful
		// generation; a ProcessInput failure left the channel open and
		// the receiver blocked forever. defer guarantees the close.
		defer close(results)

		stopWord := "\n"
		if len(opts.StopPrompts) > 0 {
			stopWord = opts.StopPrompts[0]
		}

		if err := llm.rwkv.ProcessInput(opts.Prompt); err != nil {
			fmt.Println("Error processing input: ", err)
			return
		}

		llm.rwkv.GenerateResponse(int(opts.Tokens), stopWord, float32(opts.Temperature), float32(opts.TopP), func(s string) bool {
			results <- s
			return true
		})
	}()

	return nil
}

// TokenizeString encodes opts.Prompt with the model's tokenizer and
// returns the token IDs (as int32) together with their count. It returns
// an error if encoding fails.
func (llm *LLM) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) {
	encoded, err := llm.rwkv.Tokenizer.Encode(opts.Prompt)
	if err != nil {
		return pb.TokenizationResponse{}, err
	}

	// Convert the tokenizer's token IDs into the int32 slice the proto expects.
	ids := make([]int32, 0, len(encoded))
	for _, tok := range encoded {
		ids = append(ids, int32(tok.ID))
	}

	return pb.TokenizationResponse{
		Length: int32(len(ids)),
		Tokens: ids,
	}, nil
}