|
from huggingface_hub import InferenceClient, HfApi, upload_file |
|
import datetime |
|
import gradio as gr |
|
import random |
|
import prompts |
|
import json |
|
import uuid |
|
import os |
|
|
|
|
|
|
|
# Hugging Face access token, read from the environment (None when HF_TOKEN is unset).
token = os.environ.get("HF_TOKEN")

# Owner and name of the dataset repo that generated files are uploaded to.
username = "omnibus"
dataset_name = "tmp"

# Use the token read above; the previous empty-string token ignored HF_TOKEN.
api = HfApi(token=token)
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# Module-level conversation state.
history = []
hist_out = []

# One-element lists used as mutable slots: functions below read and assign
# summary[0] / main_point[0] without needing `global` declarations.
summary = [""]
main_point = [""]
|
|
|
def format_prompt(message, history):
    """Build an ``[INST]``-tagged prompt string from past turns plus a new message.

    Args:
        message: Text of the new instruction, placed last in the prompt.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, oldest
            first. ``None`` or an empty iterable means no previous turns.

    Returns:
        ``"<s>"`` followed by one ``[INST] user [/INST] bot</s> `` segment per
        history pair and a final ``[INST] message [/INST]`` segment.
    """
    # `history or ()` lets callers pass None for "no previous turns".
    past_turns = [
        f"[INST] {user_prompt} [/INST] {bot_response}</s> "
        for user_prompt, bot_response in (history or ())
    ]
    return "<s>" + "".join(past_turns) + f"[INST] {message} [/INST]"
|
|
|
# Agent names; agents[0] is the default `agent_name` of the functions below.
agents = [
    "COMMENTER",
    "BLOG_POSTER",
    "COMPRESS_HISTORY_PROMPT",
]

# Module-level sampling settings.
temperature = 0.9
max_new_tokens = 256
# Token budget that generate() passes to the model for each long-form reply.
max_new_tokens2 = 10480
top_p = 0.95
# Plain float: a trailing comma here previously made this the tuple (1.0,).
repetition_penalty = 1.0
|
|
|
def compress_history(formatted_prompt):
    """Send ``formatted_prompt`` to the model and return the streamed completion.

    The prompt is used exactly as given; the caller is expected to have
    already embedded the compression instructions in it.

    Args:
        formatted_prompt: Fully formatted prompt string sent to ``client``.

    Returns:
        The generated text, i.e. the concatenation of every streamed token.
    """
    generate_kwargs = dict(
        temperature=0.9,
        # NOTE(review): very large budget; confirm the endpoint accepts
        # prompt length + 30480 new tokens.
        max_new_tokens=30480,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        # Fresh random seed on every call.
        seed=random.randint(1, 1111111111111111),
    )

    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    output = "".join(response.token.text for response in stream)

    # Debug output: the generated text and the current main point.
    print(output)
    print(main_point[0])
    return output
|
|
|
|
|
def question_generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,):
    """Generate text from the COMMENTER prompt combined with ``prompt``.

    The COMMENTER template is filled with the current ``main_point[0]``,
    joined to ``prompt`` with ``", "``, wrapped together with ``history`` by
    ``format_prompt`` and streamed through ``client``.

    Args:
        prompt: Text appended after the COMMENTER system prompt.
        history: ``(user_prompt, bot_response)`` pairs passed to ``format_prompt``.
        agent_name: Accepted but not used by this function.
        sys_prompt: Accepted but not used by this function.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Generation length limit forwarded to the model.
        top_p: Nucleus-sampling value, converted to ``float``.
        repetition_penalty: Forwarded unchanged to the model.

    Returns:
        The generated text, i.e. the concatenation of every streamed token.
    """
    rng_seed = random.randint(1, 1111111111111111)
    commenter_prompt = prompts.COMMENTER.format(focus=main_point[0])

    # Clamp the temperature to a floor of 0.01.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    sampling = {
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "seed": rng_seed,
    }

    model_input = format_prompt(f"{commenter_prompt}, {prompt}", history)
    token_stream = client.text_generation(
        model_input,
        stream=True,
        details=True,
        return_full_text=False,
        **sampling,
    )

    pieces = [event.token.text for event in token_stream]
    return "".join(pieces)
|
def create_valid_filename(invalid_filename: str) -> str:
    """Convert a string into one that is suitable for use as a filename.

    Runs of whitespace are collapsed into a single ``-`` (leading and trailing
    whitespace is dropped), then every character that is not an ASCII letter,
    an ASCII digit, ``_`` or ``-`` is removed.

    Args:
        invalid_filename: Arbitrary text to sanitise.

    Returns:
        The sanitised string; empty if no allowed characters remain.
    """
    # Set membership is O(1); the allowed alphabet is ASCII-only on purpose,
    # so non-ASCII letters and digits are dropped as before.
    allowed_chars = frozenset(
        "abcdefghijklmnopqrstuvwxyz"
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "0123456789_-"
    )
    # split()/join handles every kind of whitespace, so the discarded
    # `replace(" ", "-")` call from the original was dead code.
    hyphenated = '-'.join(invalid_filename.split())
    return ''.join(char for char in hyphenated if char in allowed_chars)
|
|
|
def _persist_json(records, local_path, repo_path):
    """Write ``records`` to ``local_path`` as JSON, upload that file, return the JSON text."""
    payload = json.dumps(records, indent=4)
    # The `with` block closes the file; no explicit close() is needed.
    with open(local_path, 'w') as f:
        f.write(payload)
    upload_file(
        path_or_fileobj=local_path,
        path_in_repo=repo_path,
        repo_id=f"{username}/{dataset_name}",
        repo_type="dataset",
        token=token,
    )
    return payload


def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1048, top_p=0.95, repetition_penalty=1.0,):
    """Generator that keeps producing model output and follow-up prompts.

    Each pass of the loop either:

    * streams a reply to the current prompt (when the formatted prompt is
      shorter than 50000 characters), asks ``question_generate`` for the next
      prompt, and saves/uploads the accumulated prompt/output records; or
    * asks ``compress_history`` for a summary, clears the local history,
      saves/uploads the accumulated summaries, and asks ``question_generate``
      for the next prompt.

    The loop has no exit condition, so the generator runs until its consumer
    stops iterating it.

    Args:
        prompt: Initial prompt; also stored in ``main_point[0]``.
        history: List of ``(user_prompt, bot_response)`` pairs; appended to in place.
        agent_name: Accepted but not used by this function.
        sys_prompt: Accepted but not used by this function.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Accepted but not used; the module-level
            ``max_new_tokens2`` is sent to the model instead.
        top_p: Nucleus-sampling value, converted to ``float``.
        repetition_penalty: Forwarded unchanged to the model.

    Yields:
        ``('', [(prompt, partial_output)], summary[0], json_obj, json_hist)``
        after every streamed token. ``json_obj`` and ``json_hist`` start as
        empty dicts and become JSON strings once the summaries / history have
        been saved.
    """
    main_point[0] = prompt

    uid = uuid.uuid4()
    # Timestamp with ':' and '.' replaced by '-'.
    current_time = str(datetime.datetime.now())
    current_time = current_time.replace(":", "-")
    current_time = current_time.replace(".", "-")
    print(current_time)

    system_prompt = prompts.BLOG_POSTER
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    hist_out = []   # prompt/output records written to the history file
    sum_out = []    # summary records written to the summary file
    json_hist = {}
    json_obj = {}
    filename = create_valid_filename(f'{prompt}---{current_time}')

    while True:
        seed = random.randint(1, 1111111111111111)

        generate_kwargs = dict(
            temperature=temperature,
            # NOTE(review): the `max_new_tokens` argument is ignored in favour
            # of the module-level max_new_tokens2 - confirm this is intended.
            max_new_tokens=max_new_tokens2,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=seed,
        )
        if prompt.startswith(' \"'):
            prompt = prompt.strip(' \"')
        formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)

        if len(formatted_prompt) < (50000):
            print(len(formatted_prompt))
            stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
            output = ""

            # Stream partial output to the consumer token by token.
            for response in stream:
                output += response.token.text
                yield '', [(prompt, output)], summary[0], json_obj, json_hist

            # Capture the record before `prompt` is replaced below.
            out_json = {"prompt": prompt, "output": output}

            prompt = question_generate(output, history)

            # NOTE(review): this pairs the newly generated prompt with the
            # output of the previous prompt - confirm the pairing is intended.
            history.append((prompt, output))
            print(f'Prompt:: {len(prompt)}')
            print(f'history:: {len(formatted_prompt)}')
            hist_out.append(out_json)

            json_hist = _persist_json(
                hist_out,
                f"{uid}.json",
                f"book1/{filename}.json",
            )
        else:
            # Built from the history as it stands, before it is cleared below.
            formatted_prompt = format_prompt(f"{prompts.COMPRESS_HISTORY_PROMPT.format(history=summary[0],focus=main_point[0])}, {summary[0]}", history)

            history = []
            output = compress_history(formatted_prompt)
            summary[0] = output
            sum_json = {"summary": summary[0]}
            sum_out.append(sum_json)

            json_obj = _persist_json(
                sum_out,
                f"{uid}-sum.json",
                f"summary/{filename}-summary.json",
            )

            prompt = question_generate(output, history)

    # NOTE(review): indentation was lost in the source; placed at function
    # level, this statement is unreachable because the loop never breaks.
    return prompt, history, summary[0], json_obj, json_hist
|
|
|
def load_html():
    """Read ``index.html`` from the working directory and fill in its placeholder.

    Returns:
        The file's text with every ``$name`` replaced by ``Test``.

    Raises:
        OSError: If ``index.html`` cannot be opened or read.
    """
    # Explicit encoding so the result does not depend on the platform default;
    # the `with` block closes the file, so no manual close() is needed.
    with open('index.html', 'r', encoding='utf-8') as h:
        html = h.read()
    return html.replace("$name", "Test")
|
|
|
|
|
# Gradio UI: an HTML header, a chat view with a prompt box and control
# buttons, a summary box, and two JSON panels.
# NOTE(review): indentation was lost in the source, so the nesting of the
# Row/Column children below is reconstructed - confirm against the intended layout.
with gr.Blocks() as app:
    # Filled by load_html() when the page loads (see app.load below).
    html = gr.HTML()

    chatbot=gr.Chatbot()
    msg = gr.Textbox()
    with gr.Row():
        submit_b = gr.Button()
        stop_b = gr.Button("Stop")
        clear = gr.ClearButton([msg, chatbot])
    # NOTE(review): "Summary" is passed positionally - confirm whether it was
    # meant as the initial value or as the label.
    sumbox=gr.Textbox("Summary", max_lines=100)
    with gr.Column():
        sum_out_box=gr.JSON(label="Summaries")
        hist_out_box=gr.JSON(label="History")

    # Both the button click and pressing Enter in the textbox run generate();
    # the five values it yields map onto the five output components listed.
    sub_b = submit_b.click(generate, [msg,chatbot],[msg,chatbot,sumbox,sum_out_box,hist_out_box])
    sub_e = msg.submit(generate, [msg, chatbot], [msg, chatbot,sumbox,sum_out_box,hist_out_box])
    # "Stop" runs no function of its own; it only cancels the two events above.
    stop_b.click(None,None,None, cancels=[sub_b,sub_e])

    # Populate the HTML header on page load.
    app.load(load_html,None,html)
app.launch()