from huggingface_hub import InferenceClient
import gradio as gr
import os
import re
# Get secret (HF_TOKEN)
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# HTML/CSS for the UI
DESCRIPTION = """
<div>
<h1 style="text-align: center;">Llama 3 Poem Analysis (Work-in-progress)</h1>
<h2>Copy-paste poem into textbox --> get Llama 3-generated commentary *hallucinations likely*</h2>
</div>
"""
LICENSE = """
<p/>
---
Built with Meta Llama 3
"""
# Not currently used; having trouble integrating it as a gr.Textbox parameter within the gr.ChatInterface framework
PLACEHOLDER = """
<div>
<img src="TBD" style="opacity: 0.55; ">
</div>
"""
css = """
h1 {
text-align: center;
display: block;
}
"""
# Initialize the Llama 3 model; using InferenceClient for speed
client = InferenceClient(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    token=HF_TOKEN,
)
# Get few-shot samples from PoemAnalysisSamples.txt
with open("PoemAnalysisSamples.txt", 'r') as f:
    sample_poems = f.read()
pairs = re.findall(r'<poem>(.*?)</poem>\s*<response>(.*?)</response>', sample_poems, re.DOTALL)
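# Assumed layout of PoemAnalysisSamples.txt (illustrative; the actual samples may differ),
# matching the regex above: each few-shot example is a <poem>...</poem> block followed by a
# <response>...</response> block, e.g.
#   <poem>Shall I compare thee to a summer's day? ...</poem>
#   <response>The poem opens with a simile comparing the beloved to a summer's day ...</response>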
# System message to initialize the poetry assistant
sys_message = """
Assistant provides detailed analysis of poems following the format of the few-shot samples given. Assistant uses the following poetic terms and concepts to describe the poem entered by the user: simile, metaphor, metonymy, imagery, synecdoche, meter, diction, end rhyme, internal rhyme, and slant rhyme.
"""
# Helper function for formatting
def format_prompt(message, history):
    """Formats the prompt for the LLM
    Args:
        message: current user text entry
        history: conversation history tracked by Gradio
    Returns:
        prompt: formatted properly for inference
    """
    # Start with the system message in Llama 3 message format: https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
    prompt = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>{sys_message}<|eot_id|>"
    # Unpack the user and assistant messages from the few-shot samples
    for poem, response in pairs:
        prompt += f"<|start_header_id|>user<|end_header_id|>{poem}<|eot_id|>"
        prompt += f"<|start_header_id|>assistant<|end_header_id|>{response}<|eot_id|>"
    # Unpack the conversation history stored by Gradio
    for user_prompt, bot_response in history:
        prompt += f"<|start_header_id|>user<|end_header_id|>{user_prompt}<|eot_id|>"
        prompt += f"<|start_header_id|>assistant<|end_header_id|>{bot_response}<|eot_id|>"
    # Add the new message; <|begin_of_text|> appears only once, at the start of the prompt
    prompt += f"<|start_header_id|>user<|end_header_id|>{message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
    return prompt
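# For reference, with no few-shot pairs and no history, format_prompt assembles:
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>{sys_message}<|eot_id|>
#   <|start_header_id|>user<|end_header_id|>{message}<|eot_id|>
#   <|start_header_id|>assistant<|end_header_id|>
# (split across lines here for readability; the actual string contains no added newlines)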
# Function to generate the LLM response
def generate(
    prompt, history, temperature=0.1, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,
):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
        stop_sequences=["<|eot_id|>"],  # Llama 3 requires this stop token
    )
    formatted_prompt = format_prompt(prompt, history)
    # return_full_text=False keeps the prompt out of the output; change to True to debug conversation history
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    # Yield the cumulative text so Gradio streams the response as it is generated
    output = ""
    for response in stream:
        output += response.token.text
        yield output
    return output
# Initialize sliders for the generation parameters
additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.1,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=1024,
        minimum=0,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
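# gr.ChatInterface passes these slider values positionally to generate() after (message, history),
# i.e. as temperature, max_new_tokens, top_p, and repetition_penalty.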
# Gradio UI
with gr.Blocks(css=css) as demo:
    gr.ChatInterface(
        fn=generate,
        description=DESCRIPTION,
        additional_inputs=additional_inputs,
    )
    gr.Markdown(LICENSE)
demo.queue(concurrency_count=75, max_size=100).launch(debug=True)