import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
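# --- Assumed setup (not shown in this excerpt) -------------------------------
# The names Prompter, prompt_template_name, tokenizer, model, cutoff_len,
# train_on_inputs and add_eos_token are used below but never defined here.
# A minimal sketch of what that setup could look like, assuming an
# alpaca-lora style Prompter helper and a Hugging Face checkpoint; the paths
# and values are placeholders, not the actual ELM configuration:
#
# from utils.prompter import Prompter          # assumed alpaca-lora style helper
#
# prompt_template_name = "alpaca"              # assumed template name
# base_model = "path/to/elm-checkpoint"        # placeholder path
# tokenizer = AutoTokenizer.from_pretrained(base_model)
# model = AutoModelForCausalLM.from_pretrained(base_model)
# model.eval()
#
# cutoff_len = 512                             # assumed max prompt length
# train_on_inputs = False                      # mask the prompt in the labels
# add_eos_token = True
# ------------------------------------------------------------------------------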
class CustomPrompter(Prompter):
    def get_response(self, output: str) -> str:
        # Keep only the answer: take the text after the response delimiter and
        # cut it off at any follow-on "### Instruction:" block.
        return output.split(self.template["response_split"])[1].strip().split("### Instruction:")[0]


prompter = CustomPrompter(prompt_template_name)
def tokenize(prompt, add_eos_token=True):
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=cutoff_len,
        padding=False,
        return_tensors=None,
    )
    # Append an EOS token when the prompt was not truncated and does not
    # already end with one, so the model learns where a response stops.
    if (
        result["input_ids"][-1] != tokenizer.eos_token_id
        and len(result["input_ids"]) < cutoff_len
        and add_eos_token
    ):
        result["input_ids"].append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)

    result["labels"] = result["input_ids"].copy()

    return result
def generate_and_tokenize_prompt(data_point):
    full_prompt = prompter.generate_prompt(
        data_point["instruction"],
        data_point["input"],
        data_point["output"],
    )
    tokenized_full_prompt = tokenize(full_prompt)
    if not train_on_inputs:
        # Mask the instruction/input portion of the labels with -100 so the
        # loss is only computed on the response tokens.
        user_prompt = prompter.generate_prompt(
            data_point["instruction"], data_point["input"]
        )
        tokenized_user_prompt = tokenize(
            user_prompt, add_eos_token=add_eos_token
        )
        user_prompt_len = len(tokenized_user_prompt["input_ids"])

        if add_eos_token:
            user_prompt_len -= 1

        tokenized_full_prompt["labels"] = [
            -100
        ] * user_prompt_len + tokenized_full_prompt["labels"][
            user_prompt_len:
        ]  # could be sped up, probably
    return tokenized_full_prompt
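
# generate_and_tokenize_prompt is the kind of helper that is typically mapped
# over a Hugging Face dataset when preparing training data; it is not used by
# the inference path below. A hedged sketch (file name is a placeholder):
#
#   from datasets import load_dataset
#   data = load_dataset("json", data_files="instruction_data.json")
#   train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)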
def evaluate(instruction):
    # Generate a response:
    input = None
    prompt = prompter.generate_prompt(instruction, input)
    inputs = tokenizer(prompt, return_tensors="pt")
    # inputs = inputs.to("cuda:0")
    input_ids = inputs["input_ids"]

    # Play around with generation strategies for better/diverse sequences.
    # https://huggingface.co/docs/transformers/generation_strategies
    temperature = 0.2
    top_p = 0.95
    top_k = 25
    num_beams = 1
    # num_beam_groups = num_beams  # see: 'Diverse beam search decoding'
    max_new_tokens = 256
    repetition_penalty = 2.0
    do_sample = True  # allow 'beam sample': do_sample=True, num_beams > 1
    num_return_sequences = 1  # generate multiple candidates, takes longer..

    generation_config = transformers.GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
        min_new_tokens=32,
        num_return_sequences=num_return_sequences,
        pad_token_id=0,
        # num_beam_groups=num_beam_groups,
    )

    # Mirrors the keyword arguments passed to model.generate below; kept for
    # reference but not used directly.
    generate_params = {
        "input_ids": input_ids,
        "generation_config": generation_config,
        "return_dict_in_generate": True,
        "output_scores": True,
        "max_new_tokens": max_new_tokens,
    }

    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
        )

    print(f"Instruction: {instruction}")

    for i, s in enumerate(generation_output.sequences):
        output = tokenizer.decode(s, skip_special_tokens=True)
        # print(output)
        # Returns after the first sequence; with num_return_sequences=1 this
        # is the only candidate anyway.
        return f" {prompter.get_response(output)}"
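
# Quick local check (assumed usage, outside the Gradio UI):
# print(evaluate("Explain economic growth."))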
gr.Interface(
    fn=evaluate,
    inputs=[
        gr.components.Textbox(
            lines=2,
            label="Instruction",
            placeholder="Explain economic growth.",
        ),
    ],
    outputs=[
        gr.components.Textbox(
            lines=5,
            label="Output",
        )
    ],
    title="🌲 ELM - Erasmian Language Model",
    description="ELM is a 900M parameter language model fine-tuned to follow instructions. It is trained on Erasmus University academic outputs and the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset. For more information, please visit [the GitHub repository](https://github.com/Joaoffg/ELM).",  # noqa: E501
).queue().launch(server_name="0.0.0.0", share=True)

# Old testing code follows.