imranali291 committed
Commit d2893c5 · verified · 1 Parent(s): c318566

Create app.py

Files changed (1)
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
+import gradio as gr
+import torch
+import random
+import transformers
+from transformers import T5Tokenizer, T5ForConditionalGeneration
+
+# Pick the device once at startup.
+if torch.cuda.is_available():
+    device = "cuda"
+    print("Using GPU")
+else:
+    device = "cpu"
+    print("Using CPU")
+
+tokenizer = T5Tokenizer.from_pretrained("imranali291/flux-prompt-enhancer")
+# Device placement is handled explicitly below, so device_map is not passed here.
+model = T5ForConditionalGeneration.from_pretrained("imranali291/flux-prompt-enhancer", torch_dtype="auto")
+model.to(device)
+
+
+def generate(your_prompt, task_prefix, max_new_tokens, repetition_penalty, temperature, model_precision_type, top_p, top_k, seed):
+    # A seed of 0 means "pick a random seed"; any other value makes the output reproducible.
+    if seed == 0:
+        seed = random.randint(1, 2**32 - 1)
+    transformers.set_seed(seed)
+
+    # Cast the model to the requested precision for this request.
+    if model_precision_type == "fp16":
+        dtype = torch.float16
+    else:  # "fp32"
+        dtype = torch.float32
+    model.to(dtype)
+
+    repetition_penalty = float(repetition_penalty)
+
+    input_text = f"{task_prefix}: {your_prompt}"
+    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
+
+    outputs = model.generate(
+        input_ids,
+        max_new_tokens=max_new_tokens,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+    )
+
+    better_prompt = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return better_prompt
+
+
+your_prompt = gr.Textbox(label="Your Prompt", info="The prompt you want to enhance")
+
+# Prepended to the user prompt as f"{task_prefix}: {your_prompt}" in generate().
+task_prefix = gr.Textbox(label="Task Prefix", info="Instruction prepended to your prompt before it is sent to the model")
+
+max_new_tokens = gr.Slider(value=128, minimum=25, maximum=512, step=1, label="Max New Tokens", info="The maximum number of new tokens; controls how long the output can be")
+
+repetition_penalty = gr.Slider(value=2.5, minimum=0, maximum=3.0, step=0.05, label="Repetition Penalty", info="Penalizes repeated tokens so the model repeats itself less")
+
+temperature = gr.Slider(value=0.7, minimum=0, maximum=1, step=0.05, label="Temperature", info="Higher values produce more diverse outputs")
+
+model_precision_type = gr.Dropdown(["fp16", "fp32"], value="fp16", label="Model Precision Type", info="The precision to run the model in: fp16 is faster, fp32 is more precise but uses more resources")
+
+top_p = gr.Slider(value=0.9, minimum=0, maximum=1, step=0.05, label="Top P", info="Higher values sample more low-probability tokens")
+
+top_k = gr.Slider(value=50, minimum=1, maximum=100, step=1, label="Top K", info="Higher k considers a wider range of tokens, giving more diverse outputs")
+
+seed = gr.Slider(value=42, minimum=0, maximum=2**32-1, step=1, label="Seed", info="A starting point for the generation process; set 0 for a random seed")
+
+examples = [
+    ["Futuristic cityscape at twilight descent.", "A captivating digital artwork portraying a futuristic cityscape at twilight from an aerial insect's viewpoint, featuring glistening skyscrapers with neon lights casting reflections on rain-soaked streets below. The scene harmoniously blends optimism and wistfulness as vivid, multi-colored holographic billboards flicker against the looming shadows of towering structures, creating a poignant contrast that encapsulates both anticipation and solitude in this thriving urban hub. As clouds veil the setting sun, a warm amber radiance envelops the cityscape, extending elongated silhouettes across the damp pavement while raindrops sparkle like precious gems on outstretched wings, symbolizing the intricate equilibrium between technological advancement and natural beauty within this bustling metropolis.", 128, 2.5, 0.5, "fp16", 0.9, 50, 42]
+]
+
+gr.Interface(
+    fn=generate,
+    inputs=[your_prompt, task_prefix, max_new_tokens, repetition_penalty, temperature, model_precision_type, top_p, top_k, seed],
+    outputs=gr.Textbox(label="Prompt Enhancer"),
+    title="Prompt Enhancer",
+    description="Make your prompts more detailed!",
+    examples=examples,
+).launch(share=True)
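
For reference, a minimal sketch of calling the same enhancer model outside the Gradio UI, using only the transformers calls that already appear in app.py above. The task_prefix value here is a placeholder, since the commit does not state which prefix the model was fine-tuned with; substitute the prefix you actually use in the app.

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration, set_seed

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = T5Tokenizer.from_pretrained("imranali291/flux-prompt-enhancer")
model = T5ForConditionalGeneration.from_pretrained("imranali291/flux-prompt-enhancer").to(device)

set_seed(42)  # same default seed as the Space
task_prefix = "enhance prompt"  # placeholder: use the prefix the model was trained with
prompt = "Futuristic cityscape at twilight descent."

input_ids = tokenizer(f"{task_prefix}: {prompt}", return_tensors="pt").input_ids.to(device)
outputs = model.generate(
    input_ids,
    max_new_tokens=128,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    repetition_penalty=2.5,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))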