# app.py — "Speech to Image Generator" Hugging Face Space by apruvd.
# Revision 77fb036 ("Update app.py"), 7.03 kB.
import random as r
import time

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import whisper
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
from keybert import KeyBERT
from PIL import Image

# Run on the GPU when one is present. Half-precision weights are only usable on
# CUDA, so on CPU-only hardware fall back to fp32 instead of failing at start-up
# on the hard-coded pipe.to("cuda").
_device = "cuda" if torch.cuda.is_available() else "cpu"
_dtype = torch.float16 if _device == "cuda" else torch.float32

# Speech-to-text model, placed explicitly on the selected device.
model = whisper.load_model("base", device=_device)

# Text-to-image checkpoint (previously used: "stabilityai/stable-diffusion-2").
model_id = 'prompthero/midjourney-v4-diffusion'
# NOTE: this Euler scheduler is loaded but NOT passed to the pipeline below, so
# generation uses whatever scheduler from_pretrained loads by default. To switch
# samplers, pass scheduler=scheduler to StableDiffusionPipeline.from_pretrained.
scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=_dtype)
pipe = pipe.to(_device)
def transcribe(audio, prompt_num, user_keywords):
    """Transcribe a recorded clip with Whisper and pass the other UI inputs through.

    First stage of the ``gr.Series`` pipeline: its three return values become
    the three inputs of the ``keywords`` stage.

    Args:
        audio: Path of the recorded audio file (the microphone input is
            configured with ``type="filepath"``).
        prompt_num: Requested number of images, as delivered by ``gr.Number``.
        user_keywords: Comma-delimited extra keywords typed by the user.

    Returns:
        Tuple ``(transcript_text, int(prompt_num), user_keywords)``.

    Raises:
        ValueError: if no audio was recorded or no image count was entered.
    """
    if not audio:
        raise ValueError("No audio was recorded; please record a clip first.")
    if prompt_num is None:
        raise ValueError("Please enter the number of images to generate.")
    # Decode the file once and transcribe the whole clip in a single pass.
    # The earlier version additionally ran detect_language + decode on a
    # 30-second padded copy purely to print it, doubling the Whisper work.
    audio_samples = whisper.load_audio(audio)
    final_result = model.transcribe(audio_samples)
    # transcribe() reports the language it detected, so the log line survives.
    print(f"Detected language: {final_result['language']}")
    print(final_result["text"])
    return final_result["text"], int(prompt_num), user_keywords
# Lazily-created KeyBERT instance shared by all requests (see _get_kw_model).
_kw_model = None

# Style "tail" phrases; exactly one is appended to every generated prompt.
_STYLE_PROMPTS = (
    "perfect shading, soft studio lighting, ultra-realistic, photorealistic, octane render, cinematic lighting, hdr, in-frame, 4k, 8k, edge lighting",
    "detailed, colourful, psychedelic, unreal engine, octane render, blender effect",
    "mechanical features, cybernetic eyes, baroque, rococo, anodized titanium highly detailed mechanisms, gears, fiber, cogs, bulbs, wires, cables, 70mm, Canon EOS 6D Mark II, 4k, 35mm (FX, Full-Frame), f/2.5, extremely detailed, very high details, photorealistic, hi res, hdr, UHD, hyper-detailed, ultra-realistic, vibrant, centered, vivid colors, Wide angle, zoom out",
    "detailed, soft ambiance, japanese influence, unreal engine 5, octane render",
    "perfect shading, soft studio lighting, ultra-realistic, photorealistic, octane render, cinematic lighting, hdr, in-frame, 4k, 8k, edge lighting --v 4",
)


def _get_kw_model():
    """Return the shared KeyBERT instance, constructing it on first use."""
    # Constructing KeyBERT() on every request (as before) repeated its
    # model setup for each generation.
    global _kw_model
    if _kw_model is None:
        _kw_model = KeyBERT()
    return _kw_model


def _extract_keyword_sets(text):
    """Return four lists of key phrases from *text*, one per KeyBERT strategy."""
    kw_model = _get_kw_model()
    # Same four configurations as before: plain, Max Sum similarity, and MMR
    # at high and low diversity. Each result entry is (phrase, score).
    strategies = (
        dict(stop_words=None),
        dict(stop_words='english', use_maxsum=True, nr_candidates=20, top_n=5),
        dict(stop_words='english', use_mmr=True, diversity=0.7),
        dict(stop_words='english', use_mmr=True, diversity=0.2),
    )
    return [
        [entry[0] for entry in kw_model.extract_keywords(
            text, keyphrase_ngram_range=(1, 3), **options)]
        for options in strategies
    ]


def _build_prompt(keyword_sets, user_terms):
    """Assemble one comma-delimited prompt for the mdjrny-v4 checkpoint."""
    sentence = ["mdjrny-v4 style"]
    sentence.extend(user_terms)
    # Pick a "primary" keyword set at random: two phrases from it, then one
    # from each remaining set in order. Empty sets (e.g. a silent recording)
    # are skipped instead of crashing random.choice.
    primary = r.randrange(len(keyword_sets))
    order = [primary, primary] + [k for k in range(len(keyword_sets)) if k != primary]
    for k in order:
        if keyword_sets[k]:
            sentence.append(r.choice(keyword_sets[k]))
    # Add Style Tail Prompt
    sentence.append(r.choice(_STYLE_PROMPTS))
    print("sentence: ", sentence)
    return ', '.join(sentence)


def keywords(text, prompt_num, user_keywords):
    """Build ``prompt_num`` prompts from *text* and render one image per prompt.

    Second stage of the ``gr.Series`` pipeline; its inputs are the outputs of
    ``transcribe``.

    Args:
        text: Transcript to mine for key phrases.
        prompt_num: Number of prompts/images, converted with ``int()``;
            values <= 0 yield no images.
        user_keywords: Comma-delimited extra terms placed in every prompt;
            surrounding whitespace and blank entries are dropped.

    Returns:
        List of images (one per prompt) produced by the diffusion pipeline,
        shown in the gallery output.
    """
    keyword_sets = _extract_keyword_sets(text)
    keyword_pool = [phrase for kw_set in keyword_sets for phrase in kw_set]
    print("keywords: ", keyword_pool, "length: ", len(keyword_pool))

    user_terms = [term.strip() for term in (user_keywords or "").split(',') if term.strip()]
    print(user_terms)

    # range() instead of `while count != n`: the old loop never ended for a
    # negative n, and its nested keyword loops produced the wrong prompt count.
    generated_prompts = []
    for _ in range(int(prompt_num)):
        myprompt = _build_prompt(keyword_sets, user_terms)
        print("prompt: ", myprompt)
        generated_prompts.append(myprompt)
    print("no. of prompts: ", len(generated_prompts))
    print("generated prompts: ", generated_prompts)

    images = []
    for prompt in generated_prompts:
        print(prompt)
        images.append(pipe(prompt, height=768, width=768, guidance_scale=10).images[0])
    return images
# Stage 1: record speech and collect the image count and extra keywords.
speech_text = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(source="microphone", type="filepath"),
        gr.Number(placeholder="Number of Images to be generated (int): "),
        gr.Textbox(placeholder="Additional keywords (comma delimited): "),
    ],
    outputs=["text", "number", "text"],
    title='Speech to Image Generator',
    enable_queue=True,
)
# Stage 2: turn the transcript into prompts and show the images in a gallery.
text_prompts = gr.Interface(
    fn=keywords,
    inputs=["text", "number", "text"],
    outputs=gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=[2], height="auto"),
    title='Speech to Image Generator',
    enable_queue=True,
)
# gr.Series chains the stages: stage 1's outputs feed stage 2's inputs.
# queue() has to be configured before launch(). launch() blocks when run as a
# script and returns a tuple, so the old trailing `.queue()` never took effect
# and would have failed on that tuple if reached.
gr.Series(speech_text, text_prompts).queue().launch(inline=False)