Spaces:
Running
Running
import gradio as gr | |
# Use a pipeline as a high-level helper | |
from transformers import pipeline | |
# Use a pipeline as a high-level helper | |
# Load model directly | |
from transformers import AutoImageProcessor, AutoModelForImageClassification | |
# Hugging Face Hub checkpoint used by the app (per its name, a ViT-Base/16,
# 224px model fine-tuned on EuroSAT).
MODEL_ID = "AZIIIIIIIIZ/vit-base-patch16-224-finetuned-eurosat"

# Single image-classification pipeline built once at import time and reused by
# predict() for every request.
pipe = pipeline("image-classification", model=MODEL_ID)
# $ pip install gradio_client fastapi uvicorn | |
import requests | |
from PIL import Image | |
from transformers import pipeline | |
import io | |
import base64 | |
# Initialize the pipeline | |
# pipe = pipeline('image-classification') | |
def load_image_from_path(image_path):
    """Open the image at ``image_path`` and read its pixel data into memory.

    ``Image.open`` is lazy: it only reads the header and keeps the file handle
    open until the pixels are first accessed. Calling ``load()`` inside a
    ``with`` block reads the data eagerly and then closes the file, so no
    descriptor is left open after this function returns.

    Args:
        image_path: Filesystem path of the image file.

    Returns:
        A fully loaded ``PIL.Image.Image``, usable after the file is closed.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        PIL.UnidentifiedImageError: If the file is not an image Pillow can read.
    """
    with Image.open(image_path) as image:
        image.load()
    return image
def load_image_from_url(image_url, timeout=10):
    """Download an image over HTTP(S) and decode it with Pillow.

    Args:
        image_url: URL pointing at an image resource.
        timeout: Seconds to wait for the server (passed to ``requests.get``).
            Without a timeout, a stalled server would block the caller
            indefinitely.

    Returns:
        A ``PIL.Image.Image`` backed by the downloaded bytes.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
        PIL.UnidentifiedImageError: If the body is not a readable image.
    """
    response = requests.get(image_url, timeout=timeout)
    # Fail fast on HTTP errors. Otherwise an HTML error page would reach
    # Pillow and surface as a confusing "cannot identify image file" error.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
def load_image_from_base64(base64_string):
    """Decode a base64-encoded image, optionally wrapped in a data URI.

    Accepts a bare payload (``"iVBORw0KGgo..."``) as well as a data URI such as
    ``"data:image/png;base64,iVBORw0KGgo..."``. The ``data:...,`` header is
    stripped first: ``base64.b64decode`` (non-strict) only discards characters
    outside the base64 alphabet. It would keep header characters like letters
    and ``/`` and corrupt the decoded bytes.

    Args:
        base64_string: Base64 text (``str`` or ``bytes``), optionally
            prefixed by a data-URI header when given as ``str``.

    Returns:
        A ``PIL.Image.Image`` backed by the decoded bytes.

    Raises:
        binascii.Error: If the payload is not valid base64 (e.g. bad padding).
        PIL.UnidentifiedImageError: If the decoded bytes are not an image.
    """
    if isinstance(base64_string, str) and base64_string.startswith("data:"):
        # Everything after the first comma is the base64 payload.
        base64_string = base64_string.partition(",")[2]
    image_data = base64.b64decode(base64_string)
    return Image.open(io.BytesIO(image_data))
def predict(image_input):
    """Classify an image supplied in any of several input formats.

    Accepted inputs:
      * ``PIL.Image.Image`` -- used as-is.
      * ``str`` starting with ``http://`` or ``https://`` -- downloaded.
      * ``str`` naming an existing file -- opened from disk.
      * any other ``str`` -- treated as base64-encoded image data.
      * an array-like exposing ``__array_interface__`` (e.g. the NumPy array
        Gradio's image input yields by default) -- converted with
        ``Image.fromarray``.

    Args:
        image_input: The image in one of the formats above.

    Returns:
        Whatever the module-level ``pipe`` returns for the image.

    Raises:
        ValueError: If ``image_input`` is none of the supported types
            (including ``None``).
    """
    import os  # local import: os is not among the file's top-level imports

    if isinstance(image_input, Image.Image):
        image = image_input
    elif isinstance(image_input, str):
        if image_input.startswith(("http://", "https://")):
            image = load_image_from_url(image_input)
        # Check for an actual file rather than a leading "/": base64 of any
        # JPEG begins with "/9j/" and would otherwise be misread as a path.
        elif os.path.isfile(image_input):
            image = load_image_from_path(image_input)
        else:
            image = load_image_from_base64(image_input)
    # Must come after the PIL check: PIL images also expose
    # __array_interface__.
    elif hasattr(image_input, "__array_interface__"):
        image = Image.fromarray(image_input)
    else:
        raise ValueError("Incorrect format used for image. Should be an URL linking to an image, a base64 string, a local path, a NumPy array, or a PIL image.")
    return pipe(image)
# def predict(image): | |
# return pipe(image) | |
def main():
    """Build and launch the Gradio web demo around predict()."""

    def classify(image):
        # gr.Label expects a {label: confidence} mapping. The pipeline returns
        # a list of {"label": ..., "score": ...} dicts, which a plain text
        # output would only show as a raw Python repr.
        return {pred["label"]: pred["score"] for pred in predict(image)}

    demo = gr.Interface(
        fn=classify,
        # Gradio's image input defaults to a NumPy array. Ask for a PIL image
        # so predict() receives a type it handles directly.
        inputs=gr.Image(type="pil"),
        outputs=gr.Label(),
    )
    demo.launch()


if __name__ == "__main__":
    main()