|
import os |
|
import sys |
|
import logging |
|
from flask import Flask, request, jsonify |
|
from flask_cors import CORS |
|
from vllm import LLM, SamplingParams |
|
|
|
|
|
|
|
import os |
|
import os |
|
from pathlib import Path |
|
import csv |
|
import json |
|
import openai |
|
import time |
|
import pandas as pd |
|
|
|
|
|
# Read the OpenAI credential from the environment instead of committing it to
# source control. ``None`` (variable unset) leaves the client unconfigured.
# NOTE(review): the key that used to be hard-coded here has been exposed in
# the repository history and should be revoked and rotated.
api_key = os.environ.get("OPENAI_API_KEY")

openai.api_key = api_key

# Completion engine name; not referenced elsewhere in the visible part of
# this file.
model_engine = "text-davinci-003"
|
import gradio as gr |
|
import time |
|
import argparse |
|
from vllm import LLM, SamplingParams |
|
|
|
|
|
def parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace carrying ``model`` (str, ``None`` when the flag is
        omitted) and ``n_gpu`` (int, 1 when the flag is omitted).
    """
    cli = argparse.ArgumentParser()
    # Flags are declared data-first so the option table is easy to scan.
    option_table = (
        ("--model", {"type": str}),
        ("--n_gpu", {"type": int, "default": 1}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    namespace = cli.parse_args()
    return namespace
|
|
|
def echo(message, history, system_prompt, temperature, max_tokens):
    """Stream back a description of the call, one character at a time.

    Args:
        message: Text of the current message.
        history: Accepted for signature compatibility; not used.
        system_prompt: System prompt text to echo.
        temperature: Sampling temperature to echo.
        max_tokens: Echoed, and also caps how many characters are streamed.

    Yields:
        Successively longer prefixes of the description, 0.05 s apart.
    """
    full_text = (
        f"System prompt: {system_prompt}\n Message: {message}. \n"
        f" Temperature: {temperature}. \n Max Tokens: {max_tokens}."
    )
    char_budget = min(len(full_text), int(max_tokens))
    for end in range(1, char_budget + 1):
        time.sleep(0.05)
        yield full_text[:end]
|
|
|
|
|
|
|
def get_llm_result(input_sys_prompt_str, input_history_str, prompt_str, llm):
    """Generate the assistant's reply for one chat turn.

    Args:
        input_sys_prompt_str: System prompt supplied by the caller. It is
            handed to ``predict`` but is not inserted into the model prompt;
            a fixed preamble is used instead.
        input_history_str: Iterable of ``(user_text, assistant_text)`` pairs
            from earlier turns. ``None`` or empty means no history.
        prompt_str: The new user message.
        llm: Object exposing ``generate(prompts, sampling_params)``.

    Returns:
        A 2-tuple ``(response, response)``: the generated text, twice. The
        duplicated shape is kept because callers serialise it as-is.

    Raises:
        Exception: Anything raised while generating is reported on stdout
            and re-raised unchanged.
    """
    stop_tokens = [
        "Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT",
        "Instruction:", "Instruction", "Response:", "Response",
    ]

    def predict(message, history, system_prompt, temperature, max_tokens):
        """Build the chat prompt and return the text of the last completion."""
        # NOTE(review): system_prompt is accepted but unused; the preamble
        # below is fixed. Confirm whether that is intentional.
        past_turns = [
            'USER: ' + human + ' ASSISTANT: ' + assistant + '</s>'
            for human, assistant in (history or ())
        ]
        instruction = (
            "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "
            + ''.join(past_turns)
            + 'USER: ' + message + ' ASSISTANT:'
        )
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=1,
            max_tokens=max_tokens,
            stop=stop_tokens,
        )
        completions = llm.generate([instruction], sampling_params)

        # Default to an empty reply so an empty result set cannot leave the
        # name unbound.
        generated_text = ""
        for output in completions:
            generated_text = output.outputs[0].text
        return generated_text

    try:
        response = predict(
            prompt_str, input_history_str, input_sys_prompt_str, 0.5, 3000
        )
    except Exception as ex:
        # Report what actually failed, then re-raise with the original
        # traceback intact.
        print(f"LLM generation failed: {ex!r}")
        raise

    print(response)
    return response, response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Flask application object; CORS(app) applies flask-cors defaults to it.
app = Flask(__name__)
CORS(app)

# Send application logs to stdout at INFO level.
# The handler is attached exactly once. Previously an identical handler was
# also attached under a separate ``'DYNO' in os.environ`` check, so on such
# hosts every record was emitted twice.
# NOTE(review): this assumes stdout logging is wanted in every environment;
# the source's indentation was lost, so confirm that the second attachment
# was not itself meant to be conditional.
app.logger.addHandler(logging.StreamHandler(sys.stdout))
app.logger.setLevel(logging.INFO)
|
|
|
|
|
|
|
|
|
|
|
# Checkpoint served by the /api endpoint.
_MODEL_PATH = "/workspaceblobstore/caxu/trained_models/13Bv2_497kcontinueroleplay_dsys_2048_e4_2e_5/checkpoint-75"

# Process-wide cache so the checkpoint is loaded once instead of on every
# request.
_LLM_CACHE = {}


def _get_llm():
    """Return the shared vLLM engine, loading it on first use."""
    # NOTE(review): not guarded by a lock; two concurrent first requests
    # could both start loading. Confirm the deployment's threading model.
    if "llm" not in _LLM_CACHE:
        _LLM_CACHE["llm"] = LLM(model=_MODEL_PATH, tensor_parallel_size=1)
    return _LLM_CACHE["llm"]


@app.route('/api', methods=['POST'])
def api():
    """Generate a chat reply for the posted conversation.

    Expects a JSON object with the keys ``system_prompt``, ``USER`` (the new
    user message) and ``history``; these are forwarded to ``get_llm_result``.
    The request and the result are written to the application logger and
    appended to ``test_topic_serve_log.csv``.

    Returns:
        The JSON-encoded result of ``get_llm_result``, or a JSON error with
        status 400 when the body is not an object carrying all three keys.
    """
    input_data = request.json
    app.logger.info("api_input: " + str(input_data))

    with open("test_topic_serve_log.csv", 'a', encoding='utf-8') as log:
        log.write("api_input: " + str(input_data))
        try:
            required = ("system_prompt", "USER", "history")
            if not isinstance(input_data, dict) or any(
                key not in input_data for key in required
            ):
                message = "expected a JSON object with keys: " + ", ".join(required)
                return jsonify({"error": message}), 400

            output_data = get_llm_result(
                input_data['system_prompt'],
                input_data['history'],
                input_data['USER'],
                _get_llm(),
            )
            app.logger.info("api_output: " + str(output_data))
            response = jsonify(output_data)
            log.write("api_output: " + str(output_data))
            return response
        finally:
            # Always terminate the record, so a failed request cannot leave
            # a dangling half-line that the next request's entry runs into.
            log.write("\n")
|
|
|
|
|
@app.route('/labelapi', methods=['POST'])
def labelapi():
    """Record a user labelling action.

    The JSON request body is written to the application logger and appended
    as one line to ``test_topic_label_log.csv``.

    Returns:
        JSON object ``{"input": <request body>, "output": "record_success"}``.
    """
    input_data = request.json
    app.logger.info("api_input: " + str(input_data))

    # The context manager closes the handle, so the record is flushed before
    # the response is sent.
    with open("test_topic_label_log.csv", 'a', encoding='utf-8') as log:
        log.write("api_input: " + str(input_data) + "\n")

    # Serialised explicitly, matching how the /api endpoint builds its reply.
    return jsonify({"input": input_data, "output": "record_success"})
|
|
|
@app.route('/')
def index():
    """Answer requests for the root URL with a fixed marker string."""
    landing_text = "Index API"
    return landing_text
|
|
|
|
|
@app.errorhandler(404)
def url_error(e):
    """Build the response for a 404 error.

    Args:
        e: The error object passed to the handler; its string form is
            inserted into the page.

    Returns:
        ``(body, 404)``: a short notice followed by the error text inside a
        ``<pre>`` element, and the status code.
    """
    # NOTE(review): whitespace inside this template was reconstructed from a
    # whitespace-mangled source; confirm against the original file.
    template = "\nWrong URL!\n<pre>{}</pre>"
    page = template.format(e)
    return page, 404
|
|
|
|
|
@app.errorhandler(500)
def server_error(e):
    """Build the response for a 500 error.

    Args:
        e: The error object passed to the handler; its string form is
            HTML-escaped and inserted into the page.

    Returns:
        ``(body, 500)``: the page text and the status code.
    """
    # Local import keeps this block self-contained.
    from html import escape

    # The error text may echo request data, so escape it before embedding
    # it in markup.
    template = (
        "\nAn internal error occurred: <pre>{}</pre>\n"
        "See logs for full stacktrace.\n"
    )
    return template.format(escape(str(e))), 500
|
|
|
|
|
if __name__ == '__main__':
    # The Werkzeug debugger lets anyone who can reach the server run
    # arbitrary code, so it must not be on by default while listening on all
    # interfaces. Opt in for local development with FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true")
    app.run(host='0.0.0.0', port=4455, debug=debug_enabled)
|
|
|
|
|
|