Include code of llmapp from Github
- .gitignore +163 -0
- Dockerfile-llm-app +20 -0
- Makefile +2 -0
- README.md +22 -2
- app.py +0 -4
- docker-compose.yml +31 -0
- llmlib/README.md +0 -0
- llmlib/llmlib/__init__.py +0 -0
- llmlib/llmlib/base_llm.py +36 -0
- llmlib/llmlib/bundler.py +61 -0
- llmlib/llmlib/bundler_request.py +10 -0
- llmlib/llmlib/gemini/__init__.py +0 -0
- llmlib/llmlib/gemini/media_description.py +162 -0
- llmlib/llmlib/llama3/.gitignore +1 -0
- llmlib/llmlib/llama3/README.md +5 -0
- llmlib/llmlib/llama3/__init__.py +3 -0
- llmlib/llmlib/llama3/llama3_vision_8b.py +67 -0
- llmlib/llmlib/minicpm.py +105 -0
- llmlib/llmlib/model_registry.py +28 -0
- llmlib/llmlib/openai/openai_completion.py +78 -0
- llmlib/llmlib/phi3/phi3.py +157 -0
- llmlib/llmlib/phi3/phi35.py +72 -0
- llmlib/llmlib/pixtral_demo.py +26 -0
- llmlib/llmlib/rest_api/__init__.py +0 -0
- llmlib/llmlib/rest_api/restapi_client.py +97 -0
- llmlib/llmlib/rest_api/restapi_server.py +56 -0
- llmlib/llmlib/runtime.py +21 -0
- llmlib/llmlib/whisper.py +74 -0
- llmlib/pyproject.toml +21 -0
- login_mask_simple.py +50 -0
- readme/llm-app-demo.mp4 +0 -0
- requirements.txt +12 -0
- rest_api.py +4 -0
- st_app.py +97 -0
- test-files/.gitignore +5 -0
- test-files/toxicity-prompt.txt +1 -0
- tests/__init__.py +0 -0
- tests/helpers.py +123 -0
- tests/test_bundler.py +90 -0
- tests/test_gemini.py +25 -0
- tests/test_llama3.py +26 -0
- tests/test_minicpm.py +16 -0
- tests/test_openai.py +49 -0
- tests/test_phi_3.py +70 -0
- tests/test_rest_api.py +40 -0
- tests/test_whisper.py +40 -0
.gitignore
ADDED
@@ -0,0 +1,163 @@
```text
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
models/

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
```
Dockerfile-llm-app
ADDED
@@ -0,0 +1,20 @@
```dockerfile
FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel

WORKDIR /app

RUN apt-get update
RUN apt-get install -y build-essential
RUN apt-get install -y git

COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt
RUN CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install "llama-cpp-python<=0.2.79.0"
COPY *.py ./
ADD llmlib ./llmlib
RUN pip install -e llmlib
ADD .streamlit .streamlit

#CMD [ "python", "--version"]
# CMD ["nvidia-smi"]
# CMD ["nvcc", "--version"]
CMD [ "python", "-m", "streamlit", "run", "st_app.py", "--server.port", "8020"]
```
Makefile
ADDED
@@ -0,0 +1,2 @@
```makefile
run_rest_api:
	fastapi dev rest_api.py --port 8030
```
README.md
CHANGED
````diff
@@ -5,8 +5,28 @@ colorFrom: red
 colorTo: purple
 sdk: streamlit
 sdk_version: 1.41.1
-app_file:
+app_file: st_app.py
 pinned: false
 ---
 
-
+# LLM Multimodal Vibe-Check
+We use this streamlit app to chat with different multimodal open-source and proprietary LLMs. The idea is to quickly assess qualitatively (vibe-check) whether the model understands the nuance of harmful language.
+
+https://github.com/user-attachments/assets/2fb49053-651c-4cc9-b102-92a392a3c473
+
+## Run Streamlit App
+In the `docker-compose.yml` file, you will need to change the volume to point to your own huggingface model cache. To run the app, use the following command:
+```bash
+docker compose up llmapp
+```
+
+### Run Only Inference Server
+```bash
+docker compose up rest_api
+```
+
+## Structure
+* Each multimodal LLM has a different way of consuming image(s). This codebase unifies the different interfaces, e.g. of Phi-3, MiniCPM, OpenAI GPT-4o, etc. This is done with a single base class `LLM` (interface) which is then implemented by each concrete model. You can find these implementations in the directory `llmlib/llmlib/`.
+* The open-source implementations are based on the `transformers` library. I have experimented with `vLLM`, but it made the GPU run OOM. More fiddling is needed.
+* I have extracted a REST API using `FastAPI` to decouple the frontend streamlit code from the inference server.
+* The app supports small open-source models at the moment, because the inference server is running a single 24GB VRAM GPU. We will hopefully scale this backend up soon.
````
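For orientation, a minimal sketch of the structure described above: one `Message` goes through the unified `LLM` interface via the `Bundler`. All names come from the files added below; the prompt and model choice are only examples, and the model weights must be cached locally on a GPU machine.

```python
# Sketch only: chatting with one registered model through the unified interface.
from llmlib.base_llm import Message
from llmlib.bundler import Bundler
from llmlib.bundler_request import BundlerRequest
from llmlib.runtime import filled_model_registry

bundler = Bundler(registry=filled_model_registry())
req = BundlerRequest(
    model_id="microsoft/Phi-3-vision-128k-instruct",  # any id from registry.all_model_ids()
    msgs=[Message.from_prompt("Is this sentence sarcastic: 'Great, another Monday.'?")],
)
print(bundler.get_response(req))  # loads the model onto the GPU on first use
```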
app.py
DELETED
@@ -1,4 +0,0 @@
```diff
-import streamlit as st
-
-x = st.slider('Select a value')
-st.write(x, 'squared is', x * x)
```
docker-compose.yml
ADDED
@@ -0,0 +1,31 @@
```yaml
x-common-gpu: &common-gpu
  build:
    dockerfile: Dockerfile-llm-app
  environment:
    - OPENAI_API_KEY=${OPENAI_API_KEY}
    - HF_HOME=/app/.cache/huggingface
    - HF_TOKEN=${HF_TOKEN}
    - LLMS_REST_API_KEY=${LLMS_REST_API_KEY}
    - BUGSNAG_API_KEY=${BUGSNAG_API_KEY}
  deploy:
    resources:
      reservations:
        devices:
          - driver: nvidia
            count: all
            capabilities: [gpu]
  volumes:
    - /home/tomasruiz/.cache/huggingface:/app/.cache/huggingface

services:

  llmapp:
    <<: *common-gpu
    ports:
      - "8020:8020"
  rest_api:
    <<: *common-gpu
    ports:
      - "8030:8030"
    command: fastapi run rest_api.py --port 8030
    hostname: rest_api
```
llmlib/README.md
ADDED
File without changes
llmlib/llmlib/__init__.py
ADDED
File without changes
llmlib/llmlib/base_llm.py
ADDED
@@ -0,0 +1,36 @@
```python
from pathlib import Path
from typing import Literal, Self
from PIL import Image


from dataclasses import dataclass


@dataclass
class Message:
    role: Literal["user", "assistant"]
    msg: str
    img_name: str | None = None
    img: Image.Image | None = None

    @classmethod
    def from_prompt(cls, prompt: str) -> Self:
        return cls(role="user", msg=prompt)


class LLM:
    model_id: str
    requires_gpu_exclusively: bool = False

    def complete_msgs2(self, msgs: list[Message]) -> str:
        raise NotImplementedError

    def complete_batch(self, batch: list[list[Message]]) -> list[str]:
        raise NotImplementedError

    def video_prompt(self, video_path: Path, prompt: str) -> str:
        raise NotImplementedError

    @classmethod
    def get_warnings(cls) -> list[str]:
        return []
```
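A minimal sketch of what a concrete implementation of this interface provides; `EchoLLM` is hypothetical and only illustrates the contract that Phi-3, MiniCPM, etc. fulfill in the files below.

```python
# Illustrative only: a trivial LLM subclass. Real models run inference in complete_msgs2().
from llmlib.base_llm import LLM, Message


class EchoLLM(LLM):
    model_id = "echo-llm"  # hypothetical id, not part of the commit
    requires_gpu_exclusively = False

    def complete_msgs2(self, msgs: list[Message]) -> str:
        # Echo the last user message instead of calling a model.
        return msgs[-1].msg

    def complete_batch(self, batch: list[list[Message]]) -> list[str]:
        return [self.complete_msgs2(msgs) for msgs in batch]
```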
llmlib/llmlib/bundler.py
ADDED
@@ -0,0 +1,61 @@
```python
from dataclasses import dataclass, field
import logging

from .bundler_request import BundlerRequest
from .model_registry import ModelEntry, ModelRegistry
from .base_llm import LLM
import torch
import gc


logger = logging.getLogger(__name__)


@dataclass
class Bundler:
    """Makes sure that only 1 model occupies the GPU at a time."""

    registry: ModelRegistry = field(default_factory=ModelRegistry)
    model_on_gpu: LLM | None = None
    id2_nongpu_model: dict[str, LLM] = field(default_factory=dict)

    def id_of_model_on_gpu(self) -> str | None:
        return None if self.model_on_gpu is None else self.model_on_gpu.model_id

    def get_response(self, req: BundlerRequest) -> str:
        e: ModelEntry = self.registry.get_entry(model_id=req.model_id)
        model: LLM = self._get_model_instance(e=e)
        return model.complete_msgs2(req.msgs)

    def _get_model_instance(self, e: ModelEntry) -> LLM:
        if e.clazz.requires_gpu_exclusively:
            self.set_model_on_gpu(model_id=e.model_id)
            model: LLM = self.model_on_gpu
        else:
            if e.model_id not in self.id2_nongpu_model:
                self.id2_nongpu_model[e.model_id] = e.ctor()
            model: LLM = self.id2_nongpu_model[e.model_id]
        return model

    def set_model_on_gpu(self, model_id: str) -> None:
        if (
            self.id_of_model_on_gpu() is not None
            and self.id_of_model_on_gpu() == model_id
        ):
            return
        assert model_id in self.registry.all_model_ids()

        e: ModelEntry = self.registry.get_entry(model_id)
        if not e.clazz.requires_gpu_exclusively:
            logger.info(
                "Model does not require GPU exclusively. Ignoring set_model_on_gpu() call."
            )
            return

        self.clear_model_on_gpu()
        self.model_on_gpu = e.ctor()

    def clear_model_on_gpu(self):
        self.model_on_gpu = None
        gc.collect()
        torch.cuda.empty_cache()
```
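A sketch of the swap behavior, assuming two of the GPU-exclusive models registered in `runtime.py` below are available locally.

```python
# Sketch only: swapping GPU-exclusive models through the Bundler.
from llmlib.bundler import Bundler
from llmlib.runtime import filled_model_registry

bundler = Bundler(registry=filled_model_registry())

bundler.set_model_on_gpu("microsoft/Phi-3-vision-128k-instruct")
print(bundler.id_of_model_on_gpu())  # microsoft/Phi-3-vision-128k-instruct

# Requesting a different GPU-exclusive model clears the old one first
# (gc.collect() + torch.cuda.empty_cache()) and then constructs the new one.
bundler.set_model_on_gpu("openbmb/MiniCPM-V-2_6")
print(bundler.id_of_model_on_gpu())  # openbmb/MiniCPM-V-2_6
```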
llmlib/llmlib/bundler_request.py
ADDED
@@ -0,0 +1,10 @@
```python
from .base_llm import Message


from dataclasses import dataclass


@dataclass
class BundlerRequest:
    model_id: str
    msgs: list[Message]
```
llmlib/llmlib/gemini/__init__.py
ADDED
File without changes
llmlib/llmlib/gemini/media_description.py
ADDED
@@ -0,0 +1,162 @@
```python
"""
Based on https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/video-understanding
"""

from dataclasses import dataclass
from logging import getLogger
from pathlib import Path
from typing import Literal
from google.cloud import storage
from google.cloud.storage import transfer_manager
import proto
from vertexai.generative_models import (
    GenerativeModel,
    Part,
    HarmCategory,
    HarmBlockThreshold,
    GenerationResponse,
)

import vertexai

logger = getLogger(__name__)

project_id = "css-lehrbereich"  # from google cloud console
frankfurt = "europe-west3"  # https://cloud.google.com/about/locations#europe


class Buckets:
    temp = "css-temp-bucket-for-vertex"
    output = "css-vertex-output"


def storage_uri(bucket: str, blob_name: str) -> str:
    """blob_name starts without a slash"""
    return "gs://%s/%s" % (bucket, blob_name)


class Models:
    gemini_pro = "models/gemini-1.5-pro"
    gemini_flash = "models/gemini-1.5-flash"


available_models = [Models.gemini_pro, Models.gemini_flash]


@dataclass
class Request:
    media_files: list[Path]
    model_name: Literal[Models.gemini_pro, Models.gemini_flash] = Models.gemini_pro
    prompt: str = "Describe this video in detail."

    def fetch_media_description(self) -> str:
        return fetch_media_description(self)


def fetch_media_description(req: Request) -> str:
    # TODO: Always delete the video in the end. Perhaps use finally block.
    blobs = upload_files(files=req.media_files)

    init_vertex()
    model = GenerativeModel(req.model_name)

    prompt = req.prompt
    logger.info("Calling the Google API. model_name='%s'", req.model_name)
    contents = [
        Part.from_uri(storage_uri(Buckets.temp, b.name), mime_type=mime_type(b.name))
        for b in blobs
    ]
    contents.append(prompt)
    response: GenerationResponse = model.generate_content(
        contents=contents,
        generation_config={"temperature": 0.0},
        safety_settings=block_nothing(),
    )
    logger.info("Token usage: %s", proto.Message.to_dict(response.usage_metadata))

    if len(response.candidates) == 0:
        raise ResponseRefusedException(
            "No candidates in response. prompt_feedback='%s'" % response.prompt_feedback
        )

    enum = type(response.candidates[0].finish_reason)
    if response.candidates[0].finish_reason in {enum.SAFETY, enum.PROHIBITED_CONTENT}:
        raise UnsafeResponseError(safety_ratings=response.candidates[0].safety_ratings)

    for blob in blobs:
        blob.delete()
    logger.info("Deleted %d blob(s)", len(blobs))

    return response.text


def init_vertex() -> None:
    vertexai.init(project=project_id, location=frankfurt)


def mime_type(file_name: str) -> str:
    mapping = {
        ".txt": "text/plain",
        ".jpg": "image/jpeg",
        ".png": "image/png",
        ".flac": "audio/flac",
        ".mp3": "audio/mpeg",
        ".mp4": "video/mp4",
    }
    for ext, mime in mapping.items():
        if file_name.endswith(ext):
            return mime
    raise ValueError(f"Unknown mime type for file: {file_name}")


def upload_files(files: list[Path]) -> list[storage.Blob]:
    logger.info("Uploading %d file(s)", len(files))
    bucket = _bucket(name=Buckets.temp)
    files_str = [str(f) for f in files]
    blobs = [bucket.blob(file.name) for file in files]
    transfer_manager.upload_many(
        file_blob_pairs=zip(files_str, blobs),
        skip_if_exists=True,
        raise_exception=True,
    )
    logger.info("Completed file(s) upload")
    return blobs


def _bucket(name: str) -> storage.Bucket:
    client = storage.Client(project=project_id)
    return client.bucket(name)


def upload_single_file(file: Path, bucket: str, blob_name: str) -> storage.Blob:
    logger.info("Uploading file '%s' to bucket '%s' as '%s'", file, bucket, blob_name)
    bucket: storage.Bucket = _bucket(name=bucket)
    blob = bucket.blob(blob_name)
    if blob.exists():
        logger.info("Blob '%s' already exists. Overwriting it...", blob_name)
    blob.upload_from_filename(str(file))
    return blob


def block_nothing() -> dict[HarmCategory, HarmBlockThreshold]:
    return {
        HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY: HarmBlockThreshold.BLOCK_NONE,
    }


class UnsafeResponseError(Exception):
    def __init__(self, safety_ratings: list) -> None:
        super().__init__(
            "The response was blocked by Google due to safety reasons. Categories: %s"
            % safety_ratings
        )
        self.safety_categories = safety_ratings


class ResponseRefusedException(Exception):
    pass
```
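A usage sketch for the helper above; the video path is hypothetical, and Google Cloud credentials with access to the configured project and buckets are assumed.

```python
# Sketch only: describing a local video with Gemini via the module above.
from pathlib import Path
from llmlib.gemini.media_description import Models, Request

req = Request(
    media_files=[Path("test-files/some-video.mp4")],  # hypothetical path
    model_name=Models.gemini_flash,
    prompt="Does this clip contain hateful language? Answer briefly.",
)
# Uploads the file to the temp bucket, calls Vertex AI, then deletes the blob.
description: str = req.fetch_media_description()
print(description)
```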
llmlib/llmlib/llama3/.gitignore
ADDED
@@ -0,0 +1 @@
```text
models/
```
llmlib/llmlib/llama3/README.md
ADDED
@@ -0,0 +1,5 @@
````markdown

Installation for the quantized model in `llama_cpp`:
```shell
CMAKE_ARGS="-DLLAMA_CUBLAS=on -DCUDA_PATH=/usr/local/cuda-12.5 -DCUDAToolkit_ROOT=/usr/local/cuda-12.5 -DCUDAToolkit_INCLUDE_DIR=/usr/local/cuda-12/include -DCUDAToolkit_LIBRARY_DIR=/usr/local/cuda-12.5/lib64" FORCE_CMAKE=1 pip install llama-cpp-python --no-cache-dir
```
````
llmlib/llmlib/llama3/__init__.py
ADDED
@@ -0,0 +1,3 @@
```python
from .llama3_vision_8b import LLama3Vision8B

__all__ = ["LLama3Vision8B"]
```
llmlib/llmlib/llama3/llama3_vision_8b.py
ADDED
@@ -0,0 +1,67 @@
```python
from llmlib.base_llm import Message
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
from llmlib.base_llm import LLM
from PIL import Image

_model_id = "qresearch/llama-3-vision-alpha-hf"


class LLama3Vision8B(LLM):
    model_id = _model_id
    requires_gpu_exclusively = True

    def __init__(self):
        self.model = create_model()
        self.tokenizer = create_tokenizer()

    def complete_msgs2(self, msgs: list[Message]) -> str:
        if len(msgs) != 1:
            raise ValueError(
                f"model='{_model_id}' supports only one message by the user."
            )
        msg = msgs[0]
        if msg.role != "user":
            raise ValueError(
                f"model='{_model_id}' supports only a role=user message, not role={msg.role}."
            )

        # 2024-06-20: Model does not accept image=None, therefore we create a small white image
        if msg.img is None:
            empty_img = Image.new("RGB", (3, 3), color="white")
            image = empty_img
        else:
            image = msg.img

        response: str = self.tokenizer.decode(
            self.model.answer_question(image, msg.msg, self.tokenizer),
            skip_special_tokens=True,
        )
        return response

    @classmethod
    def get_warnings(cls) -> list[str]:
        return ["This model only accepts one message by the user at a time."]


def create_model():
    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        llm_int8_skip_modules=["mm_projector", "vision_model"],
    )

    return AutoModelForCausalLM.from_pretrained(
        _model_id,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=bnb_cfg,
    )


def create_tokenizer():
    return AutoTokenizer.from_pretrained(
        _model_id,
        use_fast=True,
    )
```
llmlib/llmlib/minicpm.py
ADDED
@@ -0,0 +1,105 @@
```python
import logging
from pathlib import Path
from typing import Any
from llmlib.base_llm import LLM, Message
import torch
from transformers import AutoModel, AutoTokenizer
from PIL import Image
from decord import VideoReader, cpu  # pip install decord

logger = logging.getLogger(__name__)


_model_name = "openbmb/MiniCPM-V-2_6"


class MiniCPM(LLM):
    temperature: float

    model_id = _model_name
    requires_gpu_exclusively = True

    def __init__(self, temperature: float = 0.0, model=None) -> None:
        if model is None:
            model = _create_model()
        self.model = model
        self.tokenizer = _create_tokenizer()
        self.temperature = temperature

    def chat(self, prompt: str) -> str:
        return self.complete_msgs2([Message(role="user", msg=prompt)])

    def complete_msgs2(self, msgs: list[Message]) -> str:
        dict_msgs = [_convert_msg_to_dict(m) for m in msgs]
        use_sampling = self.temperature > 0.0
        res = self.model.chat(
            image=None,
            msgs=dict_msgs,
            tokenizer=self.tokenizer,
            sampling=use_sampling,
            temperature=self.temperature,
        )
        return res

    def video_prompt(self, video_path: Path, prompt: str) -> str:
        return video_prompt(self, video_path, prompt)


def _create_tokenizer():
    return AutoTokenizer.from_pretrained(_model_name, trust_remote_code=True)


def _create_model():
    model = AutoModel.from_pretrained(
        _model_name,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )
    model.eval().cuda()
    return model


def _convert_msg_to_dict(msg: Message) -> dict:
    if msg.img is None:
        content: list[Any] = [msg.msg]
    else:
        content = [msg.img.convert("RGB"), msg.msg]
    return {"role": msg.role, "content": content}


def to_listof_imgs(video_path: Path) -> list[Image.Image]:
    """
    Return one frame per second from the video.
    If the video is longer than MAX_NUM_FRAMES, sample MAX_NUM_FRAMES frames.
    """
    MAX_NUM_FRAMES = 64  # if cuda OOM set a smaller number
    assert video_path.exists(), video_path
    vr = VideoReader(str(video_path), ctx=cpu(0))
    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
    frame_idx = [i for i in range(0, len(vr), sample_fps)]
    if len(frame_idx) > MAX_NUM_FRAMES:
        frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
    imgs = vr.get_batch(frame_idx).asnumpy()
    imgs = [Image.fromarray(v.astype("uint8")) for v in imgs]
    return imgs


def uniform_sample(xs, n):
    gap = len(xs) / n
    idxs = [int(i * gap + gap / 2) for i in range(n)]
    return [xs[i] for i in idxs]


def video_prompt(self: MiniCPM, video_path: Path, prompt: str) -> str:
    imgs = to_listof_imgs(video_path)
    logger.info("Video turned into %d images", len(imgs))
    msgs = [
        {"role": "user", "content": [prompt] + imgs},
    ]
    # Set decode params for video
    params = {}
    params["use_image_id"] = False
    params["max_slice_nums"] = 2  # use 1 if cuda OOM and video resolution > 448*448
    answer = self.model.chat(image=None, msgs=msgs, tokenizer=self.tokenizer, **params)
    return answer
```
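A worked example of the sampling logic above: at roughly one frame per second, a 5-minute video yields 300 candidate frames, which `uniform_sample()` caps at `MAX_NUM_FRAMES=64` indices spread evenly across the clip.

```python
# Worked example for uniform_sample(); the 300-frame input is illustrative.
from llmlib.minicpm import uniform_sample

frame_idx = list(range(300))
kept = uniform_sample(frame_idx, 64)
print(len(kept), kept[:4], kept[-1])  # 64 [2, 7, 11, 16] 297
```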
llmlib/llmlib/model_registry.py
ADDED
@@ -0,0 +1,28 @@
```python
from typing_extensions import Self
from dataclasses import dataclass, field
from typing import Callable
from .base_llm import LLM


@dataclass
class ModelEntry:
    model_id: str
    clazz: type[LLM]
    ctor: Callable[[], LLM]
    warnings: list[str] = field(default_factory=list)

    @classmethod
    def from_cls_with_id(cls, T: type[LLM]) -> Self:
        return cls(model_id=T.model_id, clazz=T, ctor=T, warnings=T.get_warnings())


@dataclass
class ModelRegistry:
    models: list[ModelEntry] = field(default_factory=list)

    def get_entry(self, model_id: str) -> ModelEntry:
        id2entry = {entry.model_id: entry for entry in self.models}
        return id2entry[model_id]

    def all_model_ids(self) -> list[str]:
        return [entry.model_id for entry in self.models]
```
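A small sketch of how the registry is populated and queried; the real wiring lives in `runtime.py` below.

```python
# Sketch only: building a one-entry registry and looking it up.
from llmlib.model_registry import ModelEntry, ModelRegistry
from llmlib.phi3.phi3 import Phi3Vision

registry = ModelRegistry(models=[ModelEntry.from_cls_with_id(Phi3Vision)])
print(registry.all_model_ids())        # ['microsoft/Phi-3-vision-128k-instruct']

entry = registry.get_entry("microsoft/Phi-3-vision-128k-instruct")
print(entry.clazz.requires_gpu_exclusively)  # True; the Bundler uses this flag
```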
llmlib/llmlib/openai/openai_completion.py
ADDED
@@ -0,0 +1,78 @@
```python
import os
from PIL.Image import Image
from ..base_llm import LLM, Message
from ..rest_api.restapi_client import encode_as_png_in_base64
from openai import OpenAI, ChatCompletion
from multiprocessing import Pool

_default_model = "gpt-4o-mini"

client = OpenAI()  # must be outside of the class to avoid pickling issues


class OpenAIModel(LLM):
    model_ids = [_default_model, "gpt-4o"]

    def __init__(self, model: str = _default_model):
        self.model = model

    def complete(self, prompt: str) -> str:
        return complete(model=self.model, prompt=prompt)

    def complete_msgs(self, messages: list[dict], images: list[Image] = []) -> str:
        return complete_msgs(model=self.model, messages=messages)

    def complete_many(
        self, prompts: list[str], n_workers: int = os.cpu_count()
    ) -> list[str]:
        return complete_many(model=self.model, prompts=prompts, n_workers=n_workers)

    def complete_msgs2(self, msgs: list[Message]) -> str:
        messages: list[dict] = extract_msgs(msgs)
        return self.complete_msgs(messages)


def complete_many(
    model: str, prompts: list[str], n_workers: int = os.cpu_count()
) -> list[str]:
    print("Calling OpenAI API")
    with Pool(processes=n_workers) as pool:
        args = [(model, p) for p in prompts]
        return pool.starmap(complete, args)


def complete(model: str, prompt: str) -> str:
    messages = [{"role": "user", "content": prompt}]
    return complete_msgs(model=model, messages=messages)


def complete_msgs(model: str, messages: list[dict]) -> str:
    completion: ChatCompletion = client.chat.completions.create(
        model=model, temperature=0.0, messages=messages
    )
    assert len(completion.choices) == 1
    return completion.choices[0].message.content


def postprocess(response: str) -> str:
    return response.lower().strip(".").strip()


def extract_msgs(msgs: list[Message]) -> list[dict]:
    return [extract_msg(m) for m in msgs]


def extract_msg(msg: Message) -> dict:
    if msg.img is None:
        return {"role": msg.role, "content": msg.msg}
    img_in_base64 = encode_as_png_in_base64(msg.img)
    return {
        "role": msg.role,
        "content": [
            {"type": "text", "text": msg.msg},
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{img_in_base64}"},
            },
        ],
    }
```
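A sketch of what `extract_msg()` produces for a message that carries an image, i.e. the OpenAI chat format with a base64 PNG data URL. The tiny image is illustrative, and `OPENAI_API_KEY` must be set because the module creates a client at import time.

```python
# Sketch only: inspecting the payload built for an image-bearing message.
from PIL import Image
from llmlib.base_llm import Message
from llmlib.openai.openai_completion import extract_msg

msg = Message(
    role="user",
    msg="What is in this picture?",
    img_name="dot.png",                          # illustrative name
    img=Image.new("RGB", (2, 2), color="red"),   # illustrative image
)
payload = extract_msg(msg)
print(payload["content"][0])  # {'type': 'text', 'text': 'What is in this picture?'}
print(payload["content"][1]["image_url"]["url"][:22])  # data:image/png;base64,
```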
llmlib/llmlib/phi3/phi3.py
ADDED
@@ -0,0 +1,157 @@
```python
from dataclasses import dataclass
from typing import Any
from llmlib.base_llm import Message
from torch import Tensor
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
from llmlib.base_llm import LLM
from transformers.image_processing_utils import BatchFeature

model_id = "microsoft/Phi-3-vision-128k-instruct"


@dataclass
class GenConf:
    max_new_tokens: int = 500
    temperature: float = 0.0

    def to_dict(self) -> dict[str, Any]:
        do_sample: bool = self.temperature != 0.0
        return {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature if do_sample else None,
            "do_sample": do_sample,
        }


class Phi3Vision(LLM):
    model_id = model_id
    requires_gpu_exclusively = True

    def __init__(self, gen_conf: GenConf | None = None):
        self.model = create_model()
        self.processor = create_processor()
        self.gen_conf = GenConf() if gen_conf is None else gen_conf

    def complete(self, prompt: str) -> str:
        msg = Message(role="user", msg=prompt)
        return completion(llm=self, batch=[[msg]])[0]

    def complete_msgs2(self, msgs: list[Message]) -> str:
        return completion(llm=self, batch=[msgs])[0]

    def complete_batch(self, batch: list[list[Message]]) -> list[str]:
        return completion(llm=self, batch=batch)


def extract_imgs_and_dicts(msgs: list[Message]) -> tuple[list[Image.Image], list[dict]]:
    """
    Phi3 expects placeholders for images in the prompts in the form <|image_X|>, where X is the image number.
    It also requires the images as a separate array of PIL images.
    This function extracts the images from the messages and creates the placeholders.
    It makes sure to avoid duplication in the images and placeholders.
    """
    img_names = list(dict.fromkeys(m.img_name for m in msgs if m.img_name is not None))
    placeholders = {
        img_name: f"<|image_{i}|>" for i, img_name in enumerate(img_names, 1)
    }
    imgs = {}
    for msg in msgs:
        if msg.img is not None and msg.img_name not in imgs:
            imgs[msg.img_name] = msg.img
    images = list(imgs.values())

    messages: list[dict] = []  # entries are {"role": str, "content": str}
    for m in msgs:
        if m.img is not None and m.img_name is not None:
            img_placeholder = placeholders[m.img_name]
            content = f"{img_placeholder}\n{m.msg}"
        else:
            content = m.msg
        messages.append({"role": m.role, "content": content})
    return images, messages


def create_model(model_id: str = model_id):
    return AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda", trust_remote_code=True, torch_dtype="auto"
    )


def create_processor(model_id: str = model_id):
    return AutoProcessor.from_pretrained(model_id, trust_remote_code=True)


def convert_to_messages(prompts: list[str]) -> list[list[dict]]:
    return [[{"role": "user", "content": prompt}] for prompt in prompts]


def completion(llm: Phi3Vision, batch: list[list[Message]]) -> list[str]:
    reject_invalid_batches(batch)
    listof_inputs: list[BatchFeature] = []
    for messages in batch:
        images, messages_dicts = extract_imgs_and_dicts(messages)
        prompt: str = llm.processor.tokenizer.apply_chat_template(
            messages_dicts, tokenize=False, add_generation_prompt=True
        )
        imgs = None if len(images) == 0 else images
        inputs = llm.processor(prompt, imgs, return_tensors="pt").to("cuda")
        listof_inputs.append(inputs)

    pad_token_id = llm.processor.tokenizer.pad_token_id
    inputs = stack_and_pad_inputs(listof_inputs, pad_token_id=pad_token_id)

    generate_ids: Tensor = llm.model.generate(
        **inputs,
        eos_token_id=llm.processor.tokenizer.eos_token_id,
        **llm.gen_conf.to_dict(),
    )
    # the prompt is included in the output, so we need to drop it.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]

    responses: list[str] = llm.processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return responses


def reject_invalid_batches(batch: list[list[Message]]) -> None:
    """
    Valid batches are:
    - batch of length 1, or
    - batches with only a single message per entry, AND
      - all messages have an image, or
      - all messages are text only.
    """
    if len(batch) <= 1:
        return
    if any(len(msgs) != 1 for msgs in batch):
        raise ValueError("Batch must contain only one message per entry.")
    any_msg_has_img = any(msg.img is not None for msgs in batch for msg in msgs)
    any_msg_is_no_img = any(msg.img is None for msgs in batch for msg in msgs)
    if any_msg_has_img and any_msg_is_no_img:
        raise ValueError("Batch must contain an image in every entry or none at all.")


def pad_left(seqs: list[torch.Tensor], pad_token_id: int) -> torch.Tensor:
    max_len = max(len(seq) for seq in seqs)
    padded = torch.full((len(seqs), max_len), pad_token_id)
    for i, seq in enumerate(seqs):
        padded[i, -len(seq) :] = seq
    return padded


def stack_and_pad_inputs(inputs: list[BatchFeature], pad_token_id: int) -> BatchFeature:
    listof_input_ids = [i.input_ids[0] for i in inputs]
    new_input_ids = pad_left(listof_input_ids, pad_token_id=pad_token_id)
    data = dict(
        input_ids=new_input_ids,
        attention_mask=(new_input_ids != pad_token_id).long(),
    )
    has_imgs: bool = "pixel_values" in inputs[0]
    if has_imgs:
        data["pixel_values"] = torch.cat([i.pixel_values for i in inputs], dim=0)
        data["image_sizes"] = torch.cat([i.image_sizes for i in inputs], dim=0)

    return BatchFeature(data).to("cuda")
```
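A worked example of the left-padding above: shorter prompts are padded on the left so that generation starts at the same position for every row of the batch, and the attention mask follows directly from the pad token.

```python
# Worked example for pad_left(); the token ids are illustrative.
import torch
from llmlib.phi3.phi3 import pad_left

seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
padded = pad_left(seqs, pad_token_id=0)
print(padded.tolist())                 # [[5, 6, 7], [0, 8, 9]]
print((padded != 0).long().tolist())   # attention mask: [[1, 1, 1], [0, 1, 1]]
```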
llmlib/llmlib/phi3/phi35.py
ADDED
@@ -0,0 +1,72 @@
```python
from PIL import Image
from llmlib.phi3.phi3 import stack_and_pad_inputs
import requests
from transformers import AutoModelForCausalLM
from transformers import AutoProcessor
from transformers.image_processing_utils import BatchFeature

model_id = "microsoft/Phi-3.5-vision-instruct"


model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    trust_remote_code=True,
    torch_dtype="auto",
    _attn_implementation="flash_attention_2",
)

# for best performance, use num_crops=4 for multi-frame, num_crops=16 for single-frame.
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True, num_crops=4)

links = [
    "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg",
    "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-2-2048.jpg",
    "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-3-2048.jpg",
]
images = [Image.open(requests.get(link, stream=True).raw) for link in links]
batch = [
    [{"role": "user", "content": "<|image_1|>Who is mentioned in this picture?"}],
    [{"role": "user", "content": "<|image_1|>What is the title of this image?"}],
    [{"role": "user", "content": "<|image_1|>What icons are shown in this image?"}],
]
# batch = [
#     [{"role": "user", "content": "What is the capital of France?"}],
#     [{"role": "user", "content": "How does one make a cookie that is vegetarian?"}],
# ]
# images = [None, None]

# BatchFeature(s) are the output of the processor, which is used as input to the model.
listof_inputs: list[BatchFeature] = []
for messages, image in zip(batch, images):
    prompt = processor.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    images_ = None if image is None else [image]
    inputs = processor(prompt, images_, return_tensors="pt").to("cuda:0")
    listof_inputs.append(inputs)


inputs = stack_and_pad_inputs(
    listof_inputs, pad_token_id=processor.tokenizer.pad_token_id
)

generation_args = {
    "max_new_tokens": 1000,
    "temperature": None,
    "do_sample": False,
}

generate_ids = model.generate(
    **inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args
)

generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
responses: list[str] = processor.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)

for p, r in zip(batch, responses):
    print(p)
    print(r)
    print()
```
llmlib/llmlib/pixtral_demo.py
ADDED
@@ -0,0 +1,26 @@
```python
from vllm import LLM
from vllm.sampling_params import SamplingParams

if __name__ == "__main__":
    model_name = "mistralai/Pixtral-12B-2409"

    sampling_params = SamplingParams(max_tokens=8192)

    llm = LLM(model=model_name, gpu_memory_utilization=0.1, tokenizer_mode="mistral")

    prompt = "Describe this image in one sentence."
    image_url = "https://picsum.photos/id/237/200/300"

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": image_url}},
            ],
        },
    ]

    outputs = llm.chat(messages, sampling_params=sampling_params)

    print(outputs[0].outputs[0].text)
```
llmlib/llmlib/rest_api/__init__.py
ADDED
File without changes
llmlib/llmlib/rest_api/restapi_client.py
ADDED
@@ -0,0 +1,97 @@
```python
import base64
import io
import logging
import os
import requests
from PIL import Image
from ..base_llm import Message
from ..bundler_request import BundlerRequest
from pydantic import BaseModel
from typing import Literal

logger = logging.getLogger(__name__)


def encode_as_png_in_base64(img: Image.Image) -> str:
    stream = io.BytesIO()
    img.save(stream, format="PNG")
    return base64.b64encode(stream.getvalue()).decode("utf-8")


class MsgDto(BaseModel):
    role: Literal["user", "assistant"]
    msg: str
    img_name: str | None = None
    img_str: str | None = None

    @classmethod
    def from_bundler_msg(cls, msg: Message) -> "MsgDto":
        return cls(
            role=msg.role,
            msg=msg.msg,
            img_name=msg.img_name,
            img_str=encode_as_png_in_base64(msg.img) if msg.img is not None else None,
        )


def to_bundler_msg(msg: MsgDto) -> Message:
    return Message(
        role=msg.role,
        msg=msg.msg,
        img_name=msg.img_name,
        img=Image.open(io.BytesIO(base64.b64decode(msg.img_str)))
        if msg.img_str
        else None,
    )


class RequestDto(BaseModel):
    model: str
    msgs: list[MsgDto]

    @classmethod
    def from_bundler_request(cls, breq: BundlerRequest) -> "RequestDto":
        return cls(
            model=breq.model_id,
            msgs=[MsgDto.from_bundler_msg(msg) for msg in breq.msgs],
        )

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "model": "microsoft/Phi-3-vision-128k-instruct",
                    "msgs": [{"role": "user", "msg": "What is the capital of France?"}],
                }
            ]
        }
    }


_api_host = os.environ.get("LLMS_REST_API_HOST", "http://localhost") + ":8030"


def _headers():
    return {"X-API-Key": os.environ["LLMS_REST_API_KEY"]}


def get_completion_from_rest_api(
    breq: BundlerRequest, source=requests, **kwargs
) -> requests.Response:
    req = RequestDto.from_bundler_request(breq)
    url = _api_host + "/completion/"
    logger.info(f"Sending completion request to '{url}'.")
    return source.post(
        url=url,
        json=req.model_dump(),
        headers=_headers(),
        **kwargs,
    )


def get_models(source=requests) -> requests.Response:
    return source.get(url=_api_host + "/models/", headers=_headers())


def clear_gpu(source=requests) -> requests.Response:
    return source.post(url=_api_host + "/clear-gpu/", headers=_headers())
```
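A usage sketch for the client helpers above, assuming `LLMS_REST_API_KEY` is set and `LLMS_REST_API_HOST` points at the running `rest_api` service; the prompt is only an example.

```python
# Sketch only: querying the inference server from another process.
from llmlib.base_llm import Message
from llmlib.bundler_request import BundlerRequest
from llmlib.rest_api.restapi_client import get_completion_from_rest_api, get_models

print(get_models().json())  # list of available model ids

breq = BundlerRequest(
    model_id="openbmb/MiniCPM-V-2_6",
    msgs=[Message.from_prompt("Summarize the rules of chess in one sentence.")],
)
resp = get_completion_from_rest_api(breq, timeout=120)  # kwargs go to requests.post
print(resp.json()["response"])
```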
llmlib/llmlib/rest_api/restapi_server.py
ADDED
@@ -0,0 +1,56 @@
```python
import torch
from fastapi import Depends, FastAPI, HTTPException, Security
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from llmlib.bundler import Bundler
from llmlib.bundler_request import BundlerRequest
from llmlib.rest_api.restapi_client import RequestDto, to_bundler_msg
from llmlib.runtime import filled_model_registry


import os

import bugsnag
from bugsnag.asgi import BugsnagMiddleware


def create_fastapi_app() -> FastAPI:
    bugsnag.configure(api_key=os.environ["BUGSNAG_API_KEY"])

    bundler = Bundler(registry=filled_model_registry())
    app = FastAPI()
    app.add_middleware(BugsnagMiddleware)

    header = APIKeyHeader(name="X-API-Key")

    def is_authorized(api_key: str = Security(header)) -> bool:
        if api_key != os.environ["LLMS_REST_API_KEY"]:
            raise HTTPException(status_code=401, detail="Invalid API Key")
        return True

    @app.get("/models/")
    def _(_=Depends(is_authorized)):
        return bundler.registry.all_model_ids()

    @app.post("/completion/")
    def _(req: RequestDto, _=Depends(is_authorized)):
        breq = BundlerRequest(
            model_id=req.model, msgs=[to_bundler_msg(msg) for msg in req.msgs]
        )
        return {"response": bundler.get_response(breq)}

    @app.post("/clear-gpu/")
    def _(_=Depends(is_authorized)):
        bundler.clear_model_on_gpu()
        return {"status": "success"}

    @app.exception_handler(torch.cuda.OutOfMemoryError)
    def _(req, exc):
        return JSONResponse(
            status_code=500,
            content={
                "detail": "Error. GPU out of memory. There might be another workload running on the GPU."
            },
        )

    return app
```
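A sketch of exercising the app in-process with FastAPI's `TestClient`; it assumes `BUGSNAG_API_KEY` and `LLMS_REST_API_KEY` are set, and only hits `/models/`, which does not touch the GPU.

```python
# Sketch only: checking the API-key guard without starting a server process.
import os
from fastapi.testclient import TestClient
from rest_api import app  # rest_api.py below builds the app via create_fastapi_app()

client = TestClient(app)
headers = {"X-API-Key": os.environ["LLMS_REST_API_KEY"]}

assert client.get("/models/", headers=headers).status_code == 200
assert client.get("/models/", headers={"X-API-Key": "wrong"}).status_code == 401
```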
llmlib/llmlib/runtime.py
ADDED
@@ -0,0 +1,21 @@
```python
from .minicpm import MiniCPM
from .llama3 import LLama3Vision8B
from .model_registry import ModelEntry, ModelRegistry
from .openai.openai_completion import OpenAIModel
from .phi3.phi3 import Phi3Vision


def filled_model_registry() -> ModelRegistry:
    return ModelRegistry(
        models=[
            ModelEntry.from_cls_with_id(Phi3Vision),
            ModelEntry.from_cls_with_id(MiniCPM),
            ModelEntry.from_cls_with_id(LLama3Vision8B),
            *[
                ModelEntry(
                    model_id=id_, clazz=OpenAIModel, ctor=lambda: OpenAIModel(model=id_)
                )
                for id_ in OpenAIModel.model_ids
            ],
        ]
    )
```
llmlib/llmlib/whisper.py
ADDED
@@ -0,0 +1,74 @@
```python
from dataclasses import dataclass, field
from logging import getLogger
from typing import Any
import warnings
import torch
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
    AutomaticSpeechRecognitionPipeline,
)

logger = getLogger(__name__)


def create_whisper_pipe() -> AutomaticSpeechRecognitionPipeline:
    device = "cuda"
    torch_dtype = torch.float16

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation="flash_attention_2",
    )
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )

    return pipe


model_id = "openai/whisper-large-v3-turbo"


@dataclass
class Whisper:
    model_id = model_id

    pipe: AutomaticSpeechRecognitionPipeline = field(
        default_factory=create_whisper_pipe
    )

    def transcribe_file(self, file: str, translate=False) -> str:
        assert isinstance(file, str)
        logger.info("Transcribing file: %s", file)
        try:
            return self._transcribe(file, translate, return_timestamps=False)
        except ValueError as e:
            if "Please either pass `return_timestamps=True`" in repr(e):
                logger.info("File is >30s, transcribing with timestamps: %s", file)
                return self._transcribe(file, translate, return_timestamps=True)
            raise

    def _transcribe(self, file: str, translate: bool, return_timestamps: bool) -> str:
        kwargs: dict[str, Any] = {"return_timestamps": return_timestamps}
        if translate:
            kwargs["generate_kwargs"] = {"language": "english"}
        # ignore this warning:
        # .../site-packages/transformers/models/whisper/generation_whisper.py:496: FutureWarning: The input name `inputs` is deprecated. Please make sure to use `input_features` instead.
        with warnings.catch_warnings(action="ignore", category=FutureWarning):
            # data["chunks"] contains the timestamped transcriptions
            data = self.pipe(file, **kwargs)
        return data["text"].strip()
```
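A usage sketch for the wrapper above; the audio path is hypothetical, and constructing `Whisper()` downloads and loads `openai/whisper-large-v3-turbo` onto the GPU.

```python
# Sketch only: transcribing (and translating to English) an audio file.
from llmlib.whisper import Whisper

whisper = Whisper()  # builds the pipeline via create_whisper_pipe()
text = whisper.transcribe_file("test-files/interview.mp3", translate=True)
print(text)
```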
llmlib/pyproject.toml
ADDED
@@ -0,0 +1,21 @@
```toml
[tool.poetry]
name = "llmlib"
version = "0.1.0"
description = ""
authors = ["Tomas Ruiz <[email protected]>"]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.11"
bugsnag = "^4.7.1"
decord = "^0.6.0"
google-cloud-aiplatform = "^1.64"
# I cannot add the dependencies below. I suspect that torch is a build-time dependency for flash-attn, or something like that.
# transformers = "^4.44.2"
# accelerate = "^0.34.2"
# flash-attn = "^2.6.3"
# torch = "^2.4.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
```
login_mask_simple.py
ADDED
@@ -0,0 +1,50 @@
```python
"""
WARNING: This file is duplicated in the projects: llm-app, tiktok. Make sure changes are reflected in all projects!

Copied from https://docs.streamlit.io/knowledge-base/deploy/authentication-without-sso
"""

from functools import cache
import logging
import os
import streamlit as st


logger = logging.getLogger(__name__)


def login_form():
    with st.form("Credentials"):
        st.text_input("Password", type="password", key="password")
        st.form_submit_button("Log in", on_click=password_entered)


def password_entered():
    correct_pw = os.environ["LLMS_REST_API_KEY"]
    is_correct: bool = st.session_state.pop("password") == correct_pw
    st.session_state["password_correct"] = is_correct


def check_password() -> bool:
    """Return `True` if the user is allowed to access the app, `False` otherwise."""
    skip_pw: bool = os.environ.get("USE_STREAMLIT_PASSWORD", "true").lower() == "false"
    if skip_pw:
        log_password_check_skipped()
        return True

    """Returns `True` if the user had a correct password."""

    # Return True if the username + password is validated.
    if st.session_state.get("password_correct", False):
        return True

    # Show inputs for username + password.
    login_form()
    if "password_correct" in st.session_state:
        st.error("😕 Password incorrect")
    return False


@cache  # Print only once per session
def log_password_check_skipped():
    logger.info("Skipping password check because USE_STREAMLIT_PASSWORD=false.")
```
readme/llm-app-demo.mp4
ADDED
Binary file (884 kB)
requirements.txt
ADDED
@@ -0,0 +1,12 @@
pytest
deepdiff
pillow
openai
transformers
torch
streamlit
bitsandbytes
accelerate
flash-attn
fastapi[standard]
./llmlib
rest_api.py
ADDED
@@ -0,0 +1,4 @@
from fastapi import FastAPI
from llmlib.rest_api.restapi_server import create_fastapi_app

app: FastAPI = create_fastapi_app()
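
Aside (not part of the commit): rest_api.py only exposes a module-level FastAPI app, so it needs an ASGI server to run. A hedged sketch of one way to serve it locally with uvicorn, which the fastapi[standard] extra in requirements.txt is expected to pull in; host and port are placeholder values:

# sketch only -- serve the REST API locally with uvicorn (assumed setup)
import uvicorn

if __name__ == "__main__":
    # "rest_api:app" points at the `app` object defined in rest_api.py
    uvicorn.run("rest_api:app", host="0.0.0.0", port=8000)
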
st_app.py
ADDED
@@ -0,0 +1,97 @@
from PIL import Image
import streamlit as st
from llmlib.runtime import filled_model_registry
from llmlib.model_registry import ModelEntry, ModelRegistry
from llmlib.base_llm import Message
from llmlib.bundler import Bundler
from llmlib.bundler_request import BundlerRequest
from login_mask_simple import check_password

if not check_password():
    st.stop()

st.set_page_config(page_title="LLM App", layout="wide")

st.title("LLM App")


model_registry: ModelRegistry = filled_model_registry()


@st.cache_resource()
def create_model_bundler() -> Bundler:
    return Bundler(registry=model_registry)


def display_warnings(r: ModelRegistry, model_id: str) -> None:
    e1: ModelEntry = r.get_entry(model_id)
    if len(e1.warnings) > 0:
        st.warning(" \n".join(e1.warnings))


cs = st.columns(2)
with cs[0]:
    model1_id: str = st.selectbox("Select model", model_registry.all_model_ids())
    display_warnings(model_registry, model1_id)
with cs[1]:
    if "img-key" not in st.session_state:
        st.session_state["img-key"] = 0
    image = st.file_uploader("Include an image", key=st.session_state["img-key"])

if "messages1" not in st.session_state:
    st.session_state.messages1 = []  # list[Message]
    st.session_state.messages2 = []  # list[Message]

if st.button("Restart chat"):
    st.session_state.messages1 = []  # list[Message]
    st.session_state.messages2 = []  # list[Message]


def render_messages(msgs: list[Message]) -> None:
    for msg in msgs:
        render_message(msg)


def render_message(msg: Message):
    with st.chat_message(msg.role):
        if msg.img_name is not None:
            render_img(msg)
        st.markdown(msg.msg)


def render_img(msg: Message):
    st.image(msg.img, caption=msg.img_name, width=400)


n_cols = 1
cs = st.columns(n_cols)
render_messages(st.session_state.messages1)

prompt = st.chat_input("Type here")
if prompt is None:
    st.stop()

msg = Message(
    role="user",
    msg=prompt,
    img_name=image.name if image is not None else None,
    img=Image.open(image) if image is not None else None,
)

if image is not None:
    st.session_state["img-key"] += 1

st.session_state.messages1.append(msg)
render_message(msg)

model_bundler: Bundler = create_model_bundler()

with st.spinner("Initializing model..."):
    model_bundler.set_model_on_gpu(model_id=model1_id)

with st.spinner("Generating response..."):
    req = BundlerRequest(model_id=model1_id, msgs=st.session_state.messages1)
    response = model_bundler.get_response(req)
    msg = Message(role="assistant", msg=response)
    st.session_state.messages1.append(msg)
    render_message(msg)
test-files/.gitignore
ADDED
@@ -0,0 +1,5 @@
*.jpg
*.png
*.mp3
*.mp4
*.flac
test-files/toxicity-prompt.txt
ADDED
@@ -0,0 +1 @@
Evaluate whether this image could potentially be harmful and if so, explain why
tests/__init__.py
ADDED
File without changes
tests/helpers.py
ADDED
@@ -0,0 +1,123 @@
import os
from pathlib import Path
import PIL
from llmlib.base_llm import LLM, Message
import pytest


def assert_model_knows_capital_of_france(model: LLM) -> None:
    response: str = model.complete_msgs2(
        msgs=[Message(role="user", msg="What is the capital of France?")]
    )
    assert "paris" in response.lower()


def assert_model_can_answer_batch_of_text_prompts(model: LLM) -> None:
    prompts = [
        "What is the capital of France?",
        "What continent is south of Europe?",
        "What are the two tallest mountains in the world?",
    ]
    batch = [[Message.from_prompt(prompt)] for prompt in prompts]
    responses = model.complete_batch(batch=batch)
    assert len(responses) == 3
    assert "paris" in responses[0].lower()
    assert "africa" in responses[1].lower()
    assert "everest" in responses[2].lower()


def assert_model_can_answer_batch_of_img_prompts(model: LLM) -> None:
    batch = [
        [pyramid_message()],
        [forest_message()],
        [fish_message()],
    ]
    responses = model.complete_batch(batch=batch)
    assert len(responses) == 3
    assert "pyramid" in responses[0].lower()
    assert "forest" in responses[1].lower()
    assert "fish" in responses[2].lower()


def assert_model_rejects_unsupported_batches(model: LLM) -> None:
    mixed_textonly_and_img_batch = [
        [Message.from_prompt("What is the capital of France?")],
        [pyramid_message()],
    ]
    err_msg = "Batch must contain an image in every entry or none at all."
    with pytest.raises(ValueError, match=err_msg):
        model.complete_batch(mixed_textonly_and_img_batch)


def assert_model_recognizes_pyramid_in_image(model: LLM):
    msg = pyramid_message()
    answer: str = model.complete_msgs2(msgs=[msg])
    assert "pyramid" in answer.lower()


def assert_model_recognizes_afd_in_video(model: LLM):
    video_path = file_for_test("video.mp4")
    question = "Describe the video in english"
    answer: str = model.video_prompt(video_path, question)
    assert "alternative für deutschland" in answer.lower(), answer


def get_mona_lisa_completion(model: LLM) -> str:
    msg: Message = mona_lisa_message()
    answer: str = model.complete_msgs2(msgs=[msg])
    return answer


def mona_lisa_message() -> Message:
    _, img = mona_lisa_filename_and_img()
    prompt = "What is in the image?"
    msg = Message(role="user", msg=prompt, img=img, img_name="")
    return msg


def pyramid_message() -> Message:
    img_name = "pyramid.jpg"
    img = get_test_img(img_name)
    msg = Message(role="user", msg="What is in the image?", img=img, img_name="")
    return msg


def forest_message() -> Message:
    img_name = "forest.jpg"
    img = get_test_img(img_name)
    msg = Message(
        role="user", msg="Describe what you see in the picture.", img=img, img_name=""
    )
    return msg


def fish_message() -> Message:
    img_name = "fish.jpg"
    img = get_test_img(img_name)
    msg = Message(
        role="user",
        msg="What animal is depicted and where does it live?",
        img=img,
        img_name="",
    )
    return msg


def mona_lisa_filename_and_img() -> tuple[str, PIL.Image.Image]:
    img_name = "mona-lisa.png"
    img = get_test_img(img_name)
    return img_name, img


def get_test_img(name: str) -> PIL.Image.Image:
    path = file_for_test(name)
    return PIL.Image.open(path)


def file_for_test(name: str) -> Path:
    return Path(__file__).parent.parent / "test-files" / name


def is_ci() -> bool:
    is_ci_str: str = os.environ.get("CI", "false").lower()
    return is_ci_str != "false"
tests/test_bundler.py
ADDED
@@ -0,0 +1,90 @@
from dataclasses import dataclass
from llmlib.bundler import Bundler
from llmlib.bundler_request import BundlerRequest
from llmlib.base_llm import LLM, Message
import pytest
from llmlib.model_registry import ModelEntry, ModelRegistry


def test_model_id_on_gpu():
    b = Bundler(filled_model_registry())
    assert b.id_of_model_on_gpu() is None
    b.set_model_on_gpu(GpuLLM.model_id)
    assert b.id_of_model_on_gpu() == GpuLLM.model_id


def test_get_response():
    b = Bundler(filled_model_registry())
    msgs = [Message(role="user", msg="hello")]
    request = BundlerRequest(model_id=GpuLLM.model_id, msgs=msgs)
    expected_response = GpuLLM().complete_msgs2(msgs)
    actual_response: str = b.get_response(request)
    assert actual_response == expected_response
    assert b.id_of_model_on_gpu() == GpuLLM.model_id


def test_bundler_multiple_responses():
    b = Bundler(filled_model_registry())
    models = [GpuLLM(), GpuLLM2(), NonGpuLLM()]
    msgs = [Message(role="user", msg="hello")]

    expected_responses = [m.complete_msgs2(msgs) for m in models]
    assert expected_responses[0] != expected_responses[1]

    actual_responses = [
        b.get_response(BundlerRequest(model_id=m.model_id, msgs=msgs)) for m in models
    ]
    assert actual_responses == expected_responses

    last_gpu_model = [m for m in models if m.requires_gpu_exclusively][-1]
    assert b.id_of_model_on_gpu() == last_gpu_model.model_id


def test_set_model_on_gpu():
    b = Bundler(filled_model_registry())
    b.set_model_on_gpu(GpuLLM.model_id)
    assert b.id_of_model_on_gpu() == GpuLLM.model_id

    with pytest.raises(AssertionError):
        b.set_model_on_gpu("invalid")
    assert b.id_of_model_on_gpu() == GpuLLM.model_id

    b.set_model_on_gpu(NonGpuLLM.model_id)
    gpu_model_is_still_loaded: bool = b.id_of_model_on_gpu() == GpuLLM.model_id
    assert gpu_model_is_still_loaded


def filled_model_registry() -> ModelRegistry:
    model_entries = [
        ModelEntry.from_cls_with_id(GpuLLM),
        ModelEntry.from_cls_with_id(GpuLLM2),
        ModelEntry.from_cls_with_id(NonGpuLLM),
    ]
    return ModelRegistry(model_entries)


@dataclass
class GpuLLM(LLM):
    model_id = "gpu-llm-model"
    requires_gpu_exclusively = True

    def complete_msgs2(self, msgs: list[Message]) -> str:
        return "gpu msg"


@dataclass
class GpuLLM2(LLM):
    model_id = "gpu-llm-model-2"
    requires_gpu_exclusively = True

    def complete_msgs2(self, msgs: list[Message]) -> str:
        return "gpu msg 2"


@dataclass
class NonGpuLLM(LLM):
    model_id = "non-gpu-llm-model"
    requires_gpu_exclusively = False

    def complete_msgs2(self, msgs: list[Message]) -> str:
        return "non-gpu message"
tests/test_gemini.py
ADDED
@@ -0,0 +1,25 @@
from pathlib import Path
from llmlib.gemini.media_description import Request
import pytest

from tests.helpers import file_for_test, is_ci


@pytest.mark.skipif(condition=is_ci(), reason="Avoid costs")
def test_gemini_vision():
    files: list[Path] = [
        file_for_test("pyramid.jpg"),
        file_for_test("mona-lisa.png"),
        file_for_test("some-audio.mp3"),
    ]

    for path in files:
        assert path.exists()

    req = Request(
        media_files=files, prompt="Describe this combined images/audio/text in detail."
    )
    description: str = req.fetch_media_description().lower()
    assert "pyramid" in description
    assert "mona lisa" in description
    assert "horses are very fast" in description
tests/test_llama3.py
ADDED
@@ -0,0 +1,26 @@
from llmlib.base_llm import LLM
import pytest

from llmlib.llama3.llama3_vision_70b_quantized import LLama3Vision70BQuantized
from llmlib.llama3.llama3_vision_8b import LLama3Vision8B

from .helpers import (
    assert_model_knows_capital_of_france,
    assert_model_recognizes_pyramid_in_image,
    is_ci,
)


@pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
def test_llama_8b():
    model: LLM = LLama3Vision8B()
    assert_model_knows_capital_of_france(model)
    assert_model_recognizes_pyramid_in_image(model)


@pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
def test_llama_70b_quantized():
    model: LLM = LLama3Vision70BQuantized()
    assert_model_knows_capital_of_france(model)
    # model cannot recognize mona lisa yet
    # assert_model_recognized_mona_lisa_in_image(model)
tests/test_minicpm.py
ADDED
@@ -0,0 +1,16 @@
from llmlib.minicpm import MiniCPM
import pytest
from .helpers import (
    assert_model_knows_capital_of_france,
    assert_model_recognizes_afd_in_video,
    assert_model_recognizes_pyramid_in_image,
    is_ci,
)


@pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
def test_minicpm_vision():
    model = MiniCPM()
    assert_model_knows_capital_of_france(model)
    assert_model_recognizes_pyramid_in_image(model)
    assert_model_recognizes_afd_in_video(model)
tests/test_openai.py
ADDED
@@ -0,0 +1,49 @@
from llmlib.base_llm import LLM, Message
from PIL import Image
from llmlib.rest_api.restapi_client import encode_as_png_in_base64
import pytest
from llmlib.openai.openai_completion import (
    OpenAIModel,
    extract_msgs,
)
from deepdiff import DeepDiff

from .helpers import (
    assert_model_knows_capital_of_france,
    assert_model_recognizes_pyramid_in_image,
    is_ci,
)


def test_extract_msgs():
    img = Image.new(mode="RGB", size=(1, 1))
    msgs = [
        Message(role="user", msg="Hi"),
        Message(role="assistant", msg="Hi!"),
        Message(role="user", msg="Describe:", img=img, img_name="img1"),
    ]
    messages = extract_msgs(msgs)
    expected_msgs = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hi!"},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe:"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{encode_as_png_in_base64(img)}",
                    },
                },
            ],
        },
    ]
    assert DeepDiff(messages, expected_msgs) == {}


@pytest.mark.skipif(condition=is_ci(), reason="Avoid costs")
def test_openai_vision():
    model: LLM = OpenAIModel()
    assert_model_knows_capital_of_france(model)
    assert_model_recognizes_pyramid_in_image(model)
tests/test_phi_3.py
ADDED
@@ -0,0 +1,70 @@
from llmlib.base_llm import Message
from PIL import Image

from llmlib.phi3.phi3 import GenConf, Phi3Vision, extract_imgs_and_dicts, pad_left
import pytest
import torch

from .helpers import (
    assert_model_can_answer_batch_of_img_prompts,
    assert_model_can_answer_batch_of_text_prompts,
    assert_model_knows_capital_of_france,
    assert_model_rejects_unsupported_batches,
    get_mona_lisa_completion,
    is_ci,
)


def test_extract_imgs_and_dicts():
    img1 = Image.new(mode="RGB", size=(1, 1))
    img2 = Image.new(mode="RGB", size=(1, 1))
    msgs = [
        a_msg(),
        a_msg(img=img1, img_name="img1"),
        a_msg(img=img2, img_name="img2"),
        a_msg(),
        a_msg(img=img1, img_name="img1"),
        a_msg(img=img2, img_name="img2"),
    ]
    images, messages = extract_imgs_and_dicts(msgs)
    assert len(images) == 2
    assert len(messages) == 6
    assert "<|image_1|>" in messages[1]["content"]
    assert "<|image_1|>" in messages[4]["content"]
    assert "<|image_2|>" in messages[5]["content"]
    assert "<|image_2|>" in messages[2]["content"]


def a_msg(img: Image.Image | None = None, img_name: str | None = None) -> Message:
    return Message(role="user", msg="", img=img, img_name=img_name)


@pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
def test_phi3_vision(model: Phi3Vision):
    assert_model_knows_capital_of_france(model)
    answer: str = get_mona_lisa_completion(model)
    assert isinstance(answer, str)


@pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
def test_phi3_batching(model: Phi3Vision):
    assert_model_can_answer_batch_of_text_prompts(model)
    assert_model_can_answer_batch_of_img_prompts(model)


@pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
def test_phi3_invalid_input(model: Phi3Vision):
    assert_model_rejects_unsupported_batches(model)


@pytest.fixture(scope="module")
def model():
    yield Phi3Vision(GenConf(max_new_tokens=30))


def test_padleft():
    pad_token = -1
    seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]
    expected = torch.tensor([[1, 2, 3], [pad_token, 4, 5], [pad_token, pad_token, 6]])
    actual = pad_left(seqs, pad_token)
    assert torch.equal(actual, expected)
tests/test_rest_api.py
ADDED
@@ -0,0 +1,40 @@
from fastapi.testclient import TestClient
from llmlib.bundler_request import BundlerRequest
import llmlib.rest_api.restapi_client as llmclient
from llmlib.rest_api.restapi_server import create_fastapi_app
from llmlib.phi3.phi3 import Phi3Vision
import pytest
from .helpers import is_ci, mona_lisa_message


def app():
    return TestClient(create_fastapi_app())


@pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
def test_rest_api_get_completion():
    breq: BundlerRequest = _mona_lisa_request()
    response = llmclient.get_completion_from_rest_api(source=app(), breq=breq)
    assert response.status_code == 200, response.content
    assert "portrait" in response.json()["response"].lower()


def test_rest_api_get_models():
    response = llmclient.get_models(source=app())
    assert response.status_code == 200, response.content
    assert len(response.json()) > 3


@pytest.mark.skip(reason="This test requires the REST API to be running")
def test_rest_api_integration_test():
    breq: BundlerRequest = _mona_lisa_request()
    response = llmclient.get_completion_from_rest_api(breq)
    llmclient.clear_gpu()
    assert response.status_code == 200, response.content
    assert "portrait" in response.json()["response"].lower()


def _mona_lisa_request() -> BundlerRequest:
    msg = mona_lisa_message()
    some_valid_modelid: str = Phi3Vision.model_id
    return BundlerRequest(model_id=some_valid_modelid, msgs=[msg])
tests/test_whisper.py
ADDED
@@ -0,0 +1,40 @@
from llmlib.whisper import Whisper
import pytest
from tests.helpers import is_ci, file_for_test


@pytest.fixture(scope="module")
def model() -> Whisper:
    return Whisper()


@pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
def test_transcription(model: Whisper):
    audio_file = str(file_for_test(name="some-audio.flac"))  # Librispeech sample 2
    expected_transcription = "before he had time to answer a much encumbered vera burst into the room with the question i say can i leave these here these were a small black pig and a lusty specimen of black-red game-cock"
    actual_transcription: str = model.transcribe_file(audio_file)
    assert actual_transcription == expected_transcription


@pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
def test_video_transcription(model: Whisper):
    video_file = str(file_for_test("video.mp4"))
    expected_fragment = (
        "Die Unionsparteien oder deren Politiker sind heute wichtige Offiziere"
    )
    transcription = model.transcribe_file(video_file)
    assert expected_fragment in transcription


@pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
def test_translation(model: Whisper):
    german_video = str(file_for_test("video.mp4"))
    translation: str = model.transcribe_file(german_video, translate=True)
    assert "The parties and their politicians" in translation


@pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
def test_long_video_transcription(model: Whisper):
    video_file = str(file_for_test("long-video.mp4"))
    transcription: str = model.transcribe_file(video_file)
    assert isinstance(transcription, str)