"""Run codes.""" |
|
|
|
|
|
import gc |
|
import os |
|
import platform |
|
import random |
|
import time |
|
from dataclasses import asdict, dataclass |
|
from pathlib import Path |
|
from typing import Optional, Sequence |
|
|
|
|
|
import gradio as gr |
|
import psutil |
|
from about_time import about_time |
|
from ctransformers import AutoModelForCausalLM |
|
from dl_hf_model import dl_hf_model |
|
from examples_list import examples_list |
|
from loguru import logger |
|
|
|
url = "https://huggingface.co/TheBloke/CodeLlama-13B-Python-GGML/blob/main/codellama-13b-python.ggmlv3.Q4_K_M.bin" |
|
|
|
LLM = None |
|
gc.collect() |
|
|
|
try: |
|
logger.debug(f" dl {url}") |
|
model_loc, file_size = dl_hf_model(url) |
|
logger.info(f"done load llm {model_loc=} {file_size=}G") |
|
except Exception as exc_: |
|
logger.error(exc_) |
|
raise SystemExit(1) from exc_ |
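
# NOTE: dl_hf_model is assumed to download (and cache) the GGML file from the
# Hugging Face Hub and return the local file path plus its approximate size in
# GB -- an assumption based on how model_loc and file_size are used here, not
# on the package's documentation.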
prompt_template = """You are a helpful assistant. Let's think step by step. |
|
### Human: |
|
{question} |
|
### Assistant:""" |
|
|
|
|
|
prompt_template = """ |
|
### System: |
|
This is a system prompt, please behave and help the user. |
|
|
|
### Instruction: |
|
|
|
{question} |
|
|
|
### Response: |
|
""" |
|
prompt_template = """ |
|
[INST] Write code to solve the following coding problem that obeys the constraints and |
|
passes the example test cases. Please wrap your code answer using ```: |
|
{question} |
|
[/INST] |
|
""" |

# Use one less than the number of physical cores for generation threads.
# psutil.cpu_count(logical=False) can return None (e.g. in some containers),
# so fall back to a single thread in that case.
_ = psutil.cpu_count(logical=False)
cpu_count: int = max(_ - 1, 1) if _ else 1
logger.debug(f"{cpu_count=}")

logger.debug(f"{model_loc=}")
LLM = AutoModelForCausalLM.from_pretrained(
    model_loc,
    model_type="llama",
    threads=cpu_count,
)
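
# Optional smoke test (left commented out): a one-off, non-streaming call to
# confirm the weights load and produce output.
# logger.debug(LLM("[INST] print hello world in python [/INST]", max_new_tokens=16, stream=False))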
os.environ["TZ"] = "Asia/Shanghai" |
|
try: |
|
time.tzset() |
|
except Exception: |
|
|
|
logger.warning("Windows, cant run time.tzset()") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|


@dataclass
class GenerationConfig:
    """Generation settings used by the chat UI and the API endpoint."""

    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.9
    repetition_penalty: float = 1.0
    max_new_tokens: int = 512
    seed: int = 42
    reset: bool = False
    stream: bool = True
    threads: int = cpu_count


@dataclass
class Config:
    """Default generation settings; field names follow ctransformers' kwargs."""

    top_k: int = 40
    top_p: float = 0.95
    temperature: float = 0.8
    repetition_penalty: float = 1.1
    last_n_tokens: int = 64
    seed: int = -1

    batch_size: int = 8
    threads: int = -1

    max_new_tokens: int = 512
    stop: Optional[Sequence[str]] = None
    stream: bool = True
    reset: bool = False
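
# Both dataclasses are unpacked with asdict(...) directly into the ctransformers
# call inside generate(), so every field name must be a keyword the model's
# __call__ accepts; add new fields with that constraint in mind.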


def generate(
    question: str,
    llm=LLM,
    # either Config or GenerationConfig works here; both unpack into kwargs
    config: Config = Config(),
):
    """Run model inference; returns a generator when config.stream is True."""
    prompt = prompt_template.format(question=question)

    return llm(
        prompt,
        **asdict(config),
    )
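
# Illustrative usage (commented out): with stream=True the call yields text
# chunks as they are produced; with stream=False it returns a single string.
# for chunk in generate("Check whether a number is prime.", config=Config(stream=True)):
#     print(chunk, end="", flush=True)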
logger.debug(f"{Config(stream=True)=}") |
|
logger.debug(f"{vars(Config(stream=True))=}") |
|
|
|
|
|
def user(user_message, history): |
|
|
|
if history is None: |
|
history = [] |
|
history.append([user_message, None]) |
|
return user_message, history |
|
|
|
|
|
def user1(user_message, history): |
|
|
|
if history is None: |
|
history = [] |
|
history.append([user_message, None]) |
|
return "", history |
|
|
|
|
|
def bot_(history): |
|
user_message = history[-1][0] |
|
resp = random.choice(["How are you?", "I love you", "I'm very hungry"]) |
|
bot_message = user_message + ": " + resp |
|
history[-1][1] = "" |
|
for character in bot_message: |
|
history[-1][1] += character |
|
time.sleep(0.02) |
|
yield history |
|
|
|
history[-1][1] = resp |
|
yield history |
|
|
|
|
|


def bot(history):
    """Stream the model's reply into the last history entry."""
    user_message = ""
    try:
        user_message = history[-1][0]
    except Exception as exc:
        logger.error(exc)
    response = []

    logger.debug(f"{user_message=}")

    with about_time() as atime:
        flag = 1
        prefix = ""
        then = time.time()

        logger.debug("about to generate")

        config = GenerationConfig(reset=True)
        for elm in generate(user_message, config=config):
            if flag == 1:
                logger.debug("in the loop")
                prefix = f"({time.time() - then:.2f}s)\n"
                flag = 0
                print(prefix, end="", flush=True)
                logger.debug(f"{prefix=}")
            print(elm, end="", flush=True)

            response.append(elm)
            history[-1][1] = prefix + "".join(response)
            yield history

    # guard against division by zero when the model produces no output
    n_chars = len("".join(response)) or 1
    _ = (
        f"(time elapsed: {atime.duration_human}, "
        f"{atime.duration / n_chars:.2f}s/char)"
    )

    history[-1][1] = "".join(response) + f"\n{_}"
    yield history
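
# Gradio renders each yielded history as a streaming update, so the reply fills
# in incrementally; the final yield replaces it with the full text plus a rough
# seconds-per-character figure from about_time.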


def predict_api(prompt):
    """Non-streaming entry point used by the named Gradio API endpoint below."""
    logger.debug(f"{prompt=}")
    try:
        config = GenerationConfig(
            temperature=0.2,
            top_k=10,
            top_p=0.9,
            repetition_penalty=1.0,
            max_new_tokens=512,
            seed=42,
            reset=True,
            stream=False,
        )

        response = generate(
            prompt,
            config=config,
        )

        logger.debug(f"api: {response=}")
    except Exception as exc:
        logger.error(exc)
        response = f"{exc=}"

    return response
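
# Illustrative client call (assumes the gradio_client package and the default
# local port; adjust the URL for a deployed Space):
# from gradio_client import Client
# print(Client("http://127.0.0.1:7860/").predict("two sum in python", api_name="/api"))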
css = """ |
|
.importantButton { |
|
background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important; |
|
border: none !important; |
|
} |
|
.importantButton:hover { |
|
background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important; |
|
border: none !important; |
|
} |
|
.disclaimer {font-variant-caps: all-small-caps; font-size: xx-small;} |
|
.xsmall {font-size: x-small;} |
|
""" |
|
|
|
logger.info("start block") |
|
|
|

with gr.Blocks(
    title=f"{Path(model_loc).name}",
    theme=gr.themes.Glass(text_size="sm", spacing_size="sm"),
    css=css,
) as block:
    with gr.Accordion("🎈 Info", open=True):
        gr.Markdown(
            f"""<h5><center>{Path(model_loc).name}</center></h5>
            Doesn't quite work yet -- it may produce no output or run forever. The prompt template probably still needs adjusting.""",
            elem_classes="xsmall",
        )

    chatbot = gr.Chatbot(height=500)

    with gr.Row():
        with gr.Column(scale=5):
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Ask me anything (press Shift+Enter or click Submit to send)",
                show_label=False,
                lines=6,
                max_lines=30,
                show_copy_button=True,
            )
        with gr.Column(scale=1, min_width=50):
            with gr.Row():
                submit = gr.Button("Submit", elem_classes="xsmall")
                stop = gr.Button("Stop", visible=True)
                clear = gr.Button("Clear History", visible=True)
    with gr.Row(visible=False):
        with gr.Accordion("Advanced Options:", open=False):
            with gr.Row():
                with gr.Column(scale=2):
                    system = gr.Textbox(
                        label="System Prompt",
                        value=prompt_template,
                        show_label=False,
                        container=False,
                    )
                with gr.Column():
                    with gr.Row():
                        change = gr.Button("Change System Prompt")
                        reset = gr.Button("Reset System Prompt")

    with gr.Accordion("Example Inputs", open=True):
        examples = gr.Examples(
            examples=examples_list,
            inputs=[msg],
            examples_per_page=40,
        )
with gr.Accordion("Disclaimer", open=False): |
|
_ = Path(model_loc).name |
|
gr.Markdown( |
|
f"Disclaimer: {_} can produce factually incorrect output, and should not be relied on to produce " |
|
f"factually accurate information. {_} was trained on various public datasets; while great efforts " |
|
"have been taken to clean the pretraining data, it is possible that this model could generate lewd, " |
|
"biased, or otherwise offensive outputs.", |
|
elem_classes=["disclaimer"], |
|
) |
|
|
|
msg_submit_event = msg.submit( |
|
|
|
fn=user, |
|
inputs=[msg, chatbot], |
|
outputs=[msg, chatbot], |
|
queue=True, |
|
show_progress="full", |
|
|
|
).then(bot, chatbot, chatbot, queue=True) |
|
submit_click_event = submit.click( |
|
|
|
fn=user1, |
|
inputs=[msg, chatbot], |
|
outputs=[msg, chatbot], |
|
queue=True, |
|
|
|
show_progress="full", |
|
|
|
).then(bot, chatbot, chatbot, queue=True) |
|
stop.click( |
|
fn=None, |
|
inputs=None, |
|
outputs=None, |
|
cancels=[msg_submit_event, submit_click_event], |
|
queue=False, |
|
) |
|
clear.click(lambda: None, None, chatbot, queue=False) |
|
|
|
with gr.Accordion("For Chat/Translation API", open=False, visible=False): |
|
input_text = gr.Text() |
|
api_btn = gr.Button("Go", variant="primary") |
|
out_text = gr.Text() |
|
|
|
api_btn.click( |
|
predict_api, |
|
input_text, |
|
out_text, |
|
api_name="api", |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_ = """ |
|
# _ = int(psutil.virtual_memory().total / 10**9 // file_size - 1) |
|
# concurrency_count = max(_, 1) |
|
if psutil.cpu_count(logical=False) >= 8: |
|
# concurrency_count = max(int(32 / file_size) - 1, 1) |
|
else: |
|
# concurrency_count = max(int(16 / file_size) - 1, 1) |
|
# """ |
|
|
|
|
|
|
|
|
|
server_port = 7860 |
|
if "forindo" in platform.node(): |
|
server_port = 7861 |
|
block.queue(max_size=5).launch( |
|
debug=True, server_name="0.0.0.0", server_port=server_port |
|
) |
|
|
|
|
|
|