Skip to content

Commit 2244d33

Browse files
committed
Add content_type field to File class
1 parent 7aa8d99 commit 2244d33

File tree

2 files changed

+47
-3
lines changed

2 files changed

+47
-3
lines changed

python_multipart/multipart.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -358,13 +358,20 @@ class File:
358358
config: The configuration for this File. See above for valid configuration keys and their corresponding values.
359359
""" # noqa: E501
360360

361-
def __init__(self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig = {}) -> None:
361+
def __init__(
362+
self,
363+
file_name: bytes | None,
364+
field_name: bytes | None = None,
365+
config: FileConfig = {},
366+
content_type: bytes | None = None,
367+
) -> None:
362368
# Save configuration, set other variables default.
363369
self.logger = logging.getLogger(__name__)
364370
self._config = config
365371
self._in_memory = True
366372
self._bytes_written = 0
367373
self._fileobj: BytesIO | BufferedRandom = BytesIO()
374+
self._content_type = content_type
368375

369376
# Save the provided field/file name.
370377
self._field_name = field_name
@@ -392,6 +399,11 @@ def file_name(self) -> bytes | None:
392399
"""The file name given in the upload request."""
393400
return self._file_name
394401

402+
@property
403+
def content_type(self) -> bytes | None:
404+
"""The Content-Type given in the upload request."""
405+
return self._content_type
406+
395407
@property
396408
def actual_file_name(self) -> bytes | None:
397409
"""The file name that this file is saved as. Will be None if it's not
@@ -570,7 +582,9 @@ def close(self) -> None:
570582
self._fileobj.close()
571583

572584
def __repr__(self) -> str:
573-
return "{}(file_name={!r}, field_name={!r})".format(self.__class__.__name__, self.file_name, self.field_name)
585+
return "{}(file_name={!r}, field_name={!r}, content_type={!r})".format(
586+
self.__class__.__name__, self.file_name, self.field_name, self.content_type
587+
)
574588

575589

576590
class BaseParser:
@@ -1695,7 +1709,12 @@ def on_headers_finished() -> None:
16951709
if file_name is None:
16961710
f_multi = FieldClass(field_name)
16971711
else:
1698-
f_multi = FileClass(file_name, field_name, config=cast("FileConfig", self.config))
1712+
f_multi = FileClass(
1713+
file_name,
1714+
field_name,
1715+
config=cast("FileConfig", self.config),
1716+
content_type=headers.get(b"Content-Type", None),
1717+
)
16991718
is_file = True
17001719

17011720
# Parse the given Content-Transfer-Encoding to determine what

tests/test_multipart.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,6 +1367,31 @@ def on_header_begin() -> None:
13671367
# for each header in the multipart message.
13681368
self.assertEqual(calls, 3)
13691369

1370+
def test_file_content_type_is_set(self) -> None:
1371+
"""
1372+
This test verifies that the content_type is set on File
1373+
https://github.com/Kludex/python-multipart/issues/207
1374+
"""
1375+
1376+
file: FileProtocol | None = None
1377+
1378+
with open(os.path.join(http_tests_dir, "single_file.http"), "rb") as f:
1379+
test_data = f.read()
1380+
1381+
def on_file(f: FileProtocol) -> None:
1382+
nonlocal file
1383+
file = f
1384+
1385+
parser = FormParser("multipart/form-data", None, on_file, boundary=b"----WebKitFormBoundary5BZGOJCWtXGYC9HW")
1386+
1387+
# Create multipart parser and feed it
1388+
i = parser.write(test_data)
1389+
parser.finalize()
1390+
1391+
self.assertEqual(i, len(test_data))
1392+
self.assertIsNotNone(file)
1393+
self.assertEqual(file.content_type, b"text/plain")
1394+
13701395

13711396
class TestHelperFunctions(unittest.TestCase):
13721397
def test_create_form_parser(self) -> None:

0 commit comments

Comments
 (0)