Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import load_dataset
|
2 |
+
import gradio as gr
|
3 |
+
from gradio_client import Client
|
4 |
+
import json
|
5 |
+
import torch
|
6 |
+
from diffusers import FluxPipeline, AutoencoderKL
|
7 |
+
from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images
|
8 |
+
import spaces
|
9 |
+
|
10 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
+
|
12 |
+
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device)
|
13 |
+
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
|
14 |
+
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
|
15 |
+
|
16 |
+
llm_client = Client("CohereForAI/c4ai-command-r-plus-08-2024")
|
17 |
+
|
18 |
+
ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train")
|
19 |
+
|
20 |
+
prompt_template = """λ€μ νλ₯΄μλ μ€λͺ
μ κ°μ§ μΊλ¦ν°λ₯Ό μμ±νμΈμ: {persona_description}
|
21 |
+
λ€μ μ€λͺ
μ μΈκ³μμ: {world_description}
|
22 |
+
λ€μ νλλ₯Ό ν¬ν¨νλ JSON νμμΌλ‘ μΊλ¦ν°λ₯Ό μμ±νμΈμ:
|
23 |
+
- name: μΊλ¦ν°μ μ΄λ¦
|
24 |
+
- background: μΊλ¦ν°μ λ°°κ²½
|
25 |
+
- appearance: μΊλ¦ν°μ μΈλͺ¨
|
26 |
+
- personality: μΊλ¦ν°μ μ±κ²©
|
27 |
+
- skills_and_abilities: μΊλ¦ν°μ κΈ°μ κ³Ό λ₯λ ₯
|
28 |
+
- goals: μΊλ¦ν°μ λͺ©ν
|
29 |
+
- conflicts: μΊλ¦ν°μ κ°λ±
|
30 |
+
- backstory: μΊλ¦ν°μ κ³Όκ±° μ΄μΌκΈ°
|
31 |
+
- current_situation: μΊλ¦ν°μ νμ¬ μν©
|
32 |
+
- spoken_lines: μΊλ¦ν°μ λμ¬ (λ¬Έμμ΄ λ¦¬μ€νΈ)
|
33 |
+
JSON νμμ μΊλ¦ν° μ€λͺ
λ§ μμ±νκ³ λ€λ₯Έ λ΄μ©μ ν¬ν¨νμ§ λ§μΈμ. '```'λ ν¬ν¨νμ§ λ§μΈμ.
|
34 |
+
"""
|
35 |
+
|
36 |
+
world_description_prompt = "λ
νΉνκ³ λ¬΄μμν μΈκ³ μ€λͺ
μ μμ±νμΈμ (μΈκ³ μ€λͺ
λ§ μμ±νκ³ λ€λ₯Έ λ΄μ©μ ν¬ν¨νμ§ λ§μΈμ)."
|
37 |
+
|
38 |
+
def get_random_world_description():
|
39 |
+
result = llm_client.predict(
|
40 |
+
query=world_description_prompt,
|
41 |
+
history=[],
|
42 |
+
system="λΉμ μ μ리λ°λ° ν΄λΌμ°λκ° λ§λ Qwenμ
λλ€. λΉμ μ λμμ΄ λλ μ΄μμ€ν΄νΈμ
λλ€.",
|
43 |
+
api_name="/model_chat",
|
44 |
+
)
|
45 |
+
return result[1][0][-1]
|
46 |
+
|
47 |
+
def get_random_persona_description():
|
48 |
+
return ds.shuffle().select([100])[0]["persona"]
|
49 |
+
|
50 |
+
@spaces.GPU(duration=75)
|
51 |
+
def infer_flux(character_json):
|
52 |
+
for image in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
|
53 |
+
prompt=character_json["appearance"],
|
54 |
+
guidance_scale=3.5,
|
55 |
+
num_inference_steps=28,
|
56 |
+
width=1024,
|
57 |
+
height=1024,
|
58 |
+
generator=torch.Generator("cpu").manual_seed(0),
|
59 |
+
output_type="pil",
|
60 |
+
good_vae=good_vae,
|
61 |
+
):
|
62 |
+
yield image
|
63 |
+
|
64 |
+
def generate_character(world_description, persona_description, progress=gr.Progress(track_tqdm=True)):
|
65 |
+
result = llm_client.predict(
|
66 |
+
query=prompt_template.format(
|
67 |
+
persona_description=persona_description, world_description=world_description
|
68 |
+
),
|
69 |
+
history=[],
|
70 |
+
system="λΉμ μ μ리λ°λ° ν΄λΌμ°λκ° λ§λ Qwenμ
λλ€. λΉμ μ λμμ΄ λλ μ΄μμ€ν΄νΈμ
λλ€.",
|
71 |
+
api_name="/model_chat",
|
72 |
+
)
|
73 |
+
output = json.loads(result[1][0][-1])
|
74 |
+
return output
|
75 |
+
|
76 |
+
with gr.Blocks(title="μΊλ¦ν° μλ μμ±", theme="Nymbo/Nymbo_Theme") as app:
|
77 |
+
with gr.Column():
|
78 |
+
gr.HTML("<center><h1>μΊλ¦ν° μμ±κΈ°</h1></center>")
|
79 |
+
with gr.Column():
|
80 |
+
with gr.Row():
|
81 |
+
world_description = gr.Textbox(lines=10, label="μΈκ³ μ€λͺ
", scale=4)
|
82 |
+
persona_description = gr.Textbox(lines=10, label="νλ₯΄μλ μ€λͺ
", value=get_random_persona_description(), scale=1)
|
83 |
+
with gr.Row():
|
84 |
+
random_world_button = gr.Button(value="무μμ μΈκ³ μ€λͺ
κ°μ Έμ€κΈ°", variant="secondary", scale=1)
|
85 |
+
submit_button = gr.Button(value="ν₯λ―Έλ‘μ΄ μΊλ¦ν° μμ±νκΈ°!", variant="primary", scale=5)
|
86 |
+
random_persona_button = gr.Button(value="무μμ νλ₯΄μλ μ€λͺ
κ°μ Έμ€κΈ°", variant="secondary", scale=1)
|
87 |
+
with gr.Row():
|
88 |
+
character_image = gr.Image(label="μΊλ¦ν° μ΄λ―Έμ§")
|
89 |
+
character_json = gr.JSON(label="μΊλ¦ν° μ€λͺ
")
|
90 |
+
|
91 |
+
examples = gr.Examples(
|
92 |
+
[
|
93 |
+
"λ§λ²μ΄ μ€μ‘΄νκ³ μ©λ€μ΄ νλμ λ μλ€λλ μΈκ³μμ, λͺ¨νκ°λ€μ μΌνμ΄ μ μ€μ μΈ μ©μμ κ²μ μ°Ύμ λμλλ€.",
|
94 |
+
"μν 리μμ μ€μ κ²μ νμν©λλ€. μ΄κ³³μ 물리 λ²μΉμ΄ κ³ λ λ§λ²μ μμ§μ λ°λΌ νμ΄μ§λ κ΄ννκ³ μ λΉλ‘μ΄ μμμ
λλ€. μ΄ μΈκ³λ λμλ νλμ λ μλ μλ§μ λΆμ μ¬μΌλ‘ ꡬμ±λμ΄ μμΌλ©°, κ° μ¬μ μλͺ
κ³Ό λΉλ°λ‘ κ°λ μ°¬ λ
νΉν μνκ³μ
λλ€. μν 리μμ μ¬λ€μ μΈμ°½ν μ κΈλΆν° ν©λν μμ μ¬λ§κΉμ§ λ€μν©λλ€. μ΄λ€ μ¬μ λμ λΈλ‘λ§νΌ μκ³ , μ΄λ€ μ¬μ μλ°± λ§μΌμ κ±Έμ³ μμ΅λλ€. μ΄ λΆλ¦¬λ λ
λ©μ΄λ¦¬λ€μ μ°κ²°νλ κ²μ μμν μλμ§λ‘ λ§λ€μ΄μ§ λ°μ§μ΄λ λ€λ¦¬μ΄λ©°, beaten pathλ₯Ό λ²μ΄λ λ§οΏ½οΏ½ μ©κ°ν μ¬λλ€μ λ¨Ό 거리λ₯Ό μκ°μ μΌλ‘ μ΄λν μ μλ μ¨κ²¨μ§ ν¬νΈμ μ°Ύμ μ μμ΅λλ€. μν 리μμ μ£Όλ―Όλ€μ κ·Έ νκ²½λ§νΌμ΄λ λ€μν©λλ€. μΈκ°λ€μ λΉμ μν
λ₯΄ μ‘΄μ¬, λ°μ νΌλΆλ₯Ό κ°μ§ κ±°μΈ, λΆλ₯νκΈ° μ΄λ €μ΄ λ³ν μλ¬Όλ€κ³Ό 곡쑴ν©λλ€. κ³ λ μ μ λ€μ΄ μ¬λ€μ μ μ μ΄ μ₯μνκ³ μμ΄, κ³Όνκ³Ό λ§λ²μ κ²½κ³λ₯Ό ν리λ μνμ§ λ¬Έλͺ
κ³Ό κΈ°μ μ μμν©λλ€. μΈκ³λ μν
λ₯΄λΌλ μ λΉν λ¬Όμ§λ‘ μμ§μ΄λ©°, μ΄κ²μ λͺ¨λ κ²μ ν΅ν΄ νλ¦
λλ€. μν
λ₯΄μ νμ λ€λ£° μ μλ μ¬λλ€μ νμ€ μ체λ₯Ό μ‘°μν μ μλ κ°λ ₯ν λ§λ²μ¬κ° λ©λλ€. κ·Έλ¬λ μν
λ₯΄λ νμ λ μμμ΄λ©°, κ·Έ ν¬μμ±μΌλ‘ μΈν΄ ν΅μ κΆμ λκ³ λ€ν¬λ λ€μν μΈλ ₯ μ¬μ΄μ κ°λ±μ΄ μκ²Όμ΅λλ€. μ¬λ€ μ¬μ΄μ νλμμλ μ
μ₯ν λΉνμ λ€μ΄ λ§λ²μ κΈ°λ₯λ₯Ό νκ³ νν΄νλ©° 무μκ³Ό ννμ μ΄μ§ν©λλ€. ν΄μ λ€κ³Ό νλ μ΅κ²©μλ€μ΄ κ΅¬λ¦ κΉμν κ³³μ μ¨μ΄ λ°©μ¬ν λ¨Ήμκ°μ νμ λ
Έλ¦¬κ³ μμ΅λλ€. λΆμ νλ λ
κΉμν μλμλ μΈλ보μ΄λλΌλ μ΄λ‘κ³ μνν μμμ΄ μμΌλ©°, μ
λͺ½ κ°μ μλ¬Όλ€κ³Ό λ§λ‘ ννν μ μλ λΆκ° κ°λν©λλ€. μ€μ§ κ°μ₯ μ©κ°ν λͺ¨νκ°λ€λ§μ΄ κ·Έ κΉμ΄λ₯Ό νννλ € νκ³ , λ μ μ μλ§μ΄ κ·Έ μ΄μΌκΈ°λ₯Ό λ€λ €μ£ΌκΈ° μν΄ λμμ΅λλ€. νμ μ‘΄μ¬νλ μνμΌλ‘, νΌλμ ννμ΄ μλ €μ§ μΈκ³μ κ°μ₯μ리μμ λ§Ήμλ₯Ό λ¨μΉλ©° κ·Έ κΈΈμ μλ λͺ¨λ κ²μ μΌμΌλ²λ¦΄ μνμ κ°νκ³ μμ΅λλ€. μν 리μμ μμ
λ€μκ² μΈκ³μ λΉλ°μ λ°νκ³ λ무 λ¦κΈ° μ μ λ€κ°μ€λ μ΄λ μ λ¬Όλ¦¬μΉ λ°©λ²μ μ°Ύλ μλ¬΄κ° μ£Όμ΄μ‘μ΅λλ€. μν 리μμμλ λͺ¨λ μ¬μ΄ μ΄μΌκΈ°λ₯Ό νκ³ μκ³ , λͺ¨λ μλ¬Όμ΄ λΉλ°μ κ°μ§κ³ μμΌλ©°, λͺ¨λ λͺ¨νμ΄ μ΄ κ²½μ΄λ‘κ³ μνμ μ²ν μΈκ³μ μ΄λͺ
μ λ°κΏ μ μμ΅λλ€.",
|
95 |
+
"μ μμ μ μΈκ³μλ 'μ€λΌν€μ€'λΌλ λμκ° μμ΅λλ€. μ΄ λμλ μμν λΉμ κΈ°λ₯ μμ λ μμ΅λλ€. λμμ λ²½μ μμ μ λ¦¬λ‘ λ§λ€μ΄μ Έ μμ΄ νμ μλ²½κ³Ό ν©νΌμ μμ λ°μ¬νλ©° μμν μ²μμ κ΄μ±λ₯Ό λ°ν©λλ€. 건물λ€μ κ³μ μ λ°λΌ μ¨μ μ¬λ©° ννλ₯Ό λ°κΏλλ€ - λ΄μλ μ±μ₯νκ³ , μ¬λ¦μλ κ°ν΄μ§λ©°, κ°μμλ ν΄μνκΈ° μμν΄ κ²¨μΈμ΄ λλ©΄ μκ°κ° λ©λλ€.",
|
96 |
+
],
|
97 |
+
world_description,
|
98 |
+
)
|
99 |
+
|
100 |
+
submit_button.click(
|
101 |
+
generate_character, [world_description, persona_description], outputs=[character_json]
|
102 |
+
).then(fn=infer_flux, inputs=[character_json], outputs=[character_image])
|
103 |
+
random_world_button.click(
|
104 |
+
get_random_world_description, outputs=[world_description]
|
105 |
+
)
|
106 |
+
random_persona_button.click(
|
107 |
+
get_random_persona_description, outputs=[persona_description]
|
108 |
+
)
|
109 |
+
|
110 |
+
app.queue().launch(share=False)
|