import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.prompter import Prompter

class CustomPrompter(Prompter):
    def get_response(self, output: str) -> str:
        # Keep only the model's answer: drop everything before the response
        # separator and anything after a follow-up "### Instruction:" header.
        return (
            output.split(self.template["response_split"])[1]
            .strip()
            .split("### Instruction:")[0]
        )

prompt_template_name = "alpaca"  # The prompt template to use; defaults to alpaca.
prompter = CustomPrompter(prompt_template_name)
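
# The helper functions below reference a tokenizer, a model and a few constants
# (cutoff_len, train_on_inputs, add_eos_token) that are not defined anywhere in
# this file. The lines that follow are a minimal sketch of that setup; the
# checkpoint path and the constant values are assumptions, not taken from the
# original script.
base_model = "Joaoffg/ELM"  # placeholder -- point this at the finetuned ELM checkpoint
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(base_model)
model.eval()

cutoff_len = 512         # assumed maximum tokenized prompt length
train_on_inputs = False  # assumed: mask prompt tokens in the labels
add_eos_token = True     # assumed: append an EOS token when tokenizing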

def tokenize(prompt, add_eos_token=True):
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=cutoff_len,
        padding=False,
        return_tensors=None,
    )
    # Append the EOS token when it is missing and the prompt was not truncated.
    if (
        result["input_ids"][-1] != tokenizer.eos_token_id
        and len(result["input_ids"]) < cutoff_len
        and add_eos_token
    ):
        result["input_ids"].append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)
    result["labels"] = result["input_ids"].copy()
    return result

def generate_and_tokenize_prompt(data_point):
    full_prompt = prompter.generate_prompt(
        data_point["instruction"],
        data_point["input"],
        data_point["output"],
    )
    tokenized_full_prompt = tokenize(full_prompt)
    if not train_on_inputs:
        # Mask the prompt tokens with -100 so the loss is only computed on the response.
        user_prompt = prompter.generate_prompt(
            data_point["instruction"], data_point["input"]
        )
        tokenized_user_prompt = tokenize(
            user_prompt, add_eos_token=add_eos_token
        )
        user_prompt_len = len(tokenized_user_prompt["input_ids"])
        if add_eos_token:
            user_prompt_len -= 1
        tokenized_full_prompt["labels"] = [
            -100
        ] * user_prompt_len + tokenized_full_prompt["labels"][
            user_prompt_len:
        ]  # could be sped up, probably
    return tokenized_full_prompt
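
# Note: tokenize() and generate_and_tokenize_prompt() are training-time helpers
# and are not called by the inference app below. A typical use -- sketched here
# as an assumption rather than taken from this file -- is to map them over an
# instruction dataset with the `datasets` library:
#
#   from datasets import load_dataset
#   data = load_dataset("tatsu-lab/alpaca")  # assumed dataset identifier
#   train_data = data["train"].map(generate_and_tokenize_prompt)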

def evaluate(instruction):
    # Generate a response for a single instruction (no additional input).
    input_text = None
    prompt = prompter.generate_prompt(instruction, input_text)
    inputs = tokenizer(prompt, return_tensors="pt")
    # inputs = inputs.to("cuda:0")  # uncomment when running on a GPU
    input_ids = inputs["input_ids"]

    # Generation settings -- play around with the generation strategies for
    # better/more diverse sequences:
    # https://huggingface.co/docs/transformers/generation_strategies
    temperature = 0.2
    top_p = 0.95
    top_k = 25
    num_beams = 1
    # num_beam_groups = num_beams  # see: 'Diverse beam search decoding'
    max_new_tokens = 256
    repetition_penalty = 2.0
    do_sample = True  # sampling; set num_beams > 1 as well for 'beam sample' decoding
    num_return_sequences = 1  # generate multiple candidates, takes longer

    generation_config = transformers.GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
        min_new_tokens=32,
        num_return_sequences=num_return_sequences,
        pad_token_id=0,
        # num_beam_groups=num_beam_groups,
    )

    generate_params = {
        "input_ids": input_ids,
        "generation_config": generation_config,
        "return_dict_in_generate": True,
        "output_scores": True,
        "max_new_tokens": max_new_tokens,
    }

    with torch.no_grad():
        generation_output = model.generate(**generate_params)

    print(f"Instruction: {instruction}")
    # With num_return_sequences=1, only the first sequence is decoded and returned.
    for s in generation_output.sequences:
        output = tokenizer.decode(s, skip_special_tokens=True)
        # print(output)
        return f" {prompter.get_response(output)}"
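
# Quick sanity check when running this file directly (illustrative only; the
# instruction below is just an example, and the Space normally calls evaluate()
# through the Gradio interface defined next):
# print(evaluate("Explain economic growth."))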

# Define the Gradio interface
interface = gr.Interface(
    fn=evaluate,
    inputs=[
        gr.components.Textbox(
            lines=2,
            label="Instruction",
            placeholder="Explain economic growth.",
        ),
    ],
    outputs=[
        gr.components.Textbox(
            lines=5,
            label="Output",
        )
    ],
    title="🌲 ELM - Erasmian Language Model",
    description=(
        "ELM is a 900M parameter language model finetuned to follow instructions. "
        "It is trained on Erasmus University academic outputs and the "
        "[Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset. "
        "For more information, please visit [the GitHub repository](https://github.com/Joaoffg/ELM)."
    ),
)

# Launch the Gradio interface
interface.queue().launch()