import os
import gradio as gr
from openai import OpenAI
from PIL import Image, ImageEnhance
import cv2
import torch
from transformers import CLIPProcessor, CLIPModel
import requests
from io import BytesIO
# The OpenAI API key is read from the OPENAI_API_KEY environment variable by the client below
# Load CLIP model and processor
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
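# CLIP scores the input image against every candidate text label below and
# returns a probability per label, so the best match doubles as a zero-shot
# object classifier (see identify_objects_with_clip).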
# Expanded object labels
object_labels = [
    "cat", "dog", "house", "tree", "car", "mountain", "flower", "bird", "person", "robot",
    "a digital artwork", "a portrait", "a landscape", "a futuristic cityscape", "horse",
    "lion", "tiger", "elephant", "giraffe", "airplane", "train", "ship", "book", "laptop",
    "keyboard", "pen", "clock", "cup", "bottle", "backpack", "chair", "table", "sofa",
    "bed", "building", "street", "forest", "desert", "waterfall", "sunset", "beach",
    "bridge", "castle", "statue", "3D model",
]
# Example image for contrast check
EXAMPLE_IMAGE_URL = "https://www.watercoloraffair.com/wp-content/uploads/2023/04/monet-houses-of-parliament-low-key.jpg"  # low-key Monet painting
example_image = Image.open(BytesIO(requests.get(EXAMPLE_IMAGE_URL, timeout=10).content))
# Initialize OpenAI client
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
def process_chat(user_text):
    """Stream a GPT-4o answer token by token so the Gradio textbox updates live."""
    if not user_text.strip():
        yield "⚠️ Please enter a valid question."
        return
    try:
        # Use the client to create a streaming completion
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": "You are a helpful assistant named Diane specializing in digital art advice."},
                {"role": "user", "content": user_text},
            ],
            stream=True,  # Enable streaming
        )
        response_text = ""
        for chunk in response:
            # In openai>=1.0 the delta is an object, not a dict
            token = chunk.choices[0].delta.content
            if token:
                response_text += token
                yield response_text
    except Exception as e:
        yield f"❌ An error occurred: {str(e)}"
# Function to analyze image contrast
def analyze_contrast_opencv(image_path):
    # Standard deviation of grayscale pixel intensities is a simple global-contrast proxy
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    contrast = img.std()
    return contrast
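# Interpretation: a flat, washed-out image yields a small standard deviation;
# process_image below treats std < 25 as low contrast.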
# Function to identify objects using CLIP
def identify_objects_with_clip(image_path):
    image = Image.open(image_path).convert("RGB")
    inputs = clip_processor(text=object_labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = clip_model(**inputs)
    # logits_per_image holds one similarity score per candidate label
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1).numpy().flatten()
    best_match_label = object_labels[probs.argmax()]
    return best_match_label
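# Usage sketch (hypothetical file name, assuming the labels above):
#   identify_objects_with_clip("my_artwork.png")  # -> e.g. "a landscape"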
# Function to enhance image contrast
def enhance_contrast(image):
    enhancer = ImageEnhance.Contrast(image)
    enhanced_image = enhancer.enhance(1.5)  # factor 1.5 = 50% contrast boost (1.0 leaves the image unchanged)
    enhanced_path = "enhanced_image.png"
    enhanced_image.save(enhanced_path)
    return enhanced_path
# Function to provide additional suggestions with streaming
def provide_suggestions_streaming(object_identified):
    if not object_identified:
        yield "Sorry, I couldn't find an object in your artwork. Try a different image."
        return
    # Same streaming pattern as process_chat, via the openai>=1.0 client
    stream = client.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "You are an expert digital art advisor."},
            {"role": "user", "content": f"Suggest ways to improve a digital artwork featuring a {object_identified}."},
        ],
        stream=True,
    )
    response = ""
    for chunk in stream:
        token = chunk.choices[0].delta.content
        if token:
            response += token
            yield response
# Main image processing function
def process_image(image):
    if image is None:
        return "⚠️ Please upload an image.", None, None
    image.save("uploaded_image.png")
    contrast = analyze_contrast_opencv("uploaded_image.png")
    object_identified = identify_objects_with_clip("uploaded_image.png")
    if contrast < 25:  # empirical low-contrast threshold
        enhanced_image_path = enhance_contrast(Image.open("uploaded_image.png"))
        return (
            f"Hey, great artwork of {object_identified}! However, it looks like the contrast is a little low. I've improved the contrast for you. ✨",
            enhanced_image_path,
            object_identified,
        )
    return (
        f"Hey, great artwork of {object_identified}! Looks like the color contrast is great. Be proud of yourself! 🌟",
        None,
        object_identified,
    )
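# Each branch returns (message, optional enhanced-image path, identified object);
# the object is kept in gr.State so the suggestions step can reuse it.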
# Gradio Blocks Interface
demo = gr.Blocks(css="""
#upload-image, #example-image {
    height: 300px !important;
}
.button {
    height: 50px;
    font-size: 16px;
}
""")
with demo:
    gr.Markdown("## 🎨 DIANE (Digital Imaging and Art Neural Enhancer)")
    gr.Markdown("DIANE is here to assist you in refining your digital art. She can answer questions about digital art, analyze your images, and provide creative suggestions to enhance your work.")

    # Chatbot Section
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 💬 Ask me about digital art")
            user_text = gr.Textbox(label="Enter your question", placeholder="What is the best tool for a beginner?...")
            chat_output = gr.Textbox(label="Answer", interactive=False)
            chat_button = gr.Button("Ask", elem_classes="button")
            chat_button.click(process_chat, inputs=user_text, outputs=chat_output)
    # Image Analysis Section
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🖼️ Upload an image to check its contrast levels")
            with gr.Row(equal_height=True):
                # Left: Image upload field
                with gr.Column():
                    image_input = gr.Image(label="Upload an image", type="pil", elem_id="upload-image")
                    image_button = gr.Button("Check", elem_classes="button")
                # Right: Example image field
                with gr.Column():
                    gr.Image(value=example_image, label="Example Image", interactive=False, elem_id="example-image")
                    example_button = gr.Button("Use Example Image", elem_classes="button")
            image_output_text = gr.Textbox(label="Analysis", interactive=False)
            image_output_image = gr.Image(label="Improved Image", interactive=False)
            suggestion_button = gr.Button("I want to improve this artwork. Any suggestions?", visible=False)
            suggestions_output = gr.Textbox(label="Suggestions", interactive=False)  # output only
            state_object = gr.State()  # To store identified object
    # Load example image into the input
    def use_example_image():
        return example_image

    example_button.click(
        use_example_image,
        inputs=None,
        outputs=image_input,
    )
    # Analyze button: run the analysis, then reveal the suggestions button
    def update_suggestions_visibility(analysis, enhanced_image, identified_object):
        return gr.update(visible=True), analysis, enhanced_image

    image_button.click(
        process_image,
        inputs=image_input,
        outputs=[
            image_output_text,
            image_output_image,
            state_object,
        ],
    ).then(  # chain so the visibility update runs only after processing finishes
        update_suggestions_visibility,
        inputs=[image_output_text, image_output_image, state_object],
        outputs=[suggestion_button, image_output_text, image_output_image],
    )
    # Suggestion button functionality with streaming
    suggestion_button.click(
        provide_suggestions_streaming,
        inputs=state_object,
        outputs=suggestions_output,
    )

# queue() enables generator (streaming) callbacks on older Gradio versions
demo.queue().launch()