File size: 1,197 Bytes
fa6856c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# Optimize prompts with ILQL by training on a dataset of prompt/rating pairs
# taken from https://github.com/JD-P/simulacra-aesthetic-captions

import os
import sqlite3
from urllib.request import urlretrieve

from accelerate import Accelerator

import trlx
from trlx.data.default_configs import default_ilql_config

url = "https://raw.githubusercontent.com/JD-P/simulacra-aesthetic-captions/main/sac_public_2022_06_29.sqlite"
dbpath = "sac_public_2022_06_29.sqlite"


def fetch_database(source_url, path):
    """Download ``source_url`` to ``path`` unless ``path`` already exists.

    The data is first written to ``<path>.part`` and only renamed onto ``path``
    once the download has finished. An interrupted download therefore never
    leaves a truncated file at ``path`` that later runs would mistake for a
    complete database and skip re-fetching.

    Raises:
        OSError: (including ``urllib.error.URLError``) if the download fails;
            any partial file is removed first.
    """
    if os.path.exists(path):
        return
    print(f"fetching {path}")
    partial_path = f"{path}.part"
    try:
        urlretrieve(source_url, partial_path)
        # Atomic rename within the same directory: readers see either no file
        # or the complete file, never a half-written one.
        os.replace(partial_path, path)
    finally:
        # Only still present if the download or rename failed; drop it so the
        # next run retries from scratch.
        if os.path.exists(partial_path):
            os.remove(partial_path)


def load_prompts_and_ratings(path):
    """Return ``(prompts, ratings)`` for every rated generation in the database.

    Rows come from joining ``ratings`` -> ``images`` -> ``generations``;
    ratings that are NULL are excluded. The two returned lists are parallel:
    ``ratings[i]`` is the rating given to an image generated from ``prompts[i]``.

    Raises:
        ValueError: if the query yields no rows (there is nothing to train on).
    """
    conn = sqlite3.connect(path)
    try:
        rows = conn.execute(
            "SELECT prompt, rating FROM ratings "
            "JOIN images ON images.id=ratings.iid "
            "JOIN generations ON images.gid=generations.id "
            "WHERE rating IS NOT NULL;"
        ).fetchall()
    finally:
        # Release the database handle; the original script never closed it.
        conn.close()
    if not rows:
        raise ValueError(f"no rated prompts found in {path}")
    prompts = [prompt for prompt, _ in rows]
    ratings = [rating for _, rating in rows]
    return prompts, ratings


if __name__ == "__main__":
    accelerator = Accelerator()
    # One process per node downloads the database; every other process waits at
    # the barrier below so nobody opens the file before it is fully in place.
    if accelerator.is_local_main_process:
        fetch_database(url, dbpath)
    accelerator.wait_for_everyone()

    prompts, ratings = load_prompts_and_ratings(dbpath)
    trlx.train(
        config=default_ilql_config(),
        samples=prompts,
        rewards=ratings,
        # A single fixed prompt repeated 64 times; presumably sized to fill the
        # evaluation batches — NOTE(review): confirm against the ILQL config.
        eval_prompts=["An astronaut riding a horse"] * 64,
    )