File size: 5,358 Bytes
7def60a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
package gallery_test

import (
	"errors"
	"os"
	"path/filepath"

	"github.com/mudler/LocalAI/core/config"
	. "github.com/mudler/LocalAI/core/gallery"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
	"gopkg.in/yaml.v3"
)

// The "Model test" specs exercise the gallery install/uninstall flow end to end:
// installing a model directly from a config file, installing one through a
// gallery index, renaming on install, overriding config parameters, and
// rejecting install names that try to escape the target directory.
//
// The config-based specs read gallery_simple.yaml from the directory named by
// the FIXTURES environment variable. The gallery spec references a model
// definition hosted on raw.githubusercontent.com, so it needs network access.
var _ = Describe("Model test", func() {
	// noopProgress is passed wherever the install functions expect a
	// download-progress callback; these specs do not inspect progress.
	noopProgress := func(string, string, string, float64) {}

	// fixturePath resolves a fixture file name under $FIXTURES.
	fixturePath := func(name string) string {
		return filepath.Join(os.Getenv("FIXTURES"), name)
	}

	Context("Downloading", func() {
		It("applies model correctly", func() {
			tempdir, err := os.MkdirTemp("", "test")
			Expect(err).ToNot(HaveOccurred())
			defer os.RemoveAll(tempdir)
			c, err := ReadConfigFile(fixturePath("gallery_simple.yaml"))
			Expect(err).ToNot(HaveOccurred())
			err = InstallModel(tempdir, "", c, map[string]interface{}{}, noopProgress, true)
			Expect(err).ToNot(HaveOccurred())

			for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
				_, err = os.Stat(filepath.Join(tempdir, f))
				Expect(err).ToNot(HaveOccurred(), "expected installed file %q", f)
			}

			dat, err := os.ReadFile(filepath.Join(tempdir, "cerebras.yaml"))
			Expect(err).ToNot(HaveOccurred())

			// Decode into a pointer, as yaml.Unmarshal documents.
			content := map[string]interface{}{}
			err = yaml.Unmarshal(dat, &content)
			Expect(err).ToNot(HaveOccurred())

			Expect(content["context_size"]).To(Equal(1024))
		})

		It("applies model from gallery correctly", func() {
			tempdir, err := os.MkdirTemp("", "test")
			Expect(err).ToNot(HaveOccurred())
			defer os.RemoveAll(tempdir)

			// Write a one-entry gallery index to disk and point a file:// gallery at it.
			gallery := []GalleryModel{{
				Name: "bert",
				URL:  "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
			}}
			out, err := yaml.Marshal(gallery)
			Expect(err).ToNot(HaveOccurred())
			galleryFilePath := filepath.Join(tempdir, "gallery_simple.yaml")
			err = os.WriteFile(galleryFilePath, out, 0600)
			Expect(err).ToNot(HaveOccurred())
			Expect(filepath.IsAbs(galleryFilePath)).To(BeTrue(), galleryFilePath)
			galleries := []config.Gallery{
				{
					Name: "test",
					URL:  "file://" + galleryFilePath,
				},
			}

			// Before install: the model is listed but not installed.
			models, err := AvailableGalleryModels(galleries, tempdir)
			Expect(err).ToNot(HaveOccurred())
			Expect(models).To(HaveLen(1))
			Expect(models[0].Name).To(Equal("bert"))
			Expect(models[0].URL).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"))
			Expect(models[0].Installed).To(BeFalse())

			// Models are addressed as "<gallery>@<model>".
			err = InstallModelFromGallery(galleries, "test@bert", tempdir, GalleryModel{}, noopProgress, true)
			Expect(err).ToNot(HaveOccurred())

			dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
			Expect(err).ToNot(HaveOccurred())

			content := map[string]interface{}{}
			err = yaml.Unmarshal(dat, &content)
			Expect(err).ToNot(HaveOccurred())
			Expect(content["backend"]).To(Equal("bert-embeddings"))

			// After install: the listing reports it as installed.
			models, err = AvailableGalleryModels(galleries, tempdir)
			Expect(err).ToNot(HaveOccurred())
			Expect(models).To(HaveLen(1))
			Expect(models[0].Installed).To(BeTrue())

			// delete
			err = DeleteModelFromSystem(tempdir, "bert", []string{})
			Expect(err).ToNot(HaveOccurred())

			models, err = AvailableGalleryModels(galleries, tempdir)
			Expect(err).ToNot(HaveOccurred())
			Expect(models).To(HaveLen(1))
			Expect(models[0].Installed).To(BeFalse())

			// The model's config file must be gone from disk, not merely unlisted.
			_, err = os.Stat(filepath.Join(tempdir, "bert.yaml"))
			Expect(err).To(HaveOccurred())
			Expect(errors.Is(err, os.ErrNotExist)).To(BeTrue())
		})

		It("renames model correctly", func() {
			tempdir, err := os.MkdirTemp("", "test")
			Expect(err).ToNot(HaveOccurred())
			defer os.RemoveAll(tempdir)
			c, err := ReadConfigFile(fixturePath("gallery_simple.yaml"))
			Expect(err).ToNot(HaveOccurred())

			// A non-empty name override renames the generated config file only.
			err = InstallModel(tempdir, "foo", c, map[string]interface{}{}, noopProgress, true)
			Expect(err).ToNot(HaveOccurred())

			for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
				_, err = os.Stat(filepath.Join(tempdir, f))
				Expect(err).ToNot(HaveOccurred(), "expected installed file %q", f)
			}
		})

		It("overrides parameters", func() {
			tempdir, err := os.MkdirTemp("", "test")
			Expect(err).ToNot(HaveOccurred())
			defer os.RemoveAll(tempdir)
			c, err := ReadConfigFile(fixturePath("gallery_simple.yaml"))
			Expect(err).ToNot(HaveOccurred())

			err = InstallModel(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}, noopProgress, true)
			Expect(err).ToNot(HaveOccurred())

			for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
				_, err = os.Stat(filepath.Join(tempdir, f))
				Expect(err).ToNot(HaveOccurred(), "expected installed file %q", f)
			}

			dat, err := os.ReadFile(filepath.Join(tempdir, "foo.yaml"))
			Expect(err).ToNot(HaveOccurred())

			content := map[string]interface{}{}
			err = yaml.Unmarshal(dat, &content)
			Expect(err).ToNot(HaveOccurred())

			// The override map's value must win over the fixture's backend.
			Expect(content["backend"]).To(Equal("foo"))
		})

		It("catches path traversals", func() {
			tempdir, err := os.MkdirTemp("", "test")
			Expect(err).ToNot(HaveOccurred())
			defer os.RemoveAll(tempdir)
			c, err := ReadConfigFile(fixturePath("gallery_simple.yaml"))
			Expect(err).ToNot(HaveOccurred())

			err = InstallModel(tempdir, "../../../foo", c, map[string]interface{}{}, noopProgress, true)
			Expect(err).To(HaveOccurred(), "install names that escape the models directory must be rejected")
		})
	})
})