Update
- app.py +16 -18
- gradio_helpers.py +0 -43
- models.py +8 -10
- requirements-cpu.txt +1 -2
- vae-oid.npz +0 -3
app.py
CHANGED
```diff
@@ -9,7 +9,6 @@ import os
 import time
 
 import gradio as gr
-import jax
 import PIL.Image
 import gradio_helpers
 import models
@@ -66,7 +65,8 @@ def compute(image, prompt, model_name, sampler):
   else:
     if not model_name:
       raise gr.Error('Models not loaded yet')
-    output = models.generate(model_name, sampler, image, prompt)
+    # output = models.generate(model_name, sampler, image, prompt)
+    output = 'output'
   logging.info('output="%s"', output)
 
   width, height = image.size
@@ -217,20 +217,20 @@ def create_app():
 
     status = gr.Markdown(f'Startup: {datetime.datetime.now()}')
     gpu_kind = gr.Markdown(f'GPU=?')
-    demo.load(
-        lambda: [
-            gradio_helpers.get_status(),
-            make_model(list(gradio_helpers.get_paths())),
-        ],
-        None,
-        [status, model],
-    )
-    def get_gpu_kind():
-      device = jax.devices()[0]
-      if not gradio_helpers.should_mock() and device.platform != 'gpu':
-        raise gr.Error('GPU not visible to JAX!')
-      return f'GPU={device.device_kind}'
-    demo.load(get_gpu_kind, None, gpu_kind)
+    # demo.load(
+    #     lambda: [
+    #         gradio_helpers.get_status(),
+    #         make_model(list(gradio_helpers.get_paths())),
+    #     ],
+    #     None,
+    #     [status, model],
+    # )
+    # def get_gpu_kind():
+    #   device = jax.devices()[0]
+    #   if not gradio_helpers.should_mock() and device.platform != 'gpu':
+    #     raise gr.Error('GPU not visible to JAX!')
+    #   return f'GPU={device.device_kind}'
+    # demo.load(get_gpu_kind, None, gpu_kind)
 
   return demo
 
@@ -240,8 +240,6 @@ if __name__ == '__main__':
   logging.basicConfig(level=logging.INFO,
                       format='%(asctime)s - %(levelname)s - %(message)s')
 
-  logging.info('JAX devices: %s', jax.devices())
-
   for k, v in os.environ.items():
     logging.info('environ["%s"] = %r', k, v)
 
```
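The block commented out above is Gradio's startup hook: `Blocks.load(fn, inputs, outputs)` runs `fn` on every page load and writes its return values into the listed output components. A minimal self-contained sketch of that pattern (the status text is illustrative, not from the repo):

```python
import datetime

import gradio as gr

with gr.Blocks() as demo:
    status = gr.Markdown('Startup: ?')
    # Runs once per page load; the return value replaces the Markdown text.
    demo.load(lambda: f'Startup: {datetime.datetime.now()}', None, status)

if __name__ == '__main__':
    demo.launch()
```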
gradio_helpers.py
CHANGED
```diff
@@ -7,57 +7,14 @@ import functools
 import logging
 import os
 import shutil
-import subprocess
-import sys
-import tempfile
 import threading
 import time
-import unittest.mock
 
 import huggingface_hub
-import jax
 import numpy as np
 import psutil
 
 
-def _clone_git(url, destination_folder, commit_hash=None):
-  subprocess.run([
-      'git', 'clone', '--depth=1',
-      url, destination_folder
-  ], check=True)
-  if commit_hash:
-    subprocess.run(
-        ['git', '-C', destination_folder, 'checkout', commit_hash], check=True
-    )
-
-
-def setup():
-  """Installs big_vision repo and mocks tensorflow_text."""
-  for url, dst_name, commit_hash in (
-      (
-          'https://github.com/google-research/big_vision',
-          'big_vision_repo',
-          None,
-      ),
-  ):
-    dst_path = os.path.join(tempfile.gettempdir(), dst_name)
-    if os.path.exists(dst_path):
-      print('Found existing "%s" at "%s"' % (url, dst_path))
-    else:
-      print('Cloning "%s" into "%s"' % (url, dst_path))
-      _clone_git(url, dst_path, commit_hash)
-
-    if dst_path not in sys.path:
-      sys.path.insert(0, dst_path)
-
-    # Imported in `big_vision.pp.ops_text` but we don't use it.
-    sys.modules['tensorflow_text'] = unittest.mock.MagicMock()
-
-
-# Must be run in main app before other BV imports:
-setup()
-
-
 def should_mock():
   """Returns `True` if `MOCK_MODEL=yes` is set in environment."""
   return os.environ.get('MOCK_MODEL') == 'yes'
```
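The deleted `setup()` made `big_vision` importable from a temp clone and stubbed `tensorflow_text` by planting a mock in `sys.modules` before anything could import the real package. A self-contained sketch of that stubbing technique:

```python
import sys
import unittest.mock

# Register a stub under the module's import name *before* anything
# imports it; the import system checks sys.modules first, so the real
# (heavy, unused) package never needs to be installed.
sys.modules['tensorflow_text'] = unittest.mock.MagicMock()

import tensorflow_text  # resolves to the MagicMock

# Attribute access returns further mocks instead of raising ImportError.
print(tensorflow_text.SentencepieceTokenizer)
```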
models.py
CHANGED
```diff
@@ -8,10 +8,9 @@ import PIL.Image
 
 # pylint: disable=g-bad-import-order
 import gradio_helpers
-import paligemma_bv
 
 
-ORGANIZATION = '
+ORGANIZATION = 'abetlen'
 BASE_MODELS = [
     ('paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224'),
     ('paligemma-3b-mix-448-jax', 'paligemma-3b-mix-448'),
@@ -42,7 +41,6 @@ MODELS_INFO = {
 
 MODELS_RES_SEQ = {
     'paligemma-3b-mix-224': (224, 256),
-    'paligemma-3b-mix-448': (448, 512),
 }
 
 # "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM.
@@ -50,13 +48,13 @@ MODELS_RES_SEQ = {
 # A single bf16 is about 5860 MB.
 MAX_RAM_CACHE = int(float(os.environ.get('RAM_CACHE_GB', '0')) * 1e9)
 
-config = paligemma_bv.PaligemmaConfig(
-    ckpt='',  # will be set below
-    res=224,
-    text_len=64,
-    tokenizer='gemma(tokensets=("loc", "seg"))',
-    vocab_size=256_000 + 1024 + 128,
-)
+# config = paligemma_bv.PaligemmaConfig(
+#     ckpt='',  # will be set below
+#     res=224,
+#     text_len=64,
+#     tokenizer='gemma(tokensets=("loc", "seg"))',
+#     vocab_size=256_000 + 1024 + 128,
+# )
 
 
 def get_cached_model(
```
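The surviving `MAX_RAM_CACHE` line reads `RAM_CACHE_GB` as a float and scales it to bytes, defaulting to 0 (no cache). A worked example of the arithmetic; the `12` here is illustrative:

```python
import os

os.environ['RAM_CACHE_GB'] = '12'  # e.g. set in the Space's settings

# Same expression as in models.py: gigabytes -> bytes.
max_ram_cache = int(float(os.environ.get('RAM_CACHE_GB', '0')) * 1e9)
assert max_ram_cache == 12_000_000_000
```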
requirements-cpu.txt
CHANGED
```diff
@@ -2,8 +2,6 @@ einops
 flax
 gradio
 huggingface-hub
-jax
-jaxlib
 ml_collections
 numpy
 orbax-checkpoint
@@ -11,3 +9,4 @@ Pillow
 psutil
 sentencepiece
 tensorflow
+git+https://github.com/abetlen/llama-cpp-python.git@add-paligemma-support
```
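The added last line swaps the removed `jax`/`jaxlib` pair for llama.cpp bindings built from the `add-paligemma-support` branch. A quick post-install sanity check, assuming the branch build exposes the usual package metadata:

```python
# After `pip install -r requirements-cpu.txt`:
import llama_cpp

print(llama_cpp.__version__)  # confirms the branch build imported cleanly
```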
vae-oid.npz
DELETED
```diff
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:5586010257b8536dddefab65e7755077f21d5672d5674dacf911f73ae95a4447
-size 8479556
```