Commit
·
a6394d0
1
Parent(s):
d3e071f
Changed to be able to run in Modal.
Browse files- .gitignore +1 -0
- README.md +0 -11
- app.py +0 -4
- frontend/app.jsx +119 -0
- frontend/index.html +22 -0
- vit_gpt2_image_caption_webapp.py +43 -0
- vit_gpt2_image_captioning.py +65 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__
|
README.md
DELETED
@@ -1,11 +0,0 @@
|
|
1 |
-
---
|
2 |
-
title: Vit Gpt2 Image Captioning
|
3 |
-
emoji: 👀
|
4 |
-
colorFrom: blue
|
5 |
-
colorTo: blue
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 3.27.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
license: apache-2.0
|
11 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
DELETED
@@ -1,4 +0,0 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
|
3 |
-
gr.Interface.load("models/nlpconnect/vit-gpt2-image-captioning").launch()
|
4 |
-
|
|
|
|
|
|
|
|
|
|
frontend/app.jsx
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
function Spinner({ config }) {
|
2 |
+
const ref = React.useRef(null);
|
3 |
+
|
4 |
+
React.useEffect(() => {
|
5 |
+
const spinner = new Spin.Spinner({
|
6 |
+
lines: 13,
|
7 |
+
color: "#ffffff",
|
8 |
+
...config,
|
9 |
+
});
|
10 |
+
spinner.spin(ref.current);
|
11 |
+
return () => spinner.stop();
|
12 |
+
}, [ref]);
|
13 |
+
|
14 |
+
return <span ref={ref} />;
|
15 |
+
}
|
16 |
+
|
17 |
+
function Result({ callId, selectedFile }) {
|
18 |
+
const [result, setResult] = React.useState();
|
19 |
+
const [intervalId, setIntervalId] = React.useState();
|
20 |
+
|
21 |
+
React.useEffect(() => {
|
22 |
+
if (result) {
|
23 |
+
clearInterval(intervalId);
|
24 |
+
return;
|
25 |
+
}
|
26 |
+
|
27 |
+
const _intervalID = setInterval(async () => {
|
28 |
+
const resp = await fetch(`/result/${callId}`);
|
29 |
+
if (resp.status === 200) {
|
30 |
+
setResult(await resp.json());
|
31 |
+
}
|
32 |
+
}, 100);
|
33 |
+
|
34 |
+
setIntervalId(_intervalID);
|
35 |
+
|
36 |
+
return () => clearInterval(intervalId);
|
37 |
+
}, [result]);
|
38 |
+
|
39 |
+
return (
|
40 |
+
<div class="flex items-center content-center justify-center space-x-4 ">
|
41 |
+
<img src={URL.createObjectURL(selectedFile)} class="h-[300px]" />
|
42 |
+
{!result && <Spinner config={{}} />}
|
43 |
+
{result && (
|
44 |
+
<p class="w-[200px] p-4 bg-zinc-200 rounded-lg whitespace-pre-wrap text-xs font-mono">
|
45 |
+
{JSON.stringify(result, undefined, 1)}
|
46 |
+
</p>
|
47 |
+
)}
|
48 |
+
</div>
|
49 |
+
);
|
50 |
+
}
|
51 |
+
|
52 |
+
function Form({ onSubmit, onFileSelect, selectedFile }) {
|
53 |
+
return (
|
54 |
+
<form class="flex flex-col space-y-4 items-center">
|
55 |
+
<div class="text-2xl font-semibold text-gray-700"> ViT-GPT2 Image Captioning </div>
|
56 |
+
<input
|
57 |
+
accept="image/*"
|
58 |
+
type="file"
|
59 |
+
name="file"
|
60 |
+
onChange={onFileSelect}
|
61 |
+
class="block w-full text-sm text-gray-900 bg-gray-50 rounded-lg border border-gray-300 cursor-pointer"
|
62 |
+
/>
|
63 |
+
{selectedFile ? (
|
64 |
+
<img src={URL.createObjectURL(selectedFile)} class="h-[300px]" />
|
65 |
+
) : null}
|
66 |
+
<div>
|
67 |
+
<button
|
68 |
+
type="button"
|
69 |
+
onClick={onSubmit}
|
70 |
+
disabled={!selectedFile}
|
71 |
+
class="bg-indigo-400 disabled:bg-zinc-500 hover:bg-indigo-600 text-white font-bold py-2 px-4 rounded text-sm"
|
72 |
+
>
|
73 |
+
Upload
|
74 |
+
</button>
|
75 |
+
</div>
|
76 |
+
</form>
|
77 |
+
);
|
78 |
+
}
|
79 |
+
|
80 |
+
function App() {
|
81 |
+
const [selectedFile, setSelectedFile] = React.useState();
|
82 |
+
const [callId, setCallId] = React.useState();
|
83 |
+
|
84 |
+
const handleSubmission = async () => {
|
85 |
+
const formData = new FormData();
|
86 |
+
formData.append("image", selectedFile);
|
87 |
+
|
88 |
+
const resp = await fetch("/parse", {
|
89 |
+
method: "POST",
|
90 |
+
body: formData,
|
91 |
+
});
|
92 |
+
|
93 |
+
if (resp.status !== 200) {
|
94 |
+
throw new Error("An error occurred: " + resp.status);
|
95 |
+
}
|
96 |
+
const body = await resp.json();
|
97 |
+
setCallId(body.call_id);
|
98 |
+
};
|
99 |
+
|
100 |
+
return (
|
101 |
+
<div class="absolute inset-0 bg-gradient-to-r from-indigo-300 via-purple-300 to-pink-300">
|
102 |
+
<div class="mx-auto max-w-md py-8">
|
103 |
+
<main class="rounded-xl bg-white p-6">
|
104 |
+
{!callId && (
|
105 |
+
<Form
|
106 |
+
onSubmit={handleSubmission}
|
107 |
+
onFileSelect={(e) => setSelectedFile(e.target.files[0])}
|
108 |
+
selectedFile={selectedFile}
|
109 |
+
/>
|
110 |
+
)}
|
111 |
+
{callId && <Result callId={callId} selectedFile={selectedFile} />}
|
112 |
+
</main>
|
113 |
+
</div>
|
114 |
+
</div>
|
115 |
+
);
|
116 |
+
}
|
117 |
+
|
118 |
+
const container = document.getElementById("react");
|
119 |
+
ReactDOM.createRoot(container).render(<App />);
|
frontend/index.html
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
<html lang="en">

<head>
  <meta charset="utf-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1" />
  <title>ViT-GPT2 Image Captioning powered by Modal</title>
  <!-- Styling via the Tailwind CDN build. -->
  <script src="https://cdn.tailwindcss.com"></script>
  <!-- React 18 UMD builds plus Babel standalone, so app.jsx is transpiled
       in the browser (development setup, not production-grade). -->
  <script crossorigin src="https://unpkg.com/react@18/umd/react.development.js"></script>
  <script crossorigin src="https://unpkg.com/react-dom@18/umd/react-dom.development.js"></script>
  <script crossorigin src="https://unpkg.com/@babel/standalone/babel.min.js"></script>
  <!-- spin.js, used by the Spinner component. -->
  <script crossorigin src="https://spin.js.org/spin.umd.js"></script>
  <link rel="stylesheet" href="https://spin.js.org/spin.css" />
</head>

<body class="bg-gray-50">
  <noscript>You must have JavaScript enabled to use this app.</noscript>
  <!-- Babel standalone compiles and runs this after DOMContentLoaded, so
       the #react mount point below exists when the script executes. -->
  <script type="text/babel" src="/app.jsx"></script>
  <div id="react"></div>
</body>

</html>
|
vit_gpt2_image_caption_webapp.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""FastAPI web tier, served on Modal, fronting the image-captioning function."""
from pathlib import Path

import fastapi
import fastapi.staticfiles

from modal import Function, Mount, Stub, asgi_app

# Modal app for the web frontend. The model itself lives in the separate
# "vit-gpt2-image-captioning" Modal app and is looked up at request time.
stub = Stub("vit-gpt2-image-caption-webapp")
web_app = fastapi.FastAPI()
|
10 |
+
|
11 |
+
|
12 |
+
@web_app.post("/parse")
async def parse(request: fastapi.Request):
    """Accept a multipart image upload and spawn an async captioning job.

    Returns the Modal call id; the client polls /result/{call_id} for the
    caption.
    """
    form = await request.form()
    image = await form["image"].read()  # type: ignore

    # Look up the deployed GPU function (from vit_gpt2_image_captioning.py)
    # and fire it without blocking this request.
    predict_step = Function.lookup("vit-gpt2-image-captioning", "predict_step")
    call = predict_step.spawn(image)
    return {"call_id": call.object_id}
|
20 |
+
|
21 |
+
|
22 |
+
@web_app.get("/result/{call_id}")
async def poll_results(call_id: str):
    """Return the caption for a call id, or HTTP 202 while it is still running."""
    from modal.functions import FunctionCall

    try:
        # timeout=0 makes get() non-blocking: it raises TimeoutError when
        # the job has not finished yet.
        result = FunctionCall.from_id(call_id).get(timeout=0)
    except TimeoutError:
        return fastapi.responses.JSONResponse(content="", status_code=202)

    # predict_step returns a list of captions (one per image); a single
    # image was submitted, so return the first entry.
    return result[0]
|
33 |
+
|
34 |
+
|
35 |
+
# Static frontend (index.html + app.jsx), mounted into the container image.
assets_path = Path(__file__).parent / "frontend"


@stub.function(mounts=[Mount.from_local_dir(assets_path, remote_path="/assets")])
@asgi_app()
def wrapper():
    """Serve the API with the static frontend mounted at the root path."""
    web_app.mount("/", fastapi.staticfiles.StaticFiles(directory="/assets", html=True))
    return web_app
|
vit_gpt2_image_captioning.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Model card: https://huggingface.co/nlpconnect/vit-gpt2-image-captioning

import urllib.request

import modal

stub = modal.Stub("vit-gpt2-image-captioning")
# Persisted shared volume, mounted at CACHE_PATH to cache model weights
# across runs.
volume = modal.SharedVolume().persist("shared_vol")
CACHE_PATH = "/root/model_cache"
|
9 |
+
|
10 |
+
|
11 |
+
@stub.function(
    gpu="any",
    image=modal.Image.debian_slim().pip_install("Pillow", "transformers", "torch"),
    shared_volumes={CACHE_PATH: volume},
    retries=3,
)
def predict_step(image):
    """Generate a caption for a single image with ViT-GPT2.

    Args:
        image: raw image bytes in any format Pillow can open.

    Returns:
        list[str]: generated captions (a one-element list for one image).
    """
    import io

    from PIL import Image
    import torch
    from transformers import AutoTokenizer, VisionEncoderDecoderModel, ViTImageProcessor

    # Point Hugging Face at the mounted shared volume so weights download
    # once and are reused across cold starts. Without cache_dir the default
    # cache (~/.cache/huggingface) is used, and the volume mounted at
    # CACHE_PATH was never touched — every container re-downloaded the model.
    model = VisionEncoderDecoderModel.from_pretrained(
        "nlpconnect/vit-gpt2-image-captioning", cache_dir=CACHE_PATH
    )
    feature_extractor = ViTImageProcessor.from_pretrained(
        "nlpconnect/vit-gpt2-image-captioning", cache_dir=CACHE_PATH
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "nlpconnect/vit-gpt2-image-captioning", cache_dir=CACHE_PATH
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    gen_kwargs = {"max_length": 16, "num_beams": 4}

    input_img = Image.open(io.BytesIO(image))
    pixel_values = feature_extractor(
        images=[input_img], return_tensors="pt"
    ).pixel_values
    pixel_values = pixel_values.to(device)

    # Inference only: disabling autograd saves memory and time.
    with torch.no_grad():
        output_ids = model.generate(pixel_values, **gen_kwargs)

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return [pred.strip() for pred in preds]
|
48 |
+
|
49 |
+
|
50 |
+
@stub.local_entrypoint()
def main():
    """Caption sample.png next to this file, or a fallback image from the web."""
    from pathlib import Path

    image_filepath = Path(__file__).parent / "sample.png"
    if image_filepath.exists():
        image = image_filepath.read_bytes()
    else:
        try:
            image = urllib.request.urlopen(
                "https://drive.google.com/uc?id=0B0TjveMhQDhgLTlpOENiOTZ6Y00&export=download"
            ).read()
        except urllib.error.URLError as e:
            # Previously this fell through with `image` unbound and crashed
            # with a NameError on the call below; bail out instead.
            # (urllib.error is in scope because urllib.request imports it.)
            print(e.reason)
            return
    print(predict_step.call(image)[0])
|