starriver030515 commited on
Commit
c80eac7
·
verified ·
1 Parent(s): c1dfcc3

Update gen_sd3.py

Browse files
Files changed (1) hide show
  1. gen_sd3.py +12 -5
gen_sd3.py CHANGED
@@ -22,19 +22,26 @@ json_filename = args.json_filename
22
  cuda_device = f"cuda:{args.cuda}"
23
  print(json_filename, cuda_device)
24
 
25
- image_dir = "/mnt/petrelfs/zhuchenglin/LLaVA/playground/data/LLaVA-Pretrain/images"
26
  with open(json_filename, "r") as f:
27
  json_data = json.load(f)
28
 
29
  pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
30
- pipe.to('cuda')
31
 
32
  for text in json_data:
 
 
 
 
 
 
 
33
  image = pipe(
34
- prompt=text["conversations"][1]["value"],
35
- prompt_3=text["conversations"][1]["value"],
36
  negative_prompt="",
37
- num_inference_steps=100,
38
  height=1024,
39
  width=1024,
40
  guidance_scale=10.0,
 
22
  cuda_device = f"cuda:{args.cuda}"
23
  print(json_filename, cuda_device)
24
 
25
+ image_dir = "/mnt/petrelfs/zhuchenglin/LLaVA/playground/data"
26
  with open(json_filename, "r") as f:
27
  json_data = json.load(f)
28
 
29
  pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
30
+ pipe.to(cuda_device)
31
 
32
  for text in json_data:
33
+ prompt = ""
34
+ for caption in text['conversations']:
35
+ if caption['from'] == 'gpt':
36
+ prompt += caption['value']
37
+ # for caption in text['conversations']:
38
+ # prompt += caption['value']
39
+
40
  image = pipe(
41
+ prompt=prompt,
42
+ prompt_3=prompt,
43
  negative_prompt="",
44
+ num_inference_steps=60,
45
  height=1024,
46
  width=1024,
47
  guidance_scale=10.0,