Spaces:
Runtime error
Runtime error
File size: 2,386 Bytes
d2f4f1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
from transformers import AutoModelForCausalLM, AutoTokenizer
from flask import Flask, request
import argparse
import logging
class LLMInstance:
    """Wrap a causal language model and its tokenizer behind a single ``query`` call."""

    def __init__(self, model_path: str, device: str = "cuda"):
        """Load the model and tokenizer and move the model to *device*.

        Args:
            model_path: Value handed to ``from_pretrained`` for both the model
                and the tokenizer.
            device: Device the model is moved to; prompt tensors are sent to
                the same device in ``query``.
        """
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model.to(device)
        self.device = device

    def query(self, message: str) -> dict:
        """Generate a reply to a single user message.

        Args:
            message: Text sent to the model as the only ``user`` turn.

        Returns:
            A dict with the keys ``code``, ``ret``, ``error_msg`` and ``output``.
            On success ``code`` is 0, ``ret`` is True, ``error_msg`` is None and
            ``output`` holds the generated text. On any exception ``code`` is 1,
            ``ret`` is False, ``error_msg`` is ``str(exception)`` and ``output``
            is None; the exception is logged and not re-raised.
        """
        try:
            messages = [
                {"role": "user", "content": message},
            ]
            prompt_ids = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
            model_inputs = prompt_ids.to(self.device)
            generated_ids = self.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
            # generate() returns the prompt followed by the continuation. Slice
            # the prompt off by length instead of searching the decoded text
            # for "[/INST]": the marker may occur inside the user's message,
            # and is absent from chat templates that do not use it.
            prompt_len = model_inputs.shape[-1]
            new_tokens = generated_ids[0][prompt_len:]
            # skip_special_tokens drops the end-of-sequence marker (e.g. "</s>").
            output = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        except Exception as e:
            # Service boundary: report the failure to the caller in the normal
            # response shape, but keep the traceback in the logs.
            logging.getLogger(__name__).exception("LLM query failed")
            return {
                'code': 1,
                'ret': False,
                'error_msg': str(e),
                'output': None
            }
        return {
            'code': 0,
            'ret': True,
            'error_msg': None,
            'output': output
        }
def create_app(core):
    """Build the Flask application that exposes ``core.query`` over HTTP.

    Args:
        core: Object providing ``query(text)``; its return value is sent back
            as the response of a successful request.

    Returns:
        A Flask app with one route, ``POST /ask_llm_for_answer``, which expects
        a JSON object containing a ``user_text`` field.
    """
    app = Flask(__name__)

    @app.route('/ask_llm_for_answer', methods=['POST'])
    def ask_llm_for_answer():
        # silent=True yields None for a missing, non-JSON or malformed body
        # instead of raising, so every bad request gets the same error shape.
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict) or 'user_text' not in payload:
            return {
                'code': 1,
                'ret': False,
                'error_msg': "request body must be a JSON object with a 'user_text' field",
                'output': None
            }, 400
        return core.query(payload['user_text'])

    return app
if __name__ == "__main__":
    # Command-line entry point: parse options, configure logging, load the
    # model once, then serve it with Flask's built-in server.
    parser = argparse.ArgumentParser()
    # NOTE(review): with required=True argparse never falls back to `default`;
    # decide whether the flag should be optional or the default removed.
    parser.add_argument('-m', '--model_path', required=True, default='Mistral-7B-Instruct-v0.1', help='the model path of reward model')
    parser.add_argument('--ip', default='0.0.0.0')
    # type=int: without it a port given on the command line stays a string,
    # which the server rejects.
    parser.add_argument('-p', '--port', type=int, default=8001)
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG if args.debug else logging.INFO)
    # Set the formatter on the handler itself before attaching it, rather than
    # on handlers[0], which is a different handler if one was already installed.
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter("%(message)s"))
    root_logger.addHandler(stream_handler)

    core = LLMInstance(args.model_path)
    app = create_app(core)
    app.run(host=args.ip, port=args.port)
|