---
language: en
license: apache-2.0
---

# CodeRosetta
## Pushing the Boundaries of Unsupervised Code Translation for Parallel Programming ([📃Paper](https://arxiv.org/abs/2410.20527), [🔗Website](https://coderosetta.com/)).


CodeRosetta is an encoder-decoder translation model that supports translation across C++, CUDA, and Fortran. \
This version of the model is fine-tuned on a synthetic dataset for **C++ to CUDA translation.**

### How to use

```python
from transformers import AutoTokenizer, EncoderDecoderModel

# Load the CodeRosetta model and tokenizer
model = EncoderDecoderModel.from_pretrained('CodeRosetta/CodeRosetta_cpp2cuda_ft')
tokenizer = AutoTokenizer.from_pretrained('CodeRosetta/CodeRosetta_cpp2cuda_ft')

# Encode the input C++ Code
input_cpp_code = "void add_100 ( int numElements , int * data ) { for ( int idx = 0 ; idx < numElements ; idx ++ ) { data [ idx ] += 100 ; } }"
input_ids = tokenizer.encode(input_cpp_code, return_tensors="pt")

# Set the start token to <CUDA>
start_token = "<CUDA>"
decoder_start_token_id = tokenizer.convert_tokens_to_ids(start_token)

# Generate the CUDA code
output = model.generate(
    input_ids=input_ids, 
    decoder_start_token_id=decoder_start_token_id,
    max_length=256
)

# Decode and print the generated output
generated_code = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_code)
```
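The example input above puts a single space between every C++ token (for instance `idx ++` and `data [ idx ]`). If you start from ordinary C++ source, a small pre-processing step can produce that style before calling `tokenizer.encode`. The helper below, `space_tokenize`, is a hypothetical sketch and is not part of CodeRosetta. It assumes, without confirmation from this model card, that inputs formatted this way match what the model expects. Its regular expression covers only common C++ tokens: identifiers, simple numeric literals, string and character literals, and the usual multi-character operators.

```python
import re

# Hypothetical pre-processing helper (not part of the CodeRosetta release).
# Assumption: the model expects space-separated tokens, as in the example input.
_TOKEN_RE = re.compile(
    r"""
    [A-Za-z_][A-Za-z0-9_]*          # identifiers and keywords
    | \d+(?:\.\d+)?[fFuUlL]*         # integer and simple floating-point literals
    | "(?:\\.|[^"\\])*"              # string literals (with escapes)
    | '(?:\\.|[^'\\])*'              # character literals (with escapes)
    | <<=|>>=|->|\+\+|--|<<|>>|<=|>=|==|!=|&&|\|\||\+=|-=|\*=|/=|%=|&=|\|=|\^=|::
    | [^\sA-Za-z0-9_]                # any other single punctuation character
    """,
    re.VERBOSE,
)


def space_tokenize(cpp_code: str) -> str:
    """Return cpp_code with exactly one space between consecutive tokens."""
    return " ".join(_TOKEN_RE.findall(cpp_code))


raw = "void add_100(int numElements, int *data) { for (int idx = 0; idx < numElements; idx++) { data[idx] += 100; } }"
print(space_tokenize(raw))
```

Running this on the unformatted `add_100` function reproduces the input string used in the example above. Longer operators are listed before shorter ones because the regex tries alternatives in order, so `+=` and `++` are not split into single characters. Literals such as hexadecimal constants or raw strings are not handled specially.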

### BibTeX 

```bibtex
@inproceedings{coderosetta:neurips:2024,
  title = {CodeRosetta: Pushing the Boundaries of Unsupervised Code Translation for Parallel Programming},
  author = {TehraniJamsaz, Ali and Bhattacharjee, Arijit and Chen, Le and Ahmed, Nesreen K and Yazdanbakhsh, Amir and Jannesari, Ali},
  booktitle = {NeurIPS},
  year = {2024},
}