multimodalart HF Staff committed on
Commit
f29396c
·
verified ·
1 Parent(s): 076e3f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -11
app.py CHANGED
@@ -40,28 +40,68 @@ from demo_utils.constant import ZERO_VAE_CACHE
40
  from demo_utils.vae_block3 import VAEDecoderWrapper
41
  from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
42
 
43
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
44
 
45
  device = "cuda" if torch.cuda.is_available() else "cpu"
46
- model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
 
 
47
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
48
- model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  enhancer = pipeline(
50
- 'text2text-generation',
51
  model=model,
52
  tokenizer=tokenizer,
53
- repetition_penalty= 1.2,
54
- device=device
55
  )
56
- max_target_length = 256
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  @spaces.GPU
59
  def enhance_prompt(prompt):
60
- prefix = "enhance prompt: "
61
- short_prompt = prompt
62
- answer = enhancer(prefix + short_prompt, max_length=max_target_length)
 
 
 
 
 
 
 
 
63
  final_answer = answer[0]['generated_text']
64
- return final_answer
65
 
66
  # --- Argument Parsing ---
67
  parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
 
40
  from demo_utils.vae_block3 import VAEDecoderWrapper
41
  from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
42
 
43
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
44
 
# --- Prompt-enhancer model setup (runs once at import time) ---

# Kept as a module-level name; the enhancer pipeline below does not receive it
# (the model is placed via `device_map` instead).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Chat/instruct checkpoint used to rewrite user prompts.
model_checkpoint = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# 4-bit NF4 weights with double quantization; compute is done in bfloat16.
# `BitsAndBytesConfig` comes from `transformers` and must be imported with the
# other transformers names, otherwise this line raises NameError on import.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

# NOTE(review): `flash_attention_2` presumably requires the flash-attn package
# to be installed in the runtime -- confirm against the Space's requirements.
model = AutoModelForCausalLM.from_pretrained(
    model_checkpoint,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    quantization_config=quantization_config,
    device_map="auto",
)

# Text-generation pipeline wrapping the quantized model; the repetition
# penalty is set here so it applies to every call made through `enhancer`.
enhancer = pipeline(
    'text-generation',
    model=model,
    tokenizer=tokenizer,
    repetition_penalty=1.2,
)
71
+
# System prompt sent with every enhancement request. It is runtime text that
# the model reads verbatim, so it is kept exactly as written.
T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
Start directly with the action, and keep descriptions literal and precise.
Think like a cinematographer describing a shot list.
Do not change the user input intent, just enhance it.
Keep within 150 words.
For best results, build your prompts using this structure:
Start with main action in a single sentence
Add specific details about movements and gestures
Describe character/object appearances precisely
Include background and environment details
Specify camera angles and movements
Describe lighting and colors
Note any changes or sudden events
Do not exceed the 150 word limit!
Output the enhanced prompt only.
"""


@spaces.GPU
def enhance_prompt(prompt):
    """Rewrite a short user prompt into a detailed cinematic prompt.

    The user text is sent to the module-level ``enhancer`` pipeline as a chat
    conversation, with ``T2V_CINEMATIC_PROMPT`` as the system message.

    Args:
        prompt: The user's raw prompt text.

    Returns:
        The generated text with surrounding whitespace stripped, or an empty
        string when ``prompt`` is ``None``, empty, or only whitespace.
    """
    # Nothing to enhance: skip generation instead of asking the model to
    # elaborate on an empty (or literal "None") user message.
    if prompt is None or not str(prompt).strip():
        return ""

    messages = [
        {"role": "system", "content": T2V_CINEMATIC_PROMPT},
        {"role": "user", "content": f"user_prompt: {prompt}"},
    ]
    answer = enhancer(
        messages,
        max_new_tokens=256,
        # Only the newly generated text is wanted, not the input conversation.
        return_full_text=False,
        # EOS is reused as the padding id -- presumably because the tokenizer
        # defines no dedicated pad token; confirm against the checkpoint.
        pad_token_id=tokenizer.eos_token_id,
    )

    final_answer = answer[0]['generated_text']
    return final_answer.strip()
105
 
106
  # --- Argument Parsing ---
107
  parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")