Spaces:
Runtime error
Runtime error
Commit
·
326ad4b
1
Parent(s):
2b7cdd8
fix: bugfix mistral generate method
Browse files- model/mistral.py +7 -2
model/mistral.py
CHANGED
@@ -7,6 +7,7 @@ import gradio as gr
|
|
7 |
|
8 |
# internal imports
|
9 |
from utils import modelling as mdl
|
|
|
10 |
|
11 |
# global model and tokenizer instance (created on inital build)
|
12 |
device = mdl.get_device()
|
@@ -91,6 +92,9 @@ def format_answer(answer: str):
|
|
91 |
# empty answer string
|
92 |
formatted_answer = ""
|
93 |
|
|
|
|
|
|
|
94 |
# extracting text after INST tokens
|
95 |
parts = answer.split("[/INST]")
|
96 |
if len(parts) >= 3:
|
@@ -106,10 +110,11 @@ def format_answer(answer: str):
|
|
106 |
def respond(prompt: str):
|
107 |
|
108 |
# tokenizing inputs and configuring model
|
109 |
-
input_ids = TOKENIZER(f"{prompt}", return_tensors="pt")["input_ids"]
|
110 |
|
111 |
# generating text with tokenized input, returning output
|
112 |
-
output_ids = MODEL.generate(input_ids,
|
113 |
output_text = TOKENIZER.batch_decode(output_ids)
|
|
|
114 |
|
115 |
return format_answer(output_text)
|
|
|
7 |
|
8 |
# internal imports
|
9 |
from utils import modelling as mdl
|
10 |
+
from utils import formatting as fmt
|
11 |
|
12 |
# global model and tokenizer instance (created on inital build)
|
13 |
device = mdl.get_device()
|
|
|
92 |
# empty answer string
|
93 |
formatted_answer = ""
|
94 |
|
95 |
+
if type(answer) == list:
|
96 |
+
answer = fmt.format_output_text
|
97 |
+
|
98 |
# extracting text after INST tokens
|
99 |
parts = answer.split("[/INST]")
|
100 |
if len(parts) >= 3:
|
|
|
110 |
def respond(prompt: str):
|
111 |
|
112 |
# tokenizing inputs and configuring model
|
113 |
+
input_ids = TOKENIZER(f"{prompt}", return_tensors="pt")["input_ids"].to(device)
|
114 |
|
115 |
# generating text with tokenized input, returning output
|
116 |
+
output_ids = MODEL.generate(input_ids, generation_config=CONFIG)
|
117 |
output_text = TOKENIZER.batch_decode(output_ids)
|
118 |
+
output_text.fo
|
119 |
|
120 |
return format_answer(output_text)
|