# chatV/lm/server_lm/main_lm.py
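"""Flask server that exposes a chat endpoint backed by a vLLM-served model.

POST /api      -- runs one chat turn (system prompt, user message, history) through the model
POST /labelapi -- records a user labelling action to a local log file
GET  /         -- simple index / health check
"""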
import argparse
import logging
import os
import sys
import time

from flask import Flask, request, jsonify
from flask_cors import CORS
from vllm import LLM, SamplingParams

# from serve import get_model_api
import openai

# Set up the OpenAI API client for the (currently disabled) fallback path below.
# The key is read from the environment instead of being hardcoded in source.
openai.api_key = os.environ.get("OPENAI_API_KEY", "")

# OpenAI model used by the commented-out fallback in get_llm_result().
model_engine = "text-davinci-003"
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)              # model path
    parser.add_argument("--n_gpu", type=int, default=1)   # number of GPUs
    return parser.parse_args()
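# Hypothetical launch command (the parsed args are not yet wired into __main__,
# which currently hardcodes the model path):
#   python main_lm.py --model /path/to/checkpoint --n_gpu 1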
def echo(message, history, system_prompt, temperature, max_tokens):
    # Debug helper: streams back a description of the request one character at a time.
    response = f"System prompt: {system_prompt}\n Message: {message}. \n Temperature: {temperature}. \n Max Tokens: {max_tokens}."
    for i in range(min(len(response), int(max_tokens))):
        time.sleep(0.05)
        yield response[: i + 1]
# def get_llm_result(input_data, input_domain):
def get_llm_result(input_sys_prompt_str, input_history_str, prompt_str, llm):
    prompt = ""

    def predict(message, history, system_prompt, temperature, max_tokens):
        # Build a Vicuna-style prompt from the chat history and the new message.
        # NOTE: system_prompt is accepted but the fixed preamble below is used instead.
        instruction = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "
        for human, assistant in history:
            instruction += 'USER: ' + human + ' ASSISTANT: ' + assistant + '</s>'
        instruction += 'USER: ' + message + ' ASSISTANT:'
        problem = [instruction]
        stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
        sampling_params = SamplingParams(temperature=temperature, top_p=1, max_tokens=max_tokens, stop=stop_tokens)
        completions = llm.generate(problem, sampling_params)
        generated_text = ""
        for output in completions:
            generated_text = output.outputs[0].text
        return generated_text
        # for idx in range(len(generated_text)):
        #     yield generated_text[:idx+1]
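    # Illustrative example (not from the source) of the prompt that predict() assembles
    # for a history containing one prior turn:
    #   A chat between a curious user and an artificial intelligence assistant. The assistant
    #   gives helpful, detailed, and polite answers to the user's questions.
    #   USER: Hi ASSISTANT: Hello! How can I help?</s>USER: Tell me a joke. ASSISTANT: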
    try:
        # Previous OpenAI-based pipeline, kept for reference:
        # completion = openai.Completion.create(
        #     engine=model_engine,
        #     prompt=prompt,
        #     max_tokens=3000,
        #     n=1,
        #     stop=None,
        #     temperature=0.5,
        # )
        #
        # response = completion.choices[0].text
        # shorten_response = response.replace("\n", "").strip()
        # len_response = len(shorten_response.split(" "))
        # if len_response >= 3500:
        #     shorten_response = "".join(shorten_response.split(" ")[:3500])
        # print("X" * 10)
        # print(f"shorten_response is {shorten_response}")
        # list_shorten = shorten_response.split(" ")
        # print(list_shorten)
        # print(f"length is {len(list_shorten)}")
        # The prompt below asks (in Chinese) for a Toutiao-headline-style title that is
        # catchy and readable.
        # title_prompt = f"{shorten_response},给这个文章写一个头条号风格的标题。增加标题的吸引力,可读性。"
        # title_completion = openai.Completion.create(
        #     engine=model_engine,
        #     prompt=title_prompt,
        #     max_tokens=200,
        #     n=1,
        #     stop=None,
        #     temperature=0.5,
        # )
        # title_response = title_completion.choices[0].text

        history = input_history_str
        prompt = prompt_str
        system_prompt = input_sys_prompt_str
        response = predict(prompt, history, system_prompt, 0.5, 3000)
        print(response)
        # if not os.path.isdir(topic_file_path):
        #     print("File folder not exist")
        # topic_result_file = ""
        # topic_file_name_pattern = "step10_json_filestep9_merge_rewrite_"
        # for filename in os.listdir(topic_file_path):
        #     if filename.startswith(topic_file_name_pattern):
        #         topic_result_file = os.path.join(topic_file_path, filename)
        #
        # data_aligned = dict()
        # output_dir_name = "."
        # output_dir = os.path.join(output_dir_name, "result_topic_file")
        # Path(output_dir).mkdir(parents=True, exist_ok=True)
        # write_file_name = "save_server_" + topic_file_path.split("\\")[-1]
        # write_output_file_path = os.path.join(output_dir, write_file_name)
        #
        # with open(topic_result_file, encoding="utf8") as f:
        #     json_data = json.load(f)
        # return json_data
        return response, response
    except Exception as ex:
        print(f"get_llm_result failed: {ex}")
        raise
# config = Config()
# model = NERModel(config)

# define the app
app = Flask(__name__)
CORS(app)  # needed for cross-domain requests, allow everything by default

# log to stdout (picked up by platforms such as Heroku)
app.logger.addHandler(logging.StreamHandler(sys.stdout))
app.logger.setLevel(logging.INFO)
# load the model
# model_api = get_model_api()

# vLLM engine; created lazily on the first /api request and then reused
# (re-creating it per request would reload the model weights every time).
llm = None

# API route
@app.route('/api', methods=['POST'])
def api():
    """Chat API endpoint.

    Expects a JSON body with 'system_prompt', 'USER' (the new user message) and
    'history' (previous chat turns); returns the model output as JSON.
    """
    global llm
    input_data = request.json
    app.logger.info("api_input: " + str(input_data))

    with open("test_topic_serve_log.csv", 'a', encoding='utf-8') as log:
        log.write("api_input: " + str(input_data) + "\n")

        # input_title_str = input_data['title']
        # input_domain_str = input_data['domain']
        input_sys_prompt_str = input_data['system_prompt']
        input_USER_str = input_data['USER']
        # input_ASSISTANT_str = input_data['ASSISTANT']
        input_history_str = input_data['history']

        # output_data = model_api(input_title_str, input_domain_str)
        # Load the vLLM engine on first use and reuse it for later requests.
        model_path = "/workspaceblobstore/caxu/trained_models/13Bv2_497kcontinueroleplay_dsys_2048_e4_2e_5/checkpoint-75"
        if llm is None:
            llm = LLM(model=model_path, tensor_parallel_size=1)

        output_data = get_llm_result(input_sys_prompt_str, input_history_str, input_USER_str, llm)
        app.logger.info("api_output: " + str(output_data))
        response = jsonify(output_data)
        log.write("api_output: " + str(output_data) + "\n")

    return response
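# Illustrative client call for /api (the payload keys match the handler above; the
# message text is made up for the example):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:4455/api",
#       json={"system_prompt": "You are a helpful assistant.", "USER": "Hello!", "history": []},
#   )
#   print(resp.json())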
# API2 route
@app.route('/labelapi', methods=['POST'])
def labelapi():
    """Label API endpoint.

    Records a user labelling action to a local log file and acknowledges it.
    """
    input_data = request.json
    app.logger.info("api_input: " + str(input_data))
    with open("test_topic_label_log.csv", 'a', encoding='utf-8') as log:
        log.write("api_input: " + str(input_data) + "\n")
    output_data = {"input": input_data, "output": "record_success"}
    return output_data
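# Illustrative client call for /labelapi (the payload shape is free-form; the handler
# simply logs whatever JSON it receives):
#
#   requests.post("http://localhost:4455/labelapi", json={"label": "good", "turn_id": 3})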
@app.route('/')
def index():
    return "Index API"
# HTTP error handlers
@app.errorhandler(404)
def url_error(e):
    return """
    Wrong URL!
    <pre>{}</pre>""".format(e), 404

@app.errorhandler(500)
def server_error(e):
    return """
    An internal error occurred: <pre>{}</pre>
    See logs for full stacktrace.
    """.format(e), 500
if __name__ == '__main__':
    # This is used when running locally.
    # llm = LLM(model=model_path, tensor_parallel_size=1)
    app.run(host='0.0.0.0', port=4455, debug=True)
    # app.run(host='0.0.0.0', port=4456, debug=True)