Spaces:
Runtime error
Runtime error
File size: 7,606 Bytes
c8c7b71 850b0e4 c8c7b71 850b0e4 ee2e0b7 c8c7b71 ee2e0b7 850b0e4 c8c7b71 850b0e4 ee2e0b7 c8c7b71 ee2e0b7 c8c7b71 1ec2ec6 c8c7b71 ee2e0b7 c8c7b71 1ec2ec6 ff23853 850b0e4 1ec2ec6 ff23853 850b0e4 ff23853 ee2e0b7 c8c7b71 ff23853 850b0e4 2c782ec c8c7b71 ee2e0b7 c8c7b71 ee2e0b7 c8c7b71 ff23853 c8c7b71 ff23853 d59d1e6 1ec2ec6 ff23853 1ec2ec6 ff23853 1ec2ec6 ff23853 c8c7b71 ff23853 141b1fb d59d1e6 ee2e0b7 ff23853 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import gradio as gr
import torch
import uuid
from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
import os
import uvicorn
# Instantiate the MarioGPT language model with its default arguments.
mario_lm = MarioLM()
# Prefer the GPU, but fall back to the CPU: moving the model to a hard-coded
# 'cuda' device raises on hosts that have no usable CUDA device, which would
# stop the whole app from starting.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mario_lm = mario_lm.to(device)
# Tile directory passed to convert_level_to_png whenever a level is rendered.
TILE_DIR = "data/tiles"
# FastAPI application that the static files and the Gradio UI are mounted on.
app = FastAPI()
def make_html_file(generated_level):
    """Write a playable HTML page for ``generated_level`` and return its name.

    The level is turned into text rows with ``view_level`` (using the model's
    tokenizer), joined with newlines, and embedded in a page that loads the
    CheerpJ loader script, registers the text as ``/str/mylevel.txt``, creates
    a 512x500 display and runs ``/app/static/mario.jar``.

    Args:
        generated_level: Level tokens in whatever form ``view_level`` accepts.

    Returns:
        The file name (``demo-<uuid1>.html``) of the page written into the
        ``static/`` directory; the directory prefix is not included.
    """
    rows = view_level(generated_level, mario_lm.tokenizer)
    level_text = "\n".join(rows)
    unique_id = uuid.uuid1()
    page_name = f"demo-{unique_id}.html"
    # NOTE(review): level_text is placed inside a JavaScript template literal
    # (backticks); this assumes the rows never contain a backtick or "${".
    page = f'''<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>Mario Game</title>
<script src="https://cjrtnc.leaningtech.com/20230216/loader.js"></script>
</head>
<body>
</body>
<script>
cheerpjInit().then(function () {{
cheerpjAddStringFile("/str/mylevel.txt", `{level_text}`);
}});
cheerpjCreateDisplay(512, 500);
cheerpjRunJar("/app/static/mario.jar");
</script>
</html>'''
    with open(f"static/{page_name}", 'w', encoding='utf-8') as f:
        f.write(page)
    return page_name
def trim_level(level, multiple=14):
    """Drop trailing entries so the last axis length is a multiple of ``multiple``.

    The length check and the slice both act on the last axis, so inputs of any
    rank are handled; for the 2-D input used by the callers in this file the
    result is the same as slicing ``level[:, :-remainder]``.

    Args:
        level: Tensor or array exposing ``shape`` and supporting Ellipsis
            slicing.
        multiple: Required divisor of the last-axis length. Defaults to 14,
            the value that used to be hard-coded (presumably the number of
            tokens per level column -- confirm against mario_gpt).

    Returns:
        ``level`` itself when its last axis is already a whole multiple,
        otherwise a slice with the trailing remainder removed.

    Raises:
        ValueError: If ``multiple`` is not a positive integer.
    """
    if multiple <= 0:
        raise ValueError(f"multiple must be positive, got {multiple}")
    remainder = level.shape[-1] % multiple
    if remainder > 0:
        return level[..., :-remainder]
    return level
def reset_state(seed_state):
    """Empty ``seed_state`` in place after logging how many levels it held.

    Nothing is returned; the caller's list object itself is emptied, which is
    how the reset button (wired with no outputs) affects the stored state.

    Args:
        seed_state: Mutable list of previously chosen level segments.
    """
    count = len(seed_state)
    print(f"Resetting state with {count} levels!")
    seed_state.clear()
def _generate_level(prompts, seed, level_size, temperature):
    """Sample levels from the model and trim them with ``trim_level``.

    Args:
        prompts: List of text prompts, one per sample to generate.
        seed: Optional seed tokens forwarded to ``mario_lm.sample`` (``None``
            when starting from scratch).
        level_size: Forwarded as ``num_steps`` to the sampler.
        temperature: Sampling temperature forwarded to the sampler.

    Returns:
        The sampler output after ``trim_level`` has been applied.
    """
    print(f"Using prompts: {prompts}")
    sample_kwargs = {
        "prompts": prompts,
        "num_steps": level_size,
        "temperature": temperature,
        "use_tqdm": True,
        "seed": seed,
    }
    return trim_level(mario_lm.sample(**sample_kwargs))
def _make_gradio_html(level):
    """Return an HTML snippet that embeds the playable page for ``level``.

    A page is written via ``make_html_file`` and referenced from a 512x512
    iframe (``static/<file name>``), followed by a line of key instructions.

    Args:
        level: Level tokens passed straight to ``make_html_file``.

    Returns:
        The HTML string to show in the ``gr.HTML`` component.
    """
    page_name = make_html_file(level)
    return f'''<div>
<iframe width=512 height=512 style="margin: 0 auto" src="static/{page_name}"></iframe>
<p style="text-align:center">Press the arrow keys to move. Press <code>a</code> to run, <code>s</code> to jump and <code>d</code> to shoot fireflowers</p>
</div>'''
def initialize_generate(pipes, enemies, blocks, elevation, temperature=2.4, level_size=1400):
    """Generate one level from the composed prompt and return image plus HTML.

    This is the function wired to the cached examples in the UI.

    Args:
        pipes, enemies, blocks, elevation: Words interpolated into the prompt
            ``"<pipes> pipes, <enemies> enemies, <blocks> blocks,
            <elevation> elevation"``.
        temperature: Sampling temperature.
        level_size: Number of sampling steps.

    Returns:
        A two-item list: the first image returned by ``convert_level_to_png``
        and the playable HTML snippet for the same level.
    """
    prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
    sampled = _generate_level([prompt], None, level_size, temperature)
    level = sampled.squeeze().detach().cpu()
    rendered = convert_level_to_png(level, TILE_DIR, mario_lm.tokenizer)
    img = rendered[0]
    html = _make_gradio_html(level)
    return [img, html]
def generate_choices(pipes, enemies, blocks, elevation, temperature=2.4, level_size=1400, prompt="", seed_state=None):
    """Sample two candidate continuations of the current level.

    Args:
        pipes, enemies, blocks, elevation: Words used to compose the prompt
            when ``prompt`` is empty.
        temperature: Sampling temperature.
        level_size: Number of sampling steps; also the number of trailing
            tokens kept from each sample as the candidate segment.
        prompt: Free-text prompt. When empty (or ``None``) the composed
            prompt is used instead.
        seed_state: List of previously chosen level segments. ``None`` (the
            default) means no history; a fresh list is used so that no list
            object is shared between calls.

    Returns:
        ``[level_choices, image_1, image_2]`` where ``level_choices`` is the
        list of candidate segments and the images are their renderings.
    """
    NUM_SAMPLES = 2
    # Only the most recent part of the history is used as the seed
    # (the "context length" in the original code).
    CONTEXT_TOKENS = 48 * 14
    if seed_state is None:
        seed_state = []
    if not prompt:
        prompt = f"{pipes} pipes, {enemies} enemies, {blocks} blocks, {elevation} elevation"
    prompts = [prompt] * NUM_SAMPLES
    seed = None
    if seed_state:
        history = torch.cat(seed_state).squeeze()
        # Same seed row for every sample so the candidates share one context.
        seed = history[-CONTEXT_TOKENS:].view(1, -1).repeat(NUM_SAMPLES, 1)
    generated_levels = _generate_level(prompts, seed, level_size, temperature).detach().cpu().squeeze()
    # Keep only the newly generated tail of each sample, dropping the seed.
    level_choices = [generated_level[-level_size:] for generated_level in generated_levels]
    level_choice_images = [
        convert_level_to_png(choice, TILE_DIR, mario_lm.tokenizer)[0]
        for choice in level_choices
    ]
    # level choices + separate images
    return [level_choices, *level_choice_images]
def update_level_state(choice_id, level_choices, seed_state):
    """Append the chosen candidate to the level and re-render everything.

    Args:
        choice_id: Index of the chosen candidate (converted with ``int``).
        level_choices: Candidate segments produced by ``generate_choices``;
            ``None`` or empty when no candidates are currently available.
        seed_state: List of segments chosen so far; the choice is appended to
            it in place.

    Returns:
        A 7-tuple matching the click handler's outputs: level image, playable
        HTML, updated ``seed_state``, ``None`` for the stored choices,
        ``None`` for each of the two choice images, and the last-axis length
        of the concatenated level.

    Raises:
        gr.Error: If there are no candidates to choose from (for example a
            choice button was pressed before generating, or twice in a row).
    """
    # The stored choices start as None and are reset to None after each pick,
    # so indexing them blindly would fail with an opaque TypeError.
    if not level_choices:
        raise gr.Error("No level choices available. Press 'Generate Level' first.")
    num_choice = int(choice_id)
    level_choice = level_choices[num_choice]
    # append level choice to seed state
    seed_state.append(level_choice)
    # get new level from concatenation
    level = torch.cat(seed_state).squeeze()
    # final image and gradio html
    img = convert_level_to_png(level, TILE_DIR, mario_lm.tokenizer)[0]
    gradio_html = _make_gradio_html(level)
    # return img, gradio html, seed state, level_choice, choice_image_1, choice_image_2, current_level_size
    return img, gradio_html, seed_state, None, None, None, level.shape[-1]
# Build the Gradio UI. Components are created in display order, so the
# statement order and nesting below define the page layout.
with gr.Blocks().queue() as demo:
    gr.Markdown('''### Playable demo for MarioGPT: Open-Ended Text2Level Generation through Large Language Models
[[Github](https://github.com/shyamsn97/mario-gpt)], [[Paper](https://arxiv.org/abs/2302.05981)]
''')
    # Two ways to specify the prompt: radio buttons or free text.
    with gr.Tabs():
        with gr.TabItem("Compose prompt"):
            with gr.Row():
                pipes = gr.Radio(["no", "little", "some", "many"], label="How many pipes?")
                enemies = gr.Radio(["no", "little", "some", "many"], label="How many enemies?")
            with gr.Row():
                blocks = gr.Radio(["little", "some", "many"], label="How many blocks?")
                elevation = gr.Radio(["low", "high"], label="Elevation?")
        with gr.TabItem("Type prompt"):
            # A non-empty value here takes precedence over the radio buttons
            # in generate_choices.
            text_prompt = gr.Textbox(value="", label="Enter your MarioGPT prompt. ex: 'many pipes, many enemies, some blocks, low elevation'")
    with gr.Accordion(label="Advanced settings", open=False):
        temperature = gr.Number(value=2.0, label="temperature: Increase these for more diverse, but lower quality, generations")
        level_size = gr.Number(value=1400, precision=0, label="level_size")
    generate_btn = gr.Button("Generate Level")
    reset_btn = gr.Button("Reset Level")
    with gr.Row():
        # Current level: playable iframe plus rendered image.
        with gr.Box():
            level_play = gr.HTML()
            level_image = gr.Image(label="Current Level")
        # The two candidate continuations, each with its own pick button.
        with gr.Box():
            with gr.Column():
                level_choice1_image = gr.Image(label="Sample Choice 1")
                level_choice1_btn = gr.Button("Sample Choice 1")
            with gr.Column():
                level_choice2_image = gr.Image(label="Sample Choice 2")
                level_choice2_btn = gr.Button("Sample Choice 2")
    current_level_size = gr.Number(0, visible=True, label="Current Level Size")
    # seed_state: list of level segments chosen so far.
    seed_state = gr.State([])
    # state_choices: candidates from the last generate_choices call (None when there are none).
    state_choices = gr.State(None)
    # Hidden constants used to tell update_level_state which button was pressed.
    image_choice_1_id = gr.Number(0, visible=False)
    image_choice_2_id = gr.Number(1, visible=False)
    # choice buttons
    # update_level_state returns None for state_choices and both choice
    # images, so those outputs are reset after a pick.
    level_choice1_btn.click(fn=update_level_state, inputs=[image_choice_1_id, state_choices, seed_state], outputs=[level_image, level_play, seed_state, state_choices, level_choice1_image, level_choice2_image, current_level_size])
    level_choice2_btn.click(fn=update_level_state, inputs=[image_choice_2_id, state_choices, seed_state], outputs=[level_image, level_play, seed_state, state_choices, level_choice1_image, level_choice2_image, current_level_size])
    # generate_btn
    generate_btn.click(fn=generate_choices, inputs=[pipes, enemies, blocks, elevation, temperature, level_size, text_prompt, seed_state], outputs=[state_choices, level_choice1_image, level_choice2_image])
    # reset btn
    # No outputs are listed: reset_state empties the list it is given in
    # place, and none of the visible components are updated by this handler.
    reset_btn.click(fn=reset_state, inputs=[seed_state], outputs=[])
    # NOTE(review): each example row has five values but six input components
    # are listed (no value for level_size) -- confirm that
    # initialize_generate's level_size default is what gets used.
    gr.Examples(
        examples=[
            ["many", "many", "some", "high", 2.0],
            ["no", "some", "many", "high", 2.0],
            ["many", "many", "little", "low", 2.4],
            ["no", "no", "many", "high", 2.8],
        ],
        inputs=[pipes, enemies, blocks, elevation, temperature, level_size],
        outputs=[level_image, level_play],
        fn=initialize_generate,
        cache_examples=True,
    )
# Serve the generated demo pages (and anything else in ./static) under /static.
app.mount("/static", StaticFiles(directory="static", html=True), name="static")
# Attach the Gradio UI at the site root of the FastAPI app.
app = gr.mount_gradio_app(app, demo, "/", gradio_api_url="http://localhost:7860/")

# Start the server only when this file is executed as a script, so that
# importing the module (e.g. `uvicorn <module>:app`) does not block on, or
# try to start, a second server.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
|