root committed
Commit e676d24 · 1 Parent(s): 3c0d958
.gitignore ADDED
@@ -0,0 +1,65 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.h5
+ *.out
+
+ # Distribution / packaging
+ .Python
+ build/
+ experiments/
+ develop-eggs/
+ dist/
+ downloads/
+ saved_imgs/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ temp_data/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # Debug
+ debug.py
+ debugs/
+ tensorboard_log/
+ saved_models/
+ configs_collection/
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # pyenv
+ .python-version
+
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # saved models
+ *.pth
+
+ *.pt
+
+ *.log
Dockerfile ADDED
@@ -0,0 +1,28 @@
+ #FROM ubuntu:22.04
+ FROM python:3.11-bullseye
+
+ ARG DEBIAN_FRONTEND=noninteractive
+
+ USER root
+
+ RUN apt-get update && apt-get install -y \
+     curl \
+     nano \
+     poppler-utils \
+     software-properties-common \
+     && rm -rf /var/lib/apt/lists/*
+
+ ENV APP_ROOT=/home
+
+ WORKDIR /home
+ COPY . .
+
+ RUN chown -R root:root ${APP_ROOT} && chmod -R 777 ${APP_ROOT}
+
+ RUN pip install --upgrade pip setuptools \
+     && pip install --no-cache-dir -r requirements.txt \
+     && pip install streamlit==1.38.0
+
+ EXPOSE 7860
+
+ ENTRYPOINT ["/home/server/start_server.sh"]
requirements.txt ADDED
@@ -0,0 +1,26 @@
+ pandas
+ pypdf
+ unstructured
+ #typing  (stdlib in Python 3; the obsolete PyPI backport can cause install problems on 3.11)
+ pydantic
+ llama-index==0.10.28
+ llama-index-llms-openai-like==0.1.3
+ llama-index-embeddings-openai==0.1.7
+ llama-index-readers-web==0.1.8
+ openai==1.53.0
+ httpx==0.27.2
+ #streamlit==1.41.1
+ #streamlit-navigation-bar==3.3.0
+ streamlit-community-navigation-bar==4.0.9
+ aiohttp
+ docx2txt
+ trafilatura==1.8.1
+ motor
+ loguru
+ qdrant-client==1.12.2
+ Pillow
+ stamina
+ pdf2image==1.17.0
+ st-clickable-images==0.0.3
+
+ #arize-phoenix==2.5.0
server/.streamlit/config.toml ADDED
@@ -0,0 +1,3 @@
+ [theme]
+ base="light"
+ primaryColor="#E20074"
server/app/__init__.py ADDED
File without changes
server/app/config.py ADDED
@@ -0,0 +1,29 @@
+ import os
+ import sys
+ import time
+ from typing import Any, List, Tuple, Type
+
+ from loguru import logger
+
+ class Settings:
+     #==============================================================================
+     GLOBAL_API_BASE = "https://llm-server.llmhub.t-systems.net/queue"
+     GLOBAL_API_KEY = os.getenv("GLOBAL_AIFS_API_KEY")
+
+ app_settings = Settings()
+
+ # Log timestamps in Berlin local time
+ os.environ['TZ'] = 'Europe/Berlin'
+ time.tzset()
+
+
+ def only_level(level):
+     def is_level(record):
+         return record['level'].name == level
+
+     return is_level
+
+ #logger.remove(0)
+ formato = '{time:YYYY-MM-DD HH:mm:ss.SS!UTC} {level:8} {message} [{file} : {line}]'
+
+ logger.add(sys.stderr, format=formato, level="DEBUG")
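The only_level helper above is defined but never attached to a sink; a minimal sketch of how it could be used as a loguru filter (the error.log sink name is illustrative, not part of the commit):

from loguru import logger
from app.config import only_level, formato

# Route only ERROR records to a dedicated file sink
logger.add("error.log", format=formato, filter=only_level("ERROR"))
logger.error("written to error.log")
logger.info("not written to error.log")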
server/app/prompt_template.py ADDED
@@ -0,0 +1,6 @@
+ VDR_PROMPT='''\
+ You are a helpful AI assistant that answers the question based on the context provided.
+ If the context doesn't help, truthfully answer with: I can't find that information in the given context.
+ Based on the given context, focus on answering the following question:
+ {user_question}
+ '''
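VDRSession.ask fills this template with str.format, so {user_question} is the only placeholder; a quick sketch:

from app.prompt_template import VDR_PROMPT

prompt = VDR_PROMPT.format(user_question="What is the total revenue in 2023?")
print(prompt)  # the instruction lines followed by the question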
server/app/qdrant_db.py ADDED
@@ -0,0 +1,117 @@
+ from qdrant_client import QdrantClient
+ from qdrant_client.http import models
+ from tqdm import tqdm
+ import os
+ import time
+ import numpy as np
+ from loguru import logger
+ import stamina
+ from typing import Any, List, Tuple, Type, Literal, Optional, Union, Dict
+
+ class MyQdrantClient:
+     def __init__(self, path: str):
+         self.qdrant_client = QdrantClient(path=path)
+         logger.debug(f"Qdrant client created at {path}")
+
+     def create_collection(self, collection_name: str, vector_dim: int = 128, vector_type: str = "colbert"):
+         if vector_type == "colbert":
+             self.qdrant_client.create_collection(
+                 collection_name=collection_name,
+                 on_disk_payload=True,  # store the payload on disk
+                 vectors_config=models.VectorParams(
+                     size=vector_dim,
+                     distance=models.Distance.COSINE,
+                     on_disk=True,  # move original vectors to disk
+                     multivector_config=models.MultiVectorConfig(
+                         comparator=models.MultiVectorComparator.MAX_SIM
+                     ),
+                     #quantization_config=models.BinaryQuantization(
+                     #    binary=models.BinaryQuantizationConfig(
+                     #        always_ram=True  # keep only quantized vectors in RAM
+                     #    ),
+                     #),
+                 ),
+             )
+         elif vector_type == "dense":
+             self.qdrant_client.create_collection(
+                 collection_name=collection_name,
+                 on_disk_payload=True,  # store the payload on disk
+                 vectors_config=models.VectorParams(
+                     size=vector_dim,
+                     distance=models.Distance.COSINE,
+                     on_disk=True,  # move original vectors to disk
+                 ),
+             )
+         else:
+             raise ValueError(f"Vector type {vector_type} not supported")
+
+         logger.debug(f"Qdrant collection of type {vector_type} : {collection_name} created")
+
+     def delete_collection(self, collection_name: str):
+         self.qdrant_client.delete_collection(collection_name=collection_name)
+
+     @stamina.retry(on=Exception, attempts=3)  # retry up to 3 times if an exception occurs during the operation
+     def upsert_to_qdrant(self, batch, collection_name: str):
+         try:
+             self.qdrant_client.upsert(
+                 collection_name=collection_name,
+                 points=batch,
+                 wait=False,
+             )
+         except Exception as e:
+             logger.error(f"Error during upsert: {e}")
+             raise  # re-raise so the stamina retry decorator can actually retry
+         return True
+
+     def upsert_multivector(self, index: int, multivector_input_list: list[Any], collection_name: str):
+         try:
+             points = []
+             for j, multivector in enumerate(multivector_input_list):
+                 points.append(
+                     models.PointStruct(
+                         id=index + j,  # we just use the index as the ID
+                         vector=multivector,  # this is a list of vectors
+                         payload={
+                             "source": "user uploaded data"
+                         },  # can also add other metadata/data
+                     )
+                 )
+             # Upload points to Qdrant
+
+             self.upsert_to_qdrant(points, collection_name)
+         except Exception as e:
+             logger.error(f"Vector DB client - error during upsert: {e}")
+
+     def query_multivector(self, multivector_input, collection_name: str, top_k: int = 10) -> Optional[list[int]]:
+         try:
+             #logger.debug(f"Number of vectors: {len(multivector_input)}")
+             #logger.debug(f"Vector dim: {len(multivector_input[0])}")
+
+             start_time = time.time()
+             search_result = self.qdrant_client.query_points(
+                 collection_name=collection_name,
+                 query=multivector_input,
+                 limit=top_k,
+                 # timeout=100,
+                 # search_params=models.SearchParams(
+                 #     quantization=models.QuantizationSearchParams(
+                 #         ignore=False,
+                 #         rescore=True,
+                 #         oversampling=2.0,
+                 #     )
+                 # )
+             )
+             end_time = time.time()
+             elapsed_time = end_time - start_time
+             logger.debug(f"Search completed in {elapsed_time:.4f} seconds")
+
+             result = [x.id for x in search_result.points]
+             return result
+
+         except Exception as e:
+             logger.error(f"Error during query: {e}")
+             return None
+
+     def __del__(self):
+         self.qdrant_client.close()
+
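A minimal usage sketch of MyQdrantClient with toy data (the 4x128 random multivectors stand in for real ColQwen2-style page embeddings; not part of the commit):

import numpy as np
from app.qdrant_db import MyQdrantClient

db = MyQdrantClient(path="./temp_data/demo")  # embedded, file-backed Qdrant
db.create_collection("demo", vector_dim=128, vector_type="colbert")

# Two fake "pages", each represented by 4 vectors of dimension 128
pages = [np.random.rand(4, 128).tolist() for _ in range(2)]
db.upsert_multivector(index=0, multivector_input_list=pages, collection_name="demo")

query = np.random.rand(4, 128).tolist()  # fake query multivector
print(db.query_multivector(query, collection_name="demo", top_k=1))  # e.g. [0] or [1]
db.delete_collection("demo")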
server/app/vdr_schemas.py ADDED
File without changes
server/app/vdr_session.py ADDED
@@ -0,0 +1,279 @@
+ import os
+ import shutil
+ import uuid
+ from loguru import logger
+ from typing import Any, List, Tuple, Type, Literal, Optional, Union, Dict
+
+ import streamlit as st
+ from openai import OpenAI
+ from tqdm import tqdm
+
+ from app.config import app_settings
+ from app.qdrant_db import MyQdrantClient
+ from app.vdr_utils import (
+     get_text_embedding,
+     get_image_embedding,
+     pdf_folder_to_images,
+     scale_image,
+     pil_image_to_base64,
+     load_images,
+ )
+
+ class VDRSession:
+     def __init__(self):
+         self.client = None
+         self.api_key = None
+         self.base_url = app_settings.GLOBAL_API_BASE
+         self.SAVE_DIR = None
+         self.db_collection = None
+         self.session_id = str(uuid.uuid4())[:5]
+         self.indexed_images = []
+         self.vector_db_client = None
+
+     def set_api_key(self, api_key: str) -> bool:
+         if api_key is not None and len(api_key) > 10:
+             try:
+                 api_key = api_key.strip()
+                 client = OpenAI(api_key=api_key, base_url=self.base_url)
+                 models = client.models.list()
+                 if models:
+                     self.api_key = api_key
+                     self.client = client
+                     return True
+             except Exception as e:
+                 logger.debug(f'Incorrect API Key: {e}')
+
+         self.client = None
+         return False
+
+     def set_context(self, embed_model: str) -> bool:
+         self.embed_model = embed_model
+
+         if not self.SAVE_DIR:
+             self.SAVE_DIR = os.path.join('./temp_data', self.session_id)
+             os.makedirs(self.SAVE_DIR, exist_ok=True)
+             self.SAVE_IMAGE_DIR = os.path.join(self.SAVE_DIR, 'images')
+             logger.debug(f'Created folders: {self.SAVE_DIR} and {self.SAVE_IMAGE_DIR}')
+
+         if not self.vector_db_client:
+             self.vector_db_client = MyQdrantClient(path=self.SAVE_DIR)
+
+         if not self.db_collection:
+             self.db_collection = f"qd-{embed_model}-{self.session_id}"
+             try:
+                 if self.embed_model == "tsi-embedding-colqwen2-2b-v1":
+                     self.vector_db_client.create_collection(self.db_collection, vector_dim=128, vector_type="colbert")
+                 elif self.embed_model == "jina-embedding-clip-v1":
+                     self.vector_db_client.create_collection(self.db_collection, vector_dim=768, vector_type="dense")
+                 else:
+                     raise ValueError(f"Embedding model {self.embed_model} not supported")
+             except Exception as e:
+                 logger.error(f"Error while creating collection: {e}")
+
+         return True
+
+     def get_available_vlms(self) -> List[str]:
+         assert self.client is not None
+         model_name_list = []
+         try:
+             models = self.client.models.list()
+             for model in models.data:
+                 substrings = ['gemini', 'QWEN-VL2-7B']
+                 if any(substring in model.id for substring in substrings):
+                     model_name_list.append(model.id)
+         except Exception as e:
+             logger.error(f"Error while querying available models: {e}")
+             raise e
+
+         # Move the preferred model to the front of the list, if it is available
+         priority_item = "gemini-2.0-flash-exp-US"
+         if priority_item in model_name_list:
+             model_name_list.remove(priority_item)
+             model_name_list.insert(0, priority_item)
+
+         return model_name_list
+
+     def get_available_image_embeds(self) -> List[str]:
+         assert self.client is not None
+         model_name_list = []
+         try:
+             models = self.client.models.list()
+             for model in models.data:
+                 substrings = ['tsi-embedding', 'clip']
+                 if any(substring in model.id for substring in substrings):
+                     model_name_list.append(model.id)
+         except Exception as e:
+             logger.error(f"Error while querying available models: {e}")
+             raise e
+
+         return model_name_list
+
+     def search_images(self, text: str, top_k: int = 5) -> list[str]:
+         assert self.client is not None
+         assert self.vector_db_client is not None
+         try:
+             if not self.indexed_images:
+                 raise Exception("No indexed images found. You need to click on the 'Add selected context' button to index images.")
+             text = text.strip()
+             if len(text) < 2:
+                 return []
+
+             embeddings = get_text_embedding(
+                 texts=text,
+                 openai_client=self.client,
+                 model=self.embed_model
+             )[0]
+
+             index_results = self.vector_db_client.query_multivector(
+                 multivector_input=embeddings,
+                 collection_name=self.db_collection,
+                 top_k=top_k
+             )
+             image_list = [self.indexed_images[i] for i in index_results]
+             images = []
+             for img in image_list:
+                 encoded = pil_image_to_base64(img)
+                 images.append(f"data:image/png;base64,{encoded}")
+             return images
+         except Exception as e:
+             logger.error(f"Error while searching images: {e}")
+             raise e
+
+     def ask(self, query: str, model: str, prompt_template: str, retrieved_context: Any, modality: str = "image", stream: bool = False):
+         # Because this method contains `yield`, it is always a generator;
+         # the non-stream branch therefore also yields its single answer.
+         assert self.client is not None
+         assert query is not None
+         assert prompt_template is not None
+         assert retrieved_context is not None
+
+         try:
+             prompt = prompt_template.format(user_question=query)
+             if modality == "image":
+                 context = [
+                     {
+                         "type": "image_url",
+                         "image_url": {
+                             "url": base64_image
+                         }
+                     } for base64_image in retrieved_context
+                 ]
+
+                 content = [
+                     {
+                         "type": "text",
+                         "text": prompt
+                     }
+                 ]
+                 content = content + context
+
+                 messages = [
+                     {
+                         "role": "user",
+                         "content": content,
+                     }
+                 ]
+
+                 chat_response = self.client.chat.completions.create(
+                     model=model,
+                     messages=messages,
+                     temperature=0.1,
+                     max_tokens=2048,
+                     stream=stream,
+                 )
+                 if not stream:
+                     yield chat_response.choices[0].message.content
+                 else:
+                     for chunk in chat_response:
+                         if chunk.choices and chunk.choices[0].delta.content is not None:
+                             yield chunk.choices[0].delta.content
+
+         except Exception as e:
+             logger.error(f"Error while asking: {e}")
+             raise e
+
+     def indexing(self, uploaded_files: list, embed_model: str, indexing_bar: Optional[Any] = None) -> bool:
+         self.set_context(embed_model)
+
+         assert self.client is not None
+         assert self.db_collection is not None
+         assert self.SAVE_DIR is not None
+         assert self.embed_model is not None
+         assert len(uploaded_files) > 0
+
+         # Write the uploaded files to disk
+         for file in uploaded_files:
+             path = os.path.join(self.SAVE_DIR, file.name)
+             if os.path.exists(path):
+                 logger.debug("File already exists, skipping")
+                 continue
+             with open(path, "wb") as f:
+                 f.write(file.getvalue())
+
+         # Note: despite the name, this is a list of PIL Image objects, one per page
+         image_path_list = pdf_folder_to_images(pdf_folder=self.SAVE_DIR, output_folder=self.SAVE_IMAGE_DIR)
+         logger.debug(f"Extracted {len(image_path_list)} images from {len(uploaded_files)} files.")
+
+         indexed_images = self.index_from_images(image_path_list, indexing_bar=indexing_bar)
+         logger.debug(f"Indexed {len(indexed_images)} images.")
+
+         self.indexed_images.extend(indexed_images)
+         return True
+
+     def clear_context(self):
+         self.indexed_images = []
+         if self.vector_db_client and self.db_collection:
+             self.vector_db_client.delete_collection(self.db_collection)
+         self.db_collection = None
+         self.vector_db_client = None
+
+         if self.SAVE_DIR:
+             if os.path.exists(self.SAVE_DIR):
+                 shutil.rmtree(self.SAVE_DIR, ignore_errors=True)
+                 logger.debug(f'Removed folder: {self.SAVE_DIR}')
+             self.SAVE_DIR = None
+         return True
+
+     def __del__(self):
+         self.clear_context()
+         logger.debug('VDR session is cleaned up.')
+
+     def index_from_images(self,
+                           images_path_list: list,
+                           batch_size: int = 5,
+                           indexing_bar: Optional[Any] = None
+                           ):
+         try:
+             indexed_images = []
+             total_len = len(images_path_list)
+             with tqdm(total=total_len, desc="Indexing Progress") as pbar:
+                 for i in range(0, total_len, batch_size):
+                     try:
+                         batch = images_path_list[i:min(i + batch_size, total_len)]
+                         #batch = load_images(batch)
+                         batch = [scale_image(x, 768) for x in batch]
+
+                         embeddings = get_image_embedding(
+                             image_list=batch,
+                             openai_client=self.client,
+                             model=self.embed_model
+                         )
+                         self.vector_db_client.upsert_multivector(
+                             index=i,
+                             multivector_input_list=embeddings,
+                             collection_name=self.db_collection
+                         )
+
+                         indexed_images.extend(batch)
+                         # Update the progress indicators
+                         pbar.update(len(batch))
+                         if indexing_bar:
+                             indexing_bar.progress(min((i + len(batch)) / total_len, 1.0), text=f"Indexing {i + len(batch)}/{total_len}")
+                     except Exception as e:
+                         logger.exception(f"Error during indexing: {e}")
+                         continue
+
+             logger.debug("Indexing complete!")
+             return indexed_images
+         except Exception as e:
+             raise Exception(f"Error during indexing: {e}")
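Outside Streamlit the session can be driven directly, since indexing() accepts a missing progress bar; a rough sketch, assuming GLOBAL_AIFS_API_KEY is set and a sample.pdf exists (the SimpleNamespace shim mimics the .name/.getvalue() interface of Streamlit uploads that indexing() relies on):

import os
from types import SimpleNamespace
from app.vdr_session import VDRSession
from app.prompt_template import VDR_PROMPT

if __name__ == "__main__":  # needed because pdf_folder_to_images uses multiprocessing.Pool
    session = VDRSession()
    assert session.set_api_key(os.getenv("GLOBAL_AIFS_API_KEY"))

    with open("sample.pdf", "rb") as f:
        data = f.read()
    upload = SimpleNamespace(name="sample.pdf", getvalue=lambda: data)

    session.indexing([upload], embed_model="tsi-embedding-colqwen2-2b-v1")
    hits = session.search_images("What does figure 1 show?", top_k=3)
    for token in session.ask("What does figure 1 show?", model="gemini-2.0-flash-exp-US",
                             prompt_template=VDR_PROMPT, retrieved_context=hits, stream=True):
        print(token, end="")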
server/app/vdr_utils.py ADDED
@@ -0,0 +1,176 @@
+ from PIL import Image
+ import numpy as np
+ import base64
+ import io
+ from io import BytesIO
+ from pdf2image import convert_from_path
+ import tempfile
+ from multiprocessing import Pool
+ import os
+ from loguru import logger
+ import uuid
+
+ from typing import Any, List, Tuple, Type, Literal, Optional, Union, Dict
+
+ def encode_image(image_path):
+     with open(image_path, "rb") as image_file:
+         return base64.b64encode(image_file.read()).decode('utf-8')
+
+ def load_image_from_base64(image):
+     return Image.open(BytesIO(base64.b64decode(image)))
+
+ def pil_image_to_base64(image: Image.Image) -> str:
+     """
+     Convert a PIL Image object to its base64 representation.
+
+     Args:
+         image (Image.Image): The PIL Image object to be converted.
+
+     Returns:
+         str: The base64 representation of the image.
+     """
+
+     # Create a bytes buffer
+     buffer = io.BytesIO()
+
+     # Save the image to the buffer
+     image.save(buffer, format="PNG")
+
+     # Get the bytes from the buffer
+     img_bytes = buffer.getvalue()
+
+     # Convert the bytes to base64
+     img_base64 = base64.b64encode(img_bytes).decode("utf-8")
+
+     return img_base64
+
+ def scale_image(image: Image.Image, new_height: int = 1024) -> Image.Image:
+     """
+     Scale an image to a new height while maintaining the aspect ratio.
+     """
+     width, height = image.size
+     aspect_ratio = width / height
+     new_width = int(new_height * aspect_ratio)
+
+     scaled_image = image.resize((new_width, new_height))
+
+     return scaled_image
+
+ def unflatten_array(flat_list, vector_size=128):
+     return np.array(flat_list).reshape(-1, vector_size)
+
+ def get_image_embedding(image_list: list[Image.Image], openai_client, model: str, flatten: bool = False) -> list:
+     """
+     Get the embeddings of a list of images.
+
+     Args:
+         image_list (list[Image.Image]): The images to be embedded.
+
+     Returns:
+         list[list[float]] if flatten,
+         else: list[list[list[float]]] with shape = (number of images (m), number of vectors per image (n), vector dim = 128)
+     """
+     if not isinstance(image_list, list):
+         image_list = [image_list]
+
+     input_base64_list = [f"data:image/png;base64,{pil_image_to_base64(image)}" for image in image_list]
+     # Get the embedding of each image
+     embedding = openai_client.embeddings.create(
+         input=input_base64_list,
+         model=model,
+         extra_body={
+             "modality": "image",
+             "encoding_format": "float" if not flatten else "base64",
+         },
+     )
+
+     result = []
+     for embed in embedding.data:
+         result.append(embed.embedding)  # embed.embedding is a list[float] if flatten, else list[list[float]]
+     return result
+
+ def get_text_embedding(texts: list[str], openai_client, model: str, flatten: bool = False) -> list:
+     """
+     Get the embeddings of a list of texts.
+
+     Args:
+         texts (list[str]): The texts to be embedded.
+
+     Returns:
+         list[list[float]] if flatten,
+         else: list[list[list[float]]] with shape = (number of texts (m), number of vectors per text (n), vector dim = 128)
+     """
+     if not isinstance(texts, list):
+         texts = [texts]
+
+     # Get the embedding of each text
+     embedding = openai_client.embeddings.create(
+         input=texts,
+         model=model,
+         extra_body={
+             "encoding_format": "float" if not flatten else "base64",
+         },
+     )
+
+     result = []
+     for embed in embedding.data:
+         result.append(embed.embedding)  # embed.embedding is a list[float] if flatten, else list[list[float]]
+     return result
+
+ def load_images(image_paths):
+     """
+     Load images from a list of paths and return a list of PIL image objects.
+
+     Args:
+         image_paths (list): List of image paths.
+
+     Returns:
+         list: List of PIL image objects.
+     """
+     images = []
+     for path in image_paths:
+         try:
+             img = Image.open(path)
+             images.append(img)
+         except Exception as e:
+             logger.error(f"Error loading image at path {path}: {str(e)}")
+     return images
+
+
+ def process_pdf(pdf_path: str, output_folder: str, thread_count=1):
+     # Note: the pages are returned as PIL images; output_folder is only
+     # needed by the commented-out save-to-disk path below.
+     with tempfile.TemporaryDirectory() as temp_dir:
+         images = convert_from_path(pdf_path, dpi=200, output_folder=temp_dir, thread_count=thread_count)
+
+         # result_image_paths = []
+         # for page_num, image in enumerate(images):
+         #     image_filename = f"{str(uuid.uuid4())}.png"
+         #     image_path = os.path.join(output_folder, image_filename)
+         #     image.save(image_path, "PNG")
+         #     result_image_paths.append(image_path)
+         # return result_image_paths
+         return images
+
+
+ def pdf_folder_to_images(pdf_folder: str, output_folder: str, process_count: int = 2):
+     try:
+         if process_count is None:
+             process_count = os.cpu_count()
+
+         pdf_files = [os.path.join(pdf_folder, f) for f in os.listdir(pdf_folder)
+                      if f.lower().endswith('.pdf')]
+
+         # Create a list of tuples containing (pdf_file, output_folder)
+         args = [(pdf_file, output_folder) for pdf_file in pdf_files]
+
+         with Pool(process_count) as pool:
+             all_images = pool.starmap(process_pdf, args)
+
+         result = [img for sublist in all_images for img in sublist]
+
+         logger.debug(f"Number of pdfs processed: {len(all_images)} - Number of images: {len(result)}")
+         return result
+     except Exception as e:
+         logger.exception(f"Error during processing pdf: {e}")
+         return []
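A small self-contained sketch of the image helpers above (no server or API key needed):

from PIL import Image
from app.vdr_utils import scale_image, pil_image_to_base64, load_image_from_base64

img = Image.new("RGB", (1000, 2000), color="white")  # stand-in for a rendered PDF page
small = scale_image(img, new_height=768)
print(small.size)  # (384, 768): the aspect ratio is preserved

b64 = pil_image_to_base64(small)  # PNG bytes encoded as a base64 string
roundtrip = load_image_from_base64(b64)
print(roundtrip.size)  # (384, 768)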
server/favicon.png ADDED
server/main.py ADDED
@@ -0,0 +1,49 @@
+ import streamlit as st
+ from streamlit_navigation_bar import st_navbar
+ import st_pages as pg
+
+ st.set_page_config(page_title='T-Systems LLM Playground', page_icon='favicon.png')
+
+ with st.sidebar:
+     st.html("""<center><img src="https://upload.wikimedia.org/wikipedia/commons/0/0a/T-SYSTEMS-LOGO2013.svg" width="300" height="68" ></center>""")
+
+     st.markdown('**This is a playground for the LLMs available via the T-Systems AI Foundation Services**')
+
+ pages_name = ['Visual Retrieval', 'Documentation', 'Terms & Conditions']
+ urls = {
+     #"Create API Key": "https://apikey.llmhub.t-systems.net/#/dashboard",
+     "Documentation": "https://docs.llmhub.t-systems.net/",
+     "Terms & Conditions": "https://smartchat.ai-health.aisf.t-systems.net/privacy"
+ }
+ styles = {
+     "nav": {
+         #"background-color": "#E20074",
+         "justify-content": "center",
+     },
+     "span": {
+         "border-radius": "0.5rem",
+         "color": "rgb(49, 51, 63)",
+         "margin": "0 0.125rem",
+         "padding": "0.4375rem 0.625rem",
+     },
+     # "active": {
+     #     "background-color": "rgba(256, 0, 116, 0.25)",
+     # },
+     "hover": {
+         "background-color": "rgba(226, 0, 116, 0.5)",
+     },
+ }
+
+ # options = {
+ #     "show_menu": False,
+ #     #"show_sidebar": False,
+ # }
+ page = st_navbar(
+     pages_name,
+     urls=urls,
+     styles=styles,
+     #options=options,
+ )
+
+ if page == 'Visual Retrieval':
+     pg.page_vdr()
server/st_pages/__init__.py ADDED
@@ -0,0 +1 @@
+ from st_pages.page_vdr import page_vdr
server/st_pages/page_vdr.py ADDED
@@ -0,0 +1,114 @@
+ import streamlit as st
+ import os
+ from app.vdr_session import VDRSession
+ from st_clickable_images import clickable_images
+ from app.prompt_template import VDR_PROMPT
+
+ def page_vdr():
+     st.header("Visual Document Retrieval")
+
+     # Store session context
+     if "vdr_session" not in st.session_state:
+         st.session_state["vdr_session"] = VDRSession()
+
+     with st.sidebar:
+
+         #api_key = st.text_input('Enter API Key:', type='password')
+         api_key = os.getenv("GLOBAL_AIFS_API_KEY")
+
+         check_api_key = st.session_state["vdr_session"].set_api_key(api_key)
+
+         if check_api_key:
+             st.success('API Key is valid!', icon='✅')
+             avai_llms = st.session_state["vdr_session"].get_available_vlms()
+             avai_embeds = st.session_state["vdr_session"].get_available_image_embeds()
+             selected_llm = st.sidebar.selectbox('Choose VLM models', avai_llms, key='selected_llm', disabled=not check_api_key)
+             selected_embed = st.sidebar.selectbox('Choose Embedding models', avai_embeds, key='selected_embed', disabled=not check_api_key)
+             #st.session_state["vdr_session"].set_context(selected_llm, selected_embed)
+         else:
+             st.warning('Please enter valid credentials!', icon='⚠️')
+
+     if check_api_key:
+
+         with st.sidebar:
+             uploaded_files = st.file_uploader("Upload PDF files", key="uploaded_files", accept_multiple_files=True, disabled=not check_api_key)
+
+             if st.button("Add selected context", key="add_context", type="primary"):
+                 if uploaded_files:
+                     try:
+                         indexing_bar = st.progress(0, text="Indexing...")
+                         if st.session_state["vdr_session"].indexing(uploaded_files, selected_embed, indexing_bar):
+                             st.success('Indexing completed!')
+                             indexing_bar.empty()
+                             #st.rerun()
+                         else:
+                             st.warning('Files empty or not supported.', icon='⚠️')
+                     except Exception as e:
+                         st.error(f"Error during indexing: {e}")
+                 else:
+                     st.warning('Please upload files first!', icon='⚠️')
+
+             if st.button("🗑️ Remove all context", key="remove_context"):
+                 try:
+                     st.session_state["vdr_session"].clear_context()
+                     st.success("Context removed")
+                     st.rerun()
+                 except Exception as e:
+                     st.error(f"Error during removing context: {e}")
+
+
+             top_k_sim = st.slider(label="Top k similarity", min_value=1, max_value=10, value=3, step=1, key="top_k_sim")
+             #text_only_embed = st.toggle("Text only embedding", key="text_only_embed", value=False)
+             chat_prompt = st.text_area("Prompt template", key="chat_prompt", value=VDR_PROMPT, height=300)
+
+         query = st.text_input(label="Query", key='query', placeholder="Enter your query here", label_visibility="hidden", disabled=not st.session_state.get("vdr_session").indexed_images)
+
+         with st.expander(f"**Top {top_k_sim} retrieved contexts**", expanded=True):
+             try:
+                 if len(query.strip()) > 2:
+                     if query != st.session_state.get("last_query", None):
+                         with st.spinner('Searching...'):
+                             st.session_state["last_query"] = query
+                             st.session_state["result_images"] = st.session_state["vdr_session"].search_images(query, top_k_sim)
+
+                 if st.session_state.get("result_images", []):
+                     images = st.session_state["result_images"]
+
+                     clicked = clickable_images(
+                         images,
+                         titles=[f"Image #{str(i)}" for i in range(len(images))],
+                         div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
+                         img_style={"margin": "5px", "height": "200px"},
+                     )
+                     st.write(f"**Retrieved by: {selected_embed}**")
+
+                     @st.dialog(" ", width="large")
+                     def show_selected_image(id):
+                         st.markdown(f"**Similarity rank: {id}**")
+                         st.image(images[id])
+
+                     if clicked > -1 and clicked != st.session_state.get("clicked", None):
+                         show_selected_image(clicked)
+                         st.session_state["clicked"] = clicked
+
+             except Exception as e:
+                 st.error(f"Error during search: {e}")
+
+         if st.session_state.get("result_images", None):
+             if st.button("Generate answer", key="ask", type="primary"):
+                 if len(query.strip()) > 2:
+                     try:
+                         with st.spinner('Generating response...'):
+                             stream_response = st.session_state["vdr_session"].ask(
+                                 query=query,
+                                 model=selected_llm,
+                                 prompt_template=chat_prompt,
+                                 retrieved_context=st.session_state["result_images"],
+                                 stream=True
+                             )
+                             st.write_stream(stream_response)
+                             st.write(f"**Answered by: {selected_llm}**")
+                     except Exception as e:
+                         st.error(f"Error during asking: {e}")
+                 else:
+                     st.warning('Please enter query first!', icon='⚠️')
server/start_server.sh ADDED
@@ -0,0 +1,3 @@
+ #!/bin/bash
+ #cd /home/server/ && python3 main.py
+ cd /home/server/ && streamlit run main.py --server.port 7860