File size: 11,569 Bytes
9b0f4a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
import math
import os
from struct import calcsize, pack, unpack
from typing import BinaryIO, Dict, Iterable, List, Optional, Tuple, Union, cast

from pdf2zh.pdfexceptions import PDFValueError

# Fixed-size leading fields of a segment header, in stream order, as
# (struct format, field name) pairs.  Readers/writers dispatch on the name.
SEG_STRUCT = [
    (">L", "number"),
    (">B", "flags"),
    (">B", "retention_flags"),
    (">B", "page_assoc"),
    (">L", "data_length"),
]

# Bits of the segment header "flags" byte.
HEADER_FLAG_DEFERRED = 0x80
HEADER_FLAG_PAGE_ASSOC_LONG = 0x40

# Low six bits of the "flags" byte hold the segment type.
SEG_TYPE_MASK = 0x3F

# Referred-to segment count: top three bits of the short retention byte, or
# the low 29 bits of the long (4-byte) form, which is signalled by a short
# count of REF_COUNT_LONG.
REF_COUNT_SHORT_MASK = 0xE0
REF_COUNT_LONG_MASK = 0x1FFFFFFF
REF_COUNT_LONG = 7

# Sentinel "data_length" value for a length that is not stated in the header.
DATA_LEN_UNKNOWN = 0xFFFFFFFF

# Segment type numbers handled specially in this module.
SEG_TYPE_IMMEDIATE_GEN_REGION = 38
SEG_TYPE_END_OF_PAGE = 49
SEG_TYPE_END_OF_FILE = 51

# Stand-alone JBIG2 file header: 8-byte ID string and the "sequential" flag.
FILE_HEADER_ID = bytes((0x97, 0x4A, 0x42, 0x32, 0x0D, 0x0A, 0x1A, 0x0A))
FILE_HEAD_FLAG_SEQUENTIAL = 0x01


def bit_set(bit_pos: int, value: int) -> bool:
    """Return True if bit *bit_pos* (0 = least significant) of *value* is 1."""
    shifted = value >> bit_pos
    return (shifted & 1) == 1


def check_flag(flag: int, value: int) -> bool:
    """Return True if any bit of *flag* is also set in *value*."""
    return (flag & value) != 0


def masked_value(mask: int, value: int) -> int:
    """Extract the bit field selected by *mask* from *value*.

    The bits of *value* covered by *mask* are isolated and shifted down so the
    mask's lowest set bit lands on bit 0, e.g.
    ``masked_value(0b11100000, 0b10100000) == 0b101``.

    Raises:
        PDFValueError: if *mask* has no set bits.
    """
    # ``mask & -mask`` isolates the lowest set bit, so its bit_length() - 1 is
    # the shift amount.  Unlike a fixed 31-bit scan this also handles masks
    # whose lowest set bit is bit 31 or above.
    lowest_bit = mask & -mask
    if not lowest_bit:
        raise PDFValueError("Invalid mask or value")
    return (value & mask) >> (lowest_bit.bit_length() - 1)


def mask_value(mask: int, value: int) -> int:
    """Place *value* into the bit field selected by *mask*.

    This is the inverse of ``masked_value``: *value* is truncated to the width
    of the mask and shifted up to the mask's lowest set bit, e.g.
    ``mask_value(0b11100000, 0b101) == 0b10100000``.

    Raises:
        PDFValueError: if *mask* has no set bits.
    """
    # ``mask & -mask`` isolates the lowest set bit; see masked_value.
    lowest_bit = mask & -mask
    if not lowest_bit:
        raise PDFValueError("Invalid mask or value")
    shift = lowest_bit.bit_length() - 1
    return (value & (mask >> shift)) << shift


def unpack_int(format: str, buffer: bytes) -> int:
    """Decode one big-endian unsigned integer from *buffer*.

    *format* must be ``">B"`` (1 byte), ``">I"`` or ``">L"`` (both 4 bytes
    with struct's standard sizes); *buffer* must be exactly that long or
    ``struct.unpack`` raises ``struct.error``.
    """
    allowed_formats = {">B", ">I", ">L"}
    assert format in allowed_formats
    return cast(Tuple[int], unpack(format, buffer))[0]


# Decoded "flags" byte of a segment header (see parse_flags / encode_flags).
JBIG2SegmentFlags = Dict[str, Union[int, bool]]
# Decoded retention-flags field: referred-to count, retain bits, ref numbers.
JBIG2RetentionFlags = Dict[str, Union[int, List[int], List[bool]]]
# One segment, keyed by SEG_STRUCT field names plus extras such as raw_data.
JBIG2Segment = Dict[
    str, Union[bool, int, bytes, JBIG2SegmentFlags, JBIG2RetentionFlags]
]


class JBIG2StreamReader:
    """Read segments from a JBIG2 byte stream.

    Each segment is returned as a dict keyed by the field names in
    ``SEG_STRUCT``.  Fields that have a matching ``parse_<name>`` method are
    post-processed by it, and the segment's data bytes are stored under
    ``"raw_data"`` when its data length is non-zero.
    """

    def __init__(self, stream: BinaryIO) -> None:
        # is_eof() peeks one byte and seeks back, so the stream must be seekable.
        self.stream = stream

    def get_segments(self) -> List[JBIG2Segment]:
        """Parse segments until the end of the stream and return them in order.

        A segment whose fixed header fields are cut short by the end of the
        stream is marked with ``"_error"`` and left out of the result.
        """
        segments: List[JBIG2Segment] = []
        while not self.is_eof():
            segment: JBIG2Segment = {}
            for field_format, name in SEG_STRUCT:
                field_len = calcsize(field_format)
                field = self.stream.read(field_len)
                if len(field) < field_len:
                    segment["_error"] = True
                    break
                value = unpack_int(field_format, field)
                # Optional per-field hook, e.g. parse_flags; it may read more
                # bytes from the stream for variable-length fields.
                parser = getattr(self, "parse_%s" % name, None)
                if callable(parser):
                    value = parser(segment, value, field)
                segment[name] = value

            if not segment.get("_error"):
                segments.append(segment)
        return segments

    def is_eof(self) -> bool:
        """Return True if no bytes remain; otherwise leave the position unchanged."""
        if self.stream.read(1) == b"":
            return True
        self.stream.seek(-1, os.SEEK_CUR)
        return False

    def parse_flags(
        self,
        segment: JBIG2Segment,
        flags: int,
        field: bytes,
    ) -> JBIG2SegmentFlags:
        """Split the header flags byte into deferred / long-page / type parts."""
        return {
            "deferred": check_flag(HEADER_FLAG_DEFERRED, flags),
            "page_assoc_long": check_flag(HEADER_FLAG_PAGE_ASSOC_LONG, flags),
            "type": masked_value(SEG_TYPE_MASK, flags),
        }

    def parse_retention_flags(
        self,
        segment: JBIG2Segment,
        flags: int,
        field: bytes,
    ) -> JBIG2RetentionFlags:
        """Parse the referred-to segment count, retain bits and segment numbers.

        Reads any long-form count bytes, retain-flag bytes and referred-to
        segment numbers that follow the first retention byte from the stream.

        Returns:
            A dict with ``ref_count``, ``retain_segments`` (list of bools) and
            ``ref_segments`` (list of referred-to segment numbers).
        """
        ref_count = masked_value(REF_COUNT_SHORT_MASK, flags)
        retain_segments: List[bool] = []
        ref_segments: List[int] = []

        if ref_count < REF_COUNT_LONG:
            # Short form: retain flags live in bits 0-4 of this same byte.
            for bit_pos in range(5):
                retain_segments.append(bit_set(bit_pos, flags))
        else:
            # Long form: the count is the low 29 bits of a 4-byte field,
            # followed by ceil((ref_count + 1) / 8) bytes of retain flags.
            field += self.stream.read(3)
            ref_count = unpack_int(">L", field)
            ref_count = masked_value(REF_COUNT_LONG_MASK, ref_count)
            ret_bytes_count = int(math.ceil((ref_count + 1) / 8))
            for _ in range(ret_bytes_count):
                ret_byte = unpack_int(">B", self.stream.read(1))
                # Every one of the 8 bits carries a flag (matches the writer).
                for bit_pos in range(8):
                    retain_segments.append(bit_set(bit_pos, ret_byte))

        seg_num = segment["number"]
        assert isinstance(seg_num, int)
        # Width of each referred-to segment number depends on this segment's
        # own number: 1, 2 or 4 bytes.  ">H" is the 2-byte format; ">I" would
        # read 4 bytes with struct's standard sizes.
        if seg_num <= 256:
            ref_format = ">B"
        elif seg_num <= 65536:
            ref_format = ">H"
        else:
            ref_format = ">L"

        ref_size = calcsize(ref_format)
        for _ in range(ref_count):
            (ref,) = unpack(ref_format, self.stream.read(ref_size))
            ref_segments.append(ref)

        return {
            "ref_count": ref_count,
            "retain_segments": retain_segments,
            "ref_segments": ref_segments,
        }

    def parse_page_assoc(self, segment: JBIG2Segment, page: int, field: bytes) -> int:
        """Return the page association, reading 3 more bytes for the long form."""
        if cast(JBIG2SegmentFlags, segment["flags"])["page_assoc_long"]:
            field += self.stream.read(3)
            page = unpack_int(">L", field)
        return page

    def parse_data_length(
        self,
        segment: JBIG2Segment,
        length: int,
        field: bytes,
    ) -> int:
        """Read the segment data into ``segment["raw_data"]`` and return *length*.

        Nothing is read (and ``raw_data`` is not set) when *length* is 0.

        Raises:
            NotImplementedError: for an immediate generic region segment whose
                length is DATA_LEN_UNKNOWN.
        """
        if length:
            if (
                cast(JBIG2SegmentFlags, segment["flags"])["type"]
                == SEG_TYPE_IMMEDIATE_GEN_REGION
            ) and (length == DATA_LEN_UNKNOWN):
                raise NotImplementedError(
                    "Working with unknown segment length is not implemented yet",
                )
            else:
                segment["raw_data"] = self.stream.read(length)

        return length


class JBIG2StreamWriter:
    """Write JBIG2 segments to a file in JBIG2 format"""

    # Retention flags with no referred-to segments.  get_eop_segment and
    # get_eof_segment return this same dict object, so do not mutate it.
    EMPTY_RETENTION_FLAGS: JBIG2RetentionFlags = {
        "ref_count": 0,
        "ref_segments": cast(List[int], []),
        "retain_segments": cast(List[bool], []),
    }

    def __init__(self, stream: BinaryIO) -> None:
        self.stream = stream

    def write_segments(
        self,
        segments: Iterable[JBIG2Segment],
        fix_last_page: bool = True,
    ) -> int:
        """Encode and write *segments*; return the number of bytes written.

        With *fix_last_page*, an end-of-page segment numbered one past the
        last segment is appended if the last page seen was not closed by an
        end-of-page segment.
        """
        data_len = 0
        current_page: Optional[int] = None
        seg_num: Optional[int] = None

        for segment in segments:
            data = self.encode_segment(segment)
            self.stream.write(data)
            data_len += len(data)

            seg_num = cast(Optional[int], segment["number"])

            if fix_last_page:
                seg_page = cast(int, segment.get("page_assoc"))

                if (
                    cast(JBIG2SegmentFlags, segment["flags"])["type"]
                    == SEG_TYPE_END_OF_PAGE
                ):
                    current_page = None
                elif seg_page:
                    current_page = seg_page

        if fix_last_page and current_page and (seg_num is not None):
            segment = self.get_eop_segment(seg_num + 1, current_page)
            data = self.encode_segment(segment)
            self.stream.write(data)
            data_len += len(data)

        return data_len

    def write_file(
        self,
        segments: Iterable[JBIG2Segment],
        fix_last_page: bool = True,
    ) -> int:
        """Write a complete JBIG2 file: header, *segments*, end-of-file segment.

        Returns the total number of bytes written.
        """
        # The segments are walked twice (write, then find the last number);
        # materialize them so a generator is not exhausted by the first pass.
        segments = list(segments)

        header = FILE_HEADER_ID + pack(">B", FILE_HEAD_FLAG_SEQUENTIAL)
        # The embedded JBIG2 files in a PDF always
        # only have one page
        header += pack(">L", 1)
        self.stream.write(header)
        data_len = len(header)

        data_len += self.write_segments(segments, fix_last_page)

        seg_num = cast(int, segments[-1]["number"]) if segments else 0

        # Leave room for the end-of-page segment write_segments may append.
        seg_num_offset = 2 if fix_last_page else 1
        eof_segment = self.get_eof_segment(seg_num + seg_num_offset)
        data = self.encode_segment(eof_segment)

        self.stream.write(data)
        data_len += len(data)

        return data_len

    def encode_segment(self, segment: JBIG2Segment) -> bytes:
        """Serialize one segment header (and its data) following SEG_STRUCT."""
        parts: List[bytes] = []
        for field_format, name in SEG_STRUCT:
            value = segment.get(name)
            # Optional per-field hook, e.g. encode_flags; otherwise pack as-is.
            encoder = getattr(self, "encode_%s" % name, None)
            if callable(encoder):
                parts.append(encoder(value, segment))
            else:
                parts.append(pack(field_format, value))
        return b"".join(parts)

    @staticmethod
    def _is_page_assoc_long(flags: JBIG2SegmentFlags, segment: JBIG2Segment) -> bool:
        """Decide whether the page association is written in 4-byte form.

        An explicit ``page_assoc_long`` entry in *flags* wins; otherwise the
        long form is used when the page number does not fit in one byte.
        """
        if "page_assoc_long" in flags:
            return bool(flags["page_assoc_long"])
        return cast(int, segment.get("page_assoc") or 0) > 255

    def encode_flags(self, value: JBIG2SegmentFlags, segment: JBIG2Segment) -> bytes:
        """Pack the deferred bit, long-page bit and segment type into one byte."""
        flags = 0
        if value.get("deferred"):
            flags |= HEADER_FLAG_DEFERRED

        if self._is_page_assoc_long(value, segment):
            flags |= HEADER_FLAG_PAGE_ASSOC_LONG

        flags |= mask_value(SEG_TYPE_MASK, value["type"])

        return pack(">B", flags)

    def encode_retention_flags(
        self,
        value: JBIG2RetentionFlags,
        segment: JBIG2Segment,
    ) -> bytes:
        """Pack the referred-to count, retain bits and referred-to numbers."""
        flags: List[int] = []
        flags_format = ">B"
        ref_count = value["ref_count"]
        assert isinstance(ref_count, int)
        retain_segments = cast(List[bool], value.get("retain_segments", []))

        if ref_count <= 4:
            # Short form: count in bits 5-7, retain flags in bits 0-4; extra
            # flags are ignored so they cannot overwrite the count bits.
            flags_byte = mask_value(REF_COUNT_SHORT_MASK, ref_count)
            for ref_index, ref_retain in enumerate(retain_segments[:5]):
                if ref_retain:
                    flags_byte |= 1 << ref_index
            flags.append(flags_byte)
        else:
            # Long form: 0b111 in the top three bits and the count in the low
            # 29 bits of a 4-byte field, then 8 retain flags per byte.
            bytes_count = math.ceil((ref_count + 1) / 8)
            flags_format = ">L" + ("B" * bytes_count)
            flags_dword = (
                mask_value(REF_COUNT_SHORT_MASK, REF_COUNT_LONG) << 24
            ) | mask_value(REF_COUNT_LONG_MASK, ref_count)
            flags.append(flags_dword)

            for byte_index in range(bytes_count):
                ret_byte = 0
                ret_part = retain_segments[byte_index * 8 : byte_index * 8 + 8]
                for bit_pos, ret_seg in enumerate(ret_part):
                    if ret_seg:
                        ret_byte |= 1 << bit_pos
                flags.append(ret_byte)

        ref_segments = cast(List[int], value.get("ref_segments", []))

        # Referred-to numbers are 1, 2 or 4 bytes wide depending on this
        # segment's own number; "H" is the 2-byte struct code.
        seg_num = cast(int, segment["number"])
        if seg_num <= 256:
            ref_format = "B"
        elif seg_num <= 65536:
            ref_format = "H"
        else:
            ref_format = "L"

        for ref in ref_segments:
            flags_format += ref_format
            flags.append(ref)

        return pack(flags_format, *flags)

    def encode_page_assoc(self, value: int, segment: JBIG2Segment) -> bytes:
        """Pack the page association as 1 byte, or 4 bytes in the long form.

        The width agrees with the long-page bit written by encode_flags.
        """
        flags = cast(JBIG2SegmentFlags, segment.get("flags") or {})
        page_format = ">L" if self._is_page_assoc_long(flags, segment) else ">B"
        return pack(page_format, value)

    def encode_data_length(self, value: int, segment: JBIG2Segment) -> bytes:
        """Pack the 4-byte data length followed by the segment's raw data."""
        # The reader only stores "raw_data" for non-zero lengths, so treat a
        # missing entry as empty data instead of raising KeyError.
        raw_data = cast(bytes, segment.get("raw_data") or b"")
        return pack(">L", value) + raw_data

    def get_eop_segment(self, seg_number: int, page_number: int) -> JBIG2Segment:
        """Build an empty end-of-page segment for *page_number*."""
        return {
            "data_length": 0,
            "flags": {"deferred": False, "type": SEG_TYPE_END_OF_PAGE},
            "number": seg_number,
            "page_assoc": page_number,
            "raw_data": b"",
            "retention_flags": JBIG2StreamWriter.EMPTY_RETENTION_FLAGS,
        }

    def get_eof_segment(self, seg_number: int) -> JBIG2Segment:
        """Build an empty end-of-file segment (page association 0)."""
        return {
            "data_length": 0,
            "flags": {"deferred": False, "type": SEG_TYPE_END_OF_FILE},
            "number": seg_number,
            "page_assoc": 0,
            "raw_data": b"",
            "retention_flags": JBIG2StreamWriter.EMPTY_RETENTION_FLAGS,
        }