throaway2854 commited on
Commit
cf60c6f
·
verified ·
1 Parent(s): 0c5b16e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -56
app.py CHANGED
@@ -2,58 +2,30 @@ import gradio as gr
2
  import random
3
  import json
4
  import os
5
- from huggingface_hub import HfApi, hf_hub_download, upload_file, create_repo
6
- from huggingface_hub.utils import RepositoryNotFoundError
7
 
8
- REPO_ID = "throaway2854/promptGen"
9
- DATA_FILE = "saved_data.json"
10
- IMAGES_DIR = "character_images"
11
-
12
- def ensure_repo_exists():
13
- try:
14
- api = HfApi()
15
- api.repo_info(repo_id=REPO_ID, repo_type="space")
16
- except RepositoryNotFoundError:
17
- create_repo(REPO_ID, repo_type="space", space_sdk="gradio")
18
 
19
  def load_data():
20
- ensure_repo_exists()
21
- try:
22
- file_path = hf_hub_download(repo_id=REPO_ID, filename=DATA_FILE)
23
- with open(file_path, 'r') as f:
24
  return json.load(f)
25
- except:
26
- return {
27
- "scene_tags": [], "position_tags": [], "outfit_tags": [],
28
- "camera_tags": [], "concept_tags": [], "lora_tags": [],
29
- "characters": {}
30
- }
31
 
32
  def save_data(data):
33
- ensure_repo_exists()
34
  with open(DATA_FILE, 'w') as f:
35
  json.dump(data, f)
36
- api = HfApi()
37
- api.upload_file(
38
- path_or_fileobj=DATA_FILE,
39
- path_in_repo=DATA_FILE,
40
- repo_id=REPO_ID,
41
- repo_type="space"
42
- )
43
 
44
  def save_character_image(name, image):
45
- ensure_repo_exists()
46
- if not os.path.exists(IMAGES_DIR):
47
- os.makedirs(IMAGES_DIR)
48
  image_path = os.path.join(IMAGES_DIR, f"{name}.png")
49
  image.save(image_path)
50
- api = HfApi()
51
- api.upload_file(
52
- path_or_fileobj=image_path,
53
- path_in_repo=f"{IMAGES_DIR}/{name}.png",
54
- repo_id=REPO_ID,
55
- repo_type="space"
56
- )
57
  return image_path
58
 
59
  def generate_prompt(scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, tag_counts, data):
@@ -65,20 +37,19 @@ def generate_prompt(scene_tags, num_people, position_tags, selected_characters,
65
  "concept": concept_tags.split(','),
66
  }
67
 
68
- for category in all_tags:
69
- all_tags[category] = [tag.strip() for tag in all_tags[category] if tag.strip()]
70
 
71
- character_prompts = []
72
- for char_name in selected_characters:
73
- if char_name in data["characters"]:
74
- char_traits = data["characters"][char_name]["traits"]
75
- char_prompt = f"{char_name}, " + ", ".join(random.sample(char_traits, min(tag_counts["character"], len(char_traits))))
76
- character_prompts.append(char_prompt)
77
 
78
- selected_tags = []
79
- for category, tags in all_tags.items():
80
- if tags:
81
- selected_tags.extend(f"{tag}:{random.uniform(0.8, 1.2):.2f}" for tag in random.sample(tags, min(tag_counts[category], len(tags))))
 
82
 
83
  if num_people.strip():
84
  selected_tags.append(f"{num_people} people:1.1")
@@ -92,12 +63,10 @@ def generate_prompt(scene_tags, num_people, position_tags, selected_characters,
92
 
93
  fixed_tags = "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres, anime artwork, anime style, vibrant, studio anime, highly detailed"
94
 
95
- full_prompt = f"{main_prompt}, {fixed_tags} {lora_prompt}".strip()
96
- return full_prompt
97
 
98
  def update_data(data, key, value):
99
- if isinstance(data[key], list):
100
- data[key] = list(set(data[key] + [v.strip() for v in value.split(',') if v.strip()]))
101
  save_data(data)
102
  return data
103
 
 
2
  import random
3
  import json
4
  import os
 
 
5
 
6
+ DATA_DIR = "/data"
7
+ IMAGES_DIR = "/images"
8
+ DATA_FILE = os.path.join(DATA_DIR, "saved_data.json")
 
 
 
 
 
 
 
9
 
10
  def load_data():
11
+ if os.path.exists(DATA_FILE):
12
+ with open(DATA_FILE, 'r') as f:
 
 
13
  return json.load(f)
14
+ return {
15
+ "scene_tags": [], "position_tags": [], "outfit_tags": [],
16
+ "camera_tags": [], "concept_tags": [], "lora_tags": [],
17
+ "characters": {}
18
+ }
 
19
 
20
  def save_data(data):
21
+ os.makedirs(DATA_DIR, exist_ok=True)
22
  with open(DATA_FILE, 'w') as f:
23
  json.dump(data, f)
 
 
 
 
 
 
 
24
 
25
  def save_character_image(name, image):
26
+ os.makedirs(IMAGES_DIR, exist_ok=True)
 
 
27
  image_path = os.path.join(IMAGES_DIR, f"{name}.png")
28
  image.save(image_path)
 
 
 
 
 
 
 
29
  return image_path
30
 
31
  def generate_prompt(scene_tags, num_people, position_tags, selected_characters, outfit_tags, camera_tags, concept_tags, lora_tags, tag_counts, data):
 
37
  "concept": concept_tags.split(','),
38
  }
39
 
40
+ all_tags = {k: [tag.strip() for tag in v if tag.strip()] for k, v in all_tags.items()}
 
41
 
42
+ character_prompts = [
43
+ f"{char_name}, " + ", ".join(random.sample(data["characters"][char_name]["traits"],
44
+ min(tag_counts["character"], len(data["characters"][char_name]["traits"]))))
45
+ for char_name in selected_characters if char_name in data["characters"]
46
+ ]
 
47
 
48
+ selected_tags = [
49
+ f"{tag}:{random.uniform(0.8, 1.2):.2f}"
50
+ for category, tags in all_tags.items()
51
+ for tag in random.sample(tags, min(tag_counts[category], len(tags)))
52
+ ]
53
 
54
  if num_people.strip():
55
  selected_tags.append(f"{num_people} people:1.1")
 
63
 
64
  fixed_tags = "source_anime, score_9, score_8_up, score_7_up, masterpiece, best quality, very aesthetic, absurdres, anime artwork, anime style, vibrant, studio anime, highly detailed"
65
 
66
+ return f"{main_prompt}, {fixed_tags} {lora_prompt}".strip()
 
67
 
68
  def update_data(data, key, value):
69
+ data[key] = list(set(data[key] + [v.strip() for v in value.split(',') if v.strip()]))
 
70
  save_data(data)
71
  return data
72