tomasruiz committed on
Commit 41d24d2 · 1 Parent(s): 7475e8c

Include code of llmapp from Github
.gitignore ADDED
@@ -0,0 +1,163 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ models/
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
Dockerfile-llm-app ADDED
@@ -0,0 +1,20 @@
+ FROM pytorch/pytorch:2.4.0-cuda12.4-cudnn9-devel
+
+ WORKDIR /app
+
+ RUN apt-get update
+ RUN apt-get install -y build-essential
+ RUN apt-get install -y git
+
+ COPY requirements.txt requirements.txt
+ RUN pip install -r requirements.txt
+ RUN CMAKE_ARGS="-DLLAMA_CUBLAS=on" pip install "llama-cpp-python<=0.2.79.0"
+ COPY *.py ./
+ ADD llmlib ./llmlib
+ RUN pip install -e llmlib
+ ADD .streamlit .streamlit
+
+ #CMD [ "python", "--version"]
+ # CMD ["nvidia-smi"]
+ # CMD ["nvcc", "--version"]
+ CMD [ "python", "-m", "streamlit", "run", "st_app.py", "--server.port", "8020"]
Makefile ADDED
@@ -0,0 +1,2 @@
+ run_rest_api:
+ 	fastapi dev rest_api.py --port 8030
README.md CHANGED
@@ -5,8 +5,28 @@ colorFrom: red
  colorTo: purple
  sdk: streamlit
  sdk_version: 1.41.1
- app_file: app.py
+ app_file: st_app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # LLM Multimodal Vibe-Check
+ We use this Streamlit app to chat with different multimodal open-source and proprietary LLMs. The idea is to quickly assess qualitatively (vibe-check) whether a model understands the nuance of harmful language.
+
+ https://github.com/user-attachments/assets/2fb49053-651c-4cc9-b102-92a392a3c473
+
+ ## Run Streamlit App
+ In the `docker-compose.yml` file, you will need to change the volume to point to your own Hugging Face model cache. To run the app, use the following command:
+ ```bash
+ docker compose up llmapp
+ ```
+
+ ### Run Only Inference Server
+ ```bash
+ docker compose up rest_api
+ ```
+
+ ## Structure
+ * Each multimodal LLM has a different way of consuming image(s). This codebase unifies the different interfaces, e.g. of Phi-3, MiniCPM, OpenAI GPT-4o, etc., with a single base class `LLM` (an interface) that each concrete model implements. You can find these implementations in the directory `llmlib/llmlib/`.
+ * The open-source implementations are based on the `transformers` library. I have experimented with `vLLM`, but it made the GPU run OOM. More fiddling is needed.
+ * I have extracted a REST API using `FastAPI` to decouple the frontend Streamlit code from the inference server.
+ * The app currently supports only small open-source models, because the inference server runs on a single 24GB-VRAM GPU. We will hopefully scale this backend up soon.
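
To illustrate the `LLM` base class mentioned in the README above, here is a minimal sketch of how a new backend could plug into it, based on the interface defined in `llmlib/llmlib/base_llm.py` in this commit. `EchoLLM` is an invented example, not one of the real model wrappers in this repo.

```python
from llmlib.base_llm import LLM, Message


class EchoLLM(LLM):
    # model_id is the identifier the ModelRegistry and Bundler use to find this model
    model_id = "echo-llm"  # hypothetical id, for illustration only
    requires_gpu_exclusively = False

    def complete_msgs2(self, msgs: list[Message]) -> str:
        # A real implementation would run a multimodal model here,
        # consuming msg.msg and optionally msg.img of each Message.
        return "echo: " + msgs[-1].msg


llm = EchoLLM()
print(llm.complete_msgs2([Message.from_prompt("What is in the image?")]))
```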
app.py DELETED
@@ -1,4 +0,0 @@
- import streamlit as st
-
- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
docker-compose.yml ADDED
@@ -0,0 +1,31 @@
+ x-common-gpu: &common-gpu
+   build:
+     dockerfile: Dockerfile-llm-app
+   environment:
+     - OPENAI_API_KEY=${OPENAI_API_KEY}
+     - HF_HOME=/app/.cache/huggingface
+     - HF_TOKEN=${HF_TOKEN}
+     - LLMS_REST_API_KEY=${LLMS_REST_API_KEY}
+     - BUGSNAG_API_KEY=${BUGSNAG_API_KEY}
+   deploy:
+     resources:
+       reservations:
+         devices:
+           - driver: nvidia
+             count: all
+             capabilities: [gpu]
+   volumes:
+     - /home/tomasruiz/.cache/huggingface:/app/.cache/huggingface
+
+ services:
+
+   llmapp:
+     <<: *common-gpu
+     ports:
+       - "8020:8020"
+   rest_api:
+     <<: *common-gpu
+     ports:
+       - "8030:8030"
+     command: fastapi run rest_api.py --port 8030
+     hostname: rest_api
llmlib/README.md ADDED
File without changes
llmlib/llmlib/__init__.py ADDED
File without changes
llmlib/llmlib/base_llm.py ADDED
@@ -0,0 +1,36 @@
+ from pathlib import Path
+ from typing import Literal, Self
+ from PIL import Image
+
+
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class Message:
+     role: Literal["user", "assistant"]
+     msg: str
+     img_name: str | None = None
+     img: Image.Image | None = None
+
+     @classmethod
+     def from_prompt(cls, prompt: str) -> Self:
+         return cls(role="user", msg=prompt)
+
+
+ class LLM:
+     model_id: str
+     requires_gpu_exclusively: bool = False
+
+     def complete_msgs2(self, msgs: list[Message]) -> str:
+         raise NotImplementedError
+
+     def complete_batch(self, batch: list[list[Message]]) -> list[str]:
+         raise NotImplementedError
+
+     def video_prompt(self, video_path: Path, prompt: str) -> str:
+         raise NotImplementedError
+
+     @classmethod
+     def get_warnings(cls) -> list[str]:
+         return []
llmlib/llmlib/bundler.py ADDED
@@ -0,0 +1,61 @@
+ from dataclasses import dataclass, field
+ import logging
+
+ from .bundler_request import BundlerRequest
+ from .model_registry import ModelEntry, ModelRegistry
+ from .base_llm import LLM
+ import torch
+ import gc
+
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class Bundler:
+     """Makes sure that only 1 model occupies the GPU at a time."""
+
+     registry: ModelRegistry = field(default_factory=ModelRegistry)
+     model_on_gpu: LLM | None = None
+     id2_nongpu_model: dict[str, LLM] = field(default_factory=dict)
+
+     def id_of_model_on_gpu(self) -> str | None:
+         return None if self.model_on_gpu is None else self.model_on_gpu.model_id
+
+     def get_response(self, req: BundlerRequest) -> str:
+         e: ModelEntry = self.registry.get_entry(model_id=req.model_id)
+         model: LLM = self._get_model_instance(e=e)
+         return model.complete_msgs2(req.msgs)
+
+     def _get_model_instance(self, e: ModelEntry) -> LLM:
+         if e.clazz.requires_gpu_exclusively:
+             self.set_model_on_gpu(model_id=e.model_id)
+             model: LLM = self.model_on_gpu
+         else:
+             if e.model_id not in self.id2_nongpu_model:
+                 self.id2_nongpu_model[e.model_id] = e.ctor()
+             model: LLM = self.id2_nongpu_model[e.model_id]
+         return model
+
+     def set_model_on_gpu(self, model_id: str) -> None:
+         if (
+             self.id_of_model_on_gpu() is not None
+             and self.id_of_model_on_gpu() == model_id
+         ):
+             return
+         assert model_id in self.registry.all_model_ids()
+
+         e: ModelEntry = self.registry.get_entry(model_id)
+         if not e.clazz.requires_gpu_exclusively:
+             logger.info(
+                 "Model does not require GPU exclusively. Ignoring set_model_on_gpu() call."
+             )
+             return
+
+         self.clear_model_on_gpu()
+         self.model_on_gpu = e.ctor()
+
+     def clear_model_on_gpu(self):
+         self.model_on_gpu = None
+         gc.collect()
+         torch.cuda.empty_cache()
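
A hedged usage sketch (not part of this commit) of how the `Bundler` above is meant to be driven, using only classes that appear elsewhere in this commit (`filled_model_registry`, `BundlerRequest`, `Message`); the Phi-3 model id is taken from the registry in `runtime.py`.

```python
from llmlib.bundler import Bundler
from llmlib.bundler_request import BundlerRequest
from llmlib.base_llm import Message
from llmlib.runtime import filled_model_registry

# Only one GPU-exclusive model is kept loaded; switching model ids swaps it out.
bundler = Bundler(registry=filled_model_registry())
req = BundlerRequest(
    model_id="microsoft/Phi-3-vision-128k-instruct",
    msgs=[Message.from_prompt("What is the capital of France?")],
)
print(bundler.get_response(req))  # loads Phi-3 onto the GPU on first use
```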
llmlib/llmlib/bundler_request.py ADDED
@@ -0,0 +1,10 @@
+ from .base_llm import Message
+
+
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class BundlerRequest:
+     model_id: str
+     msgs: list[Message]
llmlib/llmlib/gemini/__init__.py ADDED
File without changes
llmlib/llmlib/gemini/media_description.py ADDED
@@ -0,0 +1,162 @@
+ """
+ Based on https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/video-understanding
+ """
+
+ from dataclasses import dataclass
+ from logging import getLogger
+ from pathlib import Path
+ from typing import Literal
+ from google.cloud import storage
+ from google.cloud.storage import transfer_manager
+ import proto
+ from vertexai.generative_models import (
+     GenerativeModel,
+     Part,
+     HarmCategory,
+     HarmBlockThreshold,
+     GenerationResponse,
+ )
+
+ import vertexai
+
+ logger = getLogger(__name__)
+
+ project_id = "css-lehrbereich"  # from google cloud console
+ frankfurt = "europe-west3"  # https://cloud.google.com/about/locations#europe
+
+
+ class Buckets:
+     temp = "css-temp-bucket-for-vertex"
+     output = "css-vertex-output"
+
+
+ def storage_uri(bucket: str, blob_name: str) -> str:
+     """blob_name starts without a slash"""
+     return "gs://%s/%s" % (bucket, blob_name)
+
+
+ class Models:
+     gemini_pro = "models/gemini-1.5-pro"
+     gemini_flash = "models/gemini-1.5-flash"
+
+
+ available_models = [Models.gemini_pro, Models.gemini_flash]
+
+
+ @dataclass
+ class Request:
+     media_files: list[Path]
+     model_name: Literal[Models.gemini_pro, Models.gemini_flash] = Models.gemini_pro
+     prompt: str = "Describe this video in detail."
+
+     def fetch_media_description(self) -> str:
+         return fetch_media_description(self)
+
+
+ def fetch_media_description(req: Request) -> str:
+     # TODO: Always delete the video in the end. Perhaps use finally block.
+     blobs = upload_files(files=req.media_files)
+
+     init_vertex()
+     model = GenerativeModel(req.model_name)
+
+     prompt = req.prompt
+     logger.info("Calling the Google API. model_name='%s'", req.model_name)
+     contents = [
+         Part.from_uri(storage_uri(Buckets.temp, b.name), mime_type=mime_type(b.name))
+         for b in blobs
+     ]
+     contents.append(prompt)
+     response: GenerationResponse = model.generate_content(
+         contents=contents,
+         generation_config={"temperature": 0.0},
+         safety_settings=block_nothing(),
+     )
+     logger.info("Token usage: %s", proto.Message.to_dict(response.usage_metadata))
+
+     if len(response.candidates) == 0:
+         raise ResponseRefusedException(
+             "No candidates in response. prompt_feedback='%s'" % response.prompt_feedback
+         )
+
+     enum = type(response.candidates[0].finish_reason)
+     if response.candidates[0].finish_reason in {enum.SAFETY, enum.PROHIBITED_CONTENT}:
+         raise UnsafeResponseError(safety_ratings=response.candidates[0].safety_ratings)
+
+     for blob in blobs:
+         blob.delete()
+     logger.info("Deleted %d blob(s)", len(blobs))
+
+     return response.text
+
+
+ def init_vertex() -> None:
+     vertexai.init(project=project_id, location=frankfurt)
+
+
+ def mime_type(file_name: str) -> str:
+     mapping = {
+         ".txt": "text/plain",
+         ".jpg": "image/jpeg",
+         ".png": "image/png",
+         ".flac": "audio/flac",
+         ".mp3": "audio/mpeg",
+         ".mp4": "video/mp4",
+     }
+     for ext, mime in mapping.items():
+         if file_name.endswith(ext):
+             return mime
+     raise ValueError(f"Unknown mime type for file: {file_name}")
+
+
+ def upload_files(files: list[Path]) -> list[storage.Blob]:
+     logger.info("Uploading %d file(s)", len(files))
+     bucket = _bucket(name=Buckets.temp)
+     files_str = [str(f) for f in files]
+     blobs = [bucket.blob(file.name) for file in files]
+     transfer_manager.upload_many(
+         file_blob_pairs=zip(files_str, blobs),
+         skip_if_exists=True,
+         raise_exception=True,
+     )
+     logger.info("Completed file(s) upload")
+     return blobs
+
+
+ def _bucket(name: str) -> storage.Bucket:
+     client = storage.Client(project=project_id)
+     return client.bucket(name)
+
+
+ def upload_single_file(file: Path, bucket: str, blob_name: str) -> storage.Blob:
+     logger.info("Uploading file '%s' to bucket '%s' as '%s'", file, bucket, blob_name)
+     bucket: storage.Bucket = _bucket(name=bucket)
+     blob = bucket.blob(blob_name)
+     if blob.exists():
+         logger.info("Blob '%s' already exists. Overwriting it...", blob_name)
+     blob.upload_from_filename(str(file))
+     return blob
+
+
+ def block_nothing() -> dict[HarmCategory, HarmBlockThreshold]:
+     return {
+         HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
+         HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+         HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+         HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+         HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+         HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY: HarmBlockThreshold.BLOCK_NONE,
+     }
+
+
+ class UnsafeResponseError(Exception):
+     def __init__(self, safety_ratings: list) -> None:
+         super().__init__(
+             "The response was blocked by Google due to safety reasons. Categories: %s"
+             % safety_ratings
+         )
+         self.safety_categories = safety_ratings
+
+
+ class ResponseRefusedException(Exception):
+     pass
llmlib/llmlib/llama3/.gitignore ADDED
@@ -0,0 +1 @@
+ models/
llmlib/llmlib/llama3/README.md ADDED
@@ -0,0 +1,5 @@
+
+ Installation for the quantized model in `llama_cpp`:
+ ```shell
+ CMAKE_ARGS="-DLLAMA_CUBLAS=on -DCUDA_PATH=/usr/local/cuda-12.5 -DCUDAToolkit_ROOT=/usr/local/cuda-12.5 -DCUDAToolkit_INCLUDE_DIR=/usr/local/cuda-12/include -DCUDAToolkit_LIBRARY_DIR=/usr/local/cuda-12.5/lib64" FORCE_CMAKE=1 pip install llama-cpp-python --no-cache-dir
+ ```
llmlib/llmlib/llama3/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .llama3_vision_8b import LLama3Vision8B
+
+ __all__ = ["LLama3Vision8B"]
llmlib/llmlib/llama3/llama3_vision_8b.py ADDED
@@ -0,0 +1,67 @@
+ from llmlib.base_llm import Message
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import BitsAndBytesConfig
+ from llmlib.base_llm import LLM
+ from PIL import Image
+
+ _model_id = "qresearch/llama-3-vision-alpha-hf"
+
+
+ class LLama3Vision8B(LLM):
+     model_id = _model_id
+     requires_gpu_exclusively = True
+
+     def __init__(self):
+         self.model = create_model()
+         self.tokenizer = create_tokenizer()
+
+     def complete_msgs2(self, msgs: list[Message]) -> str:
+         if len(msgs) != 1:
+             raise ValueError(
+                 f"model='{_model_id}' supports only one message by the user."
+             )
+         msg = msgs[0]
+         if msg.role != "user":
+             raise ValueError(
+                 f"model='{_model_id}' supports only a role=user message, not role={msg.role}."
+             )
+
+         # 2024-06-20: Model does not accept image=None, therefore we create a small white image
+         if msg.img is None:
+             empty_img = Image.new("RGB", (3, 3), color="white")
+             image = empty_img
+         else:
+             image = msg.img
+
+         response: str = self.tokenizer.decode(
+             self.model.answer_question(image, msg.msg, self.tokenizer),
+             skip_special_tokens=True,
+         )
+         return response
+
+     @classmethod
+     def get_warnings(cls) -> list[str]:
+         return ["This model only accepts one message by the user at a time."]
+
+
+ def create_model():
+     bnb_cfg = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_compute_dtype=torch.float16,
+         llm_int8_skip_modules=["mm_projector", "vision_model"],
+     )
+
+     return AutoModelForCausalLM.from_pretrained(
+         _model_id,
+         trust_remote_code=True,
+         torch_dtype=torch.float16,
+         quantization_config=bnb_cfg,
+     )
+
+
+ def create_tokenizer():
+     return AutoTokenizer.from_pretrained(
+         _model_id,
+         use_fast=True,
+     )
llmlib/llmlib/minicpm.py ADDED
@@ -0,0 +1,105 @@
+ import logging
+ from pathlib import Path
+ from typing import Any
+ from llmlib.base_llm import LLM, Message
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+ from PIL import Image
+ from decord import VideoReader, cpu  # pip install decord
+
+ logger = logging.getLogger(__name__)
+
+
+ _model_name = "openbmb/MiniCPM-V-2_6"
+
+
+ class MiniCPM(LLM):
+     temperature: float
+
+     model_id = _model_name
+     requires_gpu_exclusively = True
+
+     def __init__(self, temperature: float = 0.0, model=None) -> None:
+         if model is None:
+             model = _create_model()
+         self.model = model
+         self.tokenizer = _create_tokenizer()
+         self.temperature = temperature
+
+     def chat(self, prompt: str) -> str:
+         return self.complete_msgs2([Message(role="user", msg=prompt)])
+
+     def complete_msgs2(self, msgs: list[Message]) -> str:
+         dict_msgs = [_convert_msg_to_dict(m) for m in msgs]
+         use_sampling = self.temperature > 0.0
+         res = self.model.chat(
+             image=None,
+             msgs=dict_msgs,
+             tokenizer=self.tokenizer,
+             sampling=use_sampling,
+             temperature=self.temperature,
+         )
+         return res
+
+     def video_prompt(self, video_path: Path, prompt: str) -> str:
+         return video_prompt(self, video_path, prompt)
+
+
+ def _create_tokenizer():
+     return AutoTokenizer.from_pretrained(_model_name, trust_remote_code=True)
+
+
+ def _create_model():
+     model = AutoModel.from_pretrained(
+         _model_name,
+         trust_remote_code=True,
+         attn_implementation="flash_attention_2",
+         torch_dtype=torch.bfloat16,
+     )
+     model.eval().cuda()
+     return model
+
+
+ def _convert_msg_to_dict(msg: Message) -> dict:
+     if msg.img is None:
+         content: list[Any] = [msg.msg]
+     else:
+         content = [msg.img.convert("RGB"), msg.msg]
+     return {"role": msg.role, "content": content}
+
+
+ def to_listof_imgs(video_path: Path) -> list[Image.Image]:
+     """
+     Return one frame per second from the video.
+     If the video is longer than MAX_NUM_FRAMES, sample MAX_NUM_FRAMES frames.
+     """
+     MAX_NUM_FRAMES = 64  # if cuda OOM set a smaller number
+     assert video_path.exists(), video_path
+     vr = VideoReader(str(video_path), ctx=cpu(0))
+     sample_fps = round(vr.get_avg_fps() / 1)  # FPS
+     frame_idx = [i for i in range(0, len(vr), sample_fps)]
+     if len(frame_idx) > MAX_NUM_FRAMES:
+         frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
+     imgs = vr.get_batch(frame_idx).asnumpy()
+     imgs = [Image.fromarray(v.astype("uint8")) for v in imgs]
+     return imgs
+
+
+ def uniform_sample(xs, n):
+     gap = len(xs) / n
+     idxs = [int(i * gap + gap / 2) for i in range(n)]
+     return [xs[i] for i in idxs]
+
+
+ def video_prompt(self: MiniCPM, video_path: Path, prompt: str) -> str:
+     imgs = to_listof_imgs(video_path)
+     logger.info("Video turned into %d images", len(imgs))
+     msgs = [
+         {"role": "user", "content": [prompt] + imgs},
+     ]
+     # Set decode params for video
+     params = {}
+     params["use_image_id"] = False
+     params["max_slice_nums"] = 2  # use 1 if cuda OOM and video resolution > 448*448
+     answer = self.model.chat(image=None, msgs=msgs, tokenizer=self.tokenizer, **params)
+     return answer
llmlib/llmlib/model_registry.py ADDED
@@ -0,0 +1,28 @@
+ from typing_extensions import Self
+ from dataclasses import dataclass, field
+ from typing import Callable
+ from .base_llm import LLM
+
+
+ @dataclass
+ class ModelEntry:
+     model_id: str
+     clazz: type[LLM]
+     ctor: Callable[[], LLM]
+     warnings: list[str] = field(default_factory=list)
+
+     @classmethod
+     def from_cls_with_id(cls, T: type[LLM]) -> Self:
+         return cls(model_id=T.model_id, clazz=T, ctor=T, warnings=T.get_warnings())
+
+
+ @dataclass
+ class ModelRegistry:
+     models: list[ModelEntry] = field(default_factory=list)
+
+     def get_entry(self, model_id: str) -> ModelEntry:
+         id2entry = {entry.model_id: entry for entry in self.models}
+         return id2entry[model_id]
+
+     def all_model_ids(self) -> list[str]:
+         return [entry.model_id for entry in self.models]
llmlib/llmlib/openai/openai_completion.py ADDED
@@ -0,0 +1,78 @@
+ import os
+ from PIL.Image import Image
+ from ..base_llm import LLM, Message
+ from ..rest_api.restapi_client import encode_as_png_in_base64
+ from openai import OpenAI, ChatCompletion
+ from multiprocessing import Pool
+
+ _default_model = "gpt-4o-mini"
+
+ client = OpenAI()  # must be outside of the class to avoid pickling issues
+
+
+ class OpenAIModel(LLM):
+     model_ids = [_default_model, "gpt-4o"]
+
+     def __init__(self, model: str = _default_model):
+         self.model = model
+
+     def complete(self, prompt: str) -> str:
+         return complete(model=self.model, prompt=prompt)
+
+     def complete_msgs(self, messages: list[dict], images: list[Image] = []) -> str:
+         return complete_msgs(model=self.model, messages=messages)
+
+     def complete_many(
+         self, prompts: list[str], n_workers: int = os.cpu_count()
+     ) -> list[str]:
+         return complete_many(model=self.model, prompts=prompts, n_workers=n_workers)
+
+     def complete_msgs2(self, msgs: list[Message]) -> str:
+         messages: list[dict] = extract_msgs(msgs)
+         return self.complete_msgs(messages)
+
+
+ def complete_many(
+     model: str, prompts: list[str], n_workers: int = os.cpu_count()
+ ) -> list[str]:
+     print("Calling OpenAI API")
+     with Pool(processes=n_workers) as pool:
+         args = [(model, p) for p in prompts]
+         return pool.starmap(complete, args)
+
+
+ def complete(model: str, prompt: str) -> str:
+     messages = [{"role": "user", "content": prompt}]
+     return complete_msgs(model=model, messages=messages)
+
+
+ def complete_msgs(model: str, messages: list[dict]) -> str:
+     completion: ChatCompletion = client.chat.completions.create(
+         model=model, temperature=0.0, messages=messages
+     )
+     assert len(completion.choices) == 1
+     return completion.choices[0].message.content
+
+
+ def postprocess(response: str) -> str:
+     return response.lower().strip(".").strip()
+
+
+ def extract_msgs(msgs: list[Message]) -> list[dict]:
+     return [extract_msg(m) for m in msgs]
+
+
+ def extract_msg(msg: Message) -> dict:
+     if msg.img is None:
+         return {"role": msg.role, "content": msg.msg}
+     img_in_base64 = encode_as_png_in_base64(msg.img)
+     return {
+         "role": msg.role,
+         "content": [
+             {"type": "text", "text": msg.msg},
+             {
+                 "type": "image_url",
+                 "image_url": {"url": f"data:image/png;base64,{img_in_base64}"},
+             },
+         ],
+     }
llmlib/llmlib/phi3/phi3.py ADDED
@@ -0,0 +1,157 @@
+ from dataclasses import dataclass
+ from typing import Any
+ from llmlib.base_llm import Message
+ from torch import Tensor
+ import torch
+ from transformers import AutoModelForCausalLM, AutoProcessor
+ from PIL import Image
+ from llmlib.base_llm import LLM
+ from transformers.image_processing_utils import BatchFeature
+
+ model_id = "microsoft/Phi-3-vision-128k-instruct"
+
+
+ @dataclass
+ class GenConf:
+     max_new_tokens: int = 500
+     temperature: float = 0.0
+
+     def to_dict(self) -> dict[str, Any]:
+         do_sample: bool = self.temperature != 0.0
+         return {
+             "max_new_tokens": self.max_new_tokens,
+             "temperature": self.temperature if do_sample else None,
+             "do_sample": do_sample,
+         }
+
+
+ class Phi3Vision(LLM):
+     model_id = model_id
+     requires_gpu_exclusively = True
+
+     def __init__(self, gen_conf: GenConf | None = None):
+         self.model = create_model()
+         self.processor = create_processor()
+         self.gen_conf = GenConf() if gen_conf is None else gen_conf
+
+     def complete(self, prompt: str) -> str:
+         msg = Message(role="user", msg=prompt)
+         return completion(llm=self, batch=[[msg]])[0]
+
+     def complete_msgs2(self, msgs: list[Message]) -> str:
+         return completion(llm=self, batch=[msgs])[0]
+
+     def complete_batch(self, batch: list[list[Message]]) -> list[str]:
+         return completion(llm=self, batch=batch)
+
+
+ def extract_imgs_and_dicts(msgs: list[Message]) -> tuple[list[Image.Image], list[dict]]:
+     """
+     Phi3 expects placeholders for images in the prompts, in the form <|image_X|>, where X is the image number.
+     It also requires the images as a separate array of PIL images.
+     This function extracts the images from the messages and creates the placeholders.
+     It makes sure to avoid duplication in the images and placeholders.
+     """
+     img_names = list(dict.fromkeys(m.img_name for m in msgs if m.img_name is not None))
+     placeholders = {
+         img_name: f"<|image_{i}|>" for i, img_name in enumerate(img_names, 1)
+     }
+     imgs = {}
+     for msg in msgs:
+         if msg.img is not None and msg.img_name not in imgs:
+             imgs[msg.img_name] = msg.img
+     images = list(imgs.values())
+
+     messages: list[dict] = []  # entries are {"role": str, "content": str}
+     for m in msgs:
+         if m.img is not None and m.img_name is not None:
+             img_placeholder = placeholders[m.img_name]
+             content = f"{img_placeholder}\n{m.msg}"
+         else:
+             content = m.msg
+         messages.append({"role": m.role, "content": content})
+     return images, messages
+
+
+ def create_model(model_id: str = model_id):
+     return AutoModelForCausalLM.from_pretrained(
+         model_id, device_map="cuda", trust_remote_code=True, torch_dtype="auto"
+     )
+
+
+ def create_processor(model_id: str = model_id):
+     return AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+
+
+ def convert_to_messages(prompts: list[str]) -> list[list[dict]]:
+     return [[{"role": "user", "content": prompt}] for prompt in prompts]
+
+
+ def completion(llm: Phi3Vision, batch: list[list[Message]]) -> list[str]:
+     reject_invalid_batches(batch)
+     listof_inputs: list[BatchFeature] = []
+     for messages in batch:
+         images, messages_dicts = extract_imgs_and_dicts(messages)
+         prompt: str = llm.processor.tokenizer.apply_chat_template(
+             messages_dicts, tokenize=False, add_generation_prompt=True
+         )
+         imgs = None if len(images) == 0 else images
+         inputs = llm.processor(prompt, imgs, return_tensors="pt").to("cuda")
+         listof_inputs.append(inputs)
+
+     pad_token_id = llm.processor.tokenizer.pad_token_id
+     inputs = stack_and_pad_inputs(listof_inputs, pad_token_id=pad_token_id)
+
+     generate_ids: Tensor = llm.model.generate(
+         **inputs,
+         eos_token_id=llm.processor.tokenizer.eos_token_id,
+         **llm.gen_conf.to_dict(),
+     )
+     # the prompt is included in the output, so we need to drop it.
+     generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
+
+     responses: list[str] = llm.processor.batch_decode(
+         generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+     )
+     return responses
+
+
+ def reject_invalid_batches(batch: list[list[Message]]) -> None:
+     """
+     Valid batches are:
+     - a batch of length 1, or
+     - batches with only a single message per entry, AND
+       - all messages have an image, or
+       - all messages are text only.
+     """
+     if len(batch) <= 1:
+         return
+     if any(len(msgs) != 1 for msgs in batch):
+         raise ValueError("Batch must contain only one message per entry.")
+     any_msg_has_img = any(msg.img is not None for msgs in batch for msg in msgs)
+     any_msg_is_no_img = any(msg.img is None for msgs in batch for msg in msgs)
+     if any_msg_has_img and any_msg_is_no_img:
+         raise ValueError("Batch must contain an image in every entry or none at all.")
+
+
+ def pad_left(seqs: list[torch.Tensor], pad_token_id: int) -> torch.Tensor:
+     max_len = max(len(seq) for seq in seqs)
+     padded = torch.full((len(seqs), max_len), pad_token_id)
+     for i, seq in enumerate(seqs):
+         padded[i, -len(seq) :] = seq
+     return padded
+
+
+ def stack_and_pad_inputs(inputs: list[BatchFeature], pad_token_id: int) -> BatchFeature:
+     listof_input_ids = [i.input_ids[0] for i in inputs]
+     new_input_ids = pad_left(listof_input_ids, pad_token_id=pad_token_id)
+     data = dict(
+         input_ids=new_input_ids,
+         attention_mask=(new_input_ids != pad_token_id).long(),
+     )
+     has_imgs: bool = "pixel_values" in inputs[0]
+     if has_imgs:
+         data["pixel_values"] = torch.cat([i.pixel_values for i in inputs], dim=0)
+         data["image_sizes"] = torch.cat([i.image_sizes for i in inputs], dim=0)
+
+     return BatchFeature(data).to("cuda")
llmlib/llmlib/phi3/phi35.py ADDED
@@ -0,0 +1,72 @@
+ from PIL import Image
+ from llmlib.phi3.phi3 import stack_and_pad_inputs
+ import requests
+ from transformers import AutoModelForCausalLM
+ from transformers import AutoProcessor
+ from transformers.image_processing_utils import BatchFeature
+
+ model_id = "microsoft/Phi-3.5-vision-instruct"
+
+
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="cuda",
+     trust_remote_code=True,
+     torch_dtype="auto",
+     _attn_implementation="flash_attention_2",
+ )
+
+ # for best performance, use num_crops=4 for multi-frame, num_crops=16 for single-frame.
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True, num_crops=4)
+
+ links = [
+     "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg",
+     "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-2-2048.jpg",
+     "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-3-2048.jpg",
+ ]
+ images = [Image.open(requests.get(link, stream=True).raw) for link in links]
+ batch = [
+     [{"role": "user", "content": "<|image_1|>Who is mentioned in this picture?"}],
+     [{"role": "user", "content": "<|image_1|>What is the title of this image?"}],
+     [{"role": "user", "content": "<|image_1|>What icons are shown in this image?"}],
+ ]
+ # batch = [
+ #     [{"role": "user", "content": "What is the capital of France?"}],
+ #     [{"role": "user", "content": "How does one make a cookie that is vegetarian?"}],
+ # ]
+ # images = [None, None]
+
+ # BatchFeature(s) are the output of the processor, which is used as input to the model.
+ listof_inputs: list[BatchFeature] = []
+ for messages, image in zip(batch, images):
+     prompt = processor.tokenizer.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     images_ = None if image is None else [image]
+     inputs = processor(prompt, images_, return_tensors="pt").to("cuda:0")
+     listof_inputs.append(inputs)
+
+
+ inputs = stack_and_pad_inputs(
+     listof_inputs, pad_token_id=processor.tokenizer.pad_token_id
+ )
+
+ generation_args = {
+     "max_new_tokens": 1000,
+     "temperature": None,
+     "do_sample": False,
+ }
+
+ generate_ids = model.generate(
+     **inputs, eos_token_id=processor.tokenizer.eos_token_id, **generation_args
+ )
+
+ generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
+ responses: list[str] = processor.batch_decode(
+     generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )
+
+ for p, r in zip(batch, responses):
+     print(p)
+     print(r)
+     print()
llmlib/llmlib/pixtral_demo.py ADDED
@@ -0,0 +1,26 @@
+ from vllm import LLM
+ from vllm.sampling_params import SamplingParams
+
+ if __name__ == "__main__":
+     model_name = "mistralai/Pixtral-12B-2409"
+
+     sampling_params = SamplingParams(max_tokens=8192)
+
+     llm = LLM(model=model_name, gpu_memory_utilization=0.1, tokenizer_mode="mistral")
+
+     prompt = "Describe this image in one sentence."
+     image_url = "https://picsum.photos/id/237/200/300"
+
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "text", "text": prompt},
+                 {"type": "image_url", "image_url": {"url": image_url}},
+             ],
+         },
+     ]
+
+     outputs = llm.chat(messages, sampling_params=sampling_params)
+
+     print(outputs[0].outputs[0].text)
llmlib/llmlib/rest_api/__init__.py ADDED
File without changes
llmlib/llmlib/rest_api/restapi_client.py ADDED
@@ -0,0 +1,97 @@
+ import base64
+ import io
+ import logging
+ import os
+ import requests
+ from PIL import Image
+ from ..base_llm import Message
+ from ..bundler_request import BundlerRequest
+ from pydantic import BaseModel
+ from typing import Literal
+
+ logger = logging.getLogger(__name__)
+
+
+ def encode_as_png_in_base64(img: Image.Image) -> str:
+     stream = io.BytesIO()
+     img.save(stream, format="PNG")
+     return base64.b64encode(stream.getvalue()).decode("utf-8")
+
+
+ class MsgDto(BaseModel):
+     role: Literal["user", "assistant"]
+     msg: str
+     img_name: str | None = None
+     img_str: str | None = None
+
+     @classmethod
+     def from_bundler_msg(cls, msg: Message) -> "MsgDto":
+         return cls(
+             role=msg.role,
+             msg=msg.msg,
+             img_name=msg.img_name,
+             img_str=encode_as_png_in_base64(msg.img) if msg.img is not None else None,
+         )
+
+
+ def to_bundler_msg(msg: MsgDto) -> Message:
+     return Message(
+         role=msg.role,
+         msg=msg.msg,
+         img_name=msg.img_name,
+         img=Image.open(io.BytesIO(base64.b64decode(msg.img_str)))
+         if msg.img_str
+         else None,
+     )
+
+
+ class RequestDto(BaseModel):
+     model: str
+     msgs: list[MsgDto]
+
+     @classmethod
+     def from_bundler_request(cls, breq: BundlerRequest) -> "RequestDto":
+         return cls(
+             model=breq.model_id,
+             msgs=[MsgDto.from_bundler_msg(msg) for msg in breq.msgs],
+         )
+
+     model_config = {
+         "json_schema_extra": {
+             "examples": [
+                 {
+                     "model": "microsoft/Phi-3-vision-128k-instruct",
+                     "msgs": [{"role": "user", "msg": "What is the capital of France?"}],
+                 }
+             ]
+         }
+     }
+
+
+ _api_host = os.environ.get("LLMS_REST_API_HOST", "http://localhost") + ":8030"
+
+
+ def _headers():
+     return {"X-API-Key": os.environ["LLMS_REST_API_KEY"]}
+
+
+ def get_completion_from_rest_api(
+     breq: BundlerRequest, source=requests, **kwargs
+ ) -> requests.Response:
+     req = RequestDto.from_bundler_request(breq)
+     url = _api_host + "/completion/"
+     logger.info(f"Sending completion request to '{url}'.")
+     return source.post(
+         url=url,
+         json=req.model_dump(),
+         headers=_headers(),
+         **kwargs,
+     )
+
+
+ def get_models(source=requests) -> requests.Response:
+     return source.get(url=_api_host + "/models/", headers=_headers())
+
+
+ def clear_gpu(source=requests) -> requests.Response:
+     return source.post(url=_api_host + "/clear-gpu/", headers=_headers())
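
A hedged sketch (not part of this commit) of how the client helpers above could be called against a running inference server. It assumes `LLMS_REST_API_KEY` (and optionally `LLMS_REST_API_HOST`) are set in the environment, matching the server defined in `restapi_server.py`.

```python
from llmlib.base_llm import Message
from llmlib.bundler_request import BundlerRequest
from llmlib.rest_api.restapi_client import get_completion_from_rest_api, get_models

# List the model ids the server exposes via GET /models/
print(get_models().json())

# Ask one model for a completion via POST /completion/
breq = BundlerRequest(
    model_id="microsoft/Phi-3-vision-128k-instruct",
    msgs=[Message.from_prompt("What is the capital of France?")],
)
resp = get_completion_from_rest_api(breq)
print(resp.json()["response"])
```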
llmlib/llmlib/rest_api/restapi_server.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+ from fastapi import Depends, FastAPI, HTTPException, Security
+ from fastapi.responses import JSONResponse
+ from fastapi.security import APIKeyHeader
+ from llmlib.bundler import Bundler
+ from llmlib.bundler_request import BundlerRequest
+ from llmlib.rest_api.restapi_client import RequestDto, to_bundler_msg
+ from llmlib.runtime import filled_model_registry
+
+
+ import os
+
+ import bugsnag
+ from bugsnag.asgi import BugsnagMiddleware
+
+
+ def create_fastapi_app() -> FastAPI:
+     bugsnag.configure(api_key=os.environ["BUGSNAG_API_KEY"])
+
+     bundler = Bundler(registry=filled_model_registry())
+     app = FastAPI()
+     app.add_middleware(BugsnagMiddleware)
+
+     header = APIKeyHeader(name="X-API-Key")
+
+     def is_authorized(api_key: str = Security(header)) -> bool:
+         if api_key != os.environ["LLMS_REST_API_KEY"]:
+             raise HTTPException(status_code=401, detail="Invalid API Key")
+         return True
+
+     @app.get("/models/")
+     def _(_=Depends(is_authorized)):
+         return bundler.registry.all_model_ids()
+
+     @app.post("/completion/")
+     def _(req: RequestDto, _=Depends(is_authorized)):
+         breq = BundlerRequest(
+             model_id=req.model, msgs=[to_bundler_msg(msg) for msg in req.msgs]
+         )
+         return {"response": bundler.get_response(breq)}
+
+     @app.post("/clear-gpu/")
+     def _(_=Depends(is_authorized)):
+         bundler.clear_model_on_gpu()
+         return {"status": "success"}
+
+     @app.exception_handler(torch.cuda.OutOfMemoryError)
+     def _(req, exc):
+         return JSONResponse(
+             status_code=500,
+             content={
+                 "detail": "Error. GPU out of memory. There might be another workload running on the GPU."
+             },
+         )
+
+     return app
llmlib/llmlib/runtime.py ADDED
@@ -0,0 +1,21 @@
+ from .minicpm import MiniCPM
+ from .llama3 import LLama3Vision8B
+ from .model_registry import ModelEntry, ModelRegistry
+ from .openai.openai_completion import OpenAIModel
+ from .phi3.phi3 import Phi3Vision
+
+
+ def filled_model_registry() -> ModelRegistry:
+     return ModelRegistry(
+         models=[
+             ModelEntry.from_cls_with_id(Phi3Vision),
+             ModelEntry.from_cls_with_id(MiniCPM),
+             ModelEntry.from_cls_with_id(LLama3Vision8B),
+             *[
+                 ModelEntry(
+                     model_id=id_, clazz=OpenAIModel, ctor=lambda: OpenAIModel(model=id_)
+                 )
+                 for id_ in OpenAIModel.model_ids
+             ],
+         ]
+     )
llmlib/llmlib/whisper.py ADDED
@@ -0,0 +1,74 @@
+ from dataclasses import dataclass, field
+ from logging import getLogger
+ from typing import Any
+ import warnings
+ import torch
+ from transformers import (
+     AutoModelForSpeechSeq2Seq,
+     AutoProcessor,
+     pipeline,
+     AutomaticSpeechRecognitionPipeline,
+ )
+
+ logger = getLogger(__name__)
+
+
+ def create_whisper_pipe() -> AutomaticSpeechRecognitionPipeline:
+     device = "cuda"
+     torch_dtype = torch.float16
+
+     model = AutoModelForSpeechSeq2Seq.from_pretrained(
+         model_id,
+         torch_dtype=torch_dtype,
+         low_cpu_mem_usage=True,
+         use_safetensors=True,
+         attn_implementation="flash_attention_2",
+     )
+     model.to(device)
+
+     processor = AutoProcessor.from_pretrained(model_id)
+
+     pipe = pipeline(
+         "automatic-speech-recognition",
+         model=model,
+         tokenizer=processor.tokenizer,
+         feature_extractor=processor.feature_extractor,
+         torch_dtype=torch_dtype,
+         device=device,
+     )
+
+     return pipe
+
+
+ model_id = "openai/whisper-large-v3-turbo"
+
+
+ @dataclass
+ class Whisper:
+     model_id = model_id
+
+     pipe: AutomaticSpeechRecognitionPipeline = field(
+         default_factory=create_whisper_pipe
+     )
+
+     def transcribe_file(self, file: str, translate=False) -> str:
+         assert isinstance(file, str)
+         logger.info("Transcribing file: %s", file)
+         try:
+             return self._transcribe(file, translate, return_timestamps=False)
+         except ValueError as e:
+             if "Please either pass `return_timestamps=True`" in repr(e):
+                 logger.info("File is >30s, transcribing with timestamps: %s", file)
+                 return self._transcribe(file, translate, return_timestamps=True)
+             raise
+
+     def _transcribe(self, file: str, translate: bool, return_timestamps: bool) -> str:
+         kwargs: dict[str, Any] = {"return_timestamps": return_timestamps}
+         if translate:
+             kwargs["generate_kwargs"] = {"language": "english"}
+         # ignore this warning:
+         # .../site-packages/transformers/models/whisper/generation_whisper.py:496: FutureWarning: The input name `inputs` is deprecated. Please make sure to use `input_features` instead.
+         with warnings.catch_warnings(action="ignore", category=FutureWarning):
+             # data["chunks"] contains the timestamped transcriptions
+             data = self.pipe(file, **kwargs)
+         return data["text"].strip()
llmlib/pyproject.toml ADDED
@@ -0,0 +1,21 @@
+ [tool.poetry]
+ name = "llmlib"
+ version = "0.1.0"
+ description = ""
+ authors = ["Tomas Ruiz <[email protected]>"]
+ readme = "README.md"
+
+ [tool.poetry.dependencies]
+ python = "^3.11"
+ bugsnag = "^4.7.1"
+ decord = "^0.6.0"
+ google-cloud-aiplatform = "^1.64"
+ # I cannot add the dependencies below. I suspect that torch is a build-time dependency for flash-attn, or something like that.
+ # transformers = "^4.44.2"
+ # accelerate = "^0.34.2"
+ # flash-attn = "^2.6.3"
+ # torch = "^2.4.0"
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"
login_mask_simple.py ADDED
@@ -0,0 +1,50 @@
+ """
+ WARNING: This file is duplicated in the projects: llm-app, tiktok. Make sure changes are reflected in all projects!
+
+ Copied from https://docs.streamlit.io/knowledge-base/deploy/authentication-without-sso
+ """
+
+ from functools import cache
+ import logging
+ import os
+ import streamlit as st
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def login_form():
+     with st.form("Credentials"):
+         st.text_input("Password", type="password", key="password")
+         st.form_submit_button("Log in", on_click=password_entered)
+
+
+ def password_entered():
+     correct_pw = os.environ["LLMS_REST_API_KEY"]
+     is_correct: bool = st.session_state.pop("password") == correct_pw
+     st.session_state["password_correct"] = is_correct
+
+
+ def check_password() -> bool:
+     """Return `True` if the user is allowed to access the app, `False` otherwise."""
+     skip_pw: bool = os.environ.get("USE_STREAMLIT_PASSWORD", "true").lower() == "false"
+     if skip_pw:
+         log_password_check_skipped()
+         return True
+
+     """Returns `True` if the user had a correct password."""
+
+     # Return True if the username + password is validated.
+     if st.session_state.get("password_correct", False):
+         return True
+
+     # Show inputs for username + password.
+     login_form()
+     if "password_correct" in st.session_state:
+         st.error("😕 Password incorrect")
+     return False
+
+
+ @cache  # Print only once per session
+ def log_password_check_skipped():
+     logger.info("Skipping password check because USE_STREAMLIT_PASSWORD=false.")
readme/llm-app-demo.mp4 ADDED
Binary file (884 kB)
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ pytest
+ deepdiff
+ pillow
+ openai
+ transformers
+ torch
+ streamlit
+ bitsandbytes
+ accelerate
+ flash-attn
+ fastapi[standard]
+ ./llmlib
rest_api.py ADDED
@@ -0,0 +1,4 @@
+ from fastapi import FastAPI
+ from llmlib.rest_api.restapi_server import create_fastapi_app
+
+ app: FastAPI = create_fastapi_app()
st_app.py ADDED
@@ -0,0 +1,97 @@
+ from PIL import Image
+ import streamlit as st
+ from llmlib.runtime import filled_model_registry
+ from llmlib.model_registry import ModelEntry, ModelRegistry
+ from llmlib.base_llm import Message
+ from llmlib.bundler import Bundler
+ from llmlib.bundler_request import BundlerRequest
+ from login_mask_simple import check_password
+
+ if not check_password():
+     st.stop()
+
+ st.set_page_config(page_title="LLM App", layout="wide")
+
+ st.title("LLM App")
+
+
+ model_registry: ModelRegistry = filled_model_registry()
+
+
+ @st.cache_resource()
+ def create_model_bundler() -> Bundler:
+     return Bundler(registry=model_registry)
+
+
+ def display_warnings(r: ModelRegistry, model_id: str) -> None:
+     e1: ModelEntry = r.get_entry(model_id)
+     if len(e1.warnings) > 0:
+         st.warning(" \n".join(e1.warnings))
+
+
+ cs = st.columns(2)
+ with cs[0]:
+     model1_id: str = st.selectbox("Select model", model_registry.all_model_ids())
+     display_warnings(model_registry, model1_id)
+ with cs[1]:
+     if "img-key" not in st.session_state:
+         st.session_state["img-key"] = 0
+     image = st.file_uploader("Include an image", key=st.session_state["img-key"])
+
+ if "messages1" not in st.session_state:
+     st.session_state.messages1 = []  # list[Message]
+     st.session_state.messages2 = []  # list[Message]
+
+ if st.button("Restart chat"):
+     st.session_state.messages1 = []  # list[Message]
+     st.session_state.messages2 = []  # list[Message]
+
+
+ def render_messages(msgs: list[Message]) -> None:
+     for msg in msgs:
+         render_message(msg)
+
+
+ def render_message(msg: Message):
+     with st.chat_message(msg.role):
+         if msg.img_name is not None:
+             render_img(msg)
+         st.markdown(msg.msg)
+
+
+ def render_img(msg: Message):
+     st.image(msg.img, caption=msg.img_name, width=400)
+
+
+ n_cols = 1
+ cs = st.columns(n_cols)
+ render_messages(st.session_state.messages1)
+
+ prompt = st.chat_input("Type here")
+ if prompt is None:
+     st.stop()
+
+ msg = Message(
+     role="user",
+     msg=prompt,
+     img_name=image.name if image is not None else None,
+     img=Image.open(image) if image is not None else None,
+ )
+
+ if image is not None:
+     st.session_state["img-key"] += 1
+
+ st.session_state.messages1.append(msg)
+ render_message(msg)
+
+ model_bundler: Bundler = create_model_bundler()
+
+ with st.spinner("Initializing model..."):
+     model_bundler.set_model_on_gpu(model_id=model1_id)
+
+ with st.spinner("Generating response..."):
+     req = BundlerRequest(model_id=model1_id, msgs=st.session_state.messages1)
+     response = model_bundler.get_response(req)
+     msg = Message(role="assistant", msg=response)
+     st.session_state.messages1.append(msg)
+     render_message(msg)
test-files/.gitignore ADDED
@@ -0,0 +1,5 @@
+ *.jpg
+ *.png
+ *.mp3
+ *.mp4
+ *.flac
test-files/toxicity-prompt.txt ADDED
@@ -0,0 +1 @@
+ Evaluate whether this image could potentially be harmful and if so, explain why
tests/__init__.py ADDED
File without changes
tests/helpers.py ADDED
@@ -0,0 +1,123 @@
+ import os
+ from pathlib import Path
+ import PIL
+ from llmlib.base_llm import LLM, Message
+ import pytest
+
+
+ def assert_model_knows_capital_of_france(model: LLM) -> None:
+     response: str = model.complete_msgs2(
+         msgs=[Message(role="user", msg="What is the capital of France?")]
+     )
+     assert "paris" in response.lower()
+
+
+ def assert_model_can_answer_batch_of_text_prompts(model: LLM) -> None:
+     prompts = [
+         "What is the capital of France?",
+         "What continent is south of Europe?",
+         "What are the two tallest mountains in the world?",
+     ]
+     batch = [[Message.from_prompt(prompt)] for prompt in prompts]
+     responses = model.complete_batch(batch=batch)
+     assert len(responses) == 3
+     assert "paris" in responses[0].lower()
+     assert "africa" in responses[1].lower()
+     assert "everest" in responses[2].lower()
+
+
+ def assert_model_can_answer_batch_of_img_prompts(model: LLM) -> None:
+     batch = [
+         [pyramid_message()],
+         [forest_message()],
+         [fish_message()],
+     ]
+     responses = model.complete_batch(batch=batch)
+     assert len(responses) == 3
+     assert "pyramid" in responses[0].lower()
+     assert "forest" in responses[1].lower()
+     assert "fish" in responses[2].lower()
+
+
+ def assert_model_rejects_unsupported_batches(model: LLM) -> None:
+     mixed_textonly_and_img_batch = [
+         [Message.from_prompt("What is the capital of France?")],
+         [pyramid_message()],
+     ]
+     err_msg = "Batch must contain an image in every entry or none at all."
+     with pytest.raises(ValueError, match=err_msg):
+         model.complete_batch(mixed_textonly_and_img_batch)
+
+
+ def assert_model_recognizes_pyramid_in_image(model: LLM):
+     msg = pyramid_message()
+     answer: str = model.complete_msgs2(msgs=[msg])
+     assert "pyramid" in answer.lower()
+
+
+ def assert_model_recognizes_afd_in_video(model: LLM):
+     video_path = file_for_test("video.mp4")
+     question = "Describe the video in english"
+     answer: str = model.video_prompt(video_path, question)
+     assert "alternative für deutschland" in answer.lower(), answer
+
+
+ def get_mona_lisa_completion(model: LLM) -> str:
+     msg: Message = mona_lisa_message()
+     answer: str = model.complete_msgs2(msgs=[msg])
+     return answer
+
+
+ def mona_lisa_message() -> Message:
+     _, img = mona_lisa_filename_and_img()
+     prompt = "What is in the image?"
+     msg = Message(role="user", msg=prompt, img=img, img_name="")
+     return msg
+
+
+ def pyramid_message() -> Message:
+     img_name = "pyramid.jpg"
+     img = get_test_img(img_name)
+     msg = Message(role="user", msg="What is in the image?", img=img, img_name="")
+     return msg
+
+
+ def forest_message() -> Message:
+     img_name = "forest.jpg"
+     img = get_test_img(img_name)
+     msg = Message(
+         role="user", msg="Describe what you see in the picture.", img=img, img_name=""
+     )
+     return msg
+
+
+ def fish_message() -> Message:
+     img_name = "fish.jpg"
+     img = get_test_img(img_name)
+     msg = Message(
+         role="user",
+         msg="What animal is depicted and where does it live?",
+         img=img,
+         img_name="",
+     )
+     return msg
+
+
+ def mona_lisa_filename_and_img() -> tuple[str, PIL.Image.Image]:
+     img_name = "mona-lisa.png"
+     img = get_test_img(img_name)
+     return img_name, img
+
+
+ def get_test_img(name: str) -> PIL.Image.Image:
+     path = file_for_test(name)
+     return PIL.Image.open(path)
+
+
+ def file_for_test(name: str) -> Path:
+     return Path(__file__).parent.parent / "test-files" / name
+
+
+ def is_ci() -> bool:
+     is_ci_str: str = os.environ.get("CI", "false").lower()
+     return is_ci_str != "false"
tests/test_bundler.py ADDED
@@ -0,0 +1,90 @@
+ from dataclasses import dataclass
+ from llmlib.bundler import Bundler
+ from llmlib.bundler_request import BundlerRequest
+ from llmlib.base_llm import LLM, Message
+ import pytest
+ from llmlib.model_registry import ModelEntry, ModelRegistry
+
+
+ def test_model_id_on_gpu():
+     b = Bundler(filled_model_registry())
+     assert b.id_of_model_on_gpu() is None
+     b.set_model_on_gpu(GpuLLM.model_id)
+     assert b.id_of_model_on_gpu() == GpuLLM.model_id
+
+
+ def test_get_response():
+     b = Bundler(filled_model_registry())
+     msgs = [Message(role="user", msg="hello")]
+     request = BundlerRequest(model_id=GpuLLM.model_id, msgs=msgs)
+     expected_response = GpuLLM().complete_msgs2(msgs)
+     actual_response: str = b.get_response(request)
+     assert actual_response == expected_response
+     assert b.id_of_model_on_gpu() == GpuLLM.model_id
+
+
+ def test_bundler_multiple_responses():
+     b = Bundler(filled_model_registry())
+     models = [GpuLLM(), GpuLLM2(), NonGpuLLM()]
+     msgs = [Message(role="user", msg="hello")]
+
+     expected_responses = [m.complete_msgs2(msgs) for m in models]
+     assert expected_responses[0] != expected_responses[1]
+
+     actual_responses = [
+         b.get_response(BundlerRequest(model_id=m.model_id, msgs=msgs)) for m in models
+     ]
+     assert actual_responses == expected_responses
+
+     last_gpu_model = [m for m in models if m.requires_gpu_exclusively][-1]
+     assert b.id_of_model_on_gpu() == last_gpu_model.model_id
+
+
+ def test_set_model_on_gpu():
+     b = Bundler(filled_model_registry())
+     b.set_model_on_gpu(GpuLLM.model_id)
+     assert b.id_of_model_on_gpu() == GpuLLM.model_id
+
+     with pytest.raises(AssertionError):
+         b.set_model_on_gpu("invalid")
+     assert b.id_of_model_on_gpu() == GpuLLM.model_id
+
+     b.set_model_on_gpu(NonGpuLLM.model_id)
+     gpu_model_is_still_loaded: bool = b.id_of_model_on_gpu() == GpuLLM.model_id
+     assert gpu_model_is_still_loaded
+
+
+ def filled_model_registry() -> ModelRegistry:
+     model_entries = [
+         ModelEntry.from_cls_with_id(GpuLLM),
+         ModelEntry.from_cls_with_id(GpuLLM2),
+         ModelEntry.from_cls_with_id(NonGpuLLM),
+     ]
+     return ModelRegistry(model_entries)
+
+
+ @dataclass
+ class GpuLLM(LLM):
+     model_id = "gpu-llm-model"
+     requires_gpu_exclusively = True
+
+     def complete_msgs2(self, msgs: list[Message]) -> str:
+         return "gpu msg"
+
+
+ @dataclass
+ class GpuLLM2(LLM):
+     model_id = "gpu-llm-model-2"
+     requires_gpu_exclusively = True
+
+     def complete_msgs2(self, msgs: list[Message]) -> str:
+         return "gpu msg 2"
+
+
+ @dataclass
+ class NonGpuLLM(LLM):
+     model_id = "non-gpu-llm-model"
+     requires_gpu_exclusively = False
+
+     def complete_msgs2(self, msgs: list[Message]) -> str:
+         return "non-gpu message"
tests/test_gemini.py ADDED
@@ -0,0 +1,25 @@
+ from pathlib import Path
+ from llmlib.gemini.media_description import Request
+ import pytest
+
+ from tests.helpers import file_for_test, is_ci
+
+
+ @pytest.mark.skipif(condition=is_ci(), reason="Avoid costs")
+ def test_gemini_vision():
+     files: list[Path] = [
+         file_for_test("pyramid.jpg"),
+         file_for_test("mona-lisa.png"),
+         file_for_test("some-audio.mp3"),
+     ]
+
+     for path in files:
+         assert path.exists()
+
+     req = Request(
+         media_files=files, prompt="Describe this combined images/audio/text in detail."
+     )
+     description: str = req.fetch_media_description().lower()
+     assert "pyramid" in description
+     assert "mona lisa" in description
+     assert "horses are very fast" in description
tests/test_llama3.py ADDED
@@ -0,0 +1,26 @@
+ from llmlib.base_llm import LLM
+ import pytest
+
+ from llmlib.llama3.llama3_vision_70b_quantized import LLama3Vision70BQuantized
+ from llmlib.llama3.llama3_vision_8b import LLama3Vision8B
+
+ from .helpers import (
+     assert_model_knows_capital_of_france,
+     assert_model_recognizes_pyramid_in_image,
+     is_ci,
+ )
+
+
+ @pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
+ def test_llama_8b():
+     model: LLM = LLama3Vision8B()
+     assert_model_knows_capital_of_france(model)
+     assert_model_recognizes_pyramid_in_image(model)
+
+
+ @pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
+ def test_llama_70b_quantized():
+     model: LLM = LLama3Vision70BQuantized()
+     assert_model_knows_capital_of_france(model)
+     # model cannot recognize mona lisa yet
+     # assert_model_recognized_mona_lisa_in_image(model)
tests/test_minicpm.py ADDED
@@ -0,0 +1,16 @@
+ from llmlib.minicpm import MiniCPM
+ import pytest
+ from .helpers import (
+     assert_model_knows_capital_of_france,
+     assert_model_recognizes_afd_in_video,
+     assert_model_recognizes_pyramid_in_image,
+     is_ci,
+ )
+
+
+ @pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
+ def test_minicpm_vision():
+     model = MiniCPM()
+     assert_model_knows_capital_of_france(model)
+     assert_model_recognizes_pyramid_in_image(model)
+     assert_model_recognizes_afd_in_video(model)
tests/test_openai.py ADDED
@@ -0,0 +1,49 @@
+ from llmlib.base_llm import LLM, Message
+ from PIL import Image
+ from llmlib.rest_api.restapi_client import encode_as_png_in_base64
+ import pytest
+ from llmlib.openai.openai_completion import (
+     OpenAIModel,
+     extract_msgs,
+ )
+ from deepdiff import DeepDiff
+
+ from .helpers import (
+     assert_model_knows_capital_of_france,
+     assert_model_recognizes_pyramid_in_image,
+     is_ci,
+ )
+
+
+ def test_extract_msgs():
+     img = Image.new(mode="RGB", size=(1, 1))
+     msgs = [
+         Message(role="user", msg="Hi"),
+         Message(role="assistant", msg="Hi!"),
+         Message(role="user", msg="Describe:", img=img, img_name="img1"),
+     ]
+     messages = extract_msgs(msgs)
+     expected_msgs = [
+         {"role": "user", "content": "Hi"},
+         {"role": "assistant", "content": "Hi!"},
+         {
+             "role": "user",
+             "content": [
+                 {"type": "text", "text": "Describe:"},
+                 {
+                     "type": "image_url",
+                     "image_url": {
+                         "url": f"data:image/png;base64,{encode_as_png_in_base64(img)}",
+                     },
+                 },
+             ],
+         },
+     ]
+     assert DeepDiff(messages, expected_msgs) == {}
+
+
+ @pytest.mark.skipif(condition=is_ci(), reason="Avoid costs")
+ def test_openai_vision():
+     model: LLM = OpenAIModel()
+     assert_model_knows_capital_of_france(model)
+     assert_model_recognizes_pyramid_in_image(model)
tests/test_phi_3.py ADDED
@@ -0,0 +1,70 @@
+ from llmlib.base_llm import Message
+ from PIL import Image
+
+ from llmlib.phi3.phi3 import GenConf, Phi3Vision, extract_imgs_and_dicts, pad_left
+ import pytest
+ import torch
+
+ from .helpers import (
+     assert_model_can_answer_batch_of_img_prompts,
+     assert_model_can_answer_batch_of_text_prompts,
+     assert_model_knows_capital_of_france,
+     assert_model_rejects_unsupported_batches,
+     get_mona_lisa_completion,
+     is_ci,
+ )
+
+
+ def test_extract_imgs_and_dicts():
+     img1 = Image.new(mode="RGB", size=(1, 1))
+     img2 = Image.new(mode="RGB", size=(1, 1))
+     msgs = [
+         a_msg(),
+         a_msg(img=img1, img_name="img1"),
+         a_msg(img=img2, img_name="img2"),
+         a_msg(),
+         a_msg(img=img1, img_name="img1"),
+         a_msg(img=img2, img_name="img2"),
+     ]
+     images, messages = extract_imgs_and_dicts(msgs)
+     assert len(images) == 2
+     assert len(messages) == 6
+     assert "<|image_1|>" in messages[1]["content"]
+     assert "<|image_1|>" in messages[4]["content"]
+     assert "<|image_2|>" in messages[5]["content"]
+     assert "<|image_2|>" in messages[2]["content"]
+
+
+ def a_msg(img: Image.Image | None = None, img_name: str | None = None) -> Message:
+     return Message(role="user", msg="", img=img, img_name=img_name)
+
+
+ @pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
+ def test_phi3_vision(model: Phi3Vision):
+     assert_model_knows_capital_of_france(model)
+     answer: str = get_mona_lisa_completion(model)
+     assert isinstance(answer, str)
+
+
+ @pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
+ def test_phi3_batching(model: Phi3Vision):
+     assert_model_can_answer_batch_of_text_prompts(model)
+     assert_model_can_answer_batch_of_img_prompts(model)
+
+
+ @pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
+ def test_phi3_invalid_input(model: Phi3Vision):
+     assert_model_rejects_unsupported_batches(model)
+
+
+ @pytest.fixture(scope="module")
+ def model():
+     yield Phi3Vision(GenConf(max_new_tokens=30))
+
+
+ def test_padleft():
+     pad_token = -1
+     seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]
+     expected = torch.tensor([[1, 2, 3], [pad_token, 4, 5], [pad_token, pad_token, 6]])
+     actual = pad_left(seqs, pad_token)
+     assert torch.equal(actual, expected)
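
test_padleft fixes the expected behaviour of pad_left: variable-length token sequences are left-padded with pad_token up to the longest sequence and stacked into one tensor, so that the ends of the prompts stay aligned for batched decoding. A rough reference sketch of that behaviour; the real pad_left in llmlib.phi3.phi3 may differ in its details:

    import torch

    def pad_left_sketch(seqs: list[torch.Tensor], pad_token: int) -> torch.Tensor:
        # Left-pad every 1-D sequence to the length of the longest one, then stack.
        max_len = max(len(s) for s in seqs)
        padded = [
            torch.cat([torch.full((max_len - len(s),), pad_token, dtype=s.dtype), s])
            for s in seqs
        ]
        return torch.stack(padded)

    # pad_left_sketch([torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])], -1)
    # -> tensor([[ 1,  2,  3], [-1,  4,  5], [-1, -1,  6]])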
tests/test_rest_api.py ADDED
@@ -0,0 +1,40 @@
+ from fastapi.testclient import TestClient
+ from llmlib.bundler_request import BundlerRequest
+ import llmlib.rest_api.restapi_client as llmclient
+ from llmlib.rest_api.restapi_server import create_fastapi_app
+ from llmlib.phi3.phi3 import Phi3Vision
+ import pytest
+ from .helpers import is_ci, mona_lisa_message
+
+
+ def app():
+     return TestClient(create_fastapi_app())
+
+
+ @pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
+ def test_rest_api_get_completion():
+     breq: BundlerRequest = _mona_lisa_request()
+     response = llmclient.get_completion_from_rest_api(source=app(), breq=breq)
+     assert response.status_code == 200, response.content
+     assert "portrait" in response.json()["response"].lower()
+
+
+ def test_rest_api_get_models():
+     response = llmclient.get_models(source=app())
+     assert response.status_code == 200, response.content
+     assert len(response.json()) > 3
+
+
+ @pytest.mark.skip(reason="This test requires the REST API to be running")
+ def test_rest_api_integration_test():
+     breq: BundlerRequest = _mona_lisa_request()
+     response = llmclient.get_completion_from_rest_api(breq)
+     llmclient.clear_gpu()
+     assert response.status_code == 200, response.content
+     assert "portrait" in response.json()["response"].lower()
+
+
+ def _mona_lisa_request() -> BundlerRequest:
+     msg = mona_lisa_message()
+     some_valid_modelid: str = Phi3Vision.model_id
+     return BundlerRequest(model_id=some_valid_modelid, msgs=[msg])
tests/test_whisper.py ADDED
@@ -0,0 +1,40 @@
+ from llmlib.whisper import Whisper
+ import pytest
+ from tests.helpers import is_ci, file_for_test
+
+
+ @pytest.fixture(scope="module")
+ def model() -> Whisper:
+     return Whisper()
+
+
+ @pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
+ def test_transcription(model: Whisper):
+     audio_file = str(file_for_test(name="some-audio.flac"))  # Librispeech sample 2
+     expected_transcription = "before he had time to answer a much encumbered vera burst into the room with the question i say can i leave these here these were a small black pig and a lusty specimen of black-red game-cock"
+     actual_transcription: str = model.transcribe_file(audio_file)
+     assert actual_transcription == expected_transcription
+
+
+ @pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
+ def test_video_transcription(model: Whisper):
+     video_file = str(file_for_test("video.mp4"))
+     expected_fragment = (
+         "Die Unionsparteien oder deren Politiker sind heute wichtige Offiziere"
+     )
+     transcription = model.transcribe_file(video_file)
+     assert expected_fragment in transcription
+
+
+ @pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
+ def test_translation(model: Whisper):
+     german_video = str(file_for_test("video.mp4"))
+     translation: str = model.transcribe_file(german_video, translate=True)
+     assert "The parties and their politicians" in translation
+
+
+ @pytest.mark.skipif(condition=is_ci(), reason="No GPU in CI")
+ def test_long_video_transcription(model: Whisper):
+     video_file = str(file_for_test("long-video.mp4"))
+     transcription: str = model.transcribe_file(video_file)
+     assert isinstance(transcription, str)
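
These tests double as usage documentation: transcribe_file accepts both audio and video files, and translate=True returns an English translation instead of the original-language transcript. A minimal sketch of that usage, assuming a GPU and the media files checked in under test-files/:

    from llmlib.whisper import Whisper

    model = Whisper()
    transcript: str = model.transcribe_file("test-files/some-audio.flac")
    english: str = model.transcribe_file("test-files/video.mp4", translate=True)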