BeveledCube commited on
Commit
2cb4613
·
1 Parent(s): a1d6a9e

Modified frontend and added a few models

Browse files
Files changed (5) hide show
  1. models/fast.py +16 -0
  2. models/llama2.py +16 -0
  3. models/mamba.py +16 -0
  4. models/tiny.py +16 -0
  5. templates/index.html +90 -2
models/fast.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+
3
+ model_name = "power-greg/super-fast-llm"
4
+
5
+ def load():
6
+ global model
7
+ global tokenizer
8
+
9
+ model = AutoModelForCausalLM.from_pretrained(model_name)
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+
12
+ def generate(input_text):
13
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
14
+ output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
15
+
16
+ return tokenizer.decode(output_ids[0], skip_special_tokens=True)
models/llama2.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+
3
+ model_name = "meta-llama/Llama-2-13b-chat-hf"
4
+
5
+ def load():
6
+ global model
7
+ global tokenizer
8
+
9
+ model = AutoModelForCausalLM.from_pretrained(model_name)
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+
12
+ def generate(input_text):
13
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
14
+ output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
15
+
16
+ return tokenizer.decode(output_ids[0], skip_special_tokens=True)
models/mamba.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+
3
+ model_name = "s3nh/mamba-gpt-3b-v3-GGML"
4
+
5
+ def load():
6
+ global model
7
+ global tokenizer
8
+
9
+ model = AutoModelForCausalLM.from_pretrained(model_name)
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+
12
+ def generate(input_text):
13
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
14
+ output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
15
+
16
+ return tokenizer.decode(output_ids[0], skip_special_tokens=True)
models/tiny.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+
3
+ model_name = "roneneldan/TinyStories-1M"
4
+
5
+ def load():
6
+ global model
7
+ global tokenizer
8
+
9
+ model = AutoModelForCausalLM.from_pretrained(model_name)
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+
12
+ def generate(input_text):
13
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
14
+ output_ids = model.generate(input_ids, num_beams=5, no_repeat_ngram_size=2)
15
+
16
+ return tokenizer.decode(output_ids[0], skip_special_tokens=True)
templates/index.html CHANGED
@@ -2,6 +2,7 @@
2
  <html lang="en">
3
 
4
  <head>
 
5
  <title>AI API</title>
6
  <style>
7
  body {
@@ -10,6 +11,36 @@
10
  background-color: rgb(50, 50, 50);
11
  }
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  .img {
14
  width: 40vh;
15
  height: 40vh;
@@ -31,8 +62,65 @@
31
  </head>
32
 
33
  <body>
34
- <h1 class="text">Hello there!</h1>
35
- <span class="text">For the API use a GET request to /API</span>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  </body>
37
 
38
  </html>
 
2
  <html lang="en">
3
 
4
  <head>
5
+ <link href="https://fonts.googleapis.com/icon?family=Material+Icons" rel="stylesheet">
6
  <title>AI API</title>
7
  <style>
8
  body {
 
11
  background-color: rgb(50, 50, 50);
12
  }
13
 
14
+ button {
15
+ cursor: pointer;
16
+
17
+ border-style: solid;
18
+ border-width: 3px;
19
+ border-style: solid;
20
+ border-radius: 5px;
21
+
22
+ text-align: center;
23
+
24
+ margin: 3px;
25
+ margin-top: 0;
26
+ margin-bottom: 0;
27
+ }
28
+
29
+ input {
30
+ width: 200px;
31
+ padding: 10px;
32
+ border: 1px solid #ccc;
33
+ background-color: #6b6e7266;
34
+ color: #e9e9e9;
35
+ border-radius: 4px;
36
+
37
+ transition: all, 0.35s;
38
+ }
39
+
40
+ input:focus {
41
+ outline: none;
42
+ }
43
+
44
  .img {
45
  width: 40vh;
46
  height: 40vh;
 
62
  </head>
63
 
64
  <body>
65
+ <h1 class="text">Chat with me</h1>
66
+ <div id="responses"></div>
67
+
68
+ <input class="input" type="text" id="prompt" placeholder="bake a cake">
69
+ <button class="send-button" id="send-prompt">
70
+ <i class="material-icons">send</i>
71
+ </button>
72
+
73
+ <script>
74
+ const apiUrl = `https://beveledcube-bevelapi.hf.space/api`;
75
+ const sendPromptButton = document.getElementById("send-prompt");
76
+ const responseContainer = document.getElementById("responses");
77
+
78
+ sendPromptButton.addEventListener("click", async () => {
79
+ console.log("Sending prompt")
80
+
81
+ const responseElement = document.createElement("div");
82
+ const requestData = { prompt: getValue("prompt") };
83
+
84
+ responseElement.classList.add("response-container");
85
+
86
+ responseElement.innerHTML = `<span class="text"><p><strong>You<br></strong>${requestData.prompt}</p>`;
87
+
88
+ responseContainer.appendChild(responseElement);
89
+
90
+ fetch(apiUrl, {
91
+ method: "POST",
92
+ headers: {
93
+ "Content-Type": "application/json"
94
+ },
95
+ body: JSON.stringify(requestData)
96
+ })
97
+ .then(response => {
98
+ if (!response.ok) {
99
+ throw new Error("Network response was " + response.status.toString());
100
+ }
101
+
102
+ return response.json();
103
+ })
104
+ .then(data => {
105
+ console.log("Response from API:", data);
106
+ const responseElement = document.createElement("div");
107
+
108
+ responseElement.classList.add("response-container");
109
+
110
+ responseElement.innerHTML = `<span class="text"><p><strong>AI<br></strong>${data.answer.replace("\n", "<br>")}</p>`;
111
+
112
+ responseContainer.appendChild(responseElement);
113
+ })
114
+ .catch(error => {
115
+ console.error("Error:", error.message);
116
+ });
117
+
118
+ });
119
+
120
+ function getValue(elementId) {
121
+ return document.getElementById(elementId).value;
122
+ }
123
+ </script>
124
  </body>
125
 
126
  </html>