|
from __future__ import annotations |
|
|
|
import io |
|
import os |
|
import typing |
|
from pathlib import Path |
|
|
|
from ._types import ( |
|
AsyncByteStream, |
|
FileContent, |
|
FileTypes, |
|
RequestData, |
|
RequestFiles, |
|
SyncByteStream, |
|
) |
|
from ._utils import ( |
|
format_form_param, |
|
guess_content_type, |
|
peek_filelike_length, |
|
primitive_value_to_str, |
|
to_bytes, |
|
) |
|
|
|
|
|
def get_multipart_boundary_from_content_type(
    content_type: bytes | None,
) -> bytes | None:
    """
    Extract the multipart boundary from a ``Content-Type`` header value.

    Returns the value of the ``boundary`` parameter, with surrounding double
    quotes stripped, when ``content_type`` is a ``multipart/form-data`` media
    type. Returns ``None`` for empty/``None`` input, for any other media type,
    or when no ``boundary`` parameter is present.

    The media type and the parameter name are matched case-insensitively
    (media types and parameter names are case-insensitive per RFC 2045
    section 5.1); the boundary value itself keeps its original case.
    """
    if not content_type or not content_type.lower().startswith(b"multipart/form-data"):
        return None

    # Parameters follow the media type, separated by ";" (RFC 2046 5.1.1).
    # The first section is the media type itself, so it is skipped.
    for section in content_type.split(b";")[1:]:
        param = section.strip()
        if param.lower().startswith(b"boundary="):
            return param[len(b"boundary=") :].strip(b'"')
    return None
|
|
|
|
|
class DataField:
    """
    One plain (non-file) item of a multipart form.

    ``name`` must be a ``str``; ``value`` must be ``None`` or a ``str``,
    ``bytes``, ``int`` or ``float``. Bytes are stored unchanged, and every
    other value is converted with ``primitive_value_to_str``. The rendered
    header and data bytes are built on first use and cached on the instance.

    Raises:
        TypeError: if ``name`` or ``value`` has an unsupported type.
    """

    def __init__(self, name: str, value: str | bytes | int | float | None) -> None:
        if not isinstance(name, str):
            raise TypeError(
                f"Invalid type for name. Expected str, got {type(name)}: {name!r}"
            )
        is_primitive = value is None or isinstance(value, (str, bytes, int, float))
        if not is_primitive:
            raise TypeError(
                "Invalid type for value. Expected primitive type,"
                f" got {type(value)}: {value!r}"
            )

        self.name = name
        stored: str | bytes
        if isinstance(value, bytes):
            stored = value
        else:
            stored = primitive_value_to_str(value)
        self.value: str | bytes = stored

    def render_headers(self) -> bytes:
        """Return (and cache) the ``Content-Disposition`` header block."""
        try:
            return self._headers
        except AttributeError:
            pass
        self._headers = b"".join(
            (
                b"Content-Disposition: form-data; ",
                format_form_param("name", self.name),
                b"\r\n\r\n",
            )
        )
        return self._headers

    def render_data(self) -> bytes:
        """Return (and cache) the field value converted with ``to_bytes``."""
        try:
            return self._data
        except AttributeError:
            pass
        self._data = to_bytes(self.value)
        return self._data

    def get_length(self) -> int:
        """Return the total byte length of the rendered headers plus data."""
        return len(self.render_headers()) + len(self.render_data())

    def render(self) -> typing.Iterator[bytes]:
        """Yield the header block, then the data."""
        yield self.render_headers()
        yield self.render_data()
|
|
|
|
|
class FileField:
    """
    A single file field item, within a multipart form field.

    ``value`` is either a file content object on its own, or one of the
    tuples ``(filename, fileobj)``, ``(filename, fileobj, content_type)`` or
    ``(filename, fileobj, content_type, headers)``. When no content type is
    supplied it is guessed from the filename with ``guess_content_type``.
    The caller's ``headers`` mapping is never modified.

    Raises:
        TypeError: if the file object is an ``io.StringIO`` or any other
            ``io.TextIOBase`` (text-mode) stream.
    """

    # Maximum number of bytes requested per read() while streaming the file.
    CHUNK_SIZE = 64 * 1024

    def __init__(self, name: str, value: FileTypes) -> None:
        self.name = name

        fileobj: FileContent

        headers: dict[str, str] = {}
        content_type: str | None = None

        if isinstance(value, tuple):
            if len(value) == 2:
                # Neither the content type (3rd) nor headers (4th) were given.
                filename, fileobj = value
            elif len(value) == 3:
                filename, fileobj, content_type = value
            else:
                # All four elements given. Copy the caller's headers so the
                # Content-Type default added below neither leaks back into
                # their mapping nor fails on a read-only Mapping.
                filename, fileobj, content_type, extra_headers = value
                headers = dict(extra_headers)
        else:
            # Bare file content: derive the filename from the object's
            # `name` attribute (final path component), defaulting to "upload".
            filename = Path(str(getattr(value, "name", "upload"))).name
            fileobj = value

        if content_type is None:
            content_type = guess_content_type(filename)

        # Header names are case-insensitive but must match exactly; a
        # substring test would let e.g. "X-Content-Type-Options" suppress
        # the real Content-Type header.
        has_content_type_header = any(
            key.lower() == "content-type" for key in headers
        )
        if content_type is not None and not has_content_type_header:
            # An explicit Content-Type in `headers` takes precedence over the
            # 3rd tuple element / guessed type.
            headers["Content-Type"] = content_type

        if isinstance(fileobj, io.StringIO):
            raise TypeError(
                "Multipart file uploads require 'io.BytesIO', not 'io.StringIO'."
            )
        if isinstance(fileobj, io.TextIOBase):
            raise TypeError(
                "Multipart file uploads must be opened in binary mode, not text mode."
            )

        self.filename = filename
        self.file = fileobj
        self.headers = headers

    def get_length(self) -> int | None:
        """
        Return the rendered size of this field (headers plus content) in
        bytes, or ``None`` when ``peek_filelike_length`` cannot report the
        file's size.
        """
        headers = self.render_headers()

        if isinstance(self.file, (str, bytes)):
            return len(headers) + len(to_bytes(self.file))

        file_length = peek_filelike_length(self.file)
        if file_length is None:
            # Unknown size; the caller must fall back to a streamed body.
            return None

        return len(headers) + file_length

    def render_headers(self) -> bytes:
        """
        Return (and cache) the part headers: a Content-Disposition line with
        ``name`` and, when set, ``filename``, then one line per entry in
        ``self.headers``, terminated by a blank line.
        """
        if not hasattr(self, "_headers"):
            parts = [
                b"Content-Disposition: form-data; ",
                format_form_param("name", self.name),
            ]
            if self.filename:
                filename = format_form_param("filename", self.filename)
                parts.extend([b"; ", filename])
            for header_name, header_value in self.headers.items():
                key, val = f"\r\n{header_name}: ".encode(), header_value.encode()
                parts.extend([key, val])
            parts.append(b"\r\n\r\n")
            self._headers = b"".join(parts)

        return self._headers

    def render_data(self) -> typing.Iterator[bytes]:
        """
        Yield the file content as bytes.

        ``str``/``bytes`` content is yielded in one piece. File objects are
        rewound to position 0 when they have ``seek`` (streams that raise
        ``io.UnsupportedOperation`` are read from their current position),
        then read in chunks of at most ``CHUNK_SIZE`` until an empty read.
        """
        if isinstance(self.file, (str, bytes)):
            yield to_bytes(self.file)
            return

        if hasattr(self.file, "seek"):
            try:
                self.file.seek(0)
            except io.UnsupportedOperation:
                # Non-seekable stream: best effort, read from where it is.
                pass

        chunk = self.file.read(self.CHUNK_SIZE)
        while chunk:
            yield to_bytes(chunk)
            chunk = self.file.read(self.CHUNK_SIZE)

    def render(self) -> typing.Iterator[bytes]:
        """Yield the header block followed by the file content chunks."""
        yield self.render_headers()
        yield from self.render_data()
|
|
|
|
|
class MultipartStream(SyncByteStream, AsyncByteStream):
    """
    Request content as streaming multipart encoded form data.

    ``data`` supplies plain form fields (list/tuple values expand into one
    field per item) and ``files`` supplies file fields, either as a mapping
    or as an iterable of ``(name, value)`` pairs. Both the sync and async
    iteration protocols stream the same synchronously generated chunks.
    """

    def __init__(
        self,
        data: RequestData,
        files: RequestFiles,
        boundary: bytes | None = None,
    ) -> None:
        # Without an explicit boundary, use 16 random bytes as 32 hex chars.
        self.boundary = (
            os.urandom(16).hex().encode("ascii") if boundary is None else boundary
        )
        self.content_type = (
            f"multipart/form-data; boundary={self.boundary.decode('ascii')}"
        )
        self.fields = list(self._iter_fields(data, files))

    def _iter_fields(
        self, data: RequestData, files: RequestFiles
    ) -> typing.Iterator[FileField | DataField]:
        """Yield every DataField (from ``data``) followed by every FileField."""
        for field_name, field_value in data.items():
            items = (
                field_value
                if isinstance(field_value, (tuple, list))
                else (field_value,)
            )
            for item in items:
                yield DataField(name=field_name, value=item)

        pairs = files.items() if isinstance(files, typing.Mapping) else files
        for field_name, file_value in pairs:
            yield FileField(name=field_name, value=file_value)

    def iter_chunks(self) -> typing.Iterator[bytes]:
        """
        Yield the encoded body: for each field ``--boundary\\r\\n``, the
        field's rendered bytes and ``\\r\\n``; then ``--boundary--\\r\\n``.
        """
        delimiter = b"--%s\r\n" % self.boundary
        for field in self.fields:
            yield delimiter
            yield from field.render()
            yield b"\r\n"
        yield b"--%s--\r\n" % self.boundary

    def get_content_length(self) -> int | None:
        """
        Return the length of the multipart encoded content, or `None` if
        any of the files have a length that cannot be determined upfront.
        """
        boundary_size = len(self.boundary)
        # Per field: b"--" + boundary + b"\r\n" before it, b"\r\n" after it.
        framing = 2 + boundary_size + 2 + 2
        total = 0
        for field in self.fields:
            body_size = field.get_length()
            if body_size is None:
                return None
            total += framing + body_size
        # Closing delimiter: b"--" + boundary + b"--\r\n".
        return total + 2 + boundary_size + 4

    def get_headers(self) -> dict[str, str]:
        """
        Return the request headers for this body: ``Content-Length`` when the
        total size is known, ``Transfer-Encoding: chunked`` otherwise, plus
        ``Content-Type`` carrying the boundary.
        """
        content_length = self.get_content_length()
        if content_length is not None:
            return {
                "Content-Length": str(content_length),
                "Content-Type": self.content_type,
            }
        return {"Transfer-Encoding": "chunked", "Content-Type": self.content_type}

    def __iter__(self) -> typing.Iterator[bytes]:
        yield from self.iter_chunks()

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # File reads inside iter_chunks are synchronous calls.
        for part in self.iter_chunks():
            yield part
|
|