|
from datasets import load_dataset |
|
import gradio as gr |
|
from gradio_client import Client |
|
import json |
|
import torch |
|
from diffusers import FluxPipeline, AutoencoderKL |
|
from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images |
|
import spaces |
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device) |
|
good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device) |
|
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe) |
|
|
|
llm_client = Client("Qwen/Qwen2.5-72B-Instruct") |
|
|
|
ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train") |
|
|
|
prompt_template = """λ€μ νλ₯΄μλ μ€λͺ
μ κ°μ§ μΊλ¦ν°λ₯Ό μμ±νμΈμ: {persona_description} |
|
λ€μ μ€λͺ
μ μΈκ³μμ: {world_description} |
|
λ€μ νλλ₯Ό ν¬ν¨νλ JSON νμμΌλ‘ μΊλ¦ν°λ₯Ό μμ±νμΈμ: |
|
- name: μΊλ¦ν°μ μ΄λ¦ |
|
- background: μΊλ¦ν°μ λ°°κ²½ |
|
- appearance: μΊλ¦ν°μ μΈλͺ¨ |
|
- personality: μΊλ¦ν°μ μ±κ²© |
|
- skills_and_abilities: μΊλ¦ν°μ κΈ°μ κ³Ό λ₯λ ₯ |
|
- goals: μΊλ¦ν°μ λͺ©ν |
|
- conflicts: μΊλ¦ν°μ κ°λ± |
|
- backstory: μΊλ¦ν°μ κ³Όκ±° μ΄μΌκΈ° |
|
- current_situation: μΊλ¦ν°μ νμ¬ μν© |
|
- spoken_lines: μΊλ¦ν°μ λμ¬ (λ¬Έμμ΄ λ¦¬μ€νΈ) |
|
JSON νμμ μΊλ¦ν° μ€λͺ
λ§ μμ±νκ³ λ€λ₯Έ λ΄μ©μ ν¬ν¨νμ§ λ§μΈμ. '```'λ ν¬ν¨νμ§ λ§μΈμ. |
|
""" |
|
|
|
world_description_prompt = "λ
νΉνκ³ λ¬΄μμν μΈκ³ μ€λͺ
μ μμ±νμΈμ (μΈκ³ μ€λͺ
λ§ μμ±νκ³ λ€λ₯Έ λ΄μ©μ ν¬ν¨νμ§ λ§μΈμ)." |
|
|
|
def get_random_world_description(): |
|
result = llm_client.predict( |
|
query=world_description_prompt, |
|
history=[], |
|
system="λ°λμ νκΈλ‘ μΆλ ₯νλΌ. λΉμ μ λμμ΄ λλ μ΄μμ€ν΄νΈμ
λλ€.", |
|
api_name="/model_chat", |
|
) |
|
return result[1][0][-1] |
|
|
|
def get_random_persona_description(): |
|
return ds.shuffle().select([100])[0]["persona"] |
|
|
|
@spaces.GPU(duration=75) |
|
def infer_flux(character_json): |
|
for image in pipe.flux_pipe_call_that_returns_an_iterable_of_images( |
|
prompt=character_json["appearance"], |
|
guidance_scale=3.5, |
|
num_inference_steps=28, |
|
width=1024, |
|
height=1024, |
|
generator=torch.Generator("cpu").manual_seed(0), |
|
output_type="pil", |
|
good_vae=good_vae, |
|
): |
|
yield image |
|
|
|
def generate_character(world_description, persona_description, progress=gr.Progress(track_tqdm=True)): |
|
result = llm_client.predict( |
|
query=prompt_template.format( |
|
persona_description=persona_description, world_description=world_description |
|
), |
|
history=[], |
|
system="λ°λμ νκΈλ‘ μΆλ ₯νλΌ. λΉμ μ λμμ΄ λλ μ΄μμ€ν΄νΈμ
λλ€.", |
|
api_name="/model_chat", |
|
) |
|
output = json.loads(result[1][0][-1]) |
|
return output |
|
|
|
with gr.Blocks(title="μΊλ¦ν° μλ μμ±", theme="Nymbo/Nymbo_Theme") as app: |
|
with gr.Column(): |
|
gr.HTML("<center><h1>μΊλ¦ν° μμ±κΈ°</h1></center>") |
|
with gr.Column(): |
|
with gr.Row(): |
|
world_description = gr.Textbox(lines=10, label="μΈκ³ μ€λͺ
", scale=4) |
|
persona_description = gr.Textbox(lines=10, label="νλ₯΄μλ μ€λͺ
", value=get_random_persona_description(), scale=1) |
|
with gr.Row(): |
|
random_world_button = gr.Button(value="무μμ μΈκ³ μ€λͺ
κ°μ Έμ€κΈ°", variant="secondary", scale=1) |
|
submit_button = gr.Button(value="ν₯λ―Έλ‘μ΄ μΊλ¦ν° μμ±νκΈ°!", variant="primary", scale=5) |
|
random_persona_button = gr.Button(value="무μμ νλ₯΄μλ μ€λͺ
κ°μ Έμ€κΈ°", variant="secondary", scale=1) |
|
with gr.Row(): |
|
character_image = gr.Image(label="μΊλ¦ν° μ΄λ―Έμ§") |
|
character_json = gr.JSON(label="μΊλ¦ν° μ€λͺ
") |
|
|
|
examples = gr.Examples( |
|
[ |
|
"λ§λ²μ΄ μ€μ‘΄νκ³ μ©λ€μ΄ νλμ λ μλ€λλ μΈκ³μμ, λͺ¨νκ°λ€μ μΌνμ΄ μ μ€μ μΈ μ©μμ κ²μ μ°Ύμ λμλλ€.", |
|
"μν 리μμ μ€μ κ²μ νμν©λλ€. μ΄κ³³μ 물리 λ²μΉμ΄ κ³ λ λ§λ²μ μμ§μ λ°λΌ νμ΄μ§λ κ΄ννκ³ μ λΉλ‘μ΄ μμμ
λλ€. μ΄ μΈκ³λ λμλ νλμ λ μλ μλ§μ λΆμ μ¬μΌλ‘ ꡬμ±λμ΄ μμΌλ©°, κ° μ¬μ μλͺ
κ³Ό λΉλ°λ‘ κ°λ μ°¬ λ
νΉν μνκ³μ
λλ€. μν 리μμ μ¬λ€μ μΈμ°½ν μ κΈλΆν° ν©λν μμ μ¬λ§κΉμ§ λ€μν©λλ€. μ΄λ€ μ¬μ λμ λΈλ‘λ§νΌ μκ³ , μ΄λ€ μ¬μ μλ°± λ§μΌμ κ±Έμ³ μμ΅λλ€. μ΄ λΆλ¦¬λ λ
λ©μ΄λ¦¬λ€μ μ°κ²°νλ κ²μ μμν μλμ§λ‘ λ§λ€μ΄μ§ λ°μ§μ΄λ λ€λ¦¬μ΄λ©°, beaten pathλ₯Ό λ²μ΄λ λ§νΌ μ©κ°ν μ¬λλ€μ λ¨Ό 거리λ₯Ό μκ°μ μΌλ‘ μ΄λν μ μλ μ¨κ²¨μ§ ν¬νΈμ μ°Ύμ μ μμ΅λλ€. μν 리μμ μ£Όλ―Όλ€μ κ·Έ νκ²½λ§νΌμ΄λ λ€μν©λλ€. μΈκ°λ€μ λΉμ μν
λ₯΄ μ‘΄μ¬, λ°μ νΌλΆλ₯Ό κ°μ§ κ±°μΈ, λΆλ₯νκΈ° μ΄λ €μ΄ λ³ν μλ¬Όλ€κ³Ό 곡쑴ν©λλ€. κ³ λ μ μ λ€μ΄ μ¬λ€μ μ μ μ΄ μ₯μνκ³ μμ΄, κ³Όνκ³Ό λ§λ²μ κ²½κ³λ₯Ό ν리λ μνμ§ λ¬Έλͺ
κ³Ό κΈ°μ μ μμν©λλ€. μΈκ³λ μν
λ₯΄λΌλ μ λΉν λ¬Όμ§λ‘ μμ§μ΄λ©°, μ΄κ²μ λͺ¨λ κ²μ ν΅ν΄ νλ¦
λλ€. μν
λ₯΄μ νμ λ€λ£° μ μλ μ¬λλ€μ νμ€ μ체λ₯Ό μ‘°μν μ μλ κ°λ ₯ν λ§λ²μ¬κ° λ©λλ€. κ·Έλ¬λ μν
λ₯΄λ νμ λ μμμ΄λ©°, κ·Έ ν¬μμ±μΌλ‘ μΈν΄ ν΅μ κΆμ λκ³ λ€ν¬λ λ€μν μΈλ ₯ μ¬μ΄μ κ°λ±μ΄ μκ²Όμ΅λλ€. μ¬λ€ μ¬μ΄μ νλμμλ μ
μ₯ν λΉνμ λ€μ΄ λ§λ²μ κΈ°λ₯λ₯Ό νκ³ νν΄νλ©° 무μκ³Ό ννμ μ΄μ§ν©λλ€. ν΄μ λ€κ³Ό νλ μ΅κ²©μλ€μ΄ κ΅¬λ¦ κΉμν κ³³μ μ¨μ΄ λ°©μ¬ν λ¨Ήμκ°μ νμ λ
Έλ¦¬κ³ μμ΅λλ€. λΆμ νλ λ
κΉμν μλμλ μΈλ보μ΄λλΌλ μ΄λ‘κ³ μνν μμμ΄ μμΌλ©°, μ
λͺ½ κ°μ μλ¬Όλ€κ³Ό λ§λ‘ ννν μ μλ λΆκ° κ°λν©λλ€. μ€μ§ κ°μ₯ μ©κ°ν λͺ¨νκ°λ€λ§μ΄ κ·Έ κΉμ΄λ₯Ό νννλ € νκ³ , λ μ μ μλ§μ΄ κ·Έ μ΄μΌκΈ°λ₯Ό λ€λ €μ£ΌκΈ° μν΄ λμμ΅λλ€. νμ μ‘΄μ¬νλ μνμΌλ‘, νΌλμ ννμ΄ μλ €μ§ μΈκ³μ κ°μ₯μ리μμ λ§Ήμλ₯Ό λ¨μΉλ©° κ·Έ κΈΈμ μλ λͺ¨λ κ²μ μΌμΌλ²λ¦΄ μνμ κ°νκ³ μμ΅λλ€. μν 리μμ μμ
λ€μκ² μΈκ³μ λΉλ°μ λ°νκ³ λ무 λ¦κΈ° μ μ λ€κ°μ€λ μ΄λ μ λ¬Όλ¦¬μΉ λ°©λ²μ μ°Ύλ μλ¬΄κ° μ£Όμ΄μ‘μ΅λλ€. μν 리μμμλ λͺ¨λ μ¬μ΄ μ΄μΌκΈ°λ₯Ό νκ³ μκ³ , λͺ¨λ μλ¬Όμ΄ λΉλ°μ κ°μ§κ³ μμΌλ©°, λͺ¨λ λͺ¨νμ΄ μ΄ κ²½μ΄λ‘κ³ μνμ μ²ν μΈκ³μ μ΄λͺ
μ λ°κΏ μ μμ΅λλ€.", |
|
"μ μμ μ μΈκ³μλ 'μ€λΌν€μ€'λΌλ λμκ° μμ΅λλ€. μ΄ λμλ μμν λΉμ κΈ°λ₯ μμ λ μμ΅λλ€. λμμ λ²½μ μμ μ λ¦¬λ‘ λ§λ€μ΄μ Έ μμ΄ νμ μλ²½κ³Ό ν©νΌμ μμ λ°μ¬νλ©° μμν μ²μμ κ΄μ±λ₯Ό λ°ν©λλ€. 건물λ€μ κ³μ μ λ°λΌ μ¨μ μ¬λ©° ννλ₯Ό λ°κΏλλ€ - λ΄μλ μ±μ₯νκ³ , μ¬λ¦μλ κ°ν΄μ§λ©°, κ°μμλ ν΄μνκΈ° μμν΄ κ²¨μΈμ΄ λλ©΄ μκ°κ° λ©λλ€.", |
|
], |
|
world_description, |
|
) |
|
|
|
submit_button.click( |
|
generate_character, [world_description, persona_description], outputs=[character_json] |
|
).then(fn=infer_flux, inputs=[character_json], outputs=[character_image]) |
|
random_world_button.click( |
|
get_random_world_description, outputs=[world_description] |
|
) |
|
random_persona_button.click( |
|
get_random_persona_description, outputs=[persona_description] |
|
) |
|
|
|
app.queue().launch(share=False) |