# Hugging Face Spaces page status at capture time: Sleeping
import functools
import io
import logging

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification
# Hub repo id (or local directory) handed to `from_pretrained` for both the
# classifier and its image processor; also the default model used by `infer`.
target_folder: str = "JungminChung/India_ResNet"
@functools.lru_cache(maxsize=4)
def load_model_and_preprocessor(target_folder):
    """Load the ResNet classifier and its image processor from ``target_folder``.

    Both objects come from ``from_pretrained`` on the same repo id / directory.
    Results are memoized per ``target_folder`` so that a request handler calling
    this on every invocation does not re-instantiate the model each time.
    ``lru_cache`` does not store exceptions, so a failed load is retried on the
    next call.

    Args:
        target_folder: Hugging Face Hub repo id or local directory.

    Returns:
        A ``(model, image_processor)`` tuple.
    """
    model = ResNetForImageClassification.from_pretrained(target_folder)
    image_processor = AutoImageProcessor.from_pretrained(target_folder)
    return model, image_processor
def fetch_image(url, timeout=10):
    """Download ``url`` and return it as a fully decoded PIL image.

    Args:
        url: address of the image to fetch.
        timeout: seconds to wait for the server (connect and read); without it
            ``requests`` can block indefinitely on an unresponsive host.

    Returns:
        A ``PIL.Image.Image`` whose pixel data is already loaded into memory.

    Raises:
        requests.RequestException: on connection errors, timeouts, or an HTTP
            error status (via ``raise_for_status``).
        PIL.UnidentifiedImageError / OSError: if the body is not a decodable image.
    """
    # Browser-like User-Agent, presumably because some image hosts refuse
    # requests from non-browser clients.
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36'
    }
    # `with` releases the connection; `.content` applies any Content-Encoding
    # (reading `.raw` would not).
    with requests.get(url, headers=headers, timeout=timeout) as response:
        # Fail clearly on 4xx/5xx instead of handing an HTML error page to PIL.
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt data fails here.
        image.load()
    return image
def infer_image(image, model, image_processor, k):
    """Classify ``image`` and format the top-``k`` predictions as text.

    Args:
        image: PIL image (anything with ``.convert("RGB")``).
        model: classifier whose output has ``.logits`` and whose
            ``config.id2label`` maps class index -> label name.
        image_processor: callable turning an image into model inputs
            (called with ``return_tensors="pt"``).
        k: number of predictions to list. It may arrive as a float (e.g. from
            a Gradio Slider), and ``torch.topk`` rejects non-int or
            out-of-range values. So it is truncated to ``int`` and clamped to
            ``[0, number of classes]``.

    Returns:
        One line per prediction, best first, formatted as
        ``"<rank>. <label left-padded to 15> (<percent> %) \\n"``.
        Empty string when the clamped ``k`` is 0.
    """
    processed_img = image_processor(images=image.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        outputs = model(**processed_img)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

    num_classes = probs.shape[-1]
    k = max(0, min(int(k), num_classes))
    topk_prob, topk_indices = torch.topk(probs, k=k)

    lines = [
        f"{rank}. {model.config.id2label[index.item()]:<15} ({p.item()*100:.2f} %) \n"
        for rank, (p, index) in enumerate(zip(topk_prob[0], topk_indices[0]), start=1)
    ]
    return "".join(lines)
def infer(url, k, target_folder=target_folder): | |
try : | |
image = fetch_image(url) | |
model, image_processor = load_model_and_preprocessor(target_folder) | |
res = infer_image(image, model, image_processor, k) | |
except : | |
image = Image.new('RGB', (224, 224)) | |
res = "์ด๋ฏธ์ง๋ฅผ ๋ถ๋ฌ์ค๋๋ฐ ๋ฌธ์ ๊ฐ ์๋๋ด์. ๋ค๋ฅธ ์ด๋ฏธ์ง url๋ก ๋ค์ ์๋ํด์ฃผ์ธ์." | |
return image, res | |
# Web UI: an image URL and a top-k slider go in; the image and the ranked
# predictions (as text) come out.
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(
            value="https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRpE-UHBp8ZufNUd3BKw8gtIxSe3IUwspOfqw&s",
            label="Image URL",
        ),
        # Label: "How many of the top predictions should be shown?"
        # NOTE(review): reconstructed from mis-encoded source text; the first
        # word may originally have been "순위" rather than "상위" — confirm.
        # minimum=1 because k=0 would list no predictions at all.
        gr.Slider(minimum=1, maximum=20, step=1, value=3, label="상위 몇개까지 보여줄까요?"),
    ],
    outputs=[
        gr.Image(type="pil", label="입력 이미지"),  # "Input image"
        gr.Textbox(label="종류 (확률)"),  # "Class (probability)"
    ],
)
# Pass share=True to launch() to get a temporary public link.
demo.launch()