File size: 1,896 Bytes
5372b88
 
 
 
 
 
 
 
 
 
 
7c6ffc8
 
5372b88
 
7c6ffc8
5372b88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c6ffc8
5372b88
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
from huggingface_hub import model_info

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler


# Known LoRA fine-tunes, keyed by a short experiment name.
# ``hub_model_id`` is the Hub repo holding the LoRA attention weights.
# ``model_dir`` is presumably the local training-output directory; it is not
# read by this script.
REPOS = {
    "tom_cruise_plain": {
        "hub_model_id": "asrimanth/person-thumbs-up-plain-lora",
        "model_dir": "/l/vision/v5/sragas/easel_ai/models_plain/",
    },
    "tom_cruise": {
        "hub_model_id": "asrimanth/person-thumbs-up-lora",
        "model_dir": "/l/vision/v5/sragas/easel_ai/models/",
    },
    "tom_cruise_no_cap": {
        "hub_model_id": "asrimanth/person-thumbs-up-lora-no-cap",
        "model_dir": "/l/vision/v5/sragas/easel_ai/models_no_cap/",
    },
    "srimanth_plain": {
        "hub_model_id": "asrimanth/srimanth-thumbs-up-lora-plain",
        "model_dir": "/l/vision/v5/sragas/easel_ai/models_srimanth_plain/",
    },
}


def _resolve_base_model(info, hub_model_id):
    """Return the ``base_model`` declared in a Hub repo's model-card metadata.

    Args:
        info: Object returned by ``huggingface_hub.model_info``.
        hub_model_id: Repo id; used only to make the error message actionable.

    Returns:
        The base model id string recorded in the model card.

    Raises:
        ValueError: If the repo has no card metadata, or the metadata does not
            declare a non-empty string ``base_model``.
    """
    # Newer huggingface_hub releases expose ``card_data``; older ones only
    # ``cardData``. Both a plain dict and the card-data object offer ``.get``.
    card_data = getattr(info, "card_data", None)
    if card_data is None:
        card_data = getattr(info, "cardData", None)
    base_model = card_data.get("base_model") if card_data is not None else None
    if not isinstance(base_model, str) or not base_model:
        raise ValueError(
            f"Model card of {hub_model_id!r} does not declare a 'base_model'; "
            "cannot tell which Stable Diffusion checkpoint to load."
        )
    return base_model


def _load_pipeline(base_model, hub_model_id, cache_dir):
    """Build an fp16 Stable Diffusion pipeline with the repo's LoRA weights on CUDA."""
    pipe = StableDiffusionPipeline.from_pretrained(
        base_model, torch_dtype=torch.float16, cache_dir=cache_dir
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    # NOTE(review): recent diffusers versions recommend ``pipe.load_lora_weights``
    # over ``unet.load_attn_procs``; confirm against the installed version
    # before switching.
    pipe.unet.load_attn_procs(hub_model_id)
    pipe.to("cuda")
    return pipe


def main(
    *,
    current_repo_id="tom_cruise_no_cap",
    n_images=50,
    prompt="<tom_cruise> showing #thumbsup",
    num_inference_steps=25,
    cache_dir="/l/vision/v5/sragas/hf_models/",
    save_root="./results",
):
    """Generate images from one LoRA fine-tune and save them as PNG files.

    Images are written to ``<save_root>/<current_repo_id>/out_<i>.png``. Image
    ``i`` is sampled with a CUDA generator seeded with ``i``.

    Args:
        current_repo_id: Key into ``REPOS`` selecting the LoRA repo.
        n_images: Number of images to generate.
        prompt: Text prompt passed to the pipeline for every image.
        num_inference_steps: Denoising steps per image.
        cache_dir: Download cache for the base model weights.
        save_root: Directory under which the per-repo output folder is created.

    Raises:
        KeyError: If ``current_repo_id`` is not a key of ``REPOS``.
        ValueError: If the repo's model card does not declare a base model.
    """
    # Validate the config key before touching the filesystem or the network.
    try:
        current_repo = REPOS[current_repo_id]
    except KeyError:
        raise KeyError(
            f"Unknown repo id {current_repo_id!r}; expected one of {sorted(REPOS)}"
        ) from None

    save_dir = os.path.join(save_root, current_repo_id)
    os.makedirs(save_dir, exist_ok=True)

    print(f"{'-'*20} CURRENT REPO: {current_repo_id} {'-'*20}")
    hub_model_id = current_repo["hub_model_id"]

    base_model = _resolve_base_model(model_info(hub_model_id), hub_model_id)
    print(f"Base model is: {base_model}")

    pipe = _load_pipeline(base_model, hub_model_id, cache_dir)

    print(f"Inferencing '{prompt}' for {n_images} images.")
    for i in range(n_images):
        # Same per-index seed as before, without allocating every generator up front.
        generator = torch.Generator("cuda").manual_seed(i)
        image = pipe(
            prompt, generator=generator, num_inference_steps=num_inference_steps
        ).images[0]
        image.save(os.path.join(save_dir, f"out_{i}.png"))


if __name__ == "__main__":
    main()