Spaces:
Runtime error
Runtime error
File size: 1,197 Bytes
fa6856c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
# Optimize prompts by training on a dataset of prompt-rating pairs.
# The dataset is taken from https://github.com/JD-P/simulacra-aesthetic-captions
import os
import sqlite3
from urllib.request import urlretrieve
from accelerate import Accelerator
import trlx
from trlx.data.default_configs import default_ilql_config
# Remote location of the SQLite dump that holds the prompt/rating data.
url = (
    "https://raw.githubusercontent.com/JD-P/simulacra-aesthetic-captions"
    "/main/sac_public_2022_06_29.sqlite"
)
# Local file name (relative to the working directory) the dump is stored under.
dbpath = "sac_public_2022_06_29.sqlite"
if __name__ == "__main__":
    # Script entry point: fetch the ratings database if it is missing, read
    # every (prompt, rating) pair out of it, and hand them to trlx for
    # training with the default ILQL configuration.
    accelerator = Accelerator()

    # Only the process whose LOCAL_RANK is "0" (or unset) downloads the file;
    # every process then meets at the wait_for_everyone() call below.
    if os.environ.get("LOCAL_RANK", "0") == "0" and not os.path.exists(dbpath):
        print(f"fetching {dbpath}")
        # Download to a temporary name and rename afterwards, so that an
        # interrupted transfer never leaves a truncated file at `dbpath`
        # that a later run would mistake for a complete database.
        partial_path = dbpath + ".part"
        urlretrieve(url, partial_path)
        os.replace(partial_path, dbpath)
    accelerator.wait_for_everyone()

    conn = sqlite3.connect(dbpath)
    try:
        rows = conn.execute(
            "SELECT prompt, rating FROM ratings "
            "JOIN images ON images.id=ratings.iid "
            "JOIN generations ON images.gid=generations.id "
            "WHERE rating IS NOT NULL;"
        ).fetchall()
    finally:
        # Release the database handle even if the query fails.
        conn.close()

    if not rows:
        # zip(*[]) yields nothing, so the unpacking below would otherwise fail
        # with an unhelpful "not enough values to unpack" ValueError.
        raise ValueError(f"no rated prompts found in {dbpath}")

    # Transpose the list of (prompt, rating) rows into two parallel lists.
    prompts, ratings = map(list, zip(*rows))

    trlx.train(
        config=default_ilql_config(),
        samples=prompts,
        rewards=ratings,
        # The same prompt repeated 64 times is used for evaluation.
        eval_prompts=["An astronaut riding a horse"] * 64,
    )
|