File size: 2,942 Bytes
53e0942
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Import libraries
import torch
import re
import random
import requests
import shutil
from clip_interrogator import Config, Interrogator
from transformers import pipeline, set_seed, AutoTokenizer, AutoModelForSeq2SeqLM
from PIL import Image
import gradio as gr

# Configure the CLIP Interrogator.
# Runs on the GPU when CUDA is available, otherwise on the CPU; BLIP
# offloading is switched on only in the CPU case.
config = Config()
if torch.cuda.is_available():
    config.device = 'cuda'
else:
    config.device = 'cpu'
config.blip_offload = not torch.cuda.is_available()
# Tuning values handed to clip_interrogator as-is (see its Config for meaning).
config.chunk_size = 2048
config.flavor_intermediate_count = 512
config.blip_num_beams = 64
config.clip_model_name = "ViT-H-14/laion2b_s32b_b79k"
# Shared interrogator instance used by the image-to-prompt handler.
ci = Interrogator(config)

# Generate a prompt from an image
def get_prompt_from_image(image, mode):
    """Describe *image* as a text prompt using the module-level interrogator.

    Args:
        image: Image object exposing ``convert`` (the UI declares its image
            input with ``type='pil'``).
        mode: One of ``'best'``, ``'classic'``, ``'fast'`` or ``'negative'``.

    Returns:
        Whatever the interrogator method selected by *mode* returns for the
        RGB-converted image.

    Raises:
        ValueError: If *image* is ``None`` or *mode* is not a known mode.
    """
    # UI mode -> name of the interrogator method to call. Names are resolved
    # lazily so that bad input is rejected before the interrogator is touched.
    method_names = {
        'best': 'interrogate',
        'classic': 'interrogate_classic',
        'fast': 'interrogate_fast',
        'negative': 'interrogate_negative',
    }
    if image is None:
        # Presumably what the UI sends when no image was loaded -- confirm.
        raise ValueError('No image provided')
    if mode not in method_names:
        # Previously fell through every branch and failed on an unbound local.
        raise ValueError(
            f"Unknown mode {mode!r}; expected one of {sorted(method_names)}"
        )
    interrogate = getattr(ci, method_names[mode])
    return interrogate(image.convert('RGB'))

# Text-generation pipeline used by text_generate; created once when the
# module is loaded.
text_pipe = pipeline(
    task='text-generation',
    model='succinctly/text2image-prompt-generator',
)

def text_generate(input):
    """Expand a short idea into several candidate image prompts.

    Runs the module-level ``text_pipe`` up to six times. Each attempt asks
    for eight sequences, keeps those that meaningfully extend *input*, strips
    dotted tokens and angle brackets, and returns the first non-empty result.

    Args:
        input: The seed text. The name shadows the builtin but is kept
            because it is the function's public parameter name.

    Returns:
        The surviving prompts joined by newlines, or ``''`` when all six
        attempts produce nothing usable.
    """
    # Reseed on every call so repeated requests give different prompts.
    set_seed(random.randint(100, 1000000))

    for _ in range(6):
        sequences = text_pipe(
            input,
            max_length=random.randint(60, 90),
            num_return_sequences=8,
        )
        lines = []
        for sequence in sequences:
            line = sequence['generated_text'].strip()
            # Keep only lines that add more than four characters to the input
            # and do not stop on a dangling separator.
            if (
                line != input
                and len(line) > (len(input) + 4)
                and not line.endswith((':', '-', '—'))
            ):
                lines.append(line)

        result = "\n".join(lines)
        # Drop tokens with a dot between non-whitespace characters (e.g. URLs
        # or file names). \S rather than [^ ] keeps a match from running
        # across the newline that separates two prompts.
        result = re.sub(r'\S+\.\S+', '', result)
        result = result.replace('<', '').replace('>', '')
        if result != '':
            return result

    # Every attempt came back empty.
    return ''

# Build the Gradio interface: one tab turns an image into a prompt, the
# other expands a text idea into prompts.
with gr.Blocks() as block:
    with gr.Column():
        gr.HTML('<h1>MidJourney / SD2 Helper Tool</h1>')
        with gr.Tab('Generate from Image'):
            with gr.Row():
                input_image = gr.Image(type='pil')
                with gr.Column():
                    input_mode = gr.Radio(['best', 'fast', 'classic', 'negative'], value='best', label='Mode')
            img_btn = gr.Button('Discover Image Prompt')
            output_image = gr.Textbox(lines=6, label='Generated Prompt')

        with gr.Tab('Generate from Text'):
            input_text = gr.Textbox(lines=6, label='Your Idea', placeholder='Enter your content here...')
            output_text = gr.Textbox(lines=6, label='Generated Prompt')
            text_btn = gr.Button('Generate Prompt')

    # Wire each button to its handler and output textbox.
    img_btn.click(fn=get_prompt_from_image, inputs=[input_image, input_mode], outputs=output_image)
    text_btn.click(fn=text_generate, inputs=input_text, outputs=output_text)

# .queue() already enables request queuing (at most 64 waiting requests), so
# the deprecated launch(enable_queue=...) flag is not passed; newer Gradio
# releases no longer accept it. server_name='0.0.0.0' listens on all
# network interfaces.
block.queue(max_size=64).launch(show_api=False, debug=True, share=False, server_name='0.0.0.0')