# StoryVerseWeaver / core /image_services.py
# mgbam's picture
# Create image_services.py
# f7f6fe3 verified
# raw
# history blame
# 7.25 kB
# storyverse_weaver/core/image_services.py
import os
import requests # For generic API calls
import base64
from io import BytesIO
from PIL import Image
# from dotenv import load_dotenv
# load_dotenv()
# --- API key configuration (StoryVerse-specific environment variable names) ---
# Keys are read once, at import time; the *_CONFIGURED flags are True only
# when the corresponding key is present and not blank/whitespace-only.
STABILITY_API_KEY = os.environ.get("STORYVERSE_STABILITY_API_KEY")
OPENAI_API_KEY = os.environ.get("STORYVERSE_OPENAI_API_KEY")  # used for DALL-E
# HUGGINGFACE_TOKEN is also used by llm_services and can be reused for HF image models.

STABILITY_API_CONFIGURED = STABILITY_API_KEY is not None and STABILITY_API_KEY.strip() != ""
OPENAI_DALLE_CONFIGURED = OPENAI_API_KEY is not None and OPENAI_API_KEY.strip() != ""
# HF_IMAGE_CONFIGURED = bool(HF_TOKEN and HF_TOKEN.strip())  # assuming HF_TOKEN also covers image models
class ImageGenResponse:
    """Result of a single image-generation request.

    Attributes:
        image: PIL image object when generation succeeded, else None.
        image_url: URL of the image when the provider returned one, else None.
        error: Human-readable error message, or None when there was no error.
        success: Whether the request succeeded. Defaults to True; callers
            reporting a failure pass ``success=False`` explicitly.
        provider: Name of the backend that produced this response.
        raw_response: Optional raw payload (response text, exception object,
            ...) kept for debugging. Defaults to None.
    """

    # Annotations are quoted so that defining the class never evaluates
    # ``Image.Image``; they are hints only and have no runtime effect.
    def __init__(self, image: "Image.Image | None" = None, image_url: "str | None" = None,
                 error: "str | None" = None, success: bool = True,
                 provider: str = "unknown", raw_response: object = None):
        self.image = image  # PIL Image object
        self.image_url = image_url  # If the API returns a URL
        self.error = error
        self.success = success
        self.provider = provider
        # Accepted as a keyword so call sites that pass ``raw_response=...``
        # no longer fail with TypeError.
        self.raw_response = raw_response

    def __repr__(self):
        return (
            f"{type(self).__name__}(success={self.success!r}, provider={self.provider!r}, "
            f"error={self.error!r}, image_url={self.image_url!r}, "
            f"has_image={self.image is not None})"
        )
def initialize_image_llms():
    """Print a status line for each image provider based on its *_CONFIGURED flag."""
    print("INFO: image_services.py - Initializing Image Generation services...")
    # (flag, message when the key is present, message when it is missing)
    provider_checks = (
        (
            STABILITY_API_CONFIGURED,
            "SUCCESS: image_services.py - Stability AI API Key detected.",
            "WARNING: image_services.py - STORYVERSE_STABILITY_API_KEY not found. Stability AI disabled.",
        ),
        (
            OPENAI_DALLE_CONFIGURED,
            "SUCCESS: image_services.py - OpenAI API Key detected (for DALL-E).",
            "WARNING: image_services.py - STORYVERSE_OPENAI_API_KEY not found. DALL-E disabled.",
        ),
    )
    for key_present, detected_msg, missing_msg in provider_checks:
        print(detected_msg if key_present else missing_msg)
    # if HF_IMAGE_CONFIGURED: print("INFO: image_services.py - Hugging Face Token detected (can be used for HF image models).")
    print("INFO: image_services.py - Image LLM Init complete.")
# --- Stability AI (Example) ---
def generate_image_stabilityai(prompt: str, style_preset: str = None, negative_prompt: str = None,
                               engine_id: str = "stable-diffusion-xl-1024-v1-0",
                               steps: int = 30, cfg_scale: float = 7.0,
                               width: int = 1024, height: int = 1024) -> ImageGenResponse:
    """Generate one image via the Stability AI text-to-image REST endpoint.

    Never raises for API or decoding problems: every failure is reported as
    an ImageGenResponse with ``success=False`` and ``error`` set.

    Args:
        prompt: Text prompt to render. Must be a non-empty string.
        style_preset: Optional ``style_preset`` value forwarded to the API.
        negative_prompt: Optional text sent as a second prompt with weight -1.0.
        engine_id: Engine identifier inserted into the request URL.
        steps: Value for the ``steps`` field of the request.
        cfg_scale: Value for the ``cfg_scale`` field of the request.
        width: Requested image width.
        height: Requested image height.

    Returns:
        ImageGenResponse carrying a PIL image on success, or an error message
        (plus a ``raw_response`` attribute with the response text or the
        exception, for debugging) on failure.
    """
    provider = "StabilityAI"

    def _failure(message, raw=None):
        # raw_response is attached after construction rather than passed as a
        # keyword, so this works whether or not ImageGenResponse's constructor
        # accepts a ``raw_response`` argument.
        failed = ImageGenResponse(error=message, success=False, provider=provider)
        failed.raw_response = raw
        return failed

    if not STABILITY_API_CONFIGURED:
        return _failure("Stability AI API key not configured.")
    # Guard before the debug print below, which slices the prompt.
    if not isinstance(prompt, str) or not prompt.strip():
        return _failure("Stability AI prompt must be a non-empty string.")

    # Strip a trailing slash so a host override cannot produce a "//v1/..." URL.
    api_host = os.getenv('API_HOST', 'https://api.stability.ai').rstrip('/')
    request_url = f"{api_host}/v1/generation/{engine_id}/text-to-image"

    payload = {
        "text_prompts": [{"text": prompt}],
        "cfg_scale": cfg_scale,
        "height": height,
        "width": width,
        "steps": steps,
        "samples": 1,
    }
    if style_preset:
        payload["style_preset"] = style_preset
    if negative_prompt:
        payload["text_prompts"].append({"text": negative_prompt, "weight": -1.0})

    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Bearer {STABILITY_API_KEY}"
    }

    print(f"DEBUG: image_services.py - Calling Stability AI with prompt: {prompt[:50]}...")
    try:
        response = requests.post(request_url, headers=headers, json=payload, timeout=60)
        response.raise_for_status()  # Raises HTTPError on an unsuccessful status code

        artifacts = response.json().get("artifacts")
        if not artifacts:
            return _failure("No image artifacts found in Stability AI response.", raw=response.text)

        image_data = base64.b64decode(artifacts[0]["base64"])
        image = Image.open(BytesIO(image_data))
        # Image.open is lazy; decode now so corrupt data is reported here
        # instead of surfacing later in the caller.
        image.load()
        print("DEBUG: image_services.py - Stability AI image generated successfully.")
        return ImageGenResponse(image=image, provider=provider)
    except requests.exceptions.RequestException as e:
        error_msg = f"Stability AI API request failed: {type(e).__name__} - {str(e)}"
        if getattr(e, 'response', None) is not None:
            error_msg += f" - Response: {e.response.text[:200]}"
        print(f"ERROR: image_services.py - {error_msg}")
        return _failure(error_msg, raw=e)
    except Exception as e:  # Boundary catch-all: JSON shape, base64 or image decoding errors
        error_msg = f"Error processing Stability AI response: {type(e).__name__} - {str(e)}"
        print(f"ERROR: image_services.py - {error_msg}")
        return _failure(error_msg, raw=e)
# --- DALL-E (Conceptual - you'd need 'openai' library and setup) ---
def generate_image_dalle(prompt: str, size="1024x1024", quality="standard", n=1) -> ImageGenResponse:
    """Placeholder DALL-E generator that returns a simulated image.

    No API call is made yet: when the OpenAI key is configured, a solid
    512x512 sky-blue image is returned with provider "DALL-E (Simulated)".
    The ``prompt``, ``size``, ``quality`` and ``n`` arguments are currently
    unused. Without a configured key, an error response is returned.
    """
    if OPENAI_DALLE_CONFIGURED:
        try:
            # Sketch of the intended real implementation (needs the 'openai' library):
            #   from openai import OpenAI  # would be imported at top level
            #   client = OpenAI(api_key=OPENAI_API_KEY)
            #   response = client.images.generate(
            #       model="dall-e-3",  # or "dall-e-2"
            #       prompt=prompt,
            #       size=size,
            #       quality=quality,
            #       n=n,
            #       response_format="url",  # or "b64_json"
            #   )
            #   image_url = response.data[0].url
            #   image_content = requests.get(image_url).content
            #   image = Image.open(BytesIO(image_content))
            #   return ImageGenResponse(image=image, image_url=image_url, provider="DALL-E")
            print("DEBUG: image_services.py - DALL-E call placeholder.")
            # Stand-in image until the real call is wired up.
            placeholder = Image.new("RGB", (512, 512), color="skyblue")
            outcome = ImageGenResponse(image=placeholder, provider="DALL-E (Simulated)")
        except Exception as exc:
            failure_text = f"DALL-E API Error: {type(exc).__name__} - {str(exc)}"
            outcome = ImageGenResponse(error=failure_text, success=False, provider="DALL-E")
        return outcome
    return ImageGenResponse(error="OpenAI DALL-E API key not configured.", success=False, provider="DALL-E")
# --- Hugging Face Image Model (Conceptual - via Inference API or local Diffusers) ---
# def generate_image_hf_model(prompt: str, model_id="stabilityai/stable-diffusion-xl-base-1.0") -> ImageGenResponse:
# if not HF_IMAGE_CONFIGURED:
# return ImageGenResponse(error="HF Token not configured for image models.", success=False, provider="HF")
# try:
# You might use a client similar to hf_inference_text_client but for image generation task
# Or if it's a diffusers pipeline, you'd load and run it.
# This requires the `diffusers` library and often significant compute.
# response_bytes = hf_some_image_client.text_to_image(prompt, model=model_id) # Hypothetical client method
# image = Image.open(BytesIO(response_bytes))
# return ImageGenResponse(image=image, provider=f"HF ({model_id})")
# print("DEBUG: image_services.py - HF Image Model call placeholder.")
# dummy_image = Image.new('RGB', (512, 512), color = 'lightgreen')
# return ImageGenResponse(image=dummy_image, provider=f"HF ({model_id} - Simulated)")
# except Exception as e:
# return ImageGenResponse(error=f"HF Image Model Error: {type(e).__name__} - {str(e)}", success=False, provider=f"HF ({model_id})")
# Module-level debug trace: executed when this module's top-level code runs (i.e. on import).
print("DEBUG: core.image_services (for StoryVerseWeaver) - Module defined.")