"""FastAPI service that turns an uploaded image into a sketch.

Uploaded images are forwarded to the Gradio app ``Ashrafb/image-to-sketch-ari``
through ``gradio_client``; the first output is returned to the caller as a
base64 data URL.
"""
import base64
import logging
import mimetypes
import os
import tempfile

from fastapi import FastAPI, File, UploadFile, HTTPException, Request
from fastapi.responses import HTMLResponse, FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
from gradio_client import Client, handle_file

logger = logging.getLogger(__name__)

app = FastAPI()

# Retrieve Hugging Face token from environment variables
# (None when HF_TOKEN is unset).
hf_token = os.environ.get('HF_TOKEN')
# Created once at import time and shared by every request.
client = Client("Ashrafb/image-to-sketch-ari", hf_token=hf_token)

# Configure CORS middleware
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site make credentialed requests; restrict allow_origins before relying on
# cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust as needed, '*' allows requests from any origin
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Fallback MIME type for the returned data URL when it cannot be guessed from
# the output file's extension (the value that used to be hard-coded).
_DEFAULT_IMAGE_MIME = "image/png"


def _safe_suffix(filename):
    """Return a short alphanumeric extension of *filename*, or "" if there is none.

    The filename is client-supplied, so anything unusual is dropped rather than
    passed on to the filesystem.
    """
    ext = os.path.splitext(filename or "")[1]
    if 1 < len(ext) <= 10 and ext[1:].isalnum():
        return ext.lower()
    return ""


def _sketch_from_path(image_path):
    """Run the Gradio ``/predict`` endpoint on *image_path* and build the response.

    Blocking (network call plus file read): async callers must run it off the
    event loop.

    Returns a dict with ``sketch_image_base64`` (data URL of the first output)
    and ``result_file`` (the second output, passed through unchanged —
    presumably a local path downloaded by gradio_client; confirm), or an
    ``{"error": ...}`` dict when the endpoint does not return two outputs.
    Exceptions from the call or the file read propagate to the caller.
    """
    # Call the Gradio API
    result = client.predict(
        img=handle_file(image_path),
        api_name="/predict"
    )

    # Check if the result is valid
    if not (result and len(result) == 2):
        return {"error": "Invalid result from the prediction API."}

    sketch_path = result[0]
    # Convert the sketch image to a base64 string
    with open(sketch_path, "rb") as image_file:
        image_data = base64.b64encode(image_file.read()).decode("utf-8")
    # Label the data URL with the type implied by the output file's extension
    mime_type = mimetypes.guess_type(sketch_path)[0] or _DEFAULT_IMAGE_MIME
    return {
        "sketch_image_base64": f"data:{mime_type};base64,{image_data}",
        "result_file": result[1],
    }


@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Convert an uploaded image into a sketch via the Gradio model.

    On success returns ``{"sketch_image_base64": ..., "result_file": ...}``.
    Any failure while calling the model or reading its output is logged and
    reported as ``{"error": "<message>"}`` with HTTP 200, kept for
    compatibility with existing clients.
    """
    contents = await file.read()

    # Save the uploaded file to a temporary location, keeping its extension
    # (when it looks sane) so the forwarded file is named like the upload.
    fd, temp_file_path = tempfile.mkstemp(suffix=_safe_suffix(file.filename))
    try:
        with os.fdopen(fd, "wb") as temp_file:
            temp_file.write(contents)
        try:
            # client.predict is synchronous; running it in the threadpool keeps
            # it from blocking the event loop (and all other requests).
            return await run_in_threadpool(_sketch_from_path, temp_file_path)
        except Exception as e:
            logger.exception("Sketch generation failed")
            return {"error": str(e)}
    finally:
        # Clean up the temporary file (also runs if writing it failed)
        if os.path.exists(temp_file_path):
            os.unlink(temp_file_path)