import os
import random
from threading import Thread
from typing import Iterable

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer


# Module-level state: holds the gold answer for the most recently sampled
# example, shared between random_examples() and return_ground_truth().
ground_truth = ""

# Hugging Face access token for gated models (e.g. Llama 2).
TOKEN = os.environ.get("HF_TOKEN", None)

# Map each task/difficulty combination to the matching ConvRe split.
type2dataset = {
    "re2text-easy": load_dataset('3B-Group/ConvRe', "en-re2text", token=TOKEN, split="prompt1"),
    "re2text-hard": load_dataset('3B-Group/ConvRe', "en-re2text", token=TOKEN, split="prompt4"),
    "text2re-easy": load_dataset('3B-Group/ConvRe', "en-text2re", token=TOKEN, split="prompt1"),
    "text2re-hard": load_dataset('3B-Group/ConvRe', "en-text2re", token=TOKEN, split="prompt3"),
}

# Default backbone: Llama-2-7B-chat in bfloat16, auto-sharded across devices.
model_id = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=TOKEN)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, token=TOKEN, device_map="auto").eval()


# Alternative backbone (would also need:
#   from transformers import T5Tokenizer, T5ForConditionalGeneration):
# model_id = "google/flan-t5-base"
# tokenizer = T5Tokenizer.from_pretrained(model_id)
# model = T5ForConditionalGeneration.from_pretrained(model_id, device_map="auto")

# type2dataset = {}


def generate(input_text, sys_prompt, temperature, max_new_tokens) -> Iterable[str]:
    """Stream a Llama-2-chat completion, yielding the cumulative output."""
    # Wrap the user input in the Llama 2 chat template.
    sys_prompt = f'''[INST] <<SYS>>
{sys_prompt}
<</SYS>>

'''
    input_str = sys_prompt + input_text + " [/INST]"

    inputs = tokenizer(input_str, return_tensors="pt").to("cuda")

    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

    # `inputs` is a mapping (input_ids, attention_mask), so dict(inputs, ...)
    # merges it with the generation arguments.
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=float(temperature)
    )
    # Run generation in a background thread so the streamer can be consumed here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer and yield the running output.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output
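
# For reference, a sketch of the rendered prompt, assuming the illustrative
# values sys_prompt="You are helpful." and input_text="Hi" (these values are
# not from the original code):
# [INST] <<SYS>>
# You are helpful.
# <</SYS>>
#
# Hi [/INST]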


def random_examples(dataset_key) -> str:
    """Sample a random query from the selected split and cache its answer."""
    # target_dataset = type2dataset[f"{task.lower()}-{type.lower()}"]
    target_dataset = type2dataset[dataset_key]

    idx = random.randint(0, len(target_dataset) - 1)
    item = target_dataset[idx]

    # Cache the gold answer so return_ground_truth() can reveal it later.
    global ground_truth
    ground_truth = item['answer']

    return item['query']


def return_ground_truth() -> str:
    """Reveal the gold answer for the most recently sampled example."""
    return ground_truth
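

# --- Hedged usage sketch (not part of the original app) ----------------------
# A minimal smoke test, assuming this file is run directly on a machine with a
# GPU and HF_TOKEN set: sample a query from one split, stream a completion, then
# reveal the cached gold answer. The system prompt and sampling values below are
# illustrative placeholders, not values taken from the original app.
if __name__ == "__main__":
    query = random_examples("re2text-easy")
    print(f"Query: {query}\n")

    last = ""
    for partial in generate(query, "Answer the question.", temperature=0.7, max_new_tokens=256):
        last = partial  # each yield is the cumulative output so far
    print(f"Model output: {last}\n")

    print(f"Ground truth: {return_ground_truth()}")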