import time

import gradio as gr
from huggingface_hub import InferenceClient

client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

SYSTEM_PROMPT = (
    "You are a prompt enhancer. Your job is to enhance the given prompt in under 100 words "
    "without changing its essence. Output only the enhanced prompt and nothing else."
)


def format_prompt(message):
    """
    Format the input message using the system prompt, appending a timestamp
    so each request is unique.
    """
    timestamp = time.time()
    formatted = (
        f"<s>[INST] SYSTEM: {SYSTEM_PROMPT} [/INST]"
        f"[INST] {message} {timestamp} [/INST]"
    )
    return formatted


def generate(message, max_new_tokens=256, temperature=0.9, top_p=0.95, repetition_penalty=1.0):
    """
    Stream an enhanced prompt from the Mixtral model, yielding the accumulated
    text as each token arrives.
    """
    # Clamp temperature away from zero: sampling requires a strictly positive value.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    generate_kwargs = {
        "temperature": temperature,
        "max_new_tokens": int(max_new_tokens),
        "top_p": top_p,
        "repetition_penalty": float(repetition_penalty),
        "do_sample": True,
    }

    formatted_prompt = format_prompt(message)
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )

    output = ""
    for response in stream:
        output += response.token.text
        # str.strip("</s>") would remove any of the characters <, /, s, > from
        # both ends; removesuffix drops only a trailing end-of-sequence token.
        yield output.removesuffix("</s>")
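

# The gradio import above suggests generate() is meant to back a web UI. The
# wiring below is a minimal sketch of that assumption, not part of the original
# script: the title, labels, and slider ranges are illustrative only.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=generate,  # generator fn: Gradio streams each yielded value to the output box
        inputs=[
            gr.Textbox(label="Prompt to enhance"),
            gr.Slider(1, 1024, value=256, step=1, label="Max new tokens"),
            gr.Slider(0.01, 2.0, value=0.9, label="Temperature"),
            gr.Slider(0.01, 1.0, value=0.95, label="Top-p"),
            gr.Slider(1.0, 2.0, value=1.0, label="Repetition penalty"),
        ],
        outputs=gr.Textbox(label="Enhanced prompt"),
        title="Prompt Enhancer",
    )
    demo.launch()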