api / app.py
AdarshJi's picture
Update app.py
05d900a verified
from g4f.Provider.hf_space import BlackForestLabsFlux1Dev, G4F, BlackForestLabsFlux1Schnell
import asyncio
from flask import Flask, request, jsonify, Response, render_template
# Flask application instance; the @app.route views in this module are registered on it.
app = Flask(__name__)
class IMG:
    """Image-generation request bound to one prompt and one set of options.

    The instance stores the prompt and generation parameters, and exposes one
    public method per supported provider (``BlackForest``, ``FluxMidJourny``,
    ``BlackForestSchnell``). Each of them returns a list built from whatever
    the provider's async generator yields.
    """

    def __init__(self, prompt: str, width: int = 1024, height: int = 1024,
                 guidance_scale: float = 3.5, seed: int = 0):
        """Store the generation parameters.

        Args:
            prompt: Text prompt; also wrapped into a single user message.
            width: Forwarded to the provider as ``width``.
            height: Forwarded to the provider as ``height``.
            guidance_scale: Forwarded to the provider as ``guidance_scale``.
            seed: Forwarded to the provider as ``seed``.
        """
        self.prompt = prompt
        self.width = width
        self.height = height
        self.guidance_scale = guidance_scale
        self.seed = seed
        # Providers are called with a chat-style message list.
        self.messages = [{"role": "user", "content": self.prompt}]

    async def _run_async_generator(self, generator) -> list:
        """Drain *generator* and collect its items into a flat list.

        Items exposing an ``images`` attribute that is a list contribute the
        elements of that list; every other item is stored as ``str(item)``.
        If iteration raises, the error is printed and the items collected so
        far are returned (best-effort, never raises).
        """
        results = []
        try:
            async for result in generator:
                if hasattr(result, "images") and isinstance(result.images, list):
                    results.extend(result.images)
                else:
                    results.append(str(result))  # Convert non-image responses to string
        except Exception as e:
            print("Error processing response:", e)
        return results

    def _generate_images(self, provider_class, model) -> list:
        """Run ``provider_class.create_async_generator`` and return the collected items.

        Args:
            provider_class: Object exposing ``create_async_generator(**kwargs)``
                that returns an async iterable.
            model: Model identifier passed through as ``model``.

        Returns:
            The list produced by ``_run_async_generator``. When the provider
            raises, the error is printed and an ``"Error: <message>"`` string
            is appended after any items already yielded.
        """
        async def main():
            try:
                async for result in provider_class.create_async_generator(
                    model=model, messages=self.messages,
                    width=self.width, height=self.height,
                    guidance_scale=self.guidance_scale, seed=self.seed
                ):
                    yield result
            except Exception as e:
                print(f"Error generating images from {model}:", e)
                yield f"Error: {e}"

        # asyncio.run starts a fresh event loop so this method can be called
        # from synchronous code.
        return asyncio.run(self._run_async_generator(main()))

    @staticmethod
    def _resolve_model(provider_class, model, default):
        """Return *model* if ``provider_class.get_models()`` contains it, else *default*."""
        return model if model in provider_class.get_models() else default

    def BlackForest(self, model="black-forest-labs-flux-1-dev"):
        """Generate via ``BlackForestLabsFlux1Dev``; unknown models fall back to the default."""
        model = self._resolve_model(
            BlackForestLabsFlux1Dev, model, "black-forest-labs-flux-1-dev"
        )
        return self._generate_images(BlackForestLabsFlux1Dev, model)

    def FluxMidJourny(self, model="flux"):
        """Generate via ``G4F``; unknown models fall back to ``"flux"``."""
        model = self._resolve_model(G4F, model, "flux")
        return self._generate_images(G4F, model)

    def BlackForestSchnell(self, model="black-forest-labs-flux-1-schnell"):
        """Generate via ``BlackForestLabsFlux1Schnell``; unknown models fall back to the default."""
        model = self._resolve_model(
            BlackForestLabsFlux1Schnell, model, "black-forest-labs-flux-1-schnell"
        )
        return self._generate_images(BlackForestLabsFlux1Schnell, model)
@app.route("/generate/image", methods=["POST"])
def generate_image():
    """Generate images for the prompt given in the JSON request body.

    Body keys: ``prompt`` (required), and optionally ``model``, ``width``,
    ``height``, ``guidance_scale``, ``seed`` and ``provider`` (defaults to
    ``"flux"``).

    Returns:
        ``{"Result": [...]}`` with 200 on success;
        400 when the body is not a JSON object or ``prompt`` is missing/empty;
        404 when ``provider`` is not one of the supported names.
    """
    data = request.json
    if not isinstance(data, dict):
        # A valid JSON body that is not an object (null, list, string, ...)
        # has no .get(); reject it instead of crashing with a 500.
        return jsonify({"error": "request body must be a JSON object"}), 400

    prompt = data.get("prompt")
    if not prompt:
        return jsonify({"error": "prompt is required"}), 400

    img = IMG(
        prompt,
        data.get("width", 1024),
        data.get("height", 1024),
        data.get("guidance_scale", 3.5),
        data.get("seed", 0),
    )

    generators = {
        "blackforestlabs": img.BlackForest,
        "flux": img.FluxMidJourny,
        "blackforestlabs-schnell": img.BlackForestSchnell,
    }
    provider = data.get("provider", "flux")
    # Only strings can name a provider; the isinstance guard also keeps an
    # unhashable JSON value (list/dict) from raising inside dict.get().
    generate = generators.get(provider) if isinstance(provider, str) else None
    if generate is None:
        # Previously an unknown provider produced {"Result": null} with 200.
        return jsonify({"error": "provider not found"}), 404

    # The default model name belongs to the BlackForest dev provider; each
    # IMG provider method substitutes its own default when the name is not
    # among that provider's models.
    result = generate(data.get("model", "black-forest-labs-flux-1-dev"))
    print(result)
    return jsonify({"Result": result}), 200
@app.route("/providers", methods=["GET"])
def get_providers():
    """Return the list of provider names as JSON with status 200."""
    provider_names = ["blackforestlabs", "flux", "blackforestlabs-schnell"]
    payload = {"providers": provider_names}
    return jsonify(payload), 200
@app.route("/generate/image/model", methods=["POST"])
def get_models():
    """Return the models offered by the provider named in the JSON body.

    Body keys: ``provider`` (optional, defaults to ``"blackforestlabs"``).

    Returns:
        ``{"models": [...]}`` with 200 for a known provider;
        400 when the body is not a JSON object;
        404 with ``{"error": "provider not found"}`` otherwise.
    """
    data = request.json
    if not isinstance(data, dict):
        # A valid JSON body that is not an object (null, list, string, ...)
        # has no .get(); reject it instead of crashing with a 500.
        return jsonify({"error": "request body must be a JSON object"}), 400

    provider_classes = {
        "blackforestlabs": BlackForestLabsFlux1Dev,
        "flux": G4F,
        "blackforestlabs-schnell": BlackForestLabsFlux1Schnell,
    }
    provider = data.get("provider", "blackforestlabs")
    # The isinstance guard keeps an unhashable JSON value (list/dict) from
    # raising inside dict.get(); such values are simply "not found".
    provider_class = provider_classes.get(provider) if isinstance(provider, str) else None
    if provider_class is None:
        return jsonify({"error": "provider not found"}), 404
    return jsonify({"models": provider_class.get_models()}), 200
@app.route("/", methods=["GET"])
def index():
    """Render and return the ``index.html`` template for the root URL."""
    page = render_template("index.html")
    return page
# When this file is executed directly as a script, start the Flask server on port 7860.
if __name__ == "__main__":
    app.run(port=7860)