File size: 2,528 Bytes
d1d1942
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from tqdm.auto import tqdm
from huggingface_hub import cached_download, hf_hub_url
import os

def display_image(image):
    """
    Show the generated image; swap this stub for real display logic.
    """
    image.show()

def load_and_merge_lora(base_model_id, lora_id, lora_weight_name, lora_adapter_name):
    """Load a base diffusion pipeline and attach LoRA weights from the Hub.

    Args:
        base_model_id: Hugging Face Hub ID of the base diffusion model.
        lora_id: Hub repo ID hosting the LoRA weights.
        lora_weight_name: File name of the LoRA weight file inside that repo.
        lora_adapter_name: Name under which to register the LoRA adapter.

    Returns:
        The CUDA pipeline with LoRA weights loaded, or None on failure.
    """
    try:
        # BUG FIX: the original passed
        # DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        # inside the very call that creates `pipe`, which always raised
        # UnboundLocalError. Create the pipeline first, then swap the
        # scheduler using its own config.
        pipe = DiffusionPipeline.from_pretrained(
            base_model_id,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        ).to("cuda")
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

        # BUG FIX: load_lora_weights takes no `progress_callback` kwarg (the
        # original call raised TypeError). It also accepts a Hub repo ID
        # directly and handles the download itself, so the deprecated
        # cached_download/hf_hub_url step is unnecessary.
        with tqdm(total=1, desc="Loading LoRA weights", unit="step") as pbar:
            pipe.load_lora_weights(
                lora_id,
                weight_name=lora_weight_name,
                adapter_name=lora_adapter_name,
            )
            pbar.update(1)

        print("LoRA merged successfully!")
        return pipe
    except Exception as e:
        print(f"Error merging LoRA: {e}")
        return None

def save_merged_model(pipe, save_path):
    """Persist the merged pipeline to *save_path*; report success or failure."""
    try:
        pipe.save_pretrained(save_path)
    except Exception as e:
        print(f"Error saving the merged model: {e}")
    else:
        print(f"Merged model saved successfully to: {save_path}")

if __name__ == "__main__":
    base_model_id = input("Enter the base model ID: ") 
    lora_id = input("Enter the LoRA Hugging Face Hub ID: ") 
    lora_weight_name = input("Enter the LoRA weight file name: ")
    lora_adapter_name = input("Enter the LoRA adapter name: ")

    pipe = load_and_merge_lora(base_model_id, lora_id, lora_weight_name, lora_adapter_name)

    if pipe:
        prompt = input("Enter your prompt: ")
        lora_scale = float(input("Enter the LoRA scale (e.g., 0.9): ")) 

        image = pipe(
            prompt, 
            num_inference_steps=30, 
            cross_attention_kwargs={"scale": lora_scale}, 
            generator=torch.manual_seed(0)
        ).images[0]

        display_image(image)

        # Ask the user for a directory to save the model
        save_path = input(
            "Enter the directory where you want to save the merged model: "
        )
        save_merged_model(pipe, save_path)