---
license: mit
datasets:
- codeparrot/apps
base_model:
- google/codegemma-7b-it
---
# Model Card for MA-RLHF
<a href="https://iclr.cc/Conferences/2024" target="_blank">
      <img alt="ICLR 2025" src="https://img.shields.io/badge/Proceedings-ICLR2025-red" />
</a>
<a href="https://github.com/ernie-research/MA-RLHF" target="_blank">
      <img alt="Github" src="https://img.shields.io/badge/Github-MA_RLHF-green" />
   </a>

This repository contains the official checkpoint for [Reinforcement Learning From Human Feedback with Macro Actions (MA-RLHF)](https://arxiv.org/pdf/2410.02743). 

## Model Description

MA-RLHF is a novel framework that integrates macro actions into conventional RLHF. Macro actions are sequences of tokens or higher-level language constructs, which can be determined by various termination conditions, such as n-gram-based, perplexity-based, or parsing-based rules. By introducing macro actions into RLHF, we reduce the number of decision points and shorten decision trajectories, alleviating the credit assignment problem caused by long temporal distances.
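For intuition, the `Fixed5`/`Fixed10` suffixes in the checkpoint names below refer to a fixed n-gram termination condition, under which every n consecutive tokens form one macro action. A minimal sketch of that segmentation (illustrative only, not the official implementation):

```python
from typing import List

def fixed_ngram_macro_actions(token_ids: List[int], n: int = 5) -> List[List[int]]:
    """Split a token sequence into macro actions of at most n tokens each,
    mirroring a fixed n-gram termination condition."""
    return [token_ids[i:i + n] for i in range(0, len(token_ids), n)]

# Eight token-level decision points collapse into two macro-action decisions:
print(fixed_ngram_macro_actions([101, 7, 42, 9, 13, 5, 88, 21], n=5))
# [[101, 7, 42, 9, 13], [5, 88, 21]]
```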


|Model|Checkpoint|Base Model|Dataset|
|-----|----------|----------|-------|
|TLDR-Gemma-2B-MA-PPO-Fixed5|🤗 [HF Link](https://huggingface.co/baidu/TLDR-Gemma-2B-MA-PPO-Fixed5)|[google/gemma-2b](https://huggingface.co/google/gemma-2b)|[openai/summarize_from_feedback](https://huggingface.co/datasets/openai/summarize_from_feedback)|
|TLDR-Gemma-7B-MA-PPO-Fixed5|🤗 [HF Link](https://huggingface.co/baidu/TLDR-Gemma-7B-MA-PPO-Fixed5)|[google/gemma-7b](https://huggingface.co/google/gemma-7b)|[openai/summarize_from_feedback](https://huggingface.co/datasets/openai/summarize_from_feedback)|
|TLDR-Gemma-2-27B-MA-PPO-Fixed5|🤗 [HF Link](https://huggingface.co/baidu/TLDR-Gemma-2-27B-MA-PPO-Fixed5)|[google/gemma-2-27b](https://huggingface.co/google/gemma-2-27b)|[openai/summarize_from_feedback](https://huggingface.co/datasets/openai/summarize_from_feedback)|
|HH-RLHF-Gemma-2B-MA-PPO-Fixed5|🤗 [HF Link](https://huggingface.co/baidu/HH-RLHF-Gemma-2B-MA-PPO-Fixed5)|[google/gemma-2b](https://huggingface.co/google/gemma-2b)|[Dahoas/full-hh-rlhf](https://huggingface.co/datasets/Dahoas/full-hh-rlhf)|
|HH-RLHF-Gemma-7B-MA-PPO-Fixed5|🤗 [HF Link](https://huggingface.co/baidu/HH-RLHF-Gemma-7B-MA-PPO-Fixed5)|[google/gemma-7b](https://huggingface.co/google/gemma-7b)|[Dahoas/full-hh-rlhf](https://huggingface.co/datasets/Dahoas/full-hh-rlhf)|
|APPS-Gemma-2B-MA-PPO-Fixed10|🤗 [HF Link](https://huggingface.co/baidu/APPS-Gemma-2B-MA-PPO-Fixed10)|[google/codegemma-2b](https://huggingface.co/google/codegemma-2b)|[codeparrot/apps](https://huggingface.co/datasets/codeparrot/apps)|
|APPS-Gemma-7B-MA-PPO-Fixed10|🤗 [HF Link](https://huggingface.co/baidu/APPS-Gemma-7B-MA-PPO-Fixed10)|[google/codegemma-7b-it](https://huggingface.co/google/codegemma-7b-it)|[codeparrot/apps](https://huggingface.co/datasets/codeparrot/apps)|


## Model Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "baidu/APPS-Gemma-7B-MA-PPO-Fixed10"

# Load the tokenizer and model; device_map="auto" spreads the weights across
# available devices and torch_dtype="auto" keeps the checkpoint's native dtype.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, device_map="auto", torch_dtype="auto", trust_remote_code=True
)

# An APPS-style competitive-programming problem as the prompt
# (raw string so the LaTeX escapes are left untouched).
input_text = r"""
An accordion is a string (yes, in the real world accordions are musical instruments, but let's forget about it for a while) which can be represented as a concatenation of: an opening bracket (ASCII code $091$), a colon (ASCII code $058$), some (possibly zero) vertical line characters (ASCII code $124$), another colon, and a closing bracket (ASCII code $093$). The length of the accordion is the number of characters in it. For example, [::], [:||:] and [:|||:] are accordions having length $4$, $6$ and $7$. (:|:), {:||:}, [:], ]:||:[ are not accordions. You are given a string $s$. You want to transform it into an accordion by removing some (possibly zero) characters from it. Note that you may not insert new characters or reorder existing ones. Is it possible to obtain an accordion by removing characters from $s$, and if so, what is the maximum possible length of the result? -----Input----- The only line contains one string $s$ ($1 \le |s| \le 500000$). It consists of lowercase Latin letters and characters [, ], : and |. -----Output----- If it is not possible to obtain an accordion by removing some characters from $s$, print $-1$. Otherwise print maximum possible length of the resulting accordion. -----Examples----- Input |[a:b:|] Output 4 Input |]:[|:] Output -1
"""

# Tokenize, generate, and decode. Raise max_new_tokens to obtain a complete
# program; 20 tokens only demonstrates that generation works.
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=20)
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)

print(response)
```
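For complete solutions you would typically allow a longer generation budget and, optionally, sampling. The settings below are illustrative defaults, not values prescribed by this card:

```python
# Illustrative decoding settings (assumed, not from the original card):
output_ids = model.generate(
    **inputs,
    max_new_tokens=512,  # enough room for a full program
    do_sample=True,      # sample instead of greedy decoding
    temperature=0.7,
    top_p=0.95,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```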

## Citation

```bibtex
@inproceedings{chai2025marlhf,
  title={{MA}-{RLHF}: Reinforcement Learning from Human Feedback with Macro Actions},
  author={Yekun Chai and Haoran Sun and Huang Fang and Shuohuan Wang and Yu Sun and Hua Wu},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=WWXjMYZxfH}
}
```