yuukicammy committed
Commit a6394d0 · 1 Parent(s): d3e071f

Changed the app so that it can run on Modal.

.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
README.md DELETED
@@ -1,11 +0,0 @@
- ---
- title: Vit Gpt2 Image Captioning
- emoji: 👀
- colorFrom: blue
- colorTo: blue
- sdk: gradio
- sdk_version: 3.27.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
app.py DELETED
@@ -1,4 +0,0 @@
- import gradio as gr
-
- gr.Interface.load("models/nlpconnect/vit-gpt2-image-captioning").launch()
-
frontend/app.jsx ADDED
@@ -0,0 +1,119 @@
+ function Spinner({ config }) {
+   const ref = React.useRef(null);
+
+   React.useEffect(() => {
+     const spinner = new Spin.Spinner({
+       lines: 13,
+       color: "#ffffff",
+       ...config,
+     });
+     spinner.spin(ref.current);
+     return () => spinner.stop();
+   }, [ref]);
+
+   return <span ref={ref} />;
+ }
+
+ function Result({ callId, selectedFile }) {
+   const [result, setResult] = React.useState();
+   const [intervalId, setIntervalId] = React.useState();
+
+   React.useEffect(() => {
+     if (result) {
+       clearInterval(intervalId);
+       return;
+     }
+
+     const _intervalID = setInterval(async () => {
+       const resp = await fetch(`/result/${callId}`);
+       if (resp.status === 200) {
+         setResult(await resp.json());
+       }
+     }, 100);
+
+     setIntervalId(_intervalID);
+
+     return () => clearInterval(_intervalID);
+   }, [result]);
+
+   return (
+     <div className="flex items-center content-center justify-center space-x-4">
+       <img src={URL.createObjectURL(selectedFile)} className="h-[300px]" />
+       {!result && <Spinner config={{}} />}
+       {result && (
+         <p className="w-[200px] p-4 bg-zinc-200 rounded-lg whitespace-pre-wrap text-xs font-mono">
+           {JSON.stringify(result, undefined, 1)}
+         </p>
+       )}
+     </div>
+   );
+ }
+
+ function Form({ onSubmit, onFileSelect, selectedFile }) {
+   return (
+     <form className="flex flex-col space-y-4 items-center">
+       <div className="text-2xl font-semibold text-gray-700">ViT-GPT2 Image Captioning</div>
+       <input
+         accept="image/*"
+         type="file"
+         name="file"
+         onChange={onFileSelect}
+         className="block w-full text-sm text-gray-900 bg-gray-50 rounded-lg border border-gray-300 cursor-pointer"
+       />
+       {selectedFile ? (
+         <img src={URL.createObjectURL(selectedFile)} className="h-[300px]" />
+       ) : null}
+       <div>
+         <button
+           type="button"
+           onClick={onSubmit}
+           disabled={!selectedFile}
+           className="bg-indigo-400 disabled:bg-zinc-500 hover:bg-indigo-600 text-white font-bold py-2 px-4 rounded text-sm"
+         >
+           Upload
+         </button>
+       </div>
+     </form>
+   );
+ }
+
+ function App() {
+   const [selectedFile, setSelectedFile] = React.useState();
+   const [callId, setCallId] = React.useState();
+
+   const handleSubmission = async () => {
+     const formData = new FormData();
+     formData.append("image", selectedFile);
+
+     const resp = await fetch("/parse", {
+       method: "POST",
+       body: formData,
+     });
+
+     if (resp.status !== 200) {
+       throw new Error("An error occurred: " + resp.status);
+     }
+     const body = await resp.json();
+     setCallId(body.call_id);
+   };
+
+   return (
+     <div className="absolute inset-0 bg-gradient-to-r from-indigo-300 via-purple-300 to-pink-300">
+       <div className="mx-auto max-w-md py-8">
+         <main className="rounded-xl bg-white p-6">
+           {!callId && (
+             <Form
+               onSubmit={handleSubmission}
+               onFileSelect={(e) => setSelectedFile(e.target.files[0])}
+               selectedFile={selectedFile}
+             />
+           )}
+           {callId && <Result callId={callId} selectedFile={selectedFile} />}
+         </main>
+       </div>
+     </div>
+   );
+ }
+
+ const container = document.getElementById("react");
+ ReactDOM.createRoot(container).render(<App />);
frontend/index.html ADDED
@@ -0,0 +1,22 @@
+ <!DOCTYPE html>
+ <html lang="en">
+
+ <head>
+   <meta charset="utf-8" />
+   <meta name="viewport" content="width=device-width, initial-scale=1" />
+   <title>ViT-GPT2 Image Captioning powered by Modal</title>
+   <script src="https://cdn.tailwindcss.com"></script>
+   <script crossorigin src="https://unpkg.com/react@18/umd/react.development.js"></script>
+   <script crossorigin src="https://unpkg.com/react-dom@18/umd/react-dom.development.js"></script>
+   <script crossorigin src="https://unpkg.com/@babel/standalone/babel.min.js"></script>
+   <script crossorigin src="https://spin.js.org/spin.umd.js"></script>
+   <link rel="stylesheet" href="https://spin.js.org/spin.css" />
+ </head>
+
+ <body class="bg-gray-50">
+   <noscript>You must have JavaScript enabled to use this app.</noscript>
+   <script type="text/babel" src="/app.jsx"></script>
+   <div id="react"></div>
+ </body>
+
+ </html>
vit_gpt2_image_caption_webapp.py ADDED
@@ -0,0 +1,43 @@
+ from pathlib import Path
+
+ import fastapi
+ import fastapi.staticfiles
+
+ from modal import Function, Mount, Stub, asgi_app
+
+ stub = Stub("vit-gpt2-image-caption-webapp")
+ web_app = fastapi.FastAPI()
+
+
+ @web_app.post("/parse")
+ async def parse(request: fastapi.Request):
+     predict_step = Function.lookup("vit-gpt2-image-captioning", "predict_step")
+
+     form = await request.form()
+     image = await form["image"].read()  # type: ignore
+     call = predict_step.spawn(image)
+     return {"call_id": call.object_id}
+
+
+ @web_app.get("/result/{call_id}")
+ async def poll_results(call_id: str):
+     from modal.functions import FunctionCall
+
+     function_call = FunctionCall.from_id(call_id)
+     try:
+         result = function_call.get(timeout=0)
+     except TimeoutError:
+         return fastapi.responses.JSONResponse(content="", status_code=202)
+
+     return result[0]
+
+
+ assets_path = Path(__file__).parent / "frontend"
+
+
+ @stub.function(mounts=[Mount.from_local_dir(assets_path, remote_path="/assets")])
+ @asgi_app()
+ def wrapper():
+     web_app.mount("/", fastapi.staticfiles.StaticFiles(directory="/assets", html=True))
+
+     return web_app
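The two routes above form a spawn-then-poll pattern: POST /parse starts the captioning call with spawn and returns a call_id immediately, while GET /result/{call_id} answers 202 until the result is ready. This keeps every HTTP response fast even though inference takes seconds. Below is a minimal client sketch of that contract, assuming the app is being served (e.g. via "modal serve vit_gpt2_image_caption_webapp.py"); the base URL, the local image path, and the requests dependency are illustrative assumptions, not part of the commit.

# Hypothetical client for the spawn/poll contract above (a sketch, not part of the commit).
import time

import requests  # assumed third-party dependency

BASE_URL = "http://localhost:8000"  # placeholder for wherever the app is served

with open("sample.png", "rb") as f:  # hypothetical local image
    resp = requests.post(f"{BASE_URL}/parse", files={"image": f})
resp.raise_for_status()
call_id = resp.json()["call_id"]

# Poll until the backend stops answering 202 ("result not ready yet").
while True:
    poll = requests.get(f"{BASE_URL}/result/{call_id}")
    if poll.status_code == 200:
        print(poll.json())  # the first generated caption
        break
    time.sleep(0.5)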
vit_gpt2_image_captioning.py ADDED
@@ -0,0 +1,65 @@
+ # https://huggingface.co/nlpconnect/vit-gpt2-image-captioning
+
+ import urllib.request
+ import modal
+
+ stub = modal.Stub("vit-gpt2-image-captioning")
+ volume = modal.SharedVolume().persist("shared_vol")
+ CACHE_PATH = "/root/model_cache"
+
+
+ @stub.function(
+     gpu="any",
+     image=modal.Image.debian_slim().pip_install("Pillow", "transformers", "torch"),
+     shared_volumes={CACHE_PATH: volume},
+     retries=3,
+ )
+ def predict_step(image):
+     import io
+     from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
+     import torch
+     from PIL import Image
+
+     model = VisionEncoderDecoderModel.from_pretrained(
+         "nlpconnect/vit-gpt2-image-captioning"
+     )
+     feature_extractor = ViTImageProcessor.from_pretrained(
+         "nlpconnect/vit-gpt2-image-captioning"
+     )
+     tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device)
+
+     max_length = 16
+     num_beams = 4
+     gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+     input_img = Image.open(io.BytesIO(image))
+     pixel_values = feature_extractor(
+         images=[input_img], return_tensors="pt"
+     ).pixel_values
+     pixel_values = pixel_values.to(device)
+
+     output_ids = model.generate(pixel_values, **gen_kwargs)
+
+     preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+     preds = [pred.strip() for pred in preds]
+     return preds
+
+
+ @stub.local_entrypoint()
+ def main():
+     from pathlib import Path
+
+     image_filepath = Path(__file__).parent / "sample.png"
+     if image_filepath.exists():
+         with open(image_filepath, "rb") as f:
+             image = f.read()
+     else:
+         try:
+             image = urllib.request.urlopen(
+                 "https://drive.google.com/uc?id=0B0TjveMhQDhgLTlpOENiOTZ6Y00&export=download"
+             ).read()
+         except urllib.error.URLError as e:
+             raise SystemExit(e.reason)
+     print(predict_step.call(image)[0])
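Because predict_step is deployed as its own Modal app, other code can resolve it by name at runtime, which is exactly how the web app above reaches it via Function.lookup. Below is a minimal sketch of calling the deployed function directly, assuming "modal deploy vit_gpt2_image_captioning.py" has already been run and that a local sample.png exists; both are assumptions, not shown in the commit.

# Hypothetical direct call to the deployed captioning function (a sketch).
from modal import Function

# Resolve by (app name, function name), same as the web app does.
predict_step = Function.lookup("vit-gpt2-image-captioning", "predict_step")

with open("sample.png", "rb") as f:  # hypothetical local image path
    image_bytes = f.read()

# Blocking remote call; predict_step returns a list of stripped captions.
captions = predict_step.call(image_bytes)
print(captions[0])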