File size: 7,606 Bytes
c8c7b71
850b0e4
 
c8c7b71
850b0e4
 
 
 
 
ee2e0b7
 
c8c7b71
ee2e0b7
 
850b0e4
c8c7b71
850b0e4
 
 
 
ee2e0b7
c8c7b71
 
 
 
 
ee2e0b7
c8c7b71
 
 
 
 
 
 
 
1ec2ec6
c8c7b71
 
 
 
 
 
 
ee2e0b7
c8c7b71
 
 
1ec2ec6
ff23853
 
 
 
 
 
 
 
 
 
 
 
 
 
 
850b0e4
1ec2ec6
 
ff23853
 
850b0e4
ff23853
 
 
 
 
ee2e0b7
 
 
c8c7b71
ff23853
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
850b0e4
2c782ec
c8c7b71
 
 
 
 
 
ee2e0b7
 
c8c7b71
ee2e0b7
 
 
 
 
c8c7b71
 
ff23853
 
 
 
 
 
c8c7b71
 
ff23853
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d59d1e6
1ec2ec6
ff23853
1ec2ec6
ff23853
 
1ec2ec6
ff23853
c8c7b71
ff23853
141b1fb
d59d1e6
ee2e0b7
 
 
ff23853
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192

import gradio as gr
import torch
import uuid
from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png

from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles

import os
import uvicorn

mario_lm = MarioLM()
device = torch.device('cuda')
mario_lm = mario_lm.to(device)
TILE_DIR = "data/tiles"

app = FastAPI()

def make_html_file(generated_level):
    level_text = f"""{'''
'''.join(view_level(generated_level,mario_lm.tokenizer))}"""
    unique_id = uuid.uuid1()
    with open(f"static/demo-{unique_id}.html", 'w', encoding='utf-8') as f:
        f.write(f'''<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="utf-8">
    <title>Mario Game</title>
    <script src="https://cjrtnc.leaningtech.com/20230216/loader.js"></script>
</head>

<body>
</body>
<script>
    cheerpjInit().then(function () {{
        cheerpjAddStringFile("/str/mylevel.txt", `{level_text}`);
    }});
    cheerpjCreateDisplay(512, 500);
    cheerpjRunJar("/app/static/mario.jar");
</script>
</html>''')
    return f"demo-{unique_id}.html"

def trim_level(level):
    mod = level.shape[-1] % 14
    if mod > 0:
        return level[:, :-mod]
    return level

def reset_state(seed_state):
    length = len(seed_state)
    print(f"Resetting state with {length} levels!")
    for _ in range(length):
        seed_state.pop()

def _generate_level(prompts, seed, level_size, temperature):
    print(f"Using prompts: {prompts}")
    generated_levels = mario_lm.sample(
        prompts=prompts,
        num_steps=level_size,
        temperature=temperature,
        use_tqdm=True,
        seed = seed
    )
    generated_levels = trim_level(generated_levels)
    return generated_levels

def _make_gradio_html(level):
    filename = make_html_file(level)
    gradio_html = f'''<div>
        <iframe width=512 height=512 style="margin: 0 auto" src="static/{filename}"></iframe>
        <p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fireflowers</p>
    </div>'''
    return gradio_html

def initialize_generate(pipes, enemies, blocks, elevation, temperature = 2.4, level_size = 1400):
    prompts = [f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"]
    generated_levels = _generate_level(prompts, None, level_size, temperature)
    level = generated_levels.squeeze().detach().cpu()
    img = convert_level_to_png(level, TILE_DIR, mario_lm.tokenizer)[0]
    return [img, _make_gradio_html(level)]

def generate_choices(pipes, enemies, blocks, elevation, temperature = 2.4, level_size = 1400, prompt = "", seed_state = []):
    NUM_SAMPLES = 2
    if prompt == "":
        prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
    prompts = [prompt] * NUM_SAMPLES

    seed = None
    if len(seed_state) > 0:
        seed = torch.cat(seed_state).squeeze()[-48*14:].view(1, -1).repeat(NUM_SAMPLES, 1) # context length

    generated_levels = _generate_level(prompts, seed, level_size, temperature).detach().cpu().squeeze()
    level_choices = [generated_level[-level_size:] for generated_level in generated_levels]
    level_choice_images = [convert_level_to_png(generated_level[-level_size:], TILE_DIR, mario_lm.tokenizer)[0] for generated_level in generated_levels]

    # level choices + separate images
    return [level_choices, *level_choice_images]

def update_level_state(choice_id, level_choices, seed_state):
    num_choice = int(choice_id)
    level_choice = level_choices[num_choice]

    # append level choice to seed state
    seed_state.append(level_choice)

    # get new level from concatenation
    level = torch.cat(seed_state).squeeze()

    # final image and gradio html
    img = convert_level_to_png(level, TILE_DIR, mario_lm.tokenizer)[0]
    gradio_html = _make_gradio_html(level)

    # return img, gradio html, seed state, level_choice, choice_image_1, choice_image_2, current_level_size
    return img, gradio_html, seed_state, None, None, None, level.shape[-1]


with gr.Blocks().queue() as demo:
    gr.Markdown('''### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models
    [[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
    ''')
    with gr.Tabs():
        with gr.TabItem("Compose prompt"):
            with gr.Row():
                pipes = gr.Radio(["no", "little", "some", "many"], label="How many pipes?")
                enemies = gr.Radio(["no", "little", "some", "many"], label="How many enemies?")
            with gr.Row():
                blocks = gr.Radio(["little", "some", "many"], label="How many blocks?")
                elevation = gr.Radio(["low", "high"], label="Elevation?")
        with gr.TabItem("Type prompt"):
            text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
        
    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Number(value=2.0, label="temperature: Increase these for more diverse, but lower quality, generations")
        level_size = gr.Number(value=1400, precision=0, label="level_size")

    generate_btn = gr.Button("Generate Level")
    reset_btn = gr.Button("Reset Level")


    with gr.Row():
        with gr.Box():
            level_play = gr.HTML()
        level_image = gr.Image(label="Current Level")
        with gr.Box():
            with gr.Column():
                level_choice1_image = gr.Image(label="Sample Choice 1")
                level_choice1_btn = gr.Button("Sample Choice 1")
            with gr.Column():
                level_choice2_image = gr.Image(label="Sample Choice 2")
                level_choice2_btn = gr.Button("Sample Choice 2")
            current_level_size = gr.Number(0, visible=True, label="Current Level Size")


    seed_state = gr.State([])
    state_choices = gr.State(None)

    image_choice_1_id = gr.Number(0, visible=False)
    image_choice_2_id = gr.Number(1, visible=False)

    # choice buttons
    level_choice1_btn.click(fn=update_level_state, inputs=[image_choice_1_id, state_choices, seed_state], outputs=[level_image, level_play, seed_state, state_choices, level_choice1_image, level_choice2_image, current_level_size])
    level_choice2_btn.click(fn=update_level_state, inputs=[image_choice_2_id, state_choices, seed_state], outputs=[level_image, level_play, seed_state, state_choices, level_choice1_image, level_choice2_image, current_level_size])

    # generate_btn
    generate_btn.click(fn=generate_choices, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt, seed_state], outputs=[state_choices, level_choice1_image, level_choice2_image])

    # reset btn
    reset_btn.click(fn=reset_state, inputs=[seed_state], outputs=[])

    gr.Examples(
        examples=[
            ["many", "many", "some", "high", 2.0],
            ["no", "some", "many", "high", 2.0],
            ["many", "many", "little", "low", 2.4],
            ["no", "no", "many", "high", 2.8],
        ],
        inputs=[pipes, enemies, blocks, elevation, temperature, level_size],
        outputs=[level_image, level_play],
        fn=initialize_generate,
        cache_examples=True,
    )

app.mount("/static", StaticFiles(directory="static", html=True), name="static")
app = gr.mount_gradio_app(app, demo, "/", gradio_api_url="http://localhost:7860/")
uvicorn.run(app, host="0.0.0.0", port=7860)