Spaces:
Runtime error
Runtime error
import os | |
import random | |
from threading import Thread | |
from typing import Iterable | |
import torch | |
from huggingface_hub import HfApi | |
from datasets import load_dataset | |
from transformers import T5Tokenizer, T5ForConditionalGeneration | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Hugging Face access token taken from the HF_TOKEN environment variable
# (None when the variable is unset); passed to every Hub call below.
TOKEN = os.environ.get("HF_TOKEN", None)

# "<task>-<difficulty>" -> (ConvRe config name, prompt split to load).
_CONVRE_SOURCES = {
    "re2text-easy": ("en-re2text", "prompt1"),
    "re2text-hard": ("en-re2text", "prompt4"),
    "text2re-easy": ("en-text2re", "prompt1"),
    "text2re-hard": ("en-text2re", "prompt3"),
}

# One loaded split of 3B-Group/ConvRe per key; all four are fetched eagerly
# at import time, in the order listed above.
type2dataset = {
    key: load_dataset('3B-Group/ConvRe', config, token=TOKEN, split=split)
    for key, (config, split) in _CONVRE_SOURCES.items()
}
# Chat-tuned Llama-2 7B checkpoint used by generate().
model_id = "meta-llama/Llama-2-7b-chat-hf"

tokenizer = AutoTokenizer.from_pretrained(model_id, token=TOKEN)

# Weights are loaded in bfloat16, then switched to inference mode.
# NOTE(review): no device_map / .to() here, so the model stays wherever
# from_pretrained places it by default (presumably CPU) — confirm intended.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    token=TOKEN,
)
model.eval()
def generate(input_text, sys_prompt, *, max_new_tokens: int = 512) -> str:
    """Run one Llama-2-chat turn and return the prompt followed by the reply.

    The system prompt and user text are wrapped in the Llama-2 chat template
    (``[INST] <<SYS>> ... <</SYS>> ... [/INST]``). The module-level
    ``tokenizer``/``model`` encode the prompt and generate a continuation.

    Args:
        input_text: User message placed after the system block.
        sys_prompt: System instruction placed between the ``<<SYS>>`` tags.
        max_new_tokens: Maximum number of tokens to generate. This limit is
            independent of the prompt length, so long few-shot prompts still
            leave room for an answer.

    Returns:
        The prompt text (without the closing ``[/INST]`` tag), a blank line,
        and then the model's decoded reply.
    """
    # Reference Llama-2 chat layout: blank line between <</SYS>> and the user turn.
    prompt = f"[INST] <<SYS>>\n{sys_prompt}\n<</SYS>>\n\n{input_text}"
    # Pass the full encoding (input_ids + attention_mask) on the model's device.
    inputs = tokenizer(prompt + " [/INST]", return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Decode only the freshly generated tokens. The previous approach split the
    # whole decoded text on " [/INST]", which raised IndexError when the marker
    # was missing and cut the reply short when the marker appeared again.
    prompt_len = inputs["input_ids"].shape[-1]
    reply = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
    return prompt + "\n\n" + reply
def random_examples(dataset_key) -> str:
    """Return the ``query`` text of one uniformly chosen row of a ConvRe split.

    ``dataset_key`` selects an entry of the module-level ``type2dataset``
    mapping (e.g. ``"re2text-easy"``). An unknown key raises ``KeyError``, and
    an empty split raises ``ValueError`` from the random index draw.
    """
    rows = type2dataset[dataset_key]
    pick = random.randrange(len(rows))
    return rows[pick]['query']