import os
import random
from threading import Thread
from typing import Iterable

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
# Ground truth for the most recently sampled example (set by random_examples).
ground_truth = ""

# Access token for the gated Llama-2 weights; read from the Space's secrets.
TOKEN = os.environ.get("HF_TOKEN", None)
# Map each task/difficulty combination to the matching ConvRe split.
type2dataset = {
    "re2text-easy": load_dataset('3B-Group/ConvRe', "en-re2text", token=TOKEN, split="prompt1"),
    "re2text-hard": load_dataset('3B-Group/ConvRe', "en-re2text", token=TOKEN, split="prompt4"),
    "text2re-easy": load_dataset('3B-Group/ConvRe', "en-text2re", token=TOKEN, split="prompt1"),
    "text2re-hard": load_dataset('3B-Group/ConvRe', "en-text2re", token=TOKEN, split="prompt3"),
}
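
# For reference, each item is expected to expose at least a 'query' and an
# 'answer' field; these are the only fields the functions below rely on, e.g.:
#
#     item = type2dataset["re2text-easy"][0]
#     print(item['query'], item['answer'])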
model_id = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, token=TOKEN, device_map="auto"
).eval()
# Alternative, smaller model for quick local testing (requires the T5 imports):
# model_id = "google/flan-t5-base"
# tokenizer = T5Tokenizer.from_pretrained(model_id)
# model = T5ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
# type2dataset = {}
def generate(input_text, sys_prompt, temperature, max_new_tokens) -> Iterable[str]:
    """Stream a Llama-2 chat completion, yielding the accumulated output so far."""
    # Wrap the system prompt in Llama-2's chat template.
    sys_prompt = f'''[INST] <<SYS>>
{sys_prompt}
<</SYS>>
'''
    input_str = sys_prompt + input_text + " [/INST]"
    inputs = tokenizer(input_str, return_tensors="pt").to("cuda")

    # Stream decoded tokens as they are produced instead of waiting for the
    # full sequence; skip_prompt drops the echoed input from the stream.
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,  # a BatchEncoding is a mapping, so this unpacks input_ids/attention_mask
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=float(temperature),
    )

    # Run generation in a background thread; the streamer is fed from there.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer and yield the running output.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output

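# A small convenience wrapper (hypothetical, not part of the original app):
# it drains the streaming generator and returns only the final completion,
# which is handy for scripted evaluation where streaming is not needed.
def generate_blocking(input_text, sys_prompt, temperature=0.7, max_new_tokens=256) -> str:
    final = ""
    for partial in generate(input_text, sys_prompt, temperature, max_new_tokens):
        final = partial  # each yielded value is the full output so far
    return final
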
def random_examples(dataset_key) -> str:
    """Sample a random query from the chosen split and remember its answer."""
    target_dataset = type2dataset[dataset_key]
    idx = random.randint(0, len(target_dataset) - 1)
    item = target_dataset[idx]

    # Stash the answer so return_ground_truth() can reveal it later.
    global ground_truth
    ground_truth = item['answer']
    return item['query']

def return_ground_truth() -> str:
    """Return the answer for the most recently sampled example."""
    return ground_truth
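
# --- UI wiring (a minimal sketch; the captured file ends here) ---------------
# The Space presumably hooks these functions up to a Gradio interface. The
# layout below is an assumption for illustration: every component name and
# default value is made up, only the three functions above are real.
import gradio as gr

with gr.Blocks() as demo:
    task = gr.Dropdown(list(type2dataset.keys()), value="re2text-easy", label="Task")
    query = gr.Textbox(label="Query")
    sys_prompt = gr.Textbox(value="You are a helpful assistant.", label="System prompt")
    temperature = gr.Slider(0.1, 1.5, value=0.7, label="Temperature")
    max_new_tokens = gr.Slider(16, 512, value=256, step=16, label="Max new tokens")
    output = gr.Textbox(label="Model output")
    answer = gr.Textbox(label="Ground truth")

    gr.Button("Sample example").click(random_examples, inputs=task, outputs=query)
    # generate is a generator, so the output box streams as tokens arrive.
    gr.Button("Generate").click(
        generate, inputs=[query, sys_prompt, temperature, max_new_tokens], outputs=output
    )
    gr.Button("Reveal answer").click(return_ground_truth, inputs=None, outputs=answer)

demo.queue().launch()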