# Hugging Face Space (status: Running on Zero)
import os | |
import gradio as gr | |
from utils.prompt_utils import remove_color | |
from utils.tagger import modelLoad, analysis | |
class PromptAnalysis:
    """Gradio UI section that fills a prompt textbox with tags inferred from an image.

    The tagger model is not loaded at construction time. It is loaded on the
    first analysis request and then reused for later requests.
    """

    def __init__(self, app_config, post_filter=True,
                 default_nagative_prompt="lowres, error, extra digit, fewer digits, cropped, worst quality, "
                                         "low quality, normal quality, jpeg artifacts, blurry"):
        """Store configuration; does not load the tagger model.

        Args:
            app_config: Application config. Only its ``dpath`` attribute is
                read; it is the base directory for ``models/tagger``.
            post_filter: When true, the tagger output is passed through
                ``remove_color`` before being returned.
            default_nagative_prompt: Initial value of the negative-prompt
                textbox. The misspelled name is kept as-is for backward
                compatibility with existing callers and attribute readers.
        """
        self.default_nagative_prompt = default_nagative_prompt
        self.post_filter = post_filter
        # Populated lazily by process_prompt_analysis on first use.
        self.model = None
        self.model_dir = os.path.join(app_config.dpath, 'models/tagger')

    def layout(self, lang_util, input_image):
        """Create the prompt / negative-prompt textboxes and the analyze button.

        The components are created inside a new ``gr.Column``. Clicking the
        button runs ``process_prompt_analysis`` on the value of
        ``input_image`` and writes the result into the prompt textbox.

        Args:
            lang_util: Localisation helper; ``get_text`` supplies the labels.
            input_image: Gradio component whose value is fed to the analysis.

        Returns:
            ``[prompt_textbox, negative_prompt_textbox]``.
        """
        with gr.Column():
            with gr.Row():
                self.prompt = gr.Textbox(label=lang_util.get_text("prompt"), lines=3)
            with gr.Row():
                self.negative_prompt = gr.Textbox(label=lang_util.get_text("negative_prompt"), lines=3,
                                                  value=self.default_nagative_prompt)
            with gr.Row():
                self.prompt_analysis_button = gr.Button(lang_util.get_text("analyze_prompt"))
        self.prompt_analysis_button.click(
            self.process_prompt_analysis,
            inputs=[input_image],
            outputs=self.prompt
        )
        return [self.prompt, self.negative_prompt]

    def process_prompt_analysis(self, input_image_path):
        """Run the tagger on the given image and return its tags.

        Args:
            input_image_path: Value of the input image component, or ``None``
                when no image has been provided.

        Returns:
            The result of ``analysis`` (passed through ``remove_color`` when
            ``post_filter`` is enabled), or ``""`` when no image was given.
        """
        # Gradio passes None for an empty image component. In that case, skip
        # loading the model and calling the tagger instead of failing inside
        # analysis().
        if input_image_path is None:
            return ""
        if self.model is None:
            # Lazy load, so the model is only paid for once a user asks for it.
            self.model = modelLoad(self.model_dir)
        tags = analysis(input_image_path, self.model_dir, self.model)
        return remove_color(tags) if self.post_filter else tags