PDFTranslate / pdf2zh /jbig2.py
sanbo
update sth. at 2024-11-26 16:15:47
9b0f4a0
raw
history blame
11.6 kB
import math
import os
from struct import calcsize, pack, unpack
from typing import BinaryIO, Dict, Iterable, List, Optional, Tuple, Union, cast
from pdf2zh.pdfexceptions import PDFValueError
# Ordered layout of the fixed part of a segment header, as
# (struct format, field name) pairs.  The reader and writer in this module
# walk the list in order and dispatch to ``parse_<name>`` / ``encode_<name>``
# hooks for fields that need more than a plain unpack/pack.
SEG_STRUCT = [
    (">L", "number"),
    (">B", "flags"),
    (">B", "retention_flags"),
    (">B", "page_assoc"),
    (">L", "data_length"),
]

# Bit layout of the one-byte segment header "flags" field.
HEADER_FLAG_DEFERRED = 1 << 7
HEADER_FLAG_PAGE_ASSOC_LONG = 1 << 6
SEG_TYPE_MASK = (1 << 6) - 1  # low six bits carry the segment type

# Referred-to segment count: the top three bits of the first byte hold the
# count in the short form; the value REF_COUNT_LONG there selects the long
# form, whose count occupies the low 29 bits of a four-byte field.
REF_COUNT_SHORT_MASK = 0b111 << 5
REF_COUNT_LONG_MASK = (1 << 29) - 1
REF_COUNT_LONG = 7

# Data-length value that the reader treats as "length unknown".
DATA_LEN_UNKNOWN = (1 << 32) - 1

# Segment type numbers referenced by this module.
SEG_TYPE_IMMEDIATE_GEN_REGION = 38
SEG_TYPE_END_OF_PAGE = 49
SEG_TYPE_END_OF_FILE = 51

# Header written at the start of a standalone file: the eight-byte ID string
# and the flag bit written into the file header flags byte.
FILE_HEADER_ID = bytes((0x97, 0x4A, 0x42, 0x32, 0x0D, 0x0A, 0x1A, 0x0A))
FILE_HEAD_FLAG_SEQUENTIAL = 1
def bit_set(bit_pos: int, value: int) -> bool:
    """Return True when bit ``bit_pos`` of ``value`` is 1 (bit 0 is the LSB)."""
    shifted = value >> bit_pos
    return (shifted & 1) == 1
def check_flag(flag: int, value: int) -> bool:
    """Return True when ``value`` has at least one of the bits in ``flag`` set."""
    return (value & flag) != 0
def masked_value(mask: int, value: int) -> int:
    """Extract the field selected by ``mask`` from ``value``.

    The bits of ``value`` covered by ``mask`` are kept and shifted right so
    that the lowest set bit of ``mask`` lands on bit 0.

    Raises:
        PDFValueError: if ``mask`` has no bit set in positions 0..31.
    """
    # Header fields are up to 32 bits wide, so every position 0..31 has to be
    # considered; stopping at 30 would reject a mask that starts at bit 31.
    for bit_pos in range(32):
        if (mask >> bit_pos) & 1:
            return (value & mask) >> bit_pos

    raise PDFValueError("Invalid mask or value")
def mask_value(mask: int, value: int) -> int:
    """Position ``value`` inside the field selected by ``mask``.

    ``value`` is limited to the bits of ``mask`` (shifted down to bit 0) and
    then shifted left to the position of the lowest set bit of ``mask``.

    Raises:
        PDFValueError: if ``mask`` has no bit set in positions 0..31.
    """
    # Header fields are up to 32 bits wide, so every position 0..31 has to be
    # considered; stopping at 30 would reject a mask that starts at bit 31.
    for bit_pos in range(32):
        if (mask >> bit_pos) & 1:
            return (value & (mask >> bit_pos)) << bit_pos

    raise PDFValueError("Invalid mask or value")
def unpack_int(format: str, buffer: bytes) -> int:
    """Decode ``buffer`` as the single big-endian unsigned integer ``format``.

    Only the ``">B"``, ``">I"`` and ``">L"`` formats are accepted; anything
    else trips the assertion.
    """
    assert format in {">B", ">I", ">L"}
    values = cast(Tuple[int], unpack(format, buffer))
    (number,) = values
    return number
# Decoded segment header "flags" byte, as built by
# JBIG2StreamReader.parse_flags: "deferred" and "page_assoc_long" are
# booleans, "type" is the integer segment type.
JBIG2SegmentFlags = Dict[str, Union[int, bool]]
# Decoded referred-to-segments field, as built by
# JBIG2StreamReader.parse_retention_flags: "ref_count" (int),
# "retain_segments" (list of bools) and "ref_segments" (list of ints).
JBIG2RetentionFlags = Dict[str, Union[int, List[int], List[bool]]]
# One segment as a plain dict.  Keys used in this module are the SEG_STRUCT
# field names ("number", "flags", "retention_flags", "page_assoc",
# "data_length"), plus "raw_data" (the segment payload bytes) and the
# reader's internal "_error" marker for a truncated header.
JBIG2Segment = Dict[
    str,
    Union[bool, int, bytes, JBIG2SegmentFlags, JBIG2RetentionFlags],
]
class JBIG2StreamReader:
    """Read segments from a JBIG2 byte stream"""

    def __init__(self, stream: BinaryIO) -> None:
        """Wrap ``stream``; it must support ``read`` and relative ``seek``."""
        self.stream = stream

    def get_segments(self) -> List[JBIG2Segment]:
        """Parse segments from the current stream position until end of stream.

        Each header field listed in ``SEG_STRUCT`` is read and unpacked in
        order; when a ``parse_<field>`` method exists it post-processes the
        raw value (and may consume further bytes from the stream).

        Returns:
            The parsed segments, in stream order. A segment whose header is
            cut short by the end of the stream is dropped.
        """
        segments: List[JBIG2Segment] = []
        while not self.is_eof():
            segment: JBIG2Segment = {}
            for field_format, name in SEG_STRUCT:
                field_len = calcsize(field_format)
                field = self.stream.read(field_len)
                if len(field) < field_len:
                    # Truncated header: mark it so the segment is discarded.
                    segment["_error"] = True
                    break
                value = unpack_int(field_format, field)
                parser = getattr(self, "parse_%s" % name, None)
                if callable(parser):
                    value = parser(segment, value, field)
                segment[name] = value

            if not segment.get("_error"):
                segments.append(segment)
        return segments

    def is_eof(self) -> bool:
        """Return True at end of stream, without consuming any data."""
        if self.stream.read(1) == b"":
            return True
        # A byte was available: step back so the caller sees it again.
        self.stream.seek(-1, os.SEEK_CUR)
        return False

    def parse_flags(
        self,
        segment: JBIG2Segment,
        flags: int,
        field: bytes,
    ) -> JBIG2SegmentFlags:
        """Split the header flags byte into its three components."""
        return {
            "deferred": check_flag(HEADER_FLAG_DEFERRED, flags),
            "page_assoc_long": check_flag(HEADER_FLAG_PAGE_ASSOC_LONG, flags),
            "type": masked_value(SEG_TYPE_MASK, flags),
        }

    def parse_retention_flags(
        self,
        segment: JBIG2Segment,
        flags: int,
        field: bytes,
    ) -> JBIG2RetentionFlags:
        """Decode the referred-to-segments part of the header.

        ``flags``/``field`` hold the first byte, which was already read. The
        rest of the long-form field (if any) and the referred-to segment
        numbers are consumed from the stream here.

        Returns:
            A dict with "ref_count", "retain_segments" (one bool per retain
            bit, least significant bit first) and "ref_segments".
        """
        ref_count = masked_value(REF_COUNT_SHORT_MASK, flags)
        retain_segments: List[bool] = []
        ref_segments: List[int] = []

        if ref_count < REF_COUNT_LONG:
            # Short form: the five low bits of this byte are the retain bits.
            retain_segments.extend(bit_set(bit_pos, flags) for bit_pos in range(5))
        else:
            # Long form: the count lives in the low 29 bits of a four-byte
            # field, followed by ceil((count + 1) / 8) bytes of retain bits.
            field += self.stream.read(3)
            ref_count = unpack_int(">L", field)
            ref_count = masked_value(REF_COUNT_LONG_MASK, ref_count)
            ret_bytes_count = int(math.ceil((ref_count + 1) / 8))
            for _ in range(ret_bytes_count):
                ret_byte = unpack_int(">B", self.stream.read(1))
                # Every retain byte carries eight bits; decoding only seven
                # would drop bit 7 and misalign all following entries.
                retain_segments.extend(
                    bit_set(bit_pos, ret_byte) for bit_pos in range(8)
                )

        seg_num = segment["number"]
        assert isinstance(seg_num, int)
        # The width of each referred-to segment number depends on this
        # segment's own number: one, two or four bytes.
        if seg_num <= 256:
            ref_format = ">B"
        elif seg_num <= 65536:
            # ">H" is the two-byte format; ">I" would consume four bytes.
            ref_format = ">H"
        else:
            ref_format = ">L"

        ref_size = calcsize(ref_format)
        for _ in range(ref_count):
            ref_data = self.stream.read(ref_size)
            (ref,) = unpack(ref_format, ref_data)
            ref_segments.append(ref)

        return {
            "ref_count": ref_count,
            "retain_segments": retain_segments,
            "ref_segments": ref_segments,
        }

    def parse_page_assoc(self, segment: JBIG2Segment, page: int, field: bytes) -> int:
        """Return the page association, reading three more bytes for the long form."""
        if cast(JBIG2SegmentFlags, segment["flags"])["page_assoc_long"]:
            field += self.stream.read(3)
            page = unpack_int(">L", field)
        return page

    def parse_data_length(
        self,
        segment: JBIG2Segment,
        length: int,
        field: bytes,
    ) -> int:
        """Read the segment payload into ``segment["raw_data"]``.

        A zero ``length`` reads nothing and leaves "raw_data" unset.

        Raises:
            NotImplementedError: for an immediate generic region segment
                whose length is the "unknown" marker.
        """
        if length:
            seg_type = cast(JBIG2SegmentFlags, segment["flags"])["type"]
            if seg_type == SEG_TYPE_IMMEDIATE_GEN_REGION and (
                length == DATA_LEN_UNKNOWN
            ):
                raise NotImplementedError(
                    "Working with unknown segment length is not implemented yet",
                )
            segment["raw_data"] = self.stream.read(length)

        return length
class JBIG2StreamWriter:
    """Write JBIG2 segments to a file in JBIG2 format"""

    # Retention-flags value shared by the generated end-of-page and
    # end-of-file segments: no referred-to segments, no retain bits.
    EMPTY_RETENTION_FLAGS: JBIG2RetentionFlags = {
        "ref_count": 0,
        "ref_segments": cast(List[int], []),
        "retain_segments": cast(List[bool], []),
    }

    def __init__(self, stream: BinaryIO) -> None:
        """Wrap the binary ``stream`` that encoded data is written to."""
        self.stream = stream

    def write_segments(
        self,
        segments: Iterable[JBIG2Segment],
        fix_last_page: bool = True,
    ) -> int:
        """Encode and write ``segments`` in order.

        With ``fix_last_page`` set, the page of the most recent segment is
        tracked; if the input ends while a page is still open (no end-of-page
        segment followed it), an end-of-page segment is appended, numbered one
        past the last segment written.

        Returns:
            The number of bytes written.
        """
        data_len = 0
        current_page: Optional[int] = None
        seg_num: Optional[int] = None

        for segment in segments:
            data = self.encode_segment(segment)
            self.stream.write(data)
            data_len += len(data)

            seg_num = cast(Optional[int], segment["number"])

            if fix_last_page:
                seg_page = cast(int, segment.get("page_assoc"))

                if (
                    cast(JBIG2SegmentFlags, segment["flags"])["type"]
                    == SEG_TYPE_END_OF_PAGE
                ):
                    current_page = None
                elif seg_page:
                    current_page = seg_page

        if fix_last_page and current_page and (seg_num is not None):
            segment = self.get_eop_segment(seg_num + 1, current_page)
            data = self.encode_segment(segment)
            self.stream.write(data)
            data_len += len(data)

        return data_len

    def write_file(
        self,
        segments: Iterable[JBIG2Segment],
        fix_last_page: bool = True,
    ) -> int:
        """Write a standalone file: header, ``segments``, end-of-file segment.

        Returns:
            The total number of bytes written.
        """
        # The segments are needed twice (to write them and to find the last
        # segment number), so a one-shot iterator must be materialized first.
        segments = list(segments)

        header = FILE_HEADER_ID
        header_flags = FILE_HEAD_FLAG_SEQUENTIAL
        header += pack(">B", header_flags)
        # The embedded JBIG2 files in a PDF always
        # only have one page
        number_of_pages = pack(">L", 1)
        header += number_of_pages
        self.stream.write(header)
        data_len = len(header)

        data_len += self.write_segments(segments, fix_last_page)

        seg_num = cast(int, segments[-1]["number"]) if segments else 0

        # Leave room for the end-of-page segment write_segments may have added.
        seg_num_offset = 2 if fix_last_page else 1
        eof_segment = self.get_eof_segment(seg_num + seg_num_offset)
        data = self.encode_segment(eof_segment)

        self.stream.write(data)
        data_len += len(data)

        return data_len

    def encode_segment(self, segment: JBIG2Segment) -> bytes:
        """Serialize one segment (header fields followed by its payload).

        Fields are emitted in ``SEG_STRUCT`` order; a field with an
        ``encode_<field>`` method uses it, any other is packed directly.
        """
        parts = []
        for field_format, name in SEG_STRUCT:
            value = segment.get(name)
            encoder = getattr(self, "encode_%s" % name, None)
            if callable(encoder):
                parts.append(encoder(value, segment))
            else:
                parts.append(pack(field_format, value))
        return b"".join(parts)

    @staticmethod
    def _page_assoc_is_long(flags: JBIG2SegmentFlags, segment: JBIG2Segment) -> bool:
        """Decide whether the page association uses the four-byte form.

        The long form is used when the flags request it explicitly or when
        the page number does not fit into a single byte.
        """
        if flags.get("page_assoc_long"):
            return True
        page = segment.get("page_assoc")
        return isinstance(page, int) and page > 255

    def encode_flags(self, value: JBIG2SegmentFlags, segment: JBIG2Segment) -> bytes:
        """Build the header flags byte from the flags dict ``value``."""
        flags = 0
        if value.get("deferred"):
            flags |= HEADER_FLAG_DEFERRED

        # The page number is stored under "page_assoc"; the long-form bit
        # must agree with the width chosen by encode_page_assoc.
        if self._page_assoc_is_long(value, segment):
            flags |= HEADER_FLAG_PAGE_ASSOC_LONG

        flags |= mask_value(SEG_TYPE_MASK, value["type"])

        return pack(">B", flags)

    def encode_retention_flags(
        self,
        value: JBIG2RetentionFlags,
        segment: JBIG2Segment,
    ) -> bytes:
        """Encode the referred-to-segments field and the segment numbers.

        Up to four references use the one-byte short form; more use the long
        form (four-byte count field plus retain-bit bytes). The referred-to
        segment numbers follow, sized by this segment's own number.
        """
        flags: List[int] = []
        flags_format = ">B"
        ref_count = value["ref_count"]
        assert isinstance(ref_count, int)
        retain_segments = cast(List[bool], value.get("retain_segments", []))

        if ref_count <= 4:
            flags_byte = mask_value(REF_COUNT_SHORT_MASK, ref_count)
            # Only the five low bits are retain bits; anything beyond would
            # overwrite the count stored in the top three bits.
            for ref_index, ref_retain in enumerate(retain_segments[:5]):
                if ref_retain:
                    flags_byte |= 1 << ref_index
            flags.append(flags_byte)
        else:
            bytes_count = math.ceil((ref_count + 1) / 8)
            flags_format = ">L" + ("B" * bytes_count)
            # Top three bits select the long form; the low 29 bits must carry
            # the actual reference count, or a reader decodes a count of 0.
            flags_dword = mask_value(REF_COUNT_SHORT_MASK, REF_COUNT_LONG) << 24
            flags_dword |= ref_count & REF_COUNT_LONG_MASK
            flags.append(flags_dword)

            for byte_index in range(bytes_count):
                ret_byte = 0
                ret_part = retain_segments[byte_index * 8 : byte_index * 8 + 8]
                for bit_pos, ret_seg in enumerate(ret_part):
                    if ret_seg:
                        ret_byte |= 1 << bit_pos

                flags.append(ret_byte)

        ref_segments = cast(List[int], value.get("ref_segments", []))

        seg_num = cast(int, segment["number"])
        if seg_num <= 256:
            ref_format = "B"
        elif seg_num <= 65536:
            # "H" is the two-byte format; "I" would emit four bytes.
            ref_format = "H"
        else:
            ref_format = "L"

        for ref in ref_segments:
            flags_format += ref_format
            flags.append(ref)

        return pack(flags_format, *flags)

    def encode_page_assoc(self, value: int, segment: JBIG2Segment) -> bytes:
        """Encode the page association as one byte, or four in the long form."""
        flags = cast(JBIG2SegmentFlags, segment["flags"])
        field_format = ">L" if self._page_assoc_is_long(flags, segment) else ">B"
        return pack(field_format, value)

    def encode_data_length(self, value: int, segment: JBIG2Segment) -> bytes:
        """Encode the data length followed by the segment payload.

        A segment without "raw_data" is written with an empty payload.
        """
        data = pack(">L", value)
        data += cast(bytes, segment.get("raw_data", b""))
        return data

    def get_eop_segment(self, seg_number: int, page_number: int) -> JBIG2Segment:
        """Build an empty end-of-page segment for ``page_number``."""
        return {
            "data_length": 0,
            "flags": {"deferred": False, "type": SEG_TYPE_END_OF_PAGE},
            "number": seg_number,
            "page_assoc": page_number,
            "raw_data": b"",
            "retention_flags": JBIG2StreamWriter.EMPTY_RETENTION_FLAGS,
        }

    def get_eof_segment(self, seg_number: int) -> JBIG2Segment:
        """Build an empty end-of-file segment (page association 0)."""
        return {
            "data_length": 0,
            "flags": {"deferred": False, "type": SEG_TYPE_END_OF_FILE},
            "number": seg_number,
            "page_assoc": 0,
            "raw_data": b"",
            "retention_flags": JBIG2StreamWriter.EMPTY_RETENTION_FLAGS,
        }