diff --git a/.devcontainer/on-create-command.sh b/.devcontainer/on-create-command.sh index fdf77952f..eaebea618 100755 --- a/.devcontainer/on-create-command.sh +++ b/.devcontainer/on-create-command.sh @@ -1,9 +1,7 @@ #!/bin/bash set -e - -python3 -m venv .venv +python3 -m venv --upgrade-deps .venv . .venv/bin/activate -pip install -U pip pip install -r requirements/dev.txt pip install -e . pre-commit install --install-hooks diff --git a/.editorconfig b/.editorconfig index e32c8029d..2ff985a67 100644 --- a/.editorconfig +++ b/.editorconfig @@ -9,5 +9,5 @@ end_of_line = lf charset = utf-8 max_line_length = 88 -[*.{yml,yaml,json,js,css,html}] +[*.{css,html,js,json,jsx,scss,ts,tsx,yaml,yml}] indent_size = 2 diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 6ac59c8e2..000000000 --- a/.flake8 +++ /dev/null @@ -1,29 +0,0 @@ -[flake8] -extend-select = - # bugbear - B - # bugbear opinions - B9 - # implicit str concat - ISC -extend-ignore = - # slice notation whitespace, invalid - E203 - # import at top, too many circular import fixes - E402 - # line length, handled by bugbear B950 - E501 - # bare except, handled by bugbear B001 - E722 - # zip with strict=, requires python >= 3.10 - B905 - # string formatting opinion, B028 renamed to B907 - B028 - B907 -# up to 88 allowed by bugbear B950 -max-line-length = 80 -per-file-ignores = - # __init__ exports names - **/__init__.py: F401 - # LocalProxy assigns lambdas - src/werkzeug/local.py: E731 diff --git a/.github/ISSUE_TEMPLATE/bug-report.md b/.github/ISSUE_TEMPLATE/bug-report.md index eb5e22b21..cdbeececf 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.md +++ b/.github/ISSUE_TEMPLATE/bug-report.md @@ -4,8 +4,8 @@ about: Report a bug in Werkzeug (not other projects which depend on Werkzeug) --- diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 9df4cec0e..88a049ead 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,11 +1,11 @@ 
blank_issues_enabled: false contact_links: - name: Security issue - url: security@palletsprojects.com - about: Do not report security issues publicly. Email our security contact. - - name: Questions - url: https://stackoverflow.com/questions/tagged/werkzeug?tab=Frequent - about: Search for and ask questions about your code on Stack Overflow. - - name: Questions and discussions + url: https://github.com/pallets/werkzeug/security/advisories/new + about: Do not report security issues publicly. Create a private advisory. + - name: Questions on Discussions + url: https://github.com/pallets/werkzeug/discussions/ + about: Ask questions about your own code on the Discussions tab. + - name: Questions on Chat url: https://discord.gg/pallets - about: Discuss questions about your code on our Discord chat. + about: Ask questions about your own code on our Discord chat. diff --git a/.github/ISSUE_TEMPLATE/feature-request.md b/.github/ISSUE_TEMPLATE/feature-request.md index 48698798f..18eaef7b5 100644 --- a/.github/ISSUE_TEMPLATE/feature-request.md +++ b/.github/ISSUE_TEMPLATE/feature-request.md @@ -5,7 +5,7 @@ about: Suggest a new feature for Werkzeug -- fixes # +fixes # +--> - -Checklist: - -- [ ] Add tests that demonstrate the correct behavior of the change. Tests should fail without the change. -- [ ] Add or update relevant docs, in the docs folder and in code. -- [ ] Add an entry in `CHANGES.rst` summarizing the change and linking to the issue. -- [ ] Add `.. versionchanged::` entries in any relevant code docs. -- [ ] Run `pre-commit` hooks and fix any issues. -- [ ] Run `pytest` and `tox`, no tests failed. diff --git a/.github/workflows/lock.yaml b/.github/workflows/lock.yaml index c790fae5c..22228a1cd 100644 --- a/.github/workflows/lock.yaml +++ b/.github/workflows/lock.yaml @@ -1,25 +1,23 @@ -name: 'Lock threads' -# Lock closed issues that have not received any further activity for -# two weeks. This does not close open issues, only humans may do that. 
-# We find that it is easier to respond to new issues with fresh examples -# rather than continuing discussions on old issues. +name: Lock inactive closed issues +# Lock closed issues that have not received any further activity for two weeks. +# This does not close open issues, only humans may do that. It is easier to +# respond to new issues with fresh examples rather than continuing discussions +# on old issues. on: schedule: - cron: '0 0 * * *' - permissions: issues: write pull-requests: write - concurrency: group: lock - jobs: lock: runs-on: ubuntu-latest steps: - - uses: dessant/lock-threads@c1b35aecc5cdb1a34539d14196df55838bb2f836 + - uses: dessant/lock-threads@1bf7ec25051fe7c00bdd17e6a7cf3d7bfb7dc771 # v5.0.1 with: issue-inactive-days: 14 pr-inactive-days: 14 + discussion-inactive-days: 14 diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml new file mode 100644 index 000000000..adddea75d --- /dev/null +++ b/.github/workflows/pre-commit.yaml @@ -0,0 +1,16 @@ +name: pre-commit +on: + pull_request: + push: + branches: [main, stable] +jobs: + main: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 + with: + python-version: 3.x + - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 + - uses: pre-commit-ci/lite-action@9d882e7a565f7008d4faf128f27d1cb6503d4ebf # v1.0.2 + if: ${{ !cancelled() }} diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index f5209f4e9..61c622140 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -9,12 +9,12 @@ jobs: outputs: hash: ${{ steps.hash.outputs.hash }} steps: - - uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 - - uses: actions/setup-python@d27e3f3d7c64b4bbf8e4abfb9b63b83e846e0435 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: 
actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: '3.x' - cache: 'pip' - cache-dependency-path: 'requirements/*.txt' + cache: pip + cache-dependency-path: requirements*/*.txt - run: pip install -r requirements/build.txt # Use the commit date instead of the current date during the build. - run: echo "SOURCE_DATE_EPOCH=$(git log -1 --pretty=%ct)" >> $GITHUB_ENV @@ -23,28 +23,28 @@ jobs: - name: generate hash id: hash run: cd dist && echo "hash=$(sha256sum * | base64 -w0)" >> $GITHUB_OUTPUT - - uses: actions/upload-artifact@0b7f8abb1508181956e8e162db84b466c27e18ce + - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: path: ./dist provenance: - needs: ['build'] + needs: [build] permissions: actions: read id-token: write contents: write # Can't pin with hash due to how this workflow works. - uses: slsa-framework/slsa-github-generator/.github/workflows/generator_generic_slsa3.yml@v1.5.0 + uses: slsa-framework/slsa-github-generator/.github/workflows/generator_generic_slsa3.yml@v2.0.0 with: base64-subjects: ${{ needs.build.outputs.hash }} create-release: # Upload the sdist, wheels, and provenance to a GitHub release. They remain # available as build artifacts for a while as well. - needs: ['provenance'] + needs: [provenance] runs-on: ubuntu-latest permissions: contents: write steps: - - uses: actions/download-artifact@9bc31d5ccc31df68ecc42ccf4149144866c47d8a + - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 - name: create release run: > gh release create --draft --repo ${{ github.repository }} @@ -53,20 +53,21 @@ jobs: env: GH_TOKEN: ${{ github.token }} publish-pypi: - needs: ['provenance'] + needs: [provenance] # Wait for approval before attempting to upload to PyPI. This allows reviewing the # files in the draft release. 
- environment: 'publish' + environment: + name: publish + url: https://pypi.org/project/Werkzeug/${{ github.ref_name }} runs-on: ubuntu-latest permissions: id-token: write steps: - - uses: actions/download-artifact@9bc31d5ccc31df68ecc42ccf4149144866c47d8a - # Try uploading to Test PyPI first, in case something fails. - - uses: pypa/gh-action-pypi-publish@29930c9cf57955dc1b98162d0d8bc3ec80d9e75c + - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 + - uses: pypa/gh-action-pypi-publish@f7600683efdcb7656dec5b29656edb7bc586e597 # v1.10.3 with: repository-url: https://test.pypi.org/legacy/ packages-dir: artifact/ - - uses: pypa/gh-action-pypi-publish@29930c9cf57955dc1b98162d0d8bc3ec80d9e75c + - uses: pypa/gh-action-pypi-publish@f7600683efdcb7656dec5b29656edb7bc586e597 # v1.10.3 with: packages-dir: artifact/ diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index f29fe5f73..5d26be45b 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -1,50 +1,49 @@ name: Tests on: push: - branches: - - main - - '*.x' - paths-ignore: - - 'docs/**' - - '*.md' - - '*.rst' + branches: [main, stable] + paths-ignore: ['docs/**', '*.md', '*.rst'] pull_request: - branches: - - main - - '*.x' - paths-ignore: - - 'docs/**' - - '*.md' - - '*.rst' + paths-ignore: ['docs/**', '*.md', '*.rst'] jobs: tests: - name: ${{ matrix.name }} - runs-on: ${{ matrix.os }} + name: ${{ matrix.name || matrix.python }} + runs-on: ${{ matrix.os || 'ubuntu-latest' }} strategy: fail-fast: false matrix: include: - - {name: Linux, python: '3.11', os: ubuntu-latest, tox: py311} - - {name: Windows, python: '3.11', os: windows-latest, tox: py311} - - {name: Mac, python: '3.11', os: macos-latest, tox: py311} - - {name: '3.12-dev', python: '3.12-dev', os: ubuntu-latest, tox: py312} - - {name: '3.10', python: '3.10', os: ubuntu-latest, tox: py310} - - {name: '3.9', python: '3.9', os: ubuntu-latest, tox: py39} - - {name: '3.8', python: '3.8', 
os: ubuntu-latest, tox: py38} - - {name: 'PyPy', python: 'pypy-3.10', os: ubuntu-latest, tox: pypy310} - - {name: Typing, python: '3.11', os: ubuntu-latest, tox: typing} + - {python: '3.13'} + - {python: '3.12'} + - {name: Windows, python: '3.12', os: windows-latest} + - {name: Mac, python: '3.12', os: macos-latest} + - {python: '3.11'} + - {python: '3.10'} + - {python: '3.9'} + - {name: PyPy, python: 'pypy-3.10', tox: pypy310} steps: - - uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 - - uses: actions/setup-python@d27e3f3d7c64b4bbf8e4abfb9b63b83e846e0435 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python }} - cache: 'pip' - cache-dependency-path: 'requirements/*.txt' + allow-prereleases: true + cache: pip + cache-dependency-path: requirements*/*.txt + - run: pip install tox + - run: tox run -e ${{ matrix.tox || format('py{0}', matrix.python) }} + typing: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: '3.x' + cache: pip + cache-dependency-path: requirements*/*.txt - name: cache mypy - uses: actions/cache@88522ab9f39a2ea568f7027eddc7d8d8bc9d59c8 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 with: path: ./.mypy_cache - key: mypy|${{ matrix.python }}|${{ hashFiles('pyproject.toml') }} - if: matrix.tox == 'typing' + key: mypy|${{ hashFiles('pyproject.toml') }} - run: pip install tox - - run: tox run -e ${{ matrix.tox }} + - run: tox run -e typing diff --git a/.gitignore b/.gitignore index aecea1a7b..bbeb14f16 100644 --- a/.gitignore +++ b/.gitignore @@ -1,26 +1,11 @@ -MANIFEST -build -dist -/src/Werkzeug.egg-info -*.pyc -*.pyo -.venv -.DS_Store -docs/_build -bench/a -bench/b -.tox -.coverage -.coverage.* 
-coverage_out -htmlcov -.cache -.xprocess -.hypothesis -test_uwsgi_failed -.idea +.idea/ +.vscode/ +.venv*/ +venv*/ +__pycache__/ +dist/ +.coverage* +htmlcov/ .pytest_cache/ -venv/ -.vscode -.mypy_cache/ -.dmypy.json +.tox/ +docs/_build/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6425015cf..6ad19aacd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,42 +1,14 @@ -ci: - autoupdate_branch: "2.3.x" - autoupdate_schedule: monthly repos: - - repo: https://github.com/asottile/pyupgrade - rev: v3.10.1 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.7.1 hooks: - - id: pyupgrade - args: ["--py38-plus"] - - repo: https://github.com/asottile/reorder-python-imports - rev: v3.10.0 - hooks: - - id: reorder-python-imports - name: Reorder Python imports (src, tests) - files: "^(?!examples/)" - args: ["--application-directories", ".:src"] - - id: reorder-python-imports - name: Reorder Python imports (examples) - files: "^examples/" - args: ["--application-directories", "examples"] - - repo: https://github.com/psf/black - rev: 23.7.0 - hooks: - - id: black - - repo: https://github.com/PyCQA/flake8 - rev: 6.1.0 - hooks: - - id: flake8 - additional_dependencies: - - flake8-bugbear - - flake8-implicit-str-concat - - repo: https://github.com/peterdemin/pip-compile-multi - rev: v2.6.3 - hooks: - - id: pip-compile-multi-verify + - id: ruff + - id: ruff-format - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v5.0.0 hooks: + - id: check-merge-conflict + - id: debug-statements - id: fix-byte-order-marker - id: trailing-whitespace - id: end-of-file-fixer - exclude: "^tests/.*.http$" diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 346900b20..865c68597 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -1,8 +1,8 @@ version: 2 build: - os: ubuntu-20.04 + os: ubuntu-22.04 tools: - python: "3.10" + python: '3.12' python: install: - requirements: requirements/docs.txt diff --git a/CHANGES.rst 
b/CHANGES.rst index b348506d0..de3f2b7c9 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,12 +1,194 @@ .. currentmodule:: werkzeug +Version 3.1.3 +------------- + +Released 2024-11-08 + +- Initial data passed to ``MultiDict`` and similar interfaces only accepts + ``list``, ``tuple``, or ``set`` when passing multiple values. It had been + changed to accept any ``Collection``, but this matched types that should be + treated as single values, such as ``bytes``. :issue:`2994` +- When the ``Host`` header is not set and ``Request.host`` falls back to the + WSGI ``SERVER_NAME`` value, if that value is an IPv6 address it is wrapped + in ``[]`` to match the ``Host`` header. :issue:`2993` + + +Version 3.1.2 +------------- + +Released 2024-11-04 + +- Improve type annotation for ``TypeConversionDict.get`` to allow the ``type`` + parameter to be a callable. :issue:`2988` +- ``Headers`` does not inherit from ``MutableMapping``, as it does not + exactly match that interface. :issue:`2989` + + +Version 3.1.1 +------------- + +Released 2024-11-01 + +- Fix an issue that caused ``str(Request.headers)`` to always appear empty. + :issue:`2985` + + +Version 3.1.0 +------------- + +Released 2024-10-31 + +- Drop support for Python 3.8. :pr:`2966` +- Remove previously deprecated code. :pr:`2967` +- ``Request.max_form_memory_size`` defaults to 500kB instead of unlimited. + Non-file form fields over this size will cause a ``RequestEntityTooLarge`` + error. :issue:`2964` +- ``OrderedMultiDict`` and ``ImmutableOrderedMultiDict`` are deprecated. + Use ``MultiDict`` and ``ImmutableMultiDict`` instead. :issue:`2968` +- Behavior of properties on ``request.cache_control`` and + ``response.cache_control`` has been significantly adjusted. + + - Dict values are always ``str | None``. Setting properties will convert + the value to a string. Setting a property to ``False`` is equivalent to + setting it to ``None``.
Getting typed properties will return ``None`` if + conversion raises ``ValueError``, rather than the string. :issue:`2980` + - ``max_age`` is ``None`` if present without a value, rather than ``-1``. + :issue:`2980` + - ``no_cache`` is a boolean for requests, it is ``True`` instead of + ``"*"`` when present. It remains a string for responses. :issue:`2980` + - ``max_stale`` is ``True`` if present without a value, rather + than ``"*"``. :issue:`2980` + - ``no_transform`` is a boolean. Previously it was mistakenly always + ``None``. :issue:`2881` + - ``min_fresh`` is ``None`` if present without a value, rather than + ``"*"``. :issue:`2881` + - ``private`` is ``True`` if present without a value, rather than ``"*"``. + :issue:`2980` + - Added the ``must_understand`` property. :issue:`2881` + - Added the ``stale_while_revalidate``, and ``stale_if_error`` + properties. :issue:`2948` + - Type annotations more accurately reflect the values. :issue:`2881` + +- Support Cookie CHIPS (Partitioned Cookies). :issue:`2797` +- Add 421 ``MisdirectedRequest`` HTTP exception. :issue:`2850` +- Increase default work factor for PBKDF2 to 1,000,000 iterations. + :issue:`2969` +- Inline annotations for ``datastructures``, removing stub files. + :issue:`2970` +- ``MultiDict.getlist`` catches ``TypeError`` in addition to ``ValueError`` + when doing type conversion. :issue:`2976` +- Implement ``|`` and ``|=`` operators for ``MultiDict``, ``Headers``, and + ``CallbackDict``, and disallow ``|=`` on immutable types. :issue:`2977` + + +Version 3.0.6 +------------- + +Released 2024-10-25 + +- Fix how ``max_form_memory_size`` is applied when parsing large non-file + fields. :ghsa:`q34m-jh98-gwm2` +- ``safe_join`` catches certain paths on Windows that were not caught by + ``ntpath.isabs`` on Python < 3.11. :ghsa:`f9vj-2wh5-fj8j` + + +Version 3.0.5 +------------- + +Released 2024-10-24 + +- The Watchdog reloader ignores file closed no write events. 
:issue:`2945` +- Logging works with client addresses containing an IPv6 scope :issue:`2952` +- Ignore invalid authorization parameters. :issue:`2955` +- Improve type annotation for ``SharedDataMiddleware``. :issue:`2958` +- Compatibility with Python 3.13 when generating debugger pin and the current + UID does not have an associated name. :issue:`2957` + + +Version 3.0.4 +------------- + +Released 2024-08-21 + +- Restore behavior where parsing `multipart/x-www-form-urlencoded` data with + invalid UTF-8 bytes in the body results in no form data parsed rather than a + 413 error. :issue:`2930` +- Improve ``parse_options_header`` performance when parsing unterminated + quoted string values. :issue:`2904` +- Debugger pin auth is synchronized across threads/processes when tracking + failed entries. :issue:`2916` +- Dev server handles unexpected `SSLEOFError` due to issue in Python < 3.13. + :issue:`2926` +- Debugger pin auth works when the URL already contains a query string. + :issue:`2918` + + +Version 3.0.3 +------------- + +Released 2024-05-05 + +- Only allow ``localhost``, ``.localhost``, ``127.0.0.1``, or the specified + hostname when running the dev server, to make debugger requests. Additional + hosts can be added by using the debugger middleware directly. The debugger + UI makes requests using the full URL rather than only the path. + :ghsa:`2g68-c3qc-8985` +- Make reloader more robust when ``""`` is in ``sys.path``. :pr:`2823` +- Better TLS cert format with ``adhoc`` dev certs. :pr:`2891` +- Inform Python < 3.12 how to handle ``itms-services`` URIs correctly, rather + than using an overly-broad workaround in Werkzeug that caused some redirect + URIs to be passed on without encoding. :issue:`2828` +- Type annotation for ``Rule.endpoint`` and other uses of ``endpoint`` is + ``Any``. :issue:`2836` +- Make reloader more robust when ``""`` is in ``sys.path``.
:pr:`2823` + + +Version 3.0.2 +------------- + +Released 2024-04-01 + +- Ensure setting ``merge_slashes`` to ``False`` results in ``NotFound`` for + repeated-slash requests against single slash routes. :issue:`2834` +- Fix handling of ``TypeError`` in ``TypeConversionDict.get()`` to match + ``ValueError``. :issue:`2843` +- Fix ``response_wrapper`` type check in test client. :issue:`2831` +- Make the return type of ``MultiPartParser.parse`` more precise. + :issue:`2840` +- Raise an error if converter arguments cannot be parsed. :issue:`2822` + + +Version 3.0.1 +------------- + +Released 2023-10-24 + +- Fix slow multipart parsing for large parts potentially enabling DoS attacks. + + +Version 3.0.0 +------------- + +Released 2023-09-30 + +- Remove previously deprecated code. :pr:`2768` +- Deprecate the ``__version__`` attribute. Use feature detection, or + ``importlib.metadata.version("werkzeug")``, instead. :issue:`2770` +- ``generate_password_hash`` uses scrypt by default. :issue:`2769` +- Add the ``"werkzeug.profiler"`` item to the WSGI ``environ`` dictionary + passed to `ProfilerMiddleware`'s `filename_format` function. It contains + the ``elapsed`` and ``time`` values for the profiled request. :issue:`2775` +- Explicitly marked the PathConverter as non path isolating. :pr:`2784` + + Version 2.3.8 ------------- Released 2023-11-08 - Fix slow multipart parsing for large parts potentially enabling DoS - attacks. :cwe:`CWE-407` + attacks. Version 2.3.7 @@ -15,8 +197,8 @@ Version 2.3.7 Released 2023-08-14 - Use ``flit_core`` instead of ``setuptools`` as build backend. -- Fix parsing of multipart bodies. :issue:`2734` Adjust index of last newline - in data start. :issue:`2761` +- Fix parsing of multipart bodies. :issue:`2734` +- Adjust index of last newline in data start. :issue:`2761` - Parsing ints from header values strips spacing first. :issue:`2734` - Fix empty file streaming when testing. 
:issue:`2740` - Clearer error message when URL rule does not start with slash. :pr:`2750` @@ -1028,7 +1210,7 @@ Released 2019-03-19 (:pr:`1358`) - :func:`http.parse_cookie` ignores empty segments rather than producing a cookie with no key or value. (:issue:`1245`, :pr:`1301`) -- :func:`~http.parse_authorization_header` (and +- ``http.parse_authorization_header`` (and :class:`~datastructures.Authorization`, :attr:`~wrappers.Request.authorization`) treats the authorization header as UTF-8. On Python 2, basic auth username and password are @@ -1793,8 +1975,8 @@ Version 0.9.2 (bugfix release, released on July 18th 2013) -- Added `unsafe` parameter to :func:`~werkzeug.urls.url_quote`. -- Fixed an issue with :func:`~werkzeug.urls.url_quote_plus` not quoting +- Added ``unsafe`` parameter to ``urls.url_quote``. +- Fixed an issue with ``urls.url_quote_plus`` not quoting `'+'` correctly. - Ported remaining parts of :class:`~werkzeug.contrib.RedisCache` to Python 3.3. @@ -1843,9 +2025,8 @@ Released on June 13nd 2013, codename Planierraupe. certificates easily and load them from files. - Refactored test client to invoke the open method on the class for redirects. This makes subclassing more powerful. -- :func:`werkzeug.wsgi.make_chunk_iter` and - :func:`werkzeug.wsgi.make_line_iter` now support processing of - iterators and streams. +- ``wsgi.make_chunk_iter`` and ``make_line_iter`` now support processing + of iterators and streams. - URL generation by the routing system now no longer quotes ``+``. - URL fixing now no longer quotes certain reserved characters. @@ -1943,7 +2124,7 @@ Version 0.8.3 (bugfix release, released on February 5th 2012) -- Fixed another issue with :func:`werkzeug.wsgi.make_line_iter` +- Fixed another issue with ``wsgi.make_line_iter`` where lines longer than the buffer size were not handled properly. 
- Restore stdout after debug console finished executing so @@ -2011,7 +2192,7 @@ Released on September 29th 2011, codename Lötkolben - Werkzeug now uses a new method to check that the length of incoming data is complete and will raise IO errors by itself if the server fails to do so. -- :func:`~werkzeug.wsgi.make_line_iter` now requires a limit that is +- ``wsgi.make_line_iter`` now requires a limit that is not higher than the length the stream can provide. - Refactored form parsing into a form parser class that makes it possible to hook into individual parts of the parsing process for debugging and @@ -2211,7 +2392,7 @@ Released on Feb 19th 2010, codename Hammer. - the form data parser will now look at the filename instead the content type to figure out if it should treat the upload as regular form data or file upload. This fixes a bug with Google Chrome. -- improved performance of `make_line_iter` and the multipart parser +- improved performance of ``make_line_iter`` and the multipart parser for binary uploads. - fixed :attr:`~werkzeug.BaseResponse.is_streamed` - fixed a path quoting bug in `EnvironBuilder` that caused PATH_INFO and @@ -2340,7 +2521,7 @@ Released on April 24th, codename Schlagbohrer. - added :mod:`werkzeug.contrib.lint` - added `passthrough_errors` to `run_simple`. - added `secure_filename` -- added :func:`make_line_iter` +- added ``make_line_iter`` - :class:`MultiDict` copies now instead of revealing internal lists to the caller for `getlist` and iteration functions that return lists. diff --git a/LICENSE.rst b/LICENSE.txt similarity index 100% rename from LICENSE.rst rename to LICENSE.txt diff --git a/README.rst b/README.md similarity index 56% rename from README.rst rename to README.md index 220c9979a..011c0c45f 100644 --- a/README.rst +++ b/README.md @@ -1,9 +1,8 @@ -Werkzeug -======== +# Werkzeug *werkzeug* German noun: "tool". Etymology: *werk* ("work"), *zeug* ("stuff") -Werkzeug is a comprehensive `WSGI`_ web application library. 
It began as +Werkzeug is a comprehensive [WSGI][] web application library. It began as a simple collection of various utilities for WSGI applications and has become one of the most advanced WSGI utility libraries. @@ -31,59 +30,40 @@ choose a template engine, database adapter, and even how to handle requests. It can be used to build all sorts of end user applications such as blogs, wikis, or bulletin boards. -`Flask`_ wraps Werkzeug, using it to handle the details of WSGI while +[Flask][] wraps Werkzeug, using it to handle the details of WSGI while providing more structure and patterns for defining powerful applications. -.. _WSGI: https://wsgi.readthedocs.io/en/latest/ -.. _Flask: https://www.palletsprojects.com/p/flask/ +[WSGI]: https://wsgi.readthedocs.io/en/latest/ +[Flask]: https://www.palletsprojects.com/p/flask/ -Installing ----------- +## A Simple Example -Install and update using `pip`_: +```python +# save this as app.py +from werkzeug.wrappers import Request, Response -.. code-block:: text +@Request.application +def application(request: Request) -> Response: + return Response("Hello, World!") - pip install -U Werkzeug +if __name__ == "__main__": + from werkzeug.serving import run_simple + run_simple("127.0.0.1", 5000, application) +``` -.. _pip: https://pip.pypa.io/en/stable/getting-started/ +``` +$ python -m app + * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit) +``` -A Simple Example ----------------- - -.. code-block:: python - - from werkzeug.wrappers import Request, Response - - @Request.application - def application(request): - return Response('Hello, World!') - - if __name__ == '__main__': - from werkzeug.serving import run_simple - run_simple('localhost', 4000, application) - - -Donate ------- +## Donate The Pallets organization develops and supports Werkzeug and other popular packages. In order to grow the community of contributors and users, and allow the maintainers to devote more time to the projects, -`please donate today`_. - -.. 
_please donate today: https://palletsprojects.com/donate - - -Links ------ +[please donate today][]. -- Documentation: https://werkzeug.palletsprojects.com/ -- Changes: https://werkzeug.palletsprojects.com/changes/ -- PyPI Releases: https://pypi.org/project/Werkzeug/ -- Source Code: https://github.com/pallets/werkzeug/ -- Issue Tracker: https://github.com/pallets/werkzeug/issues/ -- Chat: https://discord.gg/pallets +[please donate today]: https://palletsprojects.com/donate diff --git a/docs/conf.py b/docs/conf.py index e09ef8f7b..5cbbd4fe7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -10,18 +10,26 @@ # General -------------------------------------------------------------- -master_doc = "index" +default_role = "code" extensions = [ "sphinx.ext.autodoc", + "sphinx.ext.extlinks", "sphinx.ext.intersphinx", - "pallets_sphinx_themes", - "sphinx_issues", "sphinxcontrib.log_cabinet", + "pallets_sphinx_themes", ] autoclass_content = "both" +autodoc_member_order = "bysource" autodoc_typehints = "description" -intersphinx_mapping = {"python": ("https://docs.python.org/3/", None)} -issues_github_path = "pallets/werkzeug" +autodoc_preserve_defaults = True +extlinks = { + "issue": ("https://github.com/pallets/werkzeug/issues/%s", "#%s"), + "pr": ("https://github.com/pallets/werkzeug/pull/%s", "#%s"), + "ghsa": ("https://github.com/advisories/GHSA-%s", "GHSA-%s"), +} +intersphinx_mapping = { + "python": ("https://docs.python.org/3/", None), +} # HTML ----------------------------------------------------------------- @@ -46,9 +54,3 @@ html_logo = "_static/werkzeug-vertical.png" html_title = f"Werkzeug Documentation ({version})" html_show_sourcelink = False - -# LaTeX ---------------------------------------------------------------- - -latex_documents = [ - (master_doc, f"Werkzeug-{version}.tex", html_title, author, "manual") -] diff --git a/docs/datastructures.rst b/docs/datastructures.rst index 01432f413..e70252534 100644 --- a/docs/datastructures.rst +++ 
b/docs/datastructures.rst @@ -27,10 +27,33 @@ General Purpose :members: :inherited-members: -.. autoclass:: OrderedMultiDict +.. class:: OrderedMultiDict -.. autoclass:: ImmutableMultiDict - :members: copy + Works like a regular :class:`MultiDict` but preserves the + order of the fields. To convert the ordered multi dict into a + list you can use the :meth:`items` method and pass it ``multi=True``. + + In general an :class:`OrderedMultiDict` is an order of magnitude + slower than a :class:`MultiDict`. + + .. admonition:: note + + Due to a limitation in Python you cannot convert an ordered + multi dict into a regular dict by using ``dict(multidict)``. + Instead you have to use the :meth:`to_dict` method, otherwise + the internal bucket objects are exposed. + + .. deprecated:: 3.1 + Will be removed in Werkzeug 3.2. Use ``MultiDict`` instead. + +.. class:: ImmutableOrderedMultiDict + + An immutable :class:`OrderedMultiDict`. + + .. deprecated:: 3.1 + Will be removed in Werkzeug 3.2. Use ``ImmutableMultiDict`` instead. + + .. versionadded:: 0.6 .. autoclass:: ImmutableOrderedMultiDict :members: copy @@ -69,26 +92,14 @@ HTTP Related .. autoclass:: LanguageAccept .. autoclass:: RequestCacheControl - :members: - - .. autoattribute:: no_cache - - .. autoattribute:: no_store - - .. autoattribute:: max_age - - .. autoattribute:: no_transform + :members: + :inherited-members: ImmutableDictMixin, CallbackDict + :member-order: groupwise .. autoclass:: ResponseCacheControl - :members: - - .. autoattribute:: no_cache - - .. autoattribute:: no_store - - .. autoattribute:: max_age - - .. autoattribute:: no_transform + :members: + :inherited-members: CallbackDict + :member-order: groupwise .. autoclass:: ETags :members: diff --git a/docs/debug.rst b/docs/debug.rst index 25a9f0b2d..d842135a7 100644 --- a/docs/debug.rst +++ b/docs/debug.rst @@ -16,7 +16,8 @@ interactive debug console to execute code in any frame.
The debugger allows the execution of arbitrary code which makes it a major security risk. **The debugger must never be used on production machines. We cannot stress this enough. Do not enable the debugger - in production.** + in production.** Production means anything that is not development, + and anything that is publicly accessible. .. note:: @@ -72,10 +73,9 @@ argument to get a detailed list of all the attributes it has. Debugger PIN ------------ -Starting with Werkzeug 0.11 the debug console is protected by a PIN. -This is a security helper to make it less likely for the debugger to be -exploited if you forget to disable it when deploying to production. The -PIN based authentication is enabled by default. +The debug console is protected by a PIN. This is a security helper to make it +less likely for the debugger to be exploited if you forget to disable it when +deploying to production. The PIN based authentication is enabled by default. The first time a console is opened, a dialog will prompt for a PIN that is printed to the command line. The PIN is generated in a stable way @@ -92,6 +92,31 @@ intended to make it harder for an attacker to exploit the debugger. Never enable the debugger in production.** +Allowed Hosts +------------- + +The debug console will only be served if the request comes from a trusted host. +If a request comes from a browser page that is not served on a trusted URL, a +400 error will be returned. + +By default, ``localhost``, any ``.localhost`` subdomain, and ``127.0.0.1`` are +trusted. ``run_simple`` will trust its ``hostname`` argument as well. To change +this further, use the debug middleware directly rather than through +``use_debugger=True``. + +.. code-block:: python + + if os.environ.get("USE_DEBUGGER") in {"1", "true"}: + app = DebuggedApplication(app, evalex=True) + app.trusted_hosts = [...] + + run_simple("localhost", 8080, app) + +**This feature is not meant to entirely secure the debugger. 
It is +intended to make it harder for an attacker to exploit the debugger. +Never enable the debugger in production.** + + Pasting Errors -------------- diff --git a/docs/exceptions.rst b/docs/exceptions.rst index 88a309d45..d5b6970b1 100644 --- a/docs/exceptions.rst +++ b/docs/exceptions.rst @@ -44,6 +44,8 @@ The following error classes exist in Werkzeug: .. autoexception:: ImATeapot +.. autoexception:: MisdirectedRequest + .. autoexception:: UnprocessableEntity .. autoexception:: Locked diff --git a/docs/http.rst b/docs/http.rst index cbf4e04ed..790de3172 100644 --- a/docs/http.rst +++ b/docs/http.rst @@ -53,10 +53,6 @@ by :rfc:`2616`, Werkzeug implements some custom data structures that are .. autofunction:: parse_cache_control_header -.. autofunction:: parse_authorization_header - -.. autofunction:: parse_www_authenticate_header - .. autofunction:: parse_if_range_header .. autofunction:: parse_range_header diff --git a/docs/installation.rst b/docs/installation.rst index 7138f08c1..00513e123 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -6,7 +6,7 @@ Python Version -------------- We recommend using the latest version of Python. Werkzeug supports -Python 3.8 and newer. +Python 3.9 and newer. Optional dependencies diff --git a/docs/license.rst b/docs/license.rst index a53a98cf3..2a445f9c6 100644 --- a/docs/license.rst +++ b/docs/license.rst @@ -1,4 +1,5 @@ BSD-3-Clause License ==================== -.. include:: ../LICENSE.rst +.. literalinclude:: ../LICENSE.txt + :language: text diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 0f3714e6e..d97764e98 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -9,7 +9,7 @@ understanding of :pep:`3333` (WSGI) and :rfc:`2616` (HTTP). WSGI Environment -================ +---------------- The WSGI environment contains all the information the user request transmits to the application. 
It is passed to the WSGI application but you can also @@ -33,7 +33,7 @@ access the form data besides parsing that data by hand. Enter Request -============= +------------- For access to the request data the :class:`Request` object is much more fun. It wraps the `environ` and provides a read-only access to the data from @@ -112,7 +112,7 @@ The keys for the headers are of course case insensitive. Header Parsing -============== +-------------- There is more. Werkzeug provides convenient access to often used HTTP headers and other request data. @@ -141,7 +141,7 @@ the quality, the best item being the first: 'text/html' >>> 'application/xhtml+xml' in request.accept_mimetypes True ->>> print request.accept_mimetypes["application/json"] +>>> print(request.accept_mimetypes["application/json"]) 0.8 The same works for languages: @@ -183,7 +183,7 @@ True Responses -========= +--------- Response objects are the opposite of request objects. They are used to send data back to the client. In reality, response objects are nothing more than diff --git a/docs/request_data.rst b/docs/request_data.rst index b1c97b2c7..75811a902 100644 --- a/docs/request_data.rst +++ b/docs/request_data.rst @@ -79,16 +79,23 @@ request in such a way that the server uses too many resources to handle it. Each these limits will raise a :exc:`~werkzeug.exceptions.RequestEntityTooLarge` if they are exceeded. -- :attr:`~Request.max_content_length` Stop reading request data after this number +- :attr:`~Request.max_content_length` - Stop reading request data after this number of bytes. It's better to configure this in the WSGI server or HTTP server, rather than the WSGI application. -- :attr:`~Request.max_form_memory_size` Stop reading request data if any form part is - larger than this number of bytes. While file parts can be moved to disk, regular - form field data is stored in memory only. 
+- :attr:`~Request.max_form_memory_size` - Stop reading request data if any + non-file form field is larger than this number of bytes. While file parts + can be moved to disk, regular form field data is stored in memory only and + could fill up memory. The default is 500kB. - :attr:`~Request.max_form_parts` Stop reading request data if more than this number of parts are sent in multipart form data. This is useful to stop a very large number of very small parts, especially file parts. The default is 1000. +Each of these values can be set on the ``Request`` class to affect the default +for all requests, or on a ``request`` instance to change the behavior for a +specific request. For example, a small limit can be set by default, and a large +limit can be set on an endpoint that accepts video uploads. These values should +be tuned to the specific needs of your application and endpoints. + Using Werkzeug to set these limits is only one layer of protection. WSGI servers and HTTPS servers should set their own limits on size and timeouts. The operating system or container manager should set limits on memory and processing time for server diff --git a/docs/test.rst b/docs/test.rst index 704eb5f59..d31ac5938 100644 --- a/docs/test.rst +++ b/docs/test.rst @@ -18,8 +18,8 @@ requests. >>> response = c.get("/") >>> response.status_code 200 ->>> resp.headers -Headers([('Content-Type', 'text/html; charset=utf-8'), ('Content-Length', '6658')]) +>>> response.headers +Headers([('Content-Type', 'text/html; charset=utf-8'), ('Content-Length', '5211')]) >>> response.get_data(as_text=True) '...' 
diff --git a/docs/tutorial.rst b/docs/tutorial.rst index 943787a7c..9cb5aef47 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -123,7 +123,7 @@ if they are not used right away, to keep it from being confusing:: import os import redis - from werkzeug.urls import url_parse + from urllib.parse import urlparse from werkzeug.wrappers import Request, Response from werkzeug.routing import Map, Rule from werkzeug.exceptions import HTTPException, NotFound @@ -308,7 +308,7 @@ we need to write a function and a helper method. For URL validation this is good enough:: def is_valid_url(url): - parts = url_parse(url) + parts = urlparse(url) return parts.scheme in ('http', 'https') For inserting the URL, all we need is this little method on our class:: diff --git a/docs/wsgi.rst b/docs/wsgi.rst index 1992bece6..67b3bb6b8 100644 --- a/docs/wsgi.rst +++ b/docs/wsgi.rst @@ -22,10 +22,6 @@ iterator and the input stream. .. autoclass:: LimitedStream :members: -.. autofunction:: make_line_iter - -.. autofunction:: make_chunk_iter - .. autofunction:: wrap_file diff --git a/pyproject.toml b/pyproject.toml index 3a1965554..2d5a6cee2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,9 @@ [project] name = "Werkzeug" +version = "3.1.3" description = "The comprehensive WSGI web application library." 
-readme = "README.rst" -license = {file = "LICENSE.rst"} +readme = "README.md" +license = {file = "LICENSE.txt"} maintainers = [{name = "Pallets", email = "contact@palletsprojects.com"}] classifiers = [ "Development Status :: 5 - Production/Stable", @@ -16,10 +17,12 @@ classifiers = [ "Topic :: Internet :: WWW/HTTP :: WSGI :: Application", "Topic :: Internet :: WWW/HTTP :: WSGI :: Middleware", "Topic :: Software Development :: Libraries :: Application Frameworks", + "Typing :: Typed", +] +requires-python = ">=3.9" +dependencies = [ + "MarkupSafe>=2.1.1", ] -requires-python = ">=3.8" -dependencies = ["MarkupSafe>=2.1.1"] -dynamic = ["version"] [project.urls] Donate = "https://palletsprojects.com/donate" @@ -67,38 +70,51 @@ source = ["werkzeug", "tests"] source = ["src", "*/site-packages"] [tool.mypy] -python_version = "3.8" +python_version = "3.9" files = ["src/werkzeug"] show_error_codes = true pretty = true -#strict = true -allow_redefinition = true -disallow_subclassing_any = true -#disallow_untyped_calls = true -disallow_untyped_defs = true -disallow_incomplete_defs = true -no_implicit_optional = true -local_partial_types = true -no_implicit_reexport = true -strict_equality = true -warn_redundant_casts = true -warn_unused_configs = true -warn_unused_ignores = true -warn_return_any = true -#warn_unreachable = True - -[[tool.mypy.overrides]] -module = ["werkzeug.wrappers"] -no_implicit_reexport = false +strict = true [[tool.mypy.overrides]] module = [ "colorama.*", "cryptography.*", - "eventlet.*", - "gevent.*", - "greenlet.*", + "ephemeral_port_reserve", "watchdog.*", "xprocess.*", ] ignore_missing_imports = true + +[tool.pyright] +pythonVersion = "3.9" +include = ["src/werkzeug"] + +[tool.ruff] +extend-exclude = ["examples/"] +src = ["src"] +fix = true +show-fixes = true +output-format = "full" + +[tool.ruff.lint] +select = [ + "B", # flake8-bugbear + "E", # pycodestyle error + "F", # pyflakes + "I", # isort + "UP", # pyupgrade + "W", # pycodestyle warning +] 
+ignore = [ + "E402", # allow circular imports at end of file +] + +[tool.ruff.lint.isort] +force-single-line = true +order-by-type = false + +[tool.gha-update] +tag-only = [ + "slsa-framework/slsa-github-generator", +] diff --git a/requirements/build.txt b/requirements/build.txt index 196545d0e..1b13b0552 100644 --- a/requirements/build.txt +++ b/requirements/build.txt @@ -1,13 +1,12 @@ -# SHA1:80754af91bfb6d1073585b046fe0a474ce868509 # -# This file is autogenerated by pip-compile-multi -# To update, run: +# This file is autogenerated by pip-compile with Python 3.13 +# by the following command: # -# pip-compile-multi +# pip-compile build.in # -build==0.10.0 - # via -r requirements/build.in -packaging==23.1 +build==1.2.2.post1 + # via -r build.in +packaging==24.1 # via build -pyproject-hooks==1.0.0 +pyproject-hooks==1.2.0 # via build diff --git a/requirements/dev.in b/requirements/dev.in index 99f5942f8..1efde82b1 100644 --- a/requirements/dev.in +++ b/requirements/dev.in @@ -1,6 +1,5 @@ --r docs.in --r tests.in --r typing.in -pip-compile-multi +-r docs.txt +-r tests.txt +-r typing.txt pre-commit tox diff --git a/requirements/dev.txt b/requirements/dev.txt index ed462080a..24eb34a53 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,64 +1,196 @@ -# SHA1:54b5b77ec8c7a0064ffa93b2fd16cb0130ba177c # -# This file is autogenerated by pip-compile-multi -# To update, run: +# This file is autogenerated by pip-compile with Python 3.13 +# by the following command: # -# pip-compile-multi +# pip-compile dev.in # --r docs.txt --r tests.txt --r typing.txt -build==0.10.0 - # via pip-tools -cachetools==5.3.1 +alabaster==1.0.0 + # via + # -r docs.txt + # sphinx +babel==2.16.0 + # via + # -r docs.txt + # sphinx +cachetools==5.5.0 # via tox -cfgv==3.3.1 +certifi==2024.8.30 + # via + # -r docs.txt + # requests +cffi==1.17.1 + # via + # -r tests.txt + # cryptography +cfgv==3.4.0 # via pre-commit -chardet==5.1.0 +chardet==5.2.0 # via tox -click==8.1.3 
+charset-normalizer==3.4.0 # via - # pip-compile-multi - # pip-tools + # -r docs.txt + # requests colorama==0.4.6 # via tox -distlib==0.3.6 +cryptography==43.0.3 + # via -r tests.txt +distlib==0.3.9 # via virtualenv -filelock==3.12.2 +docutils==0.21.2 + # via + # -r docs.txt + # sphinx +ephemeral-port-reserve==1.1.4 + # via -r tests.txt +filelock==3.16.1 # via # tox # virtualenv -identify==2.5.24 +identify==2.6.1 # via pre-commit -nodeenv==1.8.0 - # via pre-commit -pip-compile-multi==2.6.3 - # via -r requirements/dev.in -pip-tools==6.13.0 - # via pip-compile-multi -platformdirs==3.8.0 +idna==3.10 + # via + # -r docs.txt + # requests +imagesize==1.4.1 + # via + # -r docs.txt + # sphinx +iniconfig==2.0.0 + # via + # -r tests.txt + # -r typing.txt + # pytest +jinja2==3.1.4 + # via + # -r docs.txt + # sphinx +markupsafe==3.0.2 + # via + # -r docs.txt + # jinja2 +mypy==1.13.0 + # via -r typing.txt +mypy-extensions==1.0.0 + # via + # -r typing.txt + # mypy +nodeenv==1.9.1 + # via + # -r typing.txt + # pre-commit + # pyright +packaging==24.1 + # via + # -r docs.txt + # -r tests.txt + # -r typing.txt + # pallets-sphinx-themes + # pyproject-api + # pytest + # sphinx + # tox +pallets-sphinx-themes==2.3.0 + # via -r docs.txt +platformdirs==4.3.6 # via # tox # virtualenv -pre-commit==3.3.3 - # via -r requirements/dev.in -pyproject-api==1.5.2 +pluggy==1.5.0 + # via + # -r tests.txt + # -r typing.txt + # pytest + # tox +pre-commit==4.0.1 + # via -r dev.in +pycparser==2.22 + # via + # -r tests.txt + # cffi +pygments==2.18.0 + # via + # -r docs.txt + # sphinx +pyproject-api==1.8.0 # via tox -pyproject-hooks==1.0.0 - # via build -pyyaml==6.0 +pyright==1.1.386 + # via -r typing.txt +pytest==8.3.3 + # via + # -r tests.txt + # -r typing.txt + # pytest-timeout +pytest-timeout==2.3.1 + # via -r tests.txt +pyyaml==6.0.2 # via pre-commit -toposort==1.10 - # via pip-compile-multi -tox==4.6.3 - # via -r requirements/dev.in -virtualenv==20.23.1 +requests==2.32.3 + # via + # -r docs.txt + # 
sphinx +snowballstemmer==2.2.0 + # via + # -r docs.txt + # sphinx +sphinx==8.1.3 + # via + # -r docs.txt + # pallets-sphinx-themes + # sphinx-notfound-page + # sphinxcontrib-log-cabinet +sphinx-notfound-page==1.0.4 + # via + # -r docs.txt + # pallets-sphinx-themes +sphinxcontrib-applehelp==2.0.0 + # via + # -r docs.txt + # sphinx +sphinxcontrib-devhelp==2.0.0 + # via + # -r docs.txt + # sphinx +sphinxcontrib-htmlhelp==2.1.0 + # via + # -r docs.txt + # sphinx +sphinxcontrib-jsmath==1.0.1 + # via + # -r docs.txt + # sphinx +sphinxcontrib-log-cabinet==1.0.1 + # via -r docs.txt +sphinxcontrib-qthelp==2.0.0 + # via + # -r docs.txt + # sphinx +sphinxcontrib-serializinghtml==2.0.0 + # via + # -r docs.txt + # sphinx +tox==4.23.2 + # via -r dev.in +types-contextvars==2.4.7.3 + # via -r typing.txt +types-dataclasses==0.6.6 + # via -r typing.txt +types-setuptools==75.2.0.20241019 + # via -r typing.txt +typing-extensions==4.12.2 + # via + # -r typing.txt + # mypy + # pyright +urllib3==2.2.3 + # via + # -r docs.txt + # requests +virtualenv==20.27.0 # via # pre-commit # tox -wheel==0.40.0 - # via pip-tools - -# The following packages are considered to be unsafe in a requirements file: -# pip -# setuptools +watchdog==5.0.3 + # via + # -r tests.txt + # -r typing.txt diff --git a/requirements/docs.in b/requirements/docs.in index 7ec501b6d..ba3fd7774 100644 --- a/requirements/docs.in +++ b/requirements/docs.in @@ -1,4 +1,3 @@ -Pallets-Sphinx-Themes -Sphinx -sphinx-issues +pallets-sphinx-themes +sphinx sphinxcontrib-log-cabinet diff --git a/requirements/docs.txt b/requirements/docs.txt index e125c59a4..1e3a54ebb 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,61 +1,60 @@ -# SHA1:45c590f97fe95b8bdc755eef796e91adf5fbe4ea # -# This file is autogenerated by pip-compile-multi -# To update, run: +# This file is autogenerated by pip-compile with Python 3.13 +# by the following command: # -# pip-compile-multi +# pip-compile docs.in # -alabaster==0.7.13 +alabaster==1.0.0 
# via sphinx -babel==2.12.1 +babel==2.16.0 # via sphinx -certifi==2023.5.7 +certifi==2024.8.30 # via requests -charset-normalizer==3.1.0 +charset-normalizer==3.4.0 # via requests -docutils==0.20.1 +docutils==0.21.2 # via sphinx -idna==3.4 +idna==3.10 # via requests imagesize==1.4.1 # via sphinx -jinja2==3.1.2 +jinja2==3.1.4 # via sphinx -markupsafe==2.1.3 +markupsafe==3.0.2 # via jinja2 -packaging==23.1 +packaging==24.1 # via # pallets-sphinx-themes # sphinx -pallets-sphinx-themes==2.1.1 - # via -r requirements/docs.in -pygments==2.15.1 +pallets-sphinx-themes==2.3.0 + # via -r docs.in +pygments==2.18.0 # via sphinx -requests==2.31.0 +requests==2.32.3 # via sphinx snowballstemmer==2.2.0 # via sphinx -sphinx==7.0.1 +sphinx==8.1.3 # via - # -r requirements/docs.in + # -r docs.in # pallets-sphinx-themes - # sphinx-issues + # sphinx-notfound-page # sphinxcontrib-log-cabinet -sphinx-issues==3.0.1 - # via -r requirements/docs.in -sphinxcontrib-applehelp==1.0.4 +sphinx-notfound-page==1.0.4 + # via pallets-sphinx-themes +sphinxcontrib-applehelp==2.0.0 # via sphinx -sphinxcontrib-devhelp==1.0.2 +sphinxcontrib-devhelp==2.0.0 # via sphinx -sphinxcontrib-htmlhelp==2.0.1 +sphinxcontrib-htmlhelp==2.1.0 # via sphinx sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-log-cabinet==1.0.1 - # via -r requirements/docs.in -sphinxcontrib-qthelp==1.0.3 + # via -r docs.in +sphinxcontrib-qthelp==2.0.0 # via sphinx -sphinxcontrib-serializinghtml==1.1.5 +sphinxcontrib-serializinghtml==2.0.0 # via sphinx -urllib3==2.0.3 +urllib3==2.2.3 # via requests diff --git a/requirements/tests.in b/requirements/tests.in index 3ced491be..c1b5bc313 100644 --- a/requirements/tests.in +++ b/requirements/tests.in @@ -1,7 +1,6 @@ pytest pytest-timeout -pytest-xprocess cryptography -greenlet ; python_version < "3.11" watchdog ephemeral-port-reserve +cffi diff --git a/requirements/tests.txt b/requirements/tests.txt index 057d62859..d64cace9a 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ 
-1,36 +1,30 @@ -# SHA1:42b4e3e66395275e048d9a92c294b2c650393866 # -# This file is autogenerated by pip-compile-multi -# To update, run: +# This file is autogenerated by pip-compile with Python 3.13 +# by the following command: # -# pip-compile-multi +# pip-compile tests.in # -cffi==1.15.1 - # via cryptography -cryptography==41.0.1 - # via -r requirements/tests.in +cffi==1.17.1 + # via + # -r tests.in + # cryptography +cryptography==43.0.3 + # via -r tests.in ephemeral-port-reserve==1.1.4 - # via -r requirements/tests.in + # via -r tests.in iniconfig==2.0.0 # via pytest -packaging==23.1 +packaging==24.1 # via pytest -pluggy==1.2.0 +pluggy==1.5.0 # via pytest -psutil==5.9.5 - # via pytest-xprocess -py==1.11.0 - # via pytest-xprocess -pycparser==2.21 +pycparser==2.22 # via cffi -pytest==7.4.0 +pytest==8.3.3 # via - # -r requirements/tests.in + # -r tests.in # pytest-timeout - # pytest-xprocess -pytest-timeout==2.1.0 - # via -r requirements/tests.in -pytest-xprocess==0.22.2 - # via -r requirements/tests.in -watchdog==3.0.0 - # via -r requirements/tests.in +pytest-timeout==2.3.1 + # via -r tests.in +watchdog==5.0.3 + # via -r tests.in diff --git a/requirements/typing.in b/requirements/typing.in index 23ab1587b..096413b22 100644 --- a/requirements/typing.in +++ b/requirements/typing.in @@ -1,4 +1,6 @@ mypy +pyright +pytest types-contextvars types-dataclasses types-setuptools diff --git a/requirements/typing.txt b/requirements/typing.txt index 99c46d2e0..b90f838dd 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,21 +1,34 @@ -# SHA1:162796b1b3ac7a29da65fe0e32278f14b68ed8c8 # -# This file is autogenerated by pip-compile-multi -# To update, run: +# This file is autogenerated by pip-compile with Python 3.13 +# by the following command: # -# pip-compile-multi +# pip-compile typing.in # -mypy==1.4.1 - # via -r requirements/typing.in +iniconfig==2.0.0 + # via pytest +mypy==1.13.0 + # via -r typing.in mypy-extensions==1.0.0 # via mypy 
-types-contextvars==2.4.7.2 - # via -r requirements/typing.in +nodeenv==1.9.1 + # via pyright +packaging==24.1 + # via pytest +pluggy==1.5.0 + # via pytest +pyright==1.1.386 + # via -r typing.in +pytest==8.3.3 + # via -r typing.in +types-contextvars==2.4.7.3 + # via -r typing.in types-dataclasses==0.6.6 - # via -r requirements/typing.in -types-setuptools==68.0.0.0 - # via -r requirements/typing.in -typing-extensions==4.6.3 - # via mypy -watchdog==3.0.0 - # via -r requirements/typing.in + # via -r typing.in +types-setuptools==75.2.0.20241019 + # via -r typing.in +typing-extensions==4.12.2 + # via + # mypy + # pyright +watchdog==5.0.3 + # via -r typing.in diff --git a/src/werkzeug/__init__.py b/src/werkzeug/__init__.py index 0a472ae7d..0b248fd86 100644 --- a/src/werkzeug/__init__.py +++ b/src/werkzeug/__init__.py @@ -2,5 +2,3 @@ from .test import Client as Client from .wrappers import Request as Request from .wrappers import Response as Response - -__version__ = "2.3.8" diff --git a/src/werkzeug/_internal.py b/src/werkzeug/_internal.py index 6ed4d3024..7dd2fbccd 100644 --- a/src/werkzeug/_internal.py +++ b/src/werkzeug/_internal.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import operator import re import sys import typing as t @@ -10,6 +9,7 @@ if t.TYPE_CHECKING: from _typeshed.wsgi import WSGIEnvironment + from .wrappers.request import Request _logger: logging.Logger | None = None @@ -26,102 +26,12 @@ def __reduce__(self) -> str: _missing = _Missing() -@t.overload -def _make_encode_wrapper(reference: str) -> t.Callable[[str], str]: - ... - - -@t.overload -def _make_encode_wrapper(reference: bytes) -> t.Callable[[str], bytes]: - ... - - -def _make_encode_wrapper(reference: t.AnyStr) -> t.Callable[[str], t.AnyStr]: - """Create a function that will be called with a string argument. If - the reference is bytes, values will be encoded to bytes. 
- """ - if isinstance(reference, str): - return lambda x: x - - return operator.methodcaller("encode", "latin1") - - -def _check_str_tuple(value: tuple[t.AnyStr, ...]) -> None: - """Ensure tuple items are all strings or all bytes.""" - if not value: - return - - item_type = str if isinstance(value[0], str) else bytes - - if any(not isinstance(item, item_type) for item in value): - raise TypeError(f"Cannot mix str and bytes arguments (got {value!r})") - - -_default_encoding = sys.getdefaultencoding() - - -def _to_bytes( - x: str | bytes, charset: str = _default_encoding, errors: str = "strict" -) -> bytes: - if x is None or isinstance(x, bytes): - return x - - if isinstance(x, (bytearray, memoryview)): - return bytes(x) - - if isinstance(x, str): - return x.encode(charset, errors) - - raise TypeError("Expected bytes") +def _wsgi_decoding_dance(s: str) -> str: + return s.encode("latin1").decode(errors="replace") -@t.overload -def _to_str( # type: ignore - x: None, - charset: str | None = ..., - errors: str = ..., - allow_none_charset: bool = ..., -) -> None: - ... - - -@t.overload -def _to_str( - x: t.Any, - charset: str | None = ..., - errors: str = ..., - allow_none_charset: bool = ..., -) -> str: - ... 
- - -def _to_str( - x: t.Any | None, - charset: str | None = _default_encoding, - errors: str = "strict", - allow_none_charset: bool = False, -) -> str | bytes | None: - if x is None or isinstance(x, str): - return x - - if not isinstance(x, (bytes, bytearray)): - return str(x) - - if charset is None: - if allow_none_charset: - return x - - return x.decode(charset, errors) # type: ignore - - -def _wsgi_decoding_dance( - s: str, charset: str = "utf-8", errors: str = "replace" -) -> str: - return s.encode("latin1").decode(charset, errors) - - -def _wsgi_encoding_dance(s: str, charset: str = "utf-8", errors: str = "strict") -> str: - return s.encode(charset).decode("latin1", errors) +def _wsgi_encoding_dance(s: str) -> str: + return s.encode().decode("latin1") def _get_environ(obj: WSGIEnvironment | Request) -> WSGIEnvironment: @@ -151,7 +61,7 @@ def _has_level_handler(logger: logging.Logger) -> bool: return False -class _ColorStreamHandler(logging.StreamHandler): +class _ColorStreamHandler(logging.StreamHandler): # type: ignore[type-arg] """On Windows, wrap stream with Colorama for ANSI style support.""" def __init__(self) -> None: @@ -188,13 +98,11 @@ def _log(type: str, message: str, *args: t.Any, **kwargs: t.Any) -> None: @t.overload -def _dt_as_utc(dt: None) -> None: - ... +def _dt_as_utc(dt: None) -> None: ... @t.overload -def _dt_as_utc(dt: datetime) -> datetime: - ... +def _dt_as_utc(dt: datetime) -> datetime: ... def _dt_as_utc(dt: datetime | None) -> datetime | None: @@ -240,12 +148,10 @@ def lookup(self, instance: t.Any) -> t.MutableMapping[str, t.Any]: @t.overload def __get__( self, instance: None, owner: type - ) -> _DictAccessorProperty[_TAccessorValue]: - ... + ) -> _DictAccessorProperty[_TAccessorValue]: ... @t.overload - def __get__(self, instance: t.Any, owner: type) -> _TAccessorValue: - ... + def __get__(self, instance: t.Any, owner: type) -> _TAccessorValue: ... 
def __get__( self, instance: t.Any | None, owner: type @@ -287,31 +193,6 @@ def __repr__(self) -> str: return f"<{type(self).__name__} {self.name}>" -def _decode_idna(domain: str) -> str: - try: - data = domain.encode("ascii") - except UnicodeEncodeError: - # If the domain is not ASCII, it's decoded already. - return domain - - try: - # Try decoding in one shot. - return data.decode("idna") - except UnicodeDecodeError: - pass - - # Decode each part separately, leaving invalid parts as punycode. - parts = [] - - for part in data.split(b"."): - try: - parts.append(part.decode("idna")) - except UnicodeDecodeError: - parts.append(part.decode("ascii")) - - return ".".join(parts) - - _plain_int_re = re.compile(r"-?\d+", re.ASCII) diff --git a/src/werkzeug/_reloader.py b/src/werkzeug/_reloader.py index c8683593f..8fd50b963 100644 --- a/src/werkzeug/_reloader.py +++ b/src/werkzeug/_reloader.py @@ -141,7 +141,7 @@ def _find_watchdog_paths( def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]: - root: dict[str, dict] = {} + root: dict[str, dict[str, t.Any]] = {} for chunks in sorted((PurePath(x).parts for x in paths), key=len, reverse=True): node = root @@ -153,11 +153,13 @@ def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]: rv = set() - def _walk(node: t.Mapping[str, dict], path: tuple[str, ...]) -> None: + def _walk(node: t.Mapping[str, dict[str, t.Any]], path: tuple[str, ...]) -> None: for prefix, child in node.items(): _walk(child, path + (prefix,)) - if not node: + # If there are no more nodes, and a path has been accumulated, add it. + # Path may be empty if the "" entry is in sys.path. 
+ if not node and path: rv.add(os.path.join(*path)) _walk(root, ()) @@ -279,7 +281,7 @@ def trigger_reload(self, filename: str) -> None: self.log_reload(filename) sys.exit(3) - def log_reload(self, filename: str) -> None: + def log_reload(self, filename: str | bytes) -> None: filename = os.path.abspath(filename) _log("info", f" * Detected change in {filename!r}, reloading") @@ -310,17 +312,28 @@ def run_step(self) -> None: class WatchdogReloaderLoop(ReloaderLoop): def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: - from watchdog.observers import Observer - from watchdog.events import PatternMatchingEventHandler - from watchdog.events import EVENT_TYPE_OPENED + from watchdog.events import EVENT_TYPE_CLOSED + from watchdog.events import EVENT_TYPE_CREATED + from watchdog.events import EVENT_TYPE_DELETED + from watchdog.events import EVENT_TYPE_MODIFIED + from watchdog.events import EVENT_TYPE_MOVED from watchdog.events import FileModifiedEvent + from watchdog.events import PatternMatchingEventHandler + from watchdog.observers import Observer super().__init__(*args, **kwargs) trigger_reload = self.trigger_reload class EventHandler(PatternMatchingEventHandler): def on_any_event(self, event: FileModifiedEvent): # type: ignore - if event.event_type == EVENT_TYPE_OPENED: + if event.event_type not in { + EVENT_TYPE_CLOSED, + EVENT_TYPE_CREATED, + EVENT_TYPE_DELETED, + EVENT_TYPE_MODIFIED, + EVENT_TYPE_MOVED, + }: + # skip events that don't involve changes to the file return trigger_reload(event.src_path) @@ -347,7 +360,7 @@ def on_any_event(self, event: FileModifiedEvent): # type: ignore ) self.should_reload = False - def trigger_reload(self, filename: str) -> None: + def trigger_reload(self, filename: str | bytes) -> None: # This is called inside an event handler, which means throwing # SystemExit has no effect. 
# https://github.com/gorakhargosh/watchdog/issues/294 diff --git a/src/werkzeug/datastructures/__init__.py b/src/werkzeug/datastructures/__init__.py index 846ffce67..6582de02c 100644 --- a/src/werkzeug/datastructures/__init__.py +++ b/src/werkzeug/datastructures/__init__.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +import typing as t + from .accept import Accept as Accept from .accept import CharsetAccept as CharsetAccept from .accept import LanguageAccept as LanguageAccept @@ -26,9 +30,35 @@ from .structures import ImmutableDict as ImmutableDict from .structures import ImmutableList as ImmutableList from .structures import ImmutableMultiDict as ImmutableMultiDict -from .structures import ImmutableOrderedMultiDict as ImmutableOrderedMultiDict from .structures import ImmutableTypeConversionDict as ImmutableTypeConversionDict from .structures import iter_multi_items as iter_multi_items from .structures import MultiDict as MultiDict -from .structures import OrderedMultiDict as OrderedMultiDict from .structures import TypeConversionDict as TypeConversionDict + + +def __getattr__(name: str) -> t.Any: + import warnings + + if name == "OrderedMultiDict": + from .structures import _OrderedMultiDict + + warnings.warn( + "'OrderedMultiDict' is deprecated and will be removed in Werkzeug" + " 3.2. Use 'MultiDict' instead.", + DeprecationWarning, + stacklevel=2, + ) + return _OrderedMultiDict + + if name == "ImmutableOrderedMultiDict": + from .structures import _ImmutableOrderedMultiDict + + warnings.warn( + "'OrderedMultiDict' is deprecated and will be removed in Werkzeug" + " 3.2. 
Use 'ImmutableMultiDict' instead.", + DeprecationWarning, + stacklevel=2, + ) + return _ImmutableOrderedMultiDict + + raise AttributeError(name) diff --git a/src/werkzeug/datastructures/accept.py b/src/werkzeug/datastructures/accept.py index d80f0bbb8..44179a93f 100644 --- a/src/werkzeug/datastructures/accept.py +++ b/src/werkzeug/datastructures/accept.py @@ -1,12 +1,14 @@ from __future__ import annotations import codecs +import collections.abc as cabc import re +import typing as t from .structures import ImmutableList -class Accept(ImmutableList): +class Accept(ImmutableList[tuple[str, float]]): """An :class:`Accept` object is just a list subclass for lists of ``(value, quality)`` tuples. It is automatically sorted by specificity and quality. @@ -42,29 +44,39 @@ class Accept(ImmutableList): """ - def __init__(self, values=()): + def __init__( + self, values: Accept | cabc.Iterable[tuple[str, float]] | None = () + ) -> None: if values is None: - list.__init__(self) + super().__init__() self.provided = False elif isinstance(values, Accept): self.provided = values.provided - list.__init__(self, values) + super().__init__(values) else: self.provided = True values = sorted( values, key=lambda x: (self._specificity(x[0]), x[1]), reverse=True ) - list.__init__(self, values) + super().__init__(values) - def _specificity(self, value): + def _specificity(self, value: str) -> tuple[bool, ...]: """Returns a tuple describing the value's specificity.""" return (value != "*",) - def _value_matches(self, value, item): + def _value_matches(self, value: str, item: str) -> bool: """Check if a value matches a given accept item.""" return item == "*" or item.lower() == value.lower() - def __getitem__(self, key): + @t.overload + def __getitem__(self, key: str) -> float: ... + @t.overload + def __getitem__(self, key: t.SupportsIndex) -> tuple[str, float]: ... + @t.overload + def __getitem__(self, key: slice) -> list[tuple[str, float]]: ... 
+ def __getitem__( + self, key: str | t.SupportsIndex | slice + ) -> float | tuple[str, float] | list[tuple[str, float]]: """Besides index lookup (getting item n) you can also pass it a string to get the quality for the item. If the item is not in the list, the returned quality is ``0``. @@ -73,7 +85,7 @@ def __getitem__(self, key): return self.quality(key) return list.__getitem__(self, key) - def quality(self, key): + def quality(self, key: str) -> float: """Returns the quality of the key. .. versionadded:: 0.6 @@ -85,17 +97,17 @@ def quality(self, key): return quality return 0 - def __contains__(self, value): + def __contains__(self, value: str) -> bool: # type: ignore[override] for item, _quality in self: if self._value_matches(value, item): return True return False - def __repr__(self): + def __repr__(self) -> str: pairs_str = ", ".join(f"({x!r}, {y})" for x, y in self) return f"{type(self).__name__}([{pairs_str}])" - def index(self, key): + def index(self, key: str | tuple[str, float]) -> int: # type: ignore[override] """Get the position of an entry or raise :exc:`ValueError`. :param key: The key to be looked up. @@ -111,7 +123,7 @@ def index(self, key): raise ValueError(key) return list.index(self, key) - def find(self, key): + def find(self, key: str | tuple[str, float]) -> int: """Get the position of an entry or return -1. :param key: The key to be looked up. 
@@ -121,12 +133,12 @@ def find(self, key): except ValueError: return -1 - def values(self): + def values(self) -> cabc.Iterator[str]: """Iterate over all values.""" for item in self: yield item[0] - def to_header(self): + def to_header(self) -> str: """Convert the header set into an HTTP header string.""" result = [] for value, quality in self: @@ -135,17 +147,23 @@ def to_header(self): result.append(value) return ",".join(result) - def __str__(self): + def __str__(self) -> str: return self.to_header() - def _best_single_match(self, match): + def _best_single_match(self, match: str) -> tuple[str, float] | None: for client_item, quality in self: if self._value_matches(match, client_item): # self is sorted by specificity descending, we can exit return client_item, quality return None - def best_match(self, matches, default=None): + @t.overload + def best_match(self, matches: cabc.Iterable[str]) -> str | None: ... + @t.overload + def best_match(self, matches: cabc.Iterable[str], default: str = ...) -> str: ... + def best_match( + self, matches: cabc.Iterable[str], default: str | None = None + ) -> str | None: """Returns the best match from a list of possible matches based on the specificity and quality of the client. If two items have the same quality and specificity, the one is returned that comes first. @@ -154,8 +172,8 @@ def best_match(self, matches, default=None): :param default: the value that is returned if none match """ result = default - best_quality = -1 - best_specificity = (-1,) + best_quality: float = -1 + best_specificity: tuple[float, ...] 
= (-1,) for server_item in matches: match = self._best_single_match(server_item) if not match: @@ -172,16 +190,18 @@ def best_match(self, matches, default=None): return result @property - def best(self): + def best(self) -> str | None: """The best match as value.""" if self: return self[0][0] + return None + _mime_split_re = re.compile(r"/|(?:\s*;\s*)") -def _normalize_mime(value): +def _normalize_mime(value: str) -> list[str]: return _mime_split_re.split(value.lower()) @@ -190,10 +210,10 @@ class MIMEAccept(Accept): mimetypes. """ - def _specificity(self, value): + def _specificity(self, value: str) -> tuple[bool, ...]: return tuple(x != "*" for x in _mime_split_re.split(value)) - def _value_matches(self, value, item): + def _value_matches(self, value: str, item: str) -> bool: # item comes from the client, can't match if it's invalid. if "/" not in item: return False @@ -234,27 +254,25 @@ def _value_matches(self, value, item): ) @property - def accept_html(self): + def accept_html(self) -> bool: """True if this object accepts HTML.""" - return ( - "text/html" in self or "application/xhtml+xml" in self or self.accept_xhtml - ) + return "text/html" in self or self.accept_xhtml # type: ignore[comparison-overlap] @property - def accept_xhtml(self): + def accept_xhtml(self) -> bool: """True if this object accepts XHTML.""" - return "application/xhtml+xml" in self or "application/xml" in self + return "application/xhtml+xml" in self or "application/xml" in self # type: ignore[comparison-overlap] @property - def accept_json(self): + def accept_json(self) -> bool: """True if this object accepts JSON.""" - return "application/json" in self + return "application/json" in self # type: ignore[comparison-overlap] _locale_delim_re = re.compile(r"[_-]") -def _normalize_lang(value): +def _normalize_lang(value: str) -> list[str]: """Process a language tag for matching.""" return _locale_delim_re.split(value.lower()) @@ -262,10 +280,16 @@ def _normalize_lang(value): class 
LanguageAccept(Accept): """Like :class:`Accept` but with normalization for language tags.""" - def _value_matches(self, value, item): + def _value_matches(self, value: str, item: str) -> bool: return item == "*" or _normalize_lang(value) == _normalize_lang(item) - def best_match(self, matches, default=None): + @t.overload + def best_match(self, matches: cabc.Iterable[str]) -> str | None: ... + @t.overload + def best_match(self, matches: cabc.Iterable[str], default: str = ...) -> str: ... + def best_match( + self, matches: cabc.Iterable[str], default: str | None = None + ) -> str | None: """Given a list of supported values, finds the best match from the list of accepted values. @@ -316,8 +340,8 @@ def best_match(self, matches, default=None): class CharsetAccept(Accept): """Like :class:`Accept` but with normalization for charsets.""" - def _value_matches(self, value, item): - def _normalize(name): + def _value_matches(self, value: str, item: str) -> bool: + def _normalize(name: str) -> str: try: return codecs.lookup(name).name except LookupError: diff --git a/src/werkzeug/datastructures/accept.pyi b/src/werkzeug/datastructures/accept.pyi deleted file mode 100644 index 4b74dd950..000000000 --- a/src/werkzeug/datastructures/accept.pyi +++ /dev/null @@ -1,54 +0,0 @@ -from collections.abc import Iterable -from collections.abc import Iterator -from typing import overload - -from .structures import ImmutableList - -class Accept(ImmutableList[tuple[str, int]]): - provided: bool - def __init__( - self, values: Accept | Iterable[tuple[str, float]] | None = None - ) -> None: ... - def _specificity(self, value: str) -> tuple[bool, ...]: ... - def _value_matches(self, value: str, item: str) -> bool: ... - @overload # type: ignore - def __getitem__(self, key: str) -> int: ... - @overload - def __getitem__(self, key: int) -> tuple[str, int]: ... - @overload - def __getitem__(self, key: slice) -> Iterable[tuple[str, int]]: ... - def quality(self, key: str) -> int: ... 
- def __contains__(self, value: str) -> bool: ... # type: ignore - def index(self, key: str) -> int: ... # type: ignore - def find(self, key: str) -> int: ... - def values(self) -> Iterator[str]: ... - def to_header(self) -> str: ... - def _best_single_match(self, match: str) -> tuple[str, int] | None: ... - @overload - def best_match(self, matches: Iterable[str], default: str) -> str: ... - @overload - def best_match( - self, matches: Iterable[str], default: str | None = None - ) -> str | None: ... - @property - def best(self) -> str: ... - -def _normalize_mime(value: str) -> list[str]: ... - -class MIMEAccept(Accept): - def _specificity(self, value: str) -> tuple[bool, ...]: ... - def _value_matches(self, value: str, item: str) -> bool: ... - @property - def accept_html(self) -> bool: ... - @property - def accept_xhtml(self) -> bool: ... - @property - def accept_json(self) -> bool: ... - -def _normalize_lang(value: str) -> list[str]: ... - -class LanguageAccept(Accept): - def _value_matches(self, value: str, item: str) -> bool: ... - -class CharsetAccept(Accept): - def _value_matches(self, value: str, item: str) -> bool: ... 
diff --git a/src/werkzeug/datastructures/auth.py b/src/werkzeug/datastructures/auth.py index 2f2515020..42f7aa468 100644 --- a/src/werkzeug/datastructures/auth.py +++ b/src/werkzeug/datastructures/auth.py @@ -2,16 +2,13 @@ import base64 import binascii +import collections.abc as cabc import typing as t -import warnings -from functools import wraps from ..http import dump_header from ..http import parse_dict_header -from ..http import parse_set_header from ..http import quote_header_value from .structures import CallbackDict -from .structures import HeaderSet if t.TYPE_CHECKING: import typing_extensions as te @@ -46,7 +43,7 @@ class Authorization: def __init__( self, auth_type: str, - data: dict[str, str] | None = None, + data: dict[str, str | None] | None = None, token: str | None = None, ) -> None: self.type = auth_type @@ -128,7 +125,7 @@ def to_header(self) -> str: if self.type == "basic": value = base64.b64encode( f"{self.username}:{self.password}".encode() - ).decode("utf8") + ).decode("ascii") return f"Basic {value}" if self.token is not None: @@ -143,31 +140,6 @@ def __repr__(self) -> str: return f"<{type(self).__name__} {self.to_header()}>" -def auth_property(name: str, doc: str | None = None) -> property: - """A static helper function for Authentication subclasses to add - extra authentication system properties onto a class:: - - class FooAuthenticate(WWWAuthenticate): - special_realm = auth_property('special_realm') - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. - """ - warnings.warn( - "'auth_property' is deprecated and will be removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - - def _set_value(self, value): # type: ignore[no-untyped-def] - if value is None: - self.pop(name, None) - else: - self[name] = str(value) - - return property(lambda x: x.get(name), _set_value, doc=doc) - - class WWWAuthenticate: """Represents the parts of a ``WWW-Authenticate`` response header. 
@@ -196,25 +168,16 @@ class WWWAuthenticate: def __init__( self, - auth_type: str | None = None, - values: dict[str, str] | None = None, + auth_type: str, + values: dict[str, str | None] | None = None, token: str | None = None, ): - if auth_type is None: - warnings.warn( - "An auth type must be given as the first parameter. Assuming 'basic' is" - " deprecated and will be removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - auth_type = "basic" - self._type = auth_type.lower() - self._parameters: dict[str, str] = CallbackDict( # type: ignore[misc] + self._parameters: dict[str, str | None] = CallbackDict( values, lambda _: self._trigger_on_update() ) self._token = token - self._on_update: t.Callable[[WWWAuthenticate], None] | None = None + self._on_update: cabc.Callable[[WWWAuthenticate], None] | None = None def _trigger_on_update(self) -> None: if self._on_update is not None: @@ -231,7 +194,7 @@ def type(self, value: str) -> None: self._trigger_on_update() @property - def parameters(self) -> dict[str, str]: + def parameters(self) -> dict[str, str | None]: """A dict of parameters for the header. Only one of this or :attr:`token` should have a value for a given scheme. """ @@ -239,9 +202,7 @@ def parameters(self) -> dict[str, str]: @parameters.setter def parameters(self, value: dict[str, str]) -> None: - self._parameters = CallbackDict( # type: ignore[misc] - value, lambda _: self._trigger_on_update() - ) + self._parameters = CallbackDict(value, lambda _: self._trigger_on_update()) self._trigger_on_update() @property @@ -261,62 +222,6 @@ def token(self, value: str | None) -> None: self._token = value self._trigger_on_update() - def set_basic(self, realm: str = "authentication required") -> None: - """Clear any existing data and set a ``Basic`` challenge. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. Create and assign an instance instead. - """ - warnings.warn( - "The 'set_basic' method is deprecated and will be removed in Werkzeug 3.0." 
- " Create and assign an instance instead." - ) - self._type = "basic" - dict.clear(self.parameters) # type: ignore[arg-type] - dict.update( - self.parameters, # type: ignore[arg-type] - {"realm": realm}, # type: ignore[dict-item] - ) - self._token = None - self._trigger_on_update() - - def set_digest( - self, - realm: str, - nonce: str, - qop: t.Sequence[str] = ("auth",), - opaque: str | None = None, - algorithm: str | None = None, - stale: bool = False, - ) -> None: - """Clear any existing data and set a ``Digest`` challenge. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. Create and assign an instance instead. - """ - warnings.warn( - "The 'set_digest' method is deprecated and will be removed in Werkzeug 3.0." - " Create and assign an instance instead." - ) - self._type = "digest" - dict.clear(self.parameters) # type: ignore[arg-type] - parameters = { - "realm": realm, - "nonce": nonce, - "qop": ", ".join(qop), - "stale": "TRUE" if stale else "FALSE", - } - - if opaque is not None: - parameters["opaque"] = opaque - - if algorithm is not None: - parameters["algorithm"] = algorithm - - dict.update(self.parameters, parameters) # type: ignore[arg-type] - self._token = None - self._trigger_on_update() - def __getitem__(self, key: str) -> str | None: return self.parameters.get(key) @@ -410,101 +315,3 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"<{type(self).__name__} {self.to_header()}>" - - @property - def qop(self) -> set[str]: - """The ``qop`` parameter as a set. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. It will become the same as other - parameters, returning a string. - """ - warnings.warn( - "The 'qop' property is deprecated and will be removed in Werkzeug 3.0." 
- " It will become the same as other parameters, returning a string.", - DeprecationWarning, - stacklevel=2, - ) - - def on_update(value: HeaderSet) -> None: - if not value: - if "qop" in self: - del self["qop"] - - return - - self.parameters["qop"] = value.to_header() - - return parse_set_header(self.parameters.get("qop"), on_update) - - @property - def stale(self) -> bool | None: - """The ``stale`` parameter as a boolean. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. It will become the same as other - parameters, returning a string. - """ - warnings.warn( - "The 'stale' property is deprecated and will be removed in Werkzeug 3.0." - " It will become the same as other parameters, returning a string.", - DeprecationWarning, - stacklevel=2, - ) - - if "stale" in self.parameters: - return self.parameters["stale"].lower() == "true" - - return None - - @stale.setter - def stale(self, value: bool | str | None) -> None: - if value is None: - if "stale" in self.parameters: - del self.parameters["stale"] - - return - - if isinstance(value, bool): - warnings.warn( - "Setting the 'stale' property to a boolean is deprecated and will be" - " removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - self.parameters["stale"] = "TRUE" if value else "FALSE" - else: - self.parameters["stale"] = value - - auth_property = staticmethod(auth_property) - - -def _deprecated_dict_method(f): # type: ignore[no-untyped-def] - @wraps(f) - def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] - warnings.warn( - "Treating 'Authorization' and 'WWWAuthenticate' as a dict is deprecated and" - " will be removed in Werkzeug 3.0. 
Use the 'parameters' attribute instead.", - DeprecationWarning, - stacklevel=2, - ) - return f(*args, **kwargs) - - return wrapper - - -for name in ( - "__iter__", - "clear", - "copy", - "items", - "keys", - "pop", - "popitem", - "setdefault", - "update", - "values", -): - f = _deprecated_dict_method(getattr(dict, name)) - setattr(Authorization, name, f) - setattr(WWWAuthenticate, name, f) diff --git a/src/werkzeug/datastructures/cache_control.py b/src/werkzeug/datastructures/cache_control.py index bff4c18bb..8d700ab6a 100644 --- a/src/werkzeug/datastructures/cache_control.py +++ b/src/werkzeug/datastructures/cache_control.py @@ -1,25 +1,59 @@ from __future__ import annotations +import collections.abc as cabc +import typing as t +from inspect import cleandoc + from .mixins import ImmutableDictMixin -from .mixins import UpdateDictMixin +from .structures import CallbackDict -def cache_control_property(key, empty, type): +def cache_control_property( + key: str, empty: t.Any, type: type[t.Any] | None, *, doc: str | None = None +) -> t.Any: """Return a new property object for a cache header. Useful if you want to add support for a cache extension in a subclass. + :param key: The attribute name present in the parsed cache-control header dict. + :param empty: The value to use if the key is present without a value. + :param type: The type to convert the string value to instead of a string. If + conversion raises a ``ValueError``, the returned value is ``None``. + :param doc: The docstring for the property. If not given, it is generated + based on the other params. + + .. versionchanged:: 3.1 + Added the ``doc`` param. + .. versionchanged:: 2.0 Renamed from ``cache_property``. 
""" + if doc is None: + parts = [f"The ``{key}`` attribute."] + + if type is bool: + parts.append("A ``bool``, either present or not.") + else: + if type is None: + parts.append("A ``str``,") + else: + parts.append(f"A ``{type.__name__}``,") + + if empty is not None: + parts.append(f"``{empty!r}`` if present with no value,") + + parts.append("or ``None`` if not present.") + + doc = " ".join(parts) + return property( lambda x: x._get_cache_value(key, empty, type), lambda x, v: x._set_cache_value(key, v, type), lambda x: x._del_cache_value(key), - f"accessor for {key!r}", + doc=cleandoc(doc), ) -class _CacheControl(UpdateDictMixin, dict): +class _CacheControl(CallbackDict[str, t.Optional[str]]): """Subclass of a dict that stores values for a Cache-Control header. It has accessors for all the cache-control directives specified in RFC 2616. The class does not differentiate between request and response directives. @@ -32,93 +66,95 @@ class _CacheControl(UpdateDictMixin, dict): to subclass it and add your own items have a look at the sourcecode for that class. - .. versionchanged:: 2.1.0 + .. versionchanged:: 3.1 + Dict values are always ``str | None``. Setting properties will + convert the value to a string. Setting a non-bool property to + ``False`` is equivalent to setting it to ``None``. Getting typed + properties will return ``None`` if conversion raises + ``ValueError``, rather than the string. + + .. versionchanged:: 2.1 Setting int properties such as ``max_age`` will convert the value to an int. .. versionchanged:: 0.4 - - Setting `no_cache` or `private` to boolean `True` will set the implicit - none-value which is ``*``: - - >>> cc = ResponseCacheControl() - >>> cc.no_cache = True - >>> cc - - >>> cc.no_cache - '*' - >>> cc.no_cache = None - >>> cc - - - In versions before 0.5 the behavior documented here affected the now - no longer existing `CacheControl` class. + Setting ``no_cache`` or ``private`` to ``True`` will set the + implicit value ``"*"``. 
""" - no_cache = cache_control_property("no-cache", "*", None) - no_store = cache_control_property("no-store", None, bool) - max_age = cache_control_property("max-age", -1, int) - no_transform = cache_control_property("no-transform", None, None) - - def __init__(self, values=(), on_update=None): - dict.__init__(self, values or ()) - self.on_update = on_update + no_store: bool = cache_control_property("no-store", None, bool) + max_age: int | None = cache_control_property("max-age", None, int) + no_transform: bool = cache_control_property("no-transform", None, bool) + stale_if_error: int | None = cache_control_property("stale-if-error", None, int) + + def __init__( + self, + values: cabc.Mapping[str, t.Any] | cabc.Iterable[tuple[str, t.Any]] | None = (), + on_update: cabc.Callable[[_CacheControl], None] | None = None, + ): + super().__init__(values, on_update) self.provided = values is not None - def _get_cache_value(self, key, empty, type): + def _get_cache_value( + self, key: str, empty: t.Any, type: type[t.Any] | None + ) -> t.Any: """Used internally by the accessor properties.""" if type is bool: return key in self - if key in self: - value = self[key] - if value is None: - return empty - elif type is not None: - try: - value = type(value) - except ValueError: - pass - return value - return None - - def _set_cache_value(self, key, value, type): + + if key not in self: + return None + + if (value := self[key]) is None: + return empty + + if type is not None: + try: + value = type(value) + except ValueError: + return None + + return value + + def _set_cache_value( + self, key: str, value: t.Any, type: type[t.Any] | None + ) -> None: """Used internally by the accessor properties.""" if type is bool: if value: self[key] = None else: self.pop(key, None) + elif value is None or value is False: + self.pop(key, None) + elif value is True: + self[key] = None else: - if value is None: - self.pop(key, None) - elif value is True: - self[key] = None - else: - if type is not 
None: - self[key] = type(value) - else: - self[key] = value + if type is not None: + value = type(value) - def _del_cache_value(self, key): + self[key] = str(value) + + def _del_cache_value(self, key: str) -> None: """Used internally by the accessor properties.""" if key in self: del self[key] - def to_header(self): + def to_header(self) -> str: """Convert the stored values into a cache control header.""" return http.dump_header(self) - def __str__(self): + def __str__(self) -> str: return self.to_header() - def __repr__(self): + def __repr__(self) -> str: kv_str = " ".join(f"{k}={v!r}" for k, v in sorted(self.items())) return f"<{type(self).__name__} {kv_str}>" cache_property = staticmethod(cache_control_property) -class RequestCacheControl(ImmutableDictMixin, _CacheControl): +class RequestCacheControl(ImmutableDictMixin[str, t.Optional[str]], _CacheControl): # type: ignore[misc] """A cache control for requests. This is immutable and gives access to all the request-relevant cache control headers. @@ -127,18 +163,49 @@ class RequestCacheControl(ImmutableDictMixin, _CacheControl): you plan to subclass it and add your own items have a look at the sourcecode for that class. - .. versionchanged:: 2.1.0 + .. versionchanged:: 3.1 + Dict values are always ``str | None``. Setting properties will + convert the value to a string. Setting a non-bool property to + ``False`` is equivalent to setting it to ``None``. Getting typed + properties will return ``None`` if conversion raises + ``ValueError``, rather than the string. + + .. versionchanged:: 3.1 + ``max_age`` is ``None`` if present without a value, rather + than ``-1``. + + .. versionchanged:: 3.1 + ``no_cache`` is a boolean, it is ``True`` instead of ``"*"`` + when present. + + .. versionchanged:: 3.1 + ``max_stale`` is ``True`` if present without a value, rather + than ``"*"``. + + .. versionchanged:: 3.1 + ``no_transform`` is a boolean. Previously it was mistakenly + always ``None``. + + .. 
versionchanged:: 3.1 + ``min_fresh`` is ``None`` if present without a value, rather + than ``"*"``. + + .. versionchanged:: 2.1 Setting int properties such as ``max_age`` will convert the value to an int. .. versionadded:: 0.5 - In previous versions a `CacheControl` class existed that was used - both for request and response. + Response-only properties are not present on this request class. """ - max_stale = cache_control_property("max-stale", "*", int) - min_fresh = cache_control_property("min-fresh", "*", int) - only_if_cached = cache_control_property("only-if-cached", None, bool) + no_cache: bool = cache_control_property("no-cache", None, bool) + max_stale: int | t.Literal[True] | None = cache_control_property( + "max-stale", + True, + int, + ) + min_fresh: int | None = cache_control_property("min-fresh", None, int) + only_if_cached: bool = cache_control_property("only-if-cached", None, bool) class ResponseCacheControl(_CacheControl): @@ -151,24 +218,55 @@ class ResponseCacheControl(_CacheControl): you plan to subclass it and add your own items have a look at the sourcecode for that class. + .. versionchanged:: 3.1 + Dict values are always ``str | None``. Setting properties will + convert the value to a string. Setting a non-bool property to + ``False`` is equivalent to setting it to ``None``. Getting typed + properties will return ``None`` if conversion raises + ``ValueError``, rather than the string. + + .. versionchanged:: 3.1 + ``no_cache`` is ``True`` if present without a value, rather than + ``"*"``. + + .. versionchanged:: 3.1 + ``private`` is ``True`` if present without a value, rather than + ``"*"``. + + .. versionchanged:: 3.1 + ``no_transform`` is a boolean. Previously it was mistakenly + always ``None``. + + .. versionchanged:: 3.1 + Added the ``must_understand``, ``stale_while_revalidate``, and + ``stale_if_error`` properties. + .. versionchanged:: 2.1.1 ``s_maxage`` converts the value to an int. - .. versionchanged:: 2.1.0 + .. 
versionchanged:: 2.1 Setting int properties such as ``max_age`` will convert the value to an int. .. versionadded:: 0.5 - In previous versions a `CacheControl` class existed that was used - both for request and response. + Request-only properties are not present on this response class. """ - public = cache_control_property("public", None, bool) - private = cache_control_property("private", "*", None) - must_revalidate = cache_control_property("must-revalidate", None, bool) - proxy_revalidate = cache_control_property("proxy-revalidate", None, bool) - s_maxage = cache_control_property("s-maxage", None, int) - immutable = cache_control_property("immutable", None, bool) + no_cache: str | t.Literal[True] | None = cache_control_property( + "no-cache", True, None + ) + public: bool = cache_control_property("public", None, bool) + private: str | t.Literal[True] | None = cache_control_property( + "private", True, None + ) + must_revalidate: bool = cache_control_property("must-revalidate", None, bool) + proxy_revalidate: bool = cache_control_property("proxy-revalidate", None, bool) + s_maxage: int | None = cache_control_property("s-maxage", None, int) + immutable: bool = cache_control_property("immutable", None, bool) + must_understand: bool = cache_control_property("must-understand", None, bool) + stale_while_revalidate: int | None = cache_control_property( + "stale-while-revalidate", None, int + ) # circular dependencies diff --git a/src/werkzeug/datastructures/cache_control.pyi b/src/werkzeug/datastructures/cache_control.pyi deleted file mode 100644 index 06fe667a2..000000000 --- a/src/werkzeug/datastructures/cache_control.pyi +++ /dev/null @@ -1,109 +0,0 @@ -from collections.abc import Callable -from collections.abc import Iterable -from collections.abc import Mapping -from typing import TypeVar - -from .mixins import ImmutableDictMixin -from .mixins import UpdateDictMixin - -T = TypeVar("T") -_CPT = TypeVar("_CPT", str, int, bool) -_OptCPT = _CPT | None - -def 
cache_control_property(key: str, empty: _OptCPT, type: type[_CPT]) -> property: ... - -class _CacheControl(UpdateDictMixin[str, _OptCPT], dict[str, _OptCPT]): - provided: bool - def __init__( - self, - values: Mapping[str, _OptCPT] | Iterable[tuple[str, _OptCPT]] = (), - on_update: Callable[[_CacheControl], None] | None = None, - ) -> None: ... - @property - def no_cache(self) -> bool | None: ... - @no_cache.setter - def no_cache(self, value: bool | None) -> None: ... - @no_cache.deleter - def no_cache(self) -> None: ... - @property - def no_store(self) -> bool | None: ... - @no_store.setter - def no_store(self, value: bool | None) -> None: ... - @no_store.deleter - def no_store(self) -> None: ... - @property - def max_age(self) -> int | None: ... - @max_age.setter - def max_age(self, value: int | None) -> None: ... - @max_age.deleter - def max_age(self) -> None: ... - @property - def no_transform(self) -> bool | None: ... - @no_transform.setter - def no_transform(self, value: bool | None) -> None: ... - @no_transform.deleter - def no_transform(self) -> None: ... - def _get_cache_value(self, key: str, empty: T | None, type: type[T]) -> T: ... - def _set_cache_value(self, key: str, value: T | None, type: type[T]) -> None: ... - def _del_cache_value(self, key: str) -> None: ... - def to_header(self) -> str: ... - @staticmethod - def cache_property(key: str, empty: _OptCPT, type: type[_CPT]) -> property: ... - -class RequestCacheControl(ImmutableDictMixin[str, _OptCPT], _CacheControl): - @property - def max_stale(self) -> int | None: ... - @max_stale.setter - def max_stale(self, value: int | None) -> None: ... - @max_stale.deleter - def max_stale(self) -> None: ... - @property - def min_fresh(self) -> int | None: ... - @min_fresh.setter - def min_fresh(self, value: int | None) -> None: ... - @min_fresh.deleter - def min_fresh(self) -> None: ... - @property - def only_if_cached(self) -> bool | None: ... 
- @only_if_cached.setter - def only_if_cached(self, value: bool | None) -> None: ... - @only_if_cached.deleter - def only_if_cached(self) -> None: ... - -class ResponseCacheControl(_CacheControl): - @property - def public(self) -> bool | None: ... - @public.setter - def public(self, value: bool | None) -> None: ... - @public.deleter - def public(self) -> None: ... - @property - def private(self) -> bool | None: ... - @private.setter - def private(self, value: bool | None) -> None: ... - @private.deleter - def private(self) -> None: ... - @property - def must_revalidate(self) -> bool | None: ... - @must_revalidate.setter - def must_revalidate(self, value: bool | None) -> None: ... - @must_revalidate.deleter - def must_revalidate(self) -> None: ... - @property - def proxy_revalidate(self) -> bool | None: ... - @proxy_revalidate.setter - def proxy_revalidate(self, value: bool | None) -> None: ... - @proxy_revalidate.deleter - def proxy_revalidate(self) -> None: ... - @property - def s_maxage(self) -> int | None: ... - @s_maxage.setter - def s_maxage(self, value: int | None) -> None: ... - @s_maxage.deleter - def s_maxage(self) -> None: ... - @property - def immutable(self) -> bool | None: ... - @immutable.setter - def immutable(self, value: bool | None) -> None: ... - @immutable.deleter - def immutable(self) -> None: ... diff --git a/src/werkzeug/datastructures/csp.py b/src/werkzeug/datastructures/csp.py index dde941495..0353eebea 100644 --- a/src/werkzeug/datastructures/csp.py +++ b/src/werkzeug/datastructures/csp.py @@ -1,9 +1,12 @@ from __future__ import annotations -from .mixins import UpdateDictMixin +import collections.abc as cabc +import typing as t +from .structures import CallbackDict -def csp_property(key): + +def csp_property(key: str) -> t.Any: """Return a new property object for a content security policy header. Useful if you want to add support for a csp extension in a subclass. 
@@ -16,7 +19,7 @@ def csp_property(key): ) -class ContentSecurityPolicy(UpdateDictMixin, dict): +class ContentSecurityPolicy(CallbackDict[str, str]): """Subclass of a dict that stores values for a Content Security Policy header. It has accessors for all the level 3 policies. @@ -33,62 +36,65 @@ class ContentSecurityPolicy(UpdateDictMixin, dict): """ - base_uri = csp_property("base-uri") - child_src = csp_property("child-src") - connect_src = csp_property("connect-src") - default_src = csp_property("default-src") - font_src = csp_property("font-src") - form_action = csp_property("form-action") - frame_ancestors = csp_property("frame-ancestors") - frame_src = csp_property("frame-src") - img_src = csp_property("img-src") - manifest_src = csp_property("manifest-src") - media_src = csp_property("media-src") - navigate_to = csp_property("navigate-to") - object_src = csp_property("object-src") - prefetch_src = csp_property("prefetch-src") - plugin_types = csp_property("plugin-types") - report_to = csp_property("report-to") - report_uri = csp_property("report-uri") - sandbox = csp_property("sandbox") - script_src = csp_property("script-src") - script_src_attr = csp_property("script-src-attr") - script_src_elem = csp_property("script-src-elem") - style_src = csp_property("style-src") - style_src_attr = csp_property("style-src-attr") - style_src_elem = csp_property("style-src-elem") - worker_src = csp_property("worker-src") - - def __init__(self, values=(), on_update=None): - dict.__init__(self, values or ()) - self.on_update = on_update + base_uri: str | None = csp_property("base-uri") + child_src: str | None = csp_property("child-src") + connect_src: str | None = csp_property("connect-src") + default_src: str | None = csp_property("default-src") + font_src: str | None = csp_property("font-src") + form_action: str | None = csp_property("form-action") + frame_ancestors: str | None = csp_property("frame-ancestors") + frame_src: str | None = csp_property("frame-src") + 
img_src: str | None = csp_property("img-src") + manifest_src: str | None = csp_property("manifest-src") + media_src: str | None = csp_property("media-src") + navigate_to: str | None = csp_property("navigate-to") + object_src: str | None = csp_property("object-src") + prefetch_src: str | None = csp_property("prefetch-src") + plugin_types: str | None = csp_property("plugin-types") + report_to: str | None = csp_property("report-to") + report_uri: str | None = csp_property("report-uri") + sandbox: str | None = csp_property("sandbox") + script_src: str | None = csp_property("script-src") + script_src_attr: str | None = csp_property("script-src-attr") + script_src_elem: str | None = csp_property("script-src-elem") + style_src: str | None = csp_property("style-src") + style_src_attr: str | None = csp_property("style-src-attr") + style_src_elem: str | None = csp_property("style-src-elem") + worker_src: str | None = csp_property("worker-src") + + def __init__( + self, + values: cabc.Mapping[str, str] | cabc.Iterable[tuple[str, str]] | None = (), + on_update: cabc.Callable[[ContentSecurityPolicy], None] | None = None, + ) -> None: + super().__init__(values, on_update) self.provided = values is not None - def _get_value(self, key): + def _get_value(self, key: str) -> str | None: """Used internally by the accessor properties.""" return self.get(key) - def _set_value(self, key, value): + def _set_value(self, key: str, value: str | None) -> None: """Used internally by the accessor properties.""" if value is None: self.pop(key, None) else: self[key] = value - def _del_value(self, key): + def _del_value(self, key: str) -> None: """Used internally by the accessor properties.""" if key in self: del self[key] - def to_header(self): + def to_header(self) -> str: """Convert the stored values into a cache control header.""" from ..http import dump_csp_header return dump_csp_header(self) - def __str__(self): + def __str__(self) -> str: return self.to_header() - def __repr__(self): + def 
__repr__(self) -> str: kv_str = " ".join(f"{k}={v!r}" for k, v in sorted(self.items())) return f"<{type(self).__name__} {kv_str}>" diff --git a/src/werkzeug/datastructures/csp.pyi b/src/werkzeug/datastructures/csp.pyi deleted file mode 100644 index f9e2ac0f4..000000000 --- a/src/werkzeug/datastructures/csp.pyi +++ /dev/null @@ -1,169 +0,0 @@ -from collections.abc import Callable -from collections.abc import Iterable -from collections.abc import Mapping - -from .mixins import UpdateDictMixin - -def csp_property(key: str) -> property: ... - -class ContentSecurityPolicy(UpdateDictMixin[str, str], dict[str, str]): - @property - def base_uri(self) -> str | None: ... - @base_uri.setter - def base_uri(self, value: str | None) -> None: ... - @base_uri.deleter - def base_uri(self) -> None: ... - @property - def child_src(self) -> str | None: ... - @child_src.setter - def child_src(self, value: str | None) -> None: ... - @child_src.deleter - def child_src(self) -> None: ... - @property - def connect_src(self) -> str | None: ... - @connect_src.setter - def connect_src(self, value: str | None) -> None: ... - @connect_src.deleter - def connect_src(self) -> None: ... - @property - def default_src(self) -> str | None: ... - @default_src.setter - def default_src(self, value: str | None) -> None: ... - @default_src.deleter - def default_src(self) -> None: ... - @property - def font_src(self) -> str | None: ... - @font_src.setter - def font_src(self, value: str | None) -> None: ... - @font_src.deleter - def font_src(self) -> None: ... - @property - def form_action(self) -> str | None: ... - @form_action.setter - def form_action(self, value: str | None) -> None: ... - @form_action.deleter - def form_action(self) -> None: ... - @property - def frame_ancestors(self) -> str | None: ... - @frame_ancestors.setter - def frame_ancestors(self, value: str | None) -> None: ... - @frame_ancestors.deleter - def frame_ancestors(self) -> None: ... 
- @property - def frame_src(self) -> str | None: ... - @frame_src.setter - def frame_src(self, value: str | None) -> None: ... - @frame_src.deleter - def frame_src(self) -> None: ... - @property - def img_src(self) -> str | None: ... - @img_src.setter - def img_src(self, value: str | None) -> None: ... - @img_src.deleter - def img_src(self) -> None: ... - @property - def manifest_src(self) -> str | None: ... - @manifest_src.setter - def manifest_src(self, value: str | None) -> None: ... - @manifest_src.deleter - def manifest_src(self) -> None: ... - @property - def media_src(self) -> str | None: ... - @media_src.setter - def media_src(self, value: str | None) -> None: ... - @media_src.deleter - def media_src(self) -> None: ... - @property - def navigate_to(self) -> str | None: ... - @navigate_to.setter - def navigate_to(self, value: str | None) -> None: ... - @navigate_to.deleter - def navigate_to(self) -> None: ... - @property - def object_src(self) -> str | None: ... - @object_src.setter - def object_src(self, value: str | None) -> None: ... - @object_src.deleter - def object_src(self) -> None: ... - @property - def prefetch_src(self) -> str | None: ... - @prefetch_src.setter - def prefetch_src(self, value: str | None) -> None: ... - @prefetch_src.deleter - def prefetch_src(self) -> None: ... - @property - def plugin_types(self) -> str | None: ... - @plugin_types.setter - def plugin_types(self, value: str | None) -> None: ... - @plugin_types.deleter - def plugin_types(self) -> None: ... - @property - def report_to(self) -> str | None: ... - @report_to.setter - def report_to(self, value: str | None) -> None: ... - @report_to.deleter - def report_to(self) -> None: ... - @property - def report_uri(self) -> str | None: ... - @report_uri.setter - def report_uri(self, value: str | None) -> None: ... - @report_uri.deleter - def report_uri(self) -> None: ... - @property - def sandbox(self) -> str | None: ... 
- @sandbox.setter - def sandbox(self, value: str | None) -> None: ... - @sandbox.deleter - def sandbox(self) -> None: ... - @property - def script_src(self) -> str | None: ... - @script_src.setter - def script_src(self, value: str | None) -> None: ... - @script_src.deleter - def script_src(self) -> None: ... - @property - def script_src_attr(self) -> str | None: ... - @script_src_attr.setter - def script_src_attr(self, value: str | None) -> None: ... - @script_src_attr.deleter - def script_src_attr(self) -> None: ... - @property - def script_src_elem(self) -> str | None: ... - @script_src_elem.setter - def script_src_elem(self, value: str | None) -> None: ... - @script_src_elem.deleter - def script_src_elem(self) -> None: ... - @property - def style_src(self) -> str | None: ... - @style_src.setter - def style_src(self, value: str | None) -> None: ... - @style_src.deleter - def style_src(self) -> None: ... - @property - def style_src_attr(self) -> str | None: ... - @style_src_attr.setter - def style_src_attr(self, value: str | None) -> None: ... - @style_src_attr.deleter - def style_src_attr(self) -> None: ... - @property - def style_src_elem(self) -> str | None: ... - @style_src_elem.setter - def style_src_elem(self, value: str | None) -> None: ... - @style_src_elem.deleter - def style_src_elem(self) -> None: ... - @property - def worker_src(self) -> str | None: ... - @worker_src.setter - def worker_src(self, value: str | None) -> None: ... - @worker_src.deleter - def worker_src(self) -> None: ... - provided: bool - def __init__( - self, - values: Mapping[str, str] | Iterable[tuple[str, str]] = (), - on_update: Callable[[ContentSecurityPolicy], None] | None = None, - ) -> None: ... - def _get_value(self, key: str) -> str | None: ... - def _set_value(self, key: str, value: str) -> None: ... - def _del_value(self, key: str) -> None: ... - def to_header(self) -> str: ... 
diff --git a/src/werkzeug/datastructures/etag.py b/src/werkzeug/datastructures/etag.py index 747d9966d..a4ef34245 100644 --- a/src/werkzeug/datastructures/etag.py +++ b/src/werkzeug/datastructures/etag.py @@ -1,14 +1,19 @@ from __future__ import annotations -from collections.abc import Collection +import collections.abc as cabc -class ETags(Collection): +class ETags(cabc.Collection[str]): """A set that can be used to check if one etag is present in a collection of etags. """ - def __init__(self, strong_etags=None, weak_etags=None, star_tag=False): + def __init__( + self, + strong_etags: cabc.Iterable[str] | None = None, + weak_etags: cabc.Iterable[str] | None = None, + star_tag: bool = False, + ): if not star_tag and strong_etags: self._strong = frozenset(strong_etags) else: @@ -17,7 +22,7 @@ def __init__(self, strong_etags=None, weak_etags=None, star_tag=False): self._weak = frozenset(weak_etags or ()) self.star_tag = star_tag - def as_set(self, include_weak=False): + def as_set(self, include_weak: bool = False) -> set[str]: """Convert the `ETags` object into a python set. Per default all the weak etags are not part of this set.""" rv = set(self._strong) @@ -25,19 +30,19 @@ def as_set(self, include_weak=False): rv.update(self._weak) return rv - def is_weak(self, etag): + def is_weak(self, etag: str) -> bool: """Check if an etag is weak.""" return etag in self._weak - def is_strong(self, etag): + def is_strong(self, etag: str) -> bool: """Check if an etag is strong.""" return etag in self._strong - def contains_weak(self, etag): + def contains_weak(self, etag: str) -> bool: """Check if an etag is part of the set including weak and strong tags.""" return self.is_weak(etag) or self.contains(etag) - def contains(self, etag): + def contains(self, etag: str) -> bool: """Check if an etag is part of the set ignoring weak tags. It is also possible to use the ``in`` operator. 
""" @@ -45,7 +50,7 @@ def contains(self, etag): return True return self.is_strong(etag) - def contains_raw(self, etag): + def contains_raw(self, etag: str) -> bool: """When passed a quoted tag it will check if this tag is part of the set. If the tag is weak it is checked against weak and strong tags, otherwise strong only.""" @@ -56,7 +61,7 @@ def contains_raw(self, etag): return self.contains_weak(etag) return self.contains(etag) - def to_header(self): + def to_header(self) -> str: """Convert the etags set into a HTTP header string.""" if self.star_tag: return "*" @@ -64,10 +69,16 @@ def to_header(self): [f'"{x}"' for x in self._strong] + [f'W/"{x}"' for x in self._weak] ) - def __call__(self, etag=None, data=None, include_weak=False): - if [etag, data].count(None) != 1: - raise TypeError("either tag or data required, but at least one") + def __call__( + self, + etag: str | None = None, + data: bytes | None = None, + include_weak: bool = False, + ) -> bool: if etag is None: + if data is None: + raise TypeError("'data' is required when 'etag' is not given.") + from ..http import generate_etag etag = generate_etag(data) @@ -76,20 +87,20 @@ def __call__(self, etag=None, data=None, include_weak=False): return True return etag in self._strong - def __bool__(self): + def __bool__(self) -> bool: return bool(self.star_tag or self._strong or self._weak) - def __str__(self): + def __str__(self) -> str: return self.to_header() - def __len__(self): + def __len__(self) -> int: return len(self._strong) - def __iter__(self): + def __iter__(self) -> cabc.Iterator[str]: return iter(self._strong) - def __contains__(self, etag): + def __contains__(self, etag: str) -> bool: # type: ignore[override] return self.contains(etag) - def __repr__(self): + def __repr__(self) -> str: return f"<{type(self).__name__} {str(self)!r}>" diff --git a/src/werkzeug/datastructures/etag.pyi b/src/werkzeug/datastructures/etag.pyi deleted file mode 100644 index 88e54f154..000000000 --- 
a/src/werkzeug/datastructures/etag.pyi +++ /dev/null @@ -1,30 +0,0 @@ -from collections.abc import Collection -from collections.abc import Iterable -from collections.abc import Iterator - -class ETags(Collection[str]): - _strong: frozenset[str] - _weak: frozenset[str] - star_tag: bool - def __init__( - self, - strong_etags: Iterable[str] | None = None, - weak_etags: Iterable[str] | None = None, - star_tag: bool = False, - ) -> None: ... - def as_set(self, include_weak: bool = False) -> set[str]: ... - def is_weak(self, etag: str) -> bool: ... - def is_strong(self, etag: str) -> bool: ... - def contains_weak(self, etag: str) -> bool: ... - def contains(self, etag: str) -> bool: ... - def contains_raw(self, etag: str) -> bool: ... - def to_header(self) -> str: ... - def __call__( - self, - etag: str | None = None, - data: bytes | None = None, - include_weak: bool = False, - ) -> bool: ... - def __len__(self) -> int: ... - def __iter__(self) -> Iterator[str]: ... - def __contains__(self, item: str) -> bool: ... 
# type: ignore diff --git a/src/werkzeug/datastructures/file_storage.py b/src/werkzeug/datastructures/file_storage.py index e878a56d4..123424477 100644 --- a/src/werkzeug/datastructures/file_storage.py +++ b/src/werkzeug/datastructures/file_storage.py @@ -1,11 +1,15 @@ from __future__ import annotations +import collections.abc as cabc import mimetypes +import os +import typing as t from io import BytesIO from os import fsdecode from os import fspath from .._internal import _plain_int +from .headers import Headers from .structures import MultiDict @@ -19,12 +23,12 @@ class FileStorage: def __init__( self, - stream=None, - filename=None, - name=None, - content_type=None, - content_length=None, - headers=None, + stream: t.IO[bytes] | None = None, + filename: str | None = None, + name: str | None = None, + content_type: str | None = None, + content_length: int | None = None, + headers: Headers | None = None, ): self.name = name self.stream = stream or BytesIO() @@ -46,8 +50,6 @@ def __init__( self.filename = filename if headers is None: - from .headers import Headers - headers = Headers() self.headers = headers if content_type is not None: @@ -55,17 +57,17 @@ def __init__( if content_length is not None: headers["Content-Length"] = str(content_length) - def _parse_content_type(self): + def _parse_content_type(self) -> None: if not hasattr(self, "_parsed_content_type"): self._parsed_content_type = http.parse_options_header(self.content_type) @property - def content_type(self): + def content_type(self) -> str | None: """The content-type sent in the header. Usually not available""" return self.headers.get("content-type") @property - def content_length(self): + def content_length(self) -> int: """The content-length sent in the header. 
Usually not available""" if "content-length" in self.headers: try: @@ -76,7 +78,7 @@ def content_length(self): return 0 @property - def mimetype(self): + def mimetype(self) -> str: """Like :attr:`content_type`, but without parameters (eg, without charset, type etc.) and always lowercase. For example if the content type is ``text/HTML; charset=utf-8`` the mimetype would be @@ -88,7 +90,7 @@ def mimetype(self): return self._parsed_content_type[0].lower() @property - def mimetype_params(self): + def mimetype_params(self) -> dict[str, str]: """The mimetype parameters as dict. For example if the content type is ``text/html; charset=utf-8`` the params would be ``{'charset': 'utf-8'}``. @@ -98,7 +100,9 @@ def mimetype_params(self): self._parse_content_type() return self._parsed_content_type[1] - def save(self, dst, buffer_size=16384): + def save( + self, dst: str | os.PathLike[str] | t.IO[bytes], buffer_size: int = 16384 + ) -> None: """Save the file to a destination path or file object. If the destination is a file object you have to close it yourself after the call. The buffer size is the number of bytes held in memory during @@ -131,35 +135,34 @@ def save(self, dst, buffer_size=16384): if close_dst: dst.close() - def close(self): + def close(self) -> None: """Close the underlying file if possible.""" try: self.stream.close() except Exception: pass - def __bool__(self): + def __bool__(self) -> bool: return bool(self.filename) - def __getattr__(self, name): + def __getattr__(self, name: str) -> t.Any: try: return getattr(self.stream, name) except AttributeError: - # SpooledTemporaryFile doesn't implement IOBase, get the - # attribute from its backing file instead. - # https://github.com/python/cpython/pull/3249 + # SpooledTemporaryFile on Python < 3.11 doesn't implement IOBase, + # get the attribute from its backing file instead. 
if hasattr(self.stream, "_file"): return getattr(self.stream._file, name) raise - def __iter__(self): + def __iter__(self) -> cabc.Iterator[bytes]: return iter(self.stream) - def __repr__(self): + def __repr__(self) -> str: return f"<{type(self).__name__}: {self.filename!r} ({self.content_type!r})>" -class FileMultiDict(MultiDict): +class FileMultiDict(MultiDict[str, FileStorage]): """A special :class:`MultiDict` that has convenience methods to add files to it. This is used for :class:`EnvironBuilder` and generally useful for unittesting. @@ -167,7 +170,13 @@ class FileMultiDict(MultiDict): .. versionadded:: 0.5 """ - def add_file(self, name, file, filename=None, content_type=None): + def add_file( + self, + name: str, + file: str | os.PathLike[str] | t.IO[bytes] | FileStorage, + filename: str | None = None, + content_type: str | None = None, + ) -> None: """Adds a new file to the dict. `file` can be a file name or a :class:`file`-like or a :class:`FileStorage` object. @@ -177,19 +186,23 @@ def add_file(self, name, file, filename=None, content_type=None): :param content_type: an optional content type """ if isinstance(file, FileStorage): - value = file + self.add(name, file) + return + + if isinstance(file, (str, os.PathLike)): + if filename is None: + filename = os.fspath(file) + + file_obj: t.IO[bytes] = open(file, "rb") else: - if isinstance(file, str): - if filename is None: - filename = file - file = open(file, "rb") - if filename and content_type is None: - content_type = ( - mimetypes.guess_type(filename)[0] or "application/octet-stream" - ) - value = FileStorage(file, filename, name, content_type) - - self.add(name, value) + file_obj = file # type: ignore[assignment] + + if filename and content_type is None: + content_type = ( + mimetypes.guess_type(filename)[0] or "application/octet-stream" + ) + + self.add(name, FileStorage(file_obj, filename, name, content_type)) # circular dependencies diff --git a/src/werkzeug/datastructures/file_storage.pyi 
b/src/werkzeug/datastructures/file_storage.pyi deleted file mode 100644 index 730789e35..000000000 --- a/src/werkzeug/datastructures/file_storage.pyi +++ /dev/null @@ -1,47 +0,0 @@ -from collections.abc import Iterator -from os import PathLike -from typing import Any -from typing import IO - -from .headers import Headers -from .structures import MultiDict - -class FileStorage: - name: str | None - stream: IO[bytes] - filename: str | None - headers: Headers - _parsed_content_type: tuple[str, dict[str, str]] - def __init__( - self, - stream: IO[bytes] | None = None, - filename: str | PathLike | None = None, - name: str | None = None, - content_type: str | None = None, - content_length: int | None = None, - headers: Headers | None = None, - ) -> None: ... - def _parse_content_type(self) -> None: ... - @property - def content_type(self) -> str: ... - @property - def content_length(self) -> int: ... - @property - def mimetype(self) -> str: ... - @property - def mimetype_params(self) -> dict[str, str]: ... - def save(self, dst: str | PathLike | IO[bytes], buffer_size: int = ...) -> None: ... - def close(self) -> None: ... - def __bool__(self) -> bool: ... - def __getattr__(self, name: str) -> Any: ... - def __iter__(self) -> Iterator[bytes]: ... - def __repr__(self) -> str: ... - -class FileMultiDict(MultiDict[str, FileStorage]): - def add_file( - self, - name: str, - file: FileStorage | str | IO[bytes], - filename: str | None = None, - content_type: str | None = None, - ) -> None: ... 
diff --git a/src/werkzeug/datastructures/headers.py b/src/werkzeug/datastructures/headers.py index dc060c41e..1088e3bc9 100644 --- a/src/werkzeug/datastructures/headers.py +++ b/src/werkzeug/datastructures/headers.py @@ -1,8 +1,8 @@ from __future__ import annotations +import collections.abc as cabc import re import typing as t -import warnings from .._internal import _missing from ..exceptions import BadRequestKeyError @@ -10,6 +10,12 @@ from .structures import iter_multi_items from .structures import MultiDict +if t.TYPE_CHECKING: + import typing_extensions as te + from _typeshed.wsgi import WSGIEnvironment + +T = t.TypeVar("T") + class Headers: """An object that stores some headers. It has a dict-like interface, @@ -35,6 +41,9 @@ class Headers: :param defaults: The list of default values for the :class:`Headers`. + .. versionchanged:: 3.1 + Implement ``|`` and ``|=`` operators. + .. versionchanged:: 2.1.0 Default values are validated the same as values added later. @@ -48,41 +57,72 @@ class Headers: was an API that does not support the changes to the encoding model. """ - def __init__(self, defaults=None): - self._list = [] + def __init__( + self, + defaults: ( + Headers + | MultiDict[str, t.Any] + | cabc.Mapping[str, t.Any | list[t.Any] | tuple[t.Any, ...] | set[t.Any]] + | cabc.Iterable[tuple[str, t.Any]] + | None + ) = None, + ) -> None: + self._list: list[tuple[str, str]] = [] + if defaults is not None: self.extend(defaults) - def __getitem__(self, key, _get_mode=False): - if not _get_mode: - if isinstance(key, int): - return self._list[key] - elif isinstance(key, slice): - return self.__class__(self._list[key]) - if not isinstance(key, str): - raise BadRequestKeyError(key) + @t.overload + def __getitem__(self, key: str) -> str: ... + @t.overload + def __getitem__(self, key: int) -> tuple[str, str]: ... + @t.overload + def __getitem__(self, key: slice) -> te.Self: ... 
+ def __getitem__(self, key: str | int | slice) -> str | tuple[str, str] | te.Self: + if isinstance(key, str): + return self._get_key(key) + + if isinstance(key, int): + return self._list[key] + + return self.__class__(self._list[key]) + + def _get_key(self, key: str) -> str: ikey = key.lower() + for k, v in self._list: if k.lower() == ikey: return v - # micro optimization: if we are in get mode we will catch that - # exception one stack level down so we can raise a standard - # key error instead of our special one. - if _get_mode: - raise KeyError() - raise BadRequestKeyError(key) - - def __eq__(self, other): - def lowered(item): - return (item[0].lower(),) + item[1:] - return other.__class__ is self.__class__ and set( - map(lowered, other._list) - ) == set(map(lowered, self._list)) - - __hash__ = None + raise BadRequestKeyError(key) - def get(self, key, default=None, type=None, as_bytes=None): + def __eq__(self, other: object) -> bool: + if other.__class__ is not self.__class__: + return NotImplemented + + def lowered(item: tuple[str, ...]) -> tuple[str, ...]: + return item[0].lower(), *item[1:] + + return set(map(lowered, other._list)) == set(map(lowered, self._list)) # type: ignore[attr-defined] + + __hash__ = None # type: ignore[assignment] + + @t.overload + def get(self, key: str) -> str | None: ... + @t.overload + def get(self, key: str, default: str) -> str: ... + @t.overload + def get(self, key: str, default: T) -> str | T: ... + @t.overload + def get(self, key: str, type: cabc.Callable[[str], T]) -> T | None: ... + @t.overload + def get(self, key: str, default: T, type: cabc.Callable[[str], T]) -> T: ... + def get( # type: ignore[misc] + self, + key: str, + default: str | T | None = None, + type: cabc.Callable[[str], T] | None = None, + ) -> str | T | None: """Return the default value if the requested data doesn't exist. If `type` is provided and is a callable it should convert the value, return it or raise a :exc:`ValueError` if that is not possible. 
In @@ -101,35 +141,32 @@ def get(self, key, default=None, type=None, as_bytes=None): :class:`Headers`. If a :exc:`ValueError` is raised by this callable the default value is returned. - .. versionchanged:: 2.3 - The ``as_bytes`` parameter is deprecated and will be removed - in Werkzeug 3.0. + .. versionchanged:: 3.0 + The ``as_bytes`` parameter was removed. .. versionchanged:: 0.9 The ``as_bytes`` parameter was added. """ - if as_bytes is not None: - warnings.warn( - "The 'as_bytes' parameter is deprecated and will be" - " removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - try: - rv = self.__getitem__(key, _get_mode=True) + rv = self._get_key(key) except KeyError: return default - if as_bytes: - rv = rv.encode("latin1") + if type is None: return rv + try: return type(rv) except ValueError: return default - def getlist(self, key, type=None, as_bytes=None): + @t.overload + def getlist(self, key: str) -> list[str]: ... + @t.overload + def getlist(self, key: str, type: cabc.Callable[[str], T]) -> list[T]: ... + def getlist( + self, key: str, type: cabc.Callable[[str], T] | None = None + ) -> list[str] | list[T]: """Return the list of items for a given key. If that key is not in the :class:`Headers`, the return value will be an empty list. Just like :meth:`get`, :meth:`getlist` accepts a `type` parameter. All items will @@ -141,36 +178,29 @@ def getlist(self, key, type=None, as_bytes=None): by this callable the value will be removed from the list. :return: a :class:`list` of all the values for the key. - .. versionchanged:: 2.3 - The ``as_bytes`` parameter is deprecated and will be removed - in Werkzeug 3.0. + .. versionchanged:: 3.0 + The ``as_bytes`` parameter was removed. .. versionchanged:: 0.9 The ``as_bytes`` parameter was added. 
""" - if as_bytes is not None: - warnings.warn( - "The 'as_bytes' parameter is deprecated and will be" - " removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - ikey = key.lower() - result = [] - for k, v in self: - if k.lower() == ikey: - if as_bytes: - v = v.encode("latin1") - if type is not None: + + if type is not None: + result = [] + + for k, v in self: + if k.lower() == ikey: try: - v = type(v) + result.append(type(v)) except ValueError: continue - result.append(v) - return result - def get_all(self, name): + return result + + return [v for k, v in self if k.lower() == ikey] + + def get_all(self, name: str) -> list[str]: """Return a list of all the values for the named field. This method is compatible with the :mod:`wsgiref` @@ -178,21 +208,32 @@ def get_all(self, name): """ return self.getlist(name) - def items(self, lower=False): + def items(self, lower: bool = False) -> t.Iterable[tuple[str, str]]: for key, value in self: if lower: key = key.lower() yield key, value - def keys(self, lower=False): + def keys(self, lower: bool = False) -> t.Iterable[str]: for key, _ in self.items(lower): yield key - def values(self): + def values(self) -> t.Iterable[str]: for _, value in self.items(): yield value - def extend(self, *args, **kwargs): + def extend( + self, + arg: ( + Headers + | MultiDict[str, t.Any] + | cabc.Mapping[str, t.Any | list[t.Any] | tuple[t.Any, ...] | set[t.Any]] + | cabc.Iterable[tuple[str, t.Any]] + | None + ) = None, + /, + **kwargs: str, + ) -> None: """Extend headers in this object with items from another object containing header items as well as keyword arguments. @@ -206,35 +247,52 @@ def extend(self, *args, **kwargs): .. versionchanged:: 1.0 Support :class:`MultiDict`. Allow passing ``kwargs``. 
""" - if len(args) > 1: - raise TypeError(f"update expected at most 1 arguments, got {len(args)}") - - if args: - for key, value in iter_multi_items(args[0]): + if arg is not None: + for key, value in iter_multi_items(arg): self.add(key, value) for key, value in iter_multi_items(kwargs): self.add(key, value) - def __delitem__(self, key, _index_operation=True): - if _index_operation and isinstance(key, (int, slice)): - del self._list[key] + def __delitem__(self, key: str | int | slice) -> None: + if isinstance(key, str): + self._del_key(key) return + + del self._list[key] + + def _del_key(self, key: str) -> None: key = key.lower() new = [] + for k, v in self._list: if k.lower() != key: new.append((k, v)) + self._list[:] = new - def remove(self, key): + def remove(self, key: str) -> None: """Remove a key. :param key: The key to be removed. """ - return self.__delitem__(key, _index_operation=False) - - def pop(self, key=None, default=_missing): + return self._del_key(key) + + @t.overload + def pop(self) -> tuple[str, str]: ... + @t.overload + def pop(self, key: str) -> str: ... + @t.overload + def pop(self, key: int | None = ...) -> tuple[str, str]: ... + @t.overload + def pop(self, key: str, default: str) -> str: ... + @t.overload + def pop(self, key: str, default: T) -> str | T: ... + def pop( + self, + key: str | int | None = None, + default: str | T = _missing, # type: ignore[assignment] + ) -> str | tuple[str, str] | T: """Removes and returns a key or index. :param key: The key to be popped. 
If this is an integer the item at @@ -245,37 +303,42 @@ def pop(self, key=None, default=_missing): """ if key is None: return self._list.pop() + if isinstance(key, int): return self._list.pop(key) + try: - rv = self[key] - self.remove(key) + rv = self._get_key(key) except KeyError: if default is not _missing: return default + raise + + self.remove(key) return rv - def popitem(self): + def popitem(self) -> tuple[str, str]: """Removes a key or index and returns a (key, value) item.""" - return self.pop() + return self._list.pop() - def __contains__(self, key): + def __contains__(self, key: str) -> bool: """Check if a key is present.""" try: - self.__getitem__(key, _get_mode=True) + self._get_key(key) except KeyError: return False + return True - def __iter__(self): + def __iter__(self) -> t.Iterator[tuple[str, str]]: """Yield ``(key, value)`` tuples.""" return iter(self._list) - def __len__(self): + def __len__(self) -> int: return len(self._list) - def add(self, _key, _value, **kw): + def add(self, key: str, value: t.Any, /, **kwargs: t.Any) -> None: """Add a new header tuple to the list. Keyword arguments can specify additional parameters for the header @@ -288,28 +351,28 @@ def add(self, _key, _value, **kw): The keyword argument dumping uses :func:`dump_options_header` behind the scenes. - .. versionadded:: 0.4.1 + .. versionchanged:: 0.4.1 keyword arguments were added for :mod:`wsgiref` compatibility. """ - if kw: - _value = _options_header_vkw(_value, kw) - _key = _str_header_key(_key) - _value = _str_header_value(_value) - self._list.append((_key, _value)) + if kwargs: + value = _options_header_vkw(value, kwargs) + + value_str = _str_header_value(value) + self._list.append((key, value_str)) - def add_header(self, _key, _value, **_kw): + def add_header(self, key: str, value: t.Any, /, **kwargs: t.Any) -> None: """Add a new header tuple to the list. 
An alias for :meth:`add` for compatibility with the :mod:`wsgiref` :meth:`~wsgiref.headers.Headers.add_header` method. """ - self.add(_key, _value, **_kw) + self.add(key, value, **kwargs) - def clear(self): + def clear(self) -> None: """Clears all headers.""" - del self._list[:] + self._list.clear() - def set(self, _key, _value, **kw): + def set(self, key: str, value: t.Any, /, **kwargs: t.Any) -> None: """Remove all header tuples for `key` and add a new one. The newly added key either appears at the end of the list if there was no entry or replaces the first one. @@ -324,26 +387,32 @@ def set(self, _key, _value, **kw): :param key: The key to be inserted. :param value: The value to be inserted. """ - if kw: - _value = _options_header_vkw(_value, kw) - _key = _str_header_key(_key) - _value = _str_header_value(_value) + if kwargs: + value = _options_header_vkw(value, kwargs) + + value_str = _str_header_value(value) + if not self._list: - self._list.append((_key, _value)) + self._list.append((key, value_str)) return - listiter = iter(self._list) - ikey = _key.lower() - for idx, (old_key, _old_value) in enumerate(listiter): + + iter_list = iter(self._list) + ikey = key.lower() + + for idx, (old_key, _) in enumerate(iter_list): if old_key.lower() == ikey: # replace first occurrence - self._list[idx] = (_key, _value) + self._list[idx] = (key, value_str) break else: - self._list.append((_key, _value)) + # no existing occurrences + self._list.append((key, value_str)) return - self._list[idx + 1 :] = [t for t in listiter if t[0].lower() != ikey] - def setlist(self, key, values): + # remove remaining occurrences + self._list[idx + 1 :] = [t for t in iter_list if t[0].lower() != ikey] + + def setlist(self, key: str, values: cabc.Iterable[t.Any]) -> None: """Remove any existing values for a header and add new ones. :param key: The header key to set. 
@@ -360,7 +429,7 @@ def setlist(self, key, values): else: self.remove(key) - def setdefault(self, key, default): + def setdefault(self, key: str, default: t.Any) -> str: """Return the first value for the key if it is in the headers, otherwise set the header to the value given by ``default`` and return that. @@ -369,13 +438,15 @@ def setdefault(self, key, default): :param default: The value to set for the key if it is not in the headers. """ - if key in self: - return self[key] + try: + return self._get_key(key) + except KeyError: + pass self.set(key, default) - return default + return self._get_key(key) - def setlistdefault(self, key, default): + def setlistdefault(self, key: str, default: cabc.Iterable[t.Any]) -> list[str]: """Return the list of values for the key if it is in the headers, otherwise set the header to the list of values given by ``default`` and return that. @@ -394,20 +465,41 @@ def setlistdefault(self, key, default): return self.getlist(key) - def __setitem__(self, key, value): + @t.overload + def __setitem__(self, key: str, value: t.Any) -> None: ... + @t.overload + def __setitem__(self, key: int, value: tuple[str, t.Any]) -> None: ... + @t.overload + def __setitem__( + self, key: slice, value: cabc.Iterable[tuple[str, t.Any]] + ) -> None: ... 
+ def __setitem__( + self, + key: str | int | slice, + value: t.Any | tuple[str, t.Any] | cabc.Iterable[tuple[str, t.Any]], + ) -> None: """Like :meth:`set` but also supports index/slice based setting.""" - if isinstance(key, (slice, int)): - if isinstance(key, int): - value = [value] - value = [(_str_header_key(k), _str_header_value(v)) for (k, v) in value] - if isinstance(key, int): - self._list[key] = value[0] - else: - self._list[key] = value - else: + if isinstance(key, str): self.set(key, value) - - def update(self, *args, **kwargs): + elif isinstance(key, int): + self._list[key] = value[0], _str_header_value(value[1]) # type: ignore[index] + else: + self._list[key] = [(k, _str_header_value(v)) for k, v in value] # type: ignore[misc] + + def update( + self, + arg: ( + Headers + | MultiDict[str, t.Any] + | cabc.Mapping[ + str, t.Any | list[t.Any] | tuple[t.Any, ...] | cabc.Set[t.Any] + ] + | cabc.Iterable[tuple[str, t.Any]] + | None + ) = None, + /, + **kwargs: t.Any | list[t.Any] | tuple[t.Any, ...] | cabc.Set[t.Any], + ) -> None: """Replace headers in this object with items from another headers object and keyword arguments. @@ -420,45 +512,66 @@ def update(self, *args, **kwargs): .. 
versionadded:: 1.0 """ - if len(args) > 1: - raise TypeError(f"update expected at most 1 arguments, got {len(args)}") - - if args: - mapping = args[0] - - if isinstance(mapping, (Headers, MultiDict)): - for key in mapping.keys(): - self.setlist(key, mapping.getlist(key)) - elif isinstance(mapping, dict): - for key, value in mapping.items(): - if isinstance(value, (list, tuple)): + if arg is not None: + if isinstance(arg, (Headers, MultiDict)): + for key in arg.keys(): + self.setlist(key, arg.getlist(key)) + elif isinstance(arg, cabc.Mapping): + for key, value in arg.items(): + if isinstance(value, (list, tuple, set)): self.setlist(key, value) else: self.set(key, value) else: - for key, value in mapping: + for key, value in arg: self.set(key, value) for key, value in kwargs.items(): - if isinstance(value, (list, tuple)): + if isinstance(value, (list, tuple, set)): self.setlist(key, value) else: self.set(key, value) - def to_wsgi_list(self): + def __or__( + self, + other: cabc.Mapping[ + str, t.Any | list[t.Any] | tuple[t.Any, ...] | cabc.Set[t.Any] + ], + ) -> te.Self: + if not isinstance(other, cabc.Mapping): + return NotImplemented + + rv = self.copy() + rv.update(other) + return rv + + def __ior__( + self, + other: ( + cabc.Mapping[str, t.Any | list[t.Any] | tuple[t.Any, ...] | cabc.Set[t.Any]] + | cabc.Iterable[tuple[str, t.Any]] + ), + ) -> te.Self: + if not isinstance(other, (cabc.Mapping, cabc.Iterable)): + return NotImplemented + + self.update(other) + return self + + def to_wsgi_list(self) -> list[tuple[str, str]]: """Convert the headers into a list suitable for WSGI. 
:return: list """ return list(self) - def copy(self): + def copy(self) -> te.Self: return self.__class__(self._list) - def __copy__(self): + def __copy__(self) -> te.Self: return self.copy() - def __str__(self): + def __str__(self) -> str: """Returns formatted headers suitable for HTTP transmission.""" strs = [] for key, value in self.to_wsgi_list(): @@ -466,56 +579,30 @@ def __str__(self): strs.append("\r\n") return "\r\n".join(strs) - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}({list(self)!r})" -def _options_header_vkw(value: str, kw: dict[str, t.Any]): +def _options_header_vkw(value: str, kw: dict[str, t.Any]) -> str: return http.dump_options_header( value, {k.replace("_", "-"): v for k, v in kw.items()} ) -def _str_header_key(key: t.Any) -> str: - if not isinstance(key, str): - warnings.warn( - "Header keys must be strings. Passing other types is deprecated and will" - " not be supported in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - - if isinstance(key, bytes): - key = key.decode("latin-1") - else: - key = str(key) - - return key - - _newline_re = re.compile(r"[\r\n]") def _str_header_value(value: t.Any) -> str: - if isinstance(value, bytes): - warnings.warn( - "Passing bytes as a header value is deprecated and will not be supported in" - " Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - value = value.decode("latin-1") - if not isinstance(value, str): value = str(value) if _newline_re.search(value) is not None: raise ValueError("Header values must not contain newline characters.") - return value + return value # type: ignore[no-any-return] -class EnvironHeaders(ImmutableHeadersMixin, Headers): +class EnvironHeaders(ImmutableHeadersMixin, Headers): # type: ignore[misc] """Read only version of the headers from a WSGI environment. This provides the same interface as `Headers` and is constructed from a WSGI environment. 
@@ -525,30 +612,36 @@ class EnvironHeaders(ImmutableHeadersMixin, Headers): HTTP exceptions. """ - def __init__(self, environ): + def __init__(self, environ: WSGIEnvironment) -> None: + super().__init__() self.environ = environ - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, EnvironHeaders): + return NotImplemented + return self.environ is other.environ - __hash__ = None + __hash__ = None # type: ignore[assignment] + + def __getitem__(self, key: str) -> str: # type: ignore[override] + return self._get_key(key) - def __getitem__(self, key, _get_mode=False): - # _get_mode is a no-op for this class as there is no index but - # used because get() calls it. + def _get_key(self, key: str) -> str: if not isinstance(key, str): - raise KeyError(key) + raise BadRequestKeyError(key) + key = key.upper().replace("-", "_") + if key in {"CONTENT_TYPE", "CONTENT_LENGTH"}: - return self.environ[key] - return self.environ[f"HTTP_{key}"] + return self.environ[key] # type: ignore[no-any-return] + + return self.environ[f"HTTP_{key}"] # type: ignore[no-any-return] - def __len__(self): - # the iter is necessary because otherwise list calls our - # len which would call list again and so forth. 
- return len(list(iter(self))) + def __len__(self) -> int: + return sum(1 for _ in self) - def __iter__(self): + def __iter__(self) -> cabc.Iterator[tuple[str, str]]: for key, value in self.environ.items(): if key.startswith("HTTP_") and key not in { "HTTP_CONTENT_TYPE", @@ -558,7 +651,10 @@ def __iter__(self): elif key in {"CONTENT_TYPE", "CONTENT_LENGTH"} and value: yield key.replace("_", "-").title(), value - def copy(self): + def copy(self) -> t.NoReturn: + raise TypeError(f"cannot create {type(self).__name__!r} copies") + + def __or__(self, other: t.Any) -> t.NoReturn: raise TypeError(f"cannot create {type(self).__name__!r} copies") diff --git a/src/werkzeug/datastructures/headers.pyi b/src/werkzeug/datastructures/headers.pyi deleted file mode 100644 index 86502221a..000000000 --- a/src/werkzeug/datastructures/headers.pyi +++ /dev/null @@ -1,109 +0,0 @@ -from collections.abc import Callable -from collections.abc import Iterable -from collections.abc import Iterator -from collections.abc import Mapping -from typing import Literal -from typing import NoReturn -from typing import overload -from typing import TypeVar - -from _typeshed import SupportsKeysAndGetItem -from _typeshed.wsgi import WSGIEnvironment - -from .mixins import ImmutableHeadersMixin - -D = TypeVar("D") -T = TypeVar("T") - -class Headers(dict[str, str]): - _list: list[tuple[str, str]] - def __init__( - self, - defaults: Mapping[str, str | Iterable[str]] - | Iterable[tuple[str, str]] - | None = None, - ) -> None: ... - @overload - def __getitem__(self, key: str) -> str: ... - @overload - def __getitem__(self, key: int) -> tuple[str, str]: ... - @overload - def __getitem__(self, key: slice) -> Headers: ... - @overload - def __getitem__(self, key: str, _get_mode: Literal[True] = ...) -> str: ... - def __eq__(self, other: object) -> bool: ... - @overload # type: ignore - def get(self, key: str, default: str) -> str: ... 
- @overload - def get(self, key: str, default: str | None = None) -> str | None: ... - @overload - def get( - self, key: str, default: T | None = None, type: Callable[[str], T] = ... - ) -> T | None: ... - @overload - def getlist(self, key: str) -> list[str]: ... - @overload - def getlist(self, key: str, type: Callable[[str], T]) -> list[T]: ... - def get_all(self, name: str) -> list[str]: ... - def items( # type: ignore - self, lower: bool = False - ) -> Iterator[tuple[str, str]]: ... - def keys(self, lower: bool = False) -> Iterator[str]: ... # type: ignore - def values(self) -> Iterator[str]: ... # type: ignore - def extend( - self, - *args: Mapping[str, str | Iterable[str]] | Iterable[tuple[str, str]], - **kwargs: str | Iterable[str], - ) -> None: ... - @overload - def __delitem__(self, key: str | int | slice) -> None: ... - @overload - def __delitem__(self, key: str, _index_operation: Literal[False]) -> None: ... - def remove(self, key: str) -> None: ... - @overload # type: ignore - def pop(self, key: str, default: str | None = None) -> str: ... - @overload - def pop( - self, key: int | None = None, default: tuple[str, str] | None = None - ) -> tuple[str, str]: ... - def popitem(self) -> tuple[str, str]: ... - def __contains__(self, key: str) -> bool: ... # type: ignore - def has_key(self, key: str) -> bool: ... - def __iter__(self) -> Iterator[tuple[str, str]]: ... # type: ignore - def add(self, _key: str, _value: str, **kw: str) -> None: ... - def _validate_value(self, value: str) -> None: ... - def add_header(self, _key: str, _value: str, **_kw: str) -> None: ... - def clear(self) -> None: ... - def set(self, _key: str, _value: str, **kw: str) -> None: ... - def setlist(self, key: str, values: Iterable[str]) -> None: ... - def setdefault(self, key: str, default: str) -> str: ... - def setlistdefault(self, key: str, default: Iterable[str]) -> None: ... - @overload - def __setitem__(self, key: str, value: str) -> None: ... 
- @overload - def __setitem__(self, key: int, value: tuple[str, str]) -> None: ... - @overload - def __setitem__(self, key: slice, value: Iterable[tuple[str, str]]) -> None: ... - @overload - def update( - self, __m: SupportsKeysAndGetItem[str, str], **kwargs: str | Iterable[str] - ) -> None: ... - @overload - def update( - self, __m: Iterable[tuple[str, str]], **kwargs: str | Iterable[str] - ) -> None: ... - @overload - def update(self, **kwargs: str | Iterable[str]) -> None: ... - def to_wsgi_list(self) -> list[tuple[str, str]]: ... - def copy(self) -> Headers: ... - def __copy__(self) -> Headers: ... - -class EnvironHeaders(ImmutableHeadersMixin, Headers): - environ: WSGIEnvironment - def __init__(self, environ: WSGIEnvironment) -> None: ... - def __eq__(self, other: object) -> bool: ... - def __getitem__( # type: ignore - self, key: str, _get_mode: Literal[False] = False - ) -> str: ... - def __iter__(self) -> Iterator[tuple[str, str]]: ... # type: ignore - def copy(self) -> NoReturn: ... 
diff --git a/src/werkzeug/datastructures/mixins.py b/src/werkzeug/datastructures/mixins.py index 2c84ca8f2..03d461ad8 100644 --- a/src/werkzeug/datastructures/mixins.py +++ b/src/werkzeug/datastructures/mixins.py @@ -1,11 +1,22 @@ from __future__ import annotations +import collections.abc as cabc +import typing as t +from functools import update_wrapper from itertools import repeat from .._internal import _missing +if t.TYPE_CHECKING: + import typing_extensions as te -def is_immutable(self): +K = t.TypeVar("K") +V = t.TypeVar("V") +T = t.TypeVar("T") +F = t.TypeVar("F", bound=cabc.Callable[..., t.Any]) + + +def _immutable_error(self: t.Any) -> t.NoReturn: raise TypeError(f"{type(self).__name__!r} objects are immutable") @@ -17,102 +28,118 @@ class ImmutableListMixin: :private: """ - _hash_cache = None + _hash_cache: int | None = None - def __hash__(self): + def __hash__(self) -> int: if self._hash_cache is not None: return self._hash_cache - rv = self._hash_cache = hash(tuple(self)) + rv = self._hash_cache = hash(tuple(self)) # type: ignore[arg-type] return rv - def __reduce_ex__(self, protocol): - return type(self), (list(self),) + def __reduce_ex__(self, protocol: t.SupportsIndex) -> t.Any: + return type(self), (list(self),) # type: ignore[call-overload] - def __delitem__(self, key): - is_immutable(self) + def __delitem__(self, key: t.Any) -> t.NoReturn: + _immutable_error(self) - def __iadd__(self, other): - is_immutable(self) + def __iadd__(self, other: t.Any) -> t.NoReturn: + _immutable_error(self) - def __imul__(self, other): - is_immutable(self) + def __imul__(self, other: t.Any) -> t.NoReturn: + _immutable_error(self) - def __setitem__(self, key, value): - is_immutable(self) + def __setitem__(self, key: t.Any, value: t.Any) -> t.NoReturn: + _immutable_error(self) - def append(self, item): - is_immutable(self) + def append(self, item: t.Any) -> t.NoReturn: + _immutable_error(self) - def remove(self, item): - is_immutable(self) + def remove(self, item: t.Any) 
-> t.NoReturn: + _immutable_error(self) - def extend(self, iterable): - is_immutable(self) + def extend(self, iterable: t.Any) -> t.NoReturn: + _immutable_error(self) - def insert(self, pos, value): - is_immutable(self) + def insert(self, pos: t.Any, value: t.Any) -> t.NoReturn: + _immutable_error(self) - def pop(self, index=-1): - is_immutable(self) + def pop(self, index: t.Any = -1) -> t.NoReturn: + _immutable_error(self) - def reverse(self): - is_immutable(self) + def reverse(self: t.Any) -> t.NoReturn: + _immutable_error(self) - def sort(self, key=None, reverse=False): - is_immutable(self) + def sort(self, key: t.Any = None, reverse: t.Any = False) -> t.NoReturn: + _immutable_error(self) -class ImmutableDictMixin: +class ImmutableDictMixin(t.Generic[K, V]): """Makes a :class:`dict` immutable. + .. versionchanged:: 3.1 + Disallow ``|=`` operator. + .. versionadded:: 0.5 :private: """ - _hash_cache = None + _hash_cache: int | None = None @classmethod - def fromkeys(cls, keys, value=None): + @t.overload + def fromkeys( + cls, keys: cabc.Iterable[K], value: None + ) -> ImmutableDictMixin[K, t.Any | None]: ... + @classmethod + @t.overload + def fromkeys(cls, keys: cabc.Iterable[K], value: V) -> ImmutableDictMixin[K, V]: ... 
+ @classmethod + def fromkeys( + cls, keys: cabc.Iterable[K], value: V | None = None + ) -> ImmutableDictMixin[K, t.Any | None] | ImmutableDictMixin[K, V]: instance = super().__new__(cls) - instance.__init__(zip(keys, repeat(value))) + instance.__init__(zip(keys, repeat(value))) # type: ignore[misc] return instance - def __reduce_ex__(self, protocol): - return type(self), (dict(self),) + def __reduce_ex__(self, protocol: t.SupportsIndex) -> t.Any: + return type(self), (dict(self),) # type: ignore[call-overload] - def _iter_hashitems(self): - return self.items() + def _iter_hashitems(self) -> t.Iterable[t.Any]: + return self.items() # type: ignore[attr-defined,no-any-return] - def __hash__(self): + def __hash__(self) -> int: if self._hash_cache is not None: return self._hash_cache rv = self._hash_cache = hash(frozenset(self._iter_hashitems())) return rv - def setdefault(self, key, default=None): - is_immutable(self) + def setdefault(self, key: t.Any, default: t.Any = None) -> t.NoReturn: + _immutable_error(self) + + def update(self, arg: t.Any, /, **kwargs: t.Any) -> t.NoReturn: + _immutable_error(self) - def update(self, *args, **kwargs): - is_immutable(self) + def __ior__(self, other: t.Any) -> t.NoReturn: + _immutable_error(self) - def pop(self, key, default=None): - is_immutable(self) + def pop(self, key: t.Any, default: t.Any = None) -> t.NoReturn: + _immutable_error(self) - def popitem(self): - is_immutable(self) + def popitem(self) -> t.NoReturn: + _immutable_error(self) - def __setitem__(self, key, value): - is_immutable(self) + def __setitem__(self, key: t.Any, value: t.Any) -> t.NoReturn: + _immutable_error(self) - def __delitem__(self, key): - is_immutable(self) + def __delitem__(self, key: t.Any) -> t.NoReturn: + _immutable_error(self) - def clear(self): - is_immutable(self) + def clear(self) -> t.NoReturn: + _immutable_error(self) -class ImmutableMultiDictMixin(ImmutableDictMixin): +class ImmutableMultiDictMixin(ImmutableDictMixin[K, V]): """Makes a 
:class:`MultiDict` immutable. .. versionadded:: 0.5 @@ -120,26 +147,26 @@ class ImmutableMultiDictMixin(ImmutableDictMixin): :private: """ - def __reduce_ex__(self, protocol): - return type(self), (list(self.items(multi=True)),) + def __reduce_ex__(self, protocol: t.SupportsIndex) -> t.Any: + return type(self), (list(self.items(multi=True)),) # type: ignore[attr-defined] - def _iter_hashitems(self): - return self.items(multi=True) + def _iter_hashitems(self) -> t.Iterable[t.Any]: + return self.items(multi=True) # type: ignore[attr-defined,no-any-return] - def add(self, key, value): - is_immutable(self) + def add(self, key: t.Any, value: t.Any) -> t.NoReturn: + _immutable_error(self) - def popitemlist(self): - is_immutable(self) + def popitemlist(self) -> t.NoReturn: + _immutable_error(self) - def poplist(self, key): - is_immutable(self) + def poplist(self, key: t.Any) -> t.NoReturn: + _immutable_error(self) - def setlist(self, key, new_list): - is_immutable(self) + def setlist(self, key: t.Any, new_list: t.Any) -> t.NoReturn: + _immutable_error(self) - def setlistdefault(self, key, default_list=None): - is_immutable(self) + def setlistdefault(self, key: t.Any, default_list: t.Any = None) -> t.NoReturn: + _immutable_error(self) class ImmutableHeadersMixin: @@ -147,96 +174,144 @@ class ImmutableHeadersMixin: hashable though since the only usecase for this datastructure in Werkzeug is a view on a mutable structure. + .. versionchanged:: 3.1 + Disallow ``|=`` operator. + .. 
versionadded:: 0.5 :private: """ - def __delitem__(self, key, **kwargs): - is_immutable(self) + def __delitem__(self, key: t.Any, **kwargs: t.Any) -> t.NoReturn: + _immutable_error(self) + + def __setitem__(self, key: t.Any, value: t.Any) -> t.NoReturn: + _immutable_error(self) - def __setitem__(self, key, value): - is_immutable(self) + def set(self, key: t.Any, value: t.Any, /, **kwargs: t.Any) -> t.NoReturn: + _immutable_error(self) - def set(self, _key, _value, **kwargs): - is_immutable(self) + def setlist(self, key: t.Any, values: t.Any) -> t.NoReturn: + _immutable_error(self) - def setlist(self, key, values): - is_immutable(self) + def add(self, key: t.Any, value: t.Any, /, **kwargs: t.Any) -> t.NoReturn: + _immutable_error(self) - def add(self, _key, _value, **kwargs): - is_immutable(self) + def add_header(self, key: t.Any, value: t.Any, /, **kwargs: t.Any) -> t.NoReturn: + _immutable_error(self) - def add_header(self, _key, _value, **_kwargs): - is_immutable(self) + def remove(self, key: t.Any) -> t.NoReturn: + _immutable_error(self) - def remove(self, key): - is_immutable(self) + def extend(self, arg: t.Any, /, **kwargs: t.Any) -> t.NoReturn: + _immutable_error(self) - def extend(self, *args, **kwargs): - is_immutable(self) + def update(self, arg: t.Any, /, **kwargs: t.Any) -> t.NoReturn: + _immutable_error(self) - def update(self, *args, **kwargs): - is_immutable(self) + def __ior__(self, other: t.Any) -> t.NoReturn: + _immutable_error(self) - def insert(self, pos, value): - is_immutable(self) + def insert(self, pos: t.Any, value: t.Any) -> t.NoReturn: + _immutable_error(self) - def pop(self, key=None, default=_missing): - is_immutable(self) + def pop(self, key: t.Any = None, default: t.Any = _missing) -> t.NoReturn: + _immutable_error(self) - def popitem(self): - is_immutable(self) + def popitem(self) -> t.NoReturn: + _immutable_error(self) - def setdefault(self, key, default): - is_immutable(self) + def setdefault(self, key: t.Any, default: t.Any) -> 
t.NoReturn: + _immutable_error(self) - def setlistdefault(self, key, default): - is_immutable(self) + def setlistdefault(self, key: t.Any, default: t.Any) -> t.NoReturn: + _immutable_error(self) -def _calls_update(name): - def oncall(self, *args, **kw): - rv = getattr(super(UpdateDictMixin, self), name)(*args, **kw) +def _always_update(f: F) -> F: + def wrapper( + self: UpdateDictMixin[t.Any, t.Any], /, *args: t.Any, **kwargs: t.Any + ) -> t.Any: + rv = f(self, *args, **kwargs) if self.on_update is not None: self.on_update(self) return rv - oncall.__name__ = name - return oncall + return update_wrapper(wrapper, f) # type: ignore[return-value] -class UpdateDictMixin(dict): +class UpdateDictMixin(dict[K, V]): """Makes dicts call `self.on_update` on modifications. + .. versionchanged:: 3.1 + Implement ``|=`` operator. + .. versionadded:: 0.5 :private: """ - on_update = None + on_update: cabc.Callable[[te.Self], None] | None = None - def setdefault(self, key, default=None): + def setdefault(self: te.Self, key: K, default: V | None = None) -> V: modified = key not in self - rv = super().setdefault(key, default) + rv = super().setdefault(key, default) # type: ignore[arg-type] if modified and self.on_update is not None: self.on_update(self) return rv - def pop(self, key, default=_missing): + @t.overload + def pop(self: te.Self, key: K) -> V: ... + @t.overload + def pop(self: te.Self, key: K, default: V) -> V: ... + @t.overload + def pop(self: te.Self, key: K, default: T) -> T: ... 
+ def pop( + self: te.Self, + key: K, + default: V | T = _missing, # type: ignore[assignment] + ) -> V | T: modified = key in self if default is _missing: rv = super().pop(key) else: - rv = super().pop(key, default) + rv = super().pop(key, default) # type: ignore[arg-type] if modified and self.on_update is not None: self.on_update(self) return rv - __setitem__ = _calls_update("__setitem__") - __delitem__ = _calls_update("__delitem__") - clear = _calls_update("clear") - popitem = _calls_update("popitem") - update = _calls_update("update") + @_always_update + def __setitem__(self, key: K, value: V) -> None: + super().__setitem__(key, value) + + @_always_update + def __delitem__(self, key: K) -> None: + super().__delitem__(key) + + @_always_update + def clear(self) -> None: + super().clear() + + @_always_update + def popitem(self) -> tuple[K, V]: + return super().popitem() + + @_always_update + def update( # type: ignore[override] + self, + arg: cabc.Mapping[K, V] | cabc.Iterable[tuple[K, V]] | None = None, + /, + **kwargs: V, + ) -> None: + if arg is None: + super().update(**kwargs) + else: + super().update(arg, **kwargs) + + @_always_update + def __ior__( # type: ignore[override] + self, other: cabc.Mapping[K, V] | cabc.Iterable[tuple[K, V]] + ) -> te.Self: + return super().__ior__(other) diff --git a/src/werkzeug/datastructures/mixins.pyi b/src/werkzeug/datastructures/mixins.pyi deleted file mode 100644 index 74ed4b81e..000000000 --- a/src/werkzeug/datastructures/mixins.pyi +++ /dev/null @@ -1,97 +0,0 @@ -from collections.abc import Callable -from collections.abc import Hashable -from collections.abc import Iterable -from typing import Any -from typing import NoReturn -from typing import overload -from typing import SupportsIndex -from typing import TypeVar - -from _typeshed import SupportsKeysAndGetItem - -from .headers import Headers - -K = TypeVar("K") -T = TypeVar("T") -V = TypeVar("V") - -def is_immutable(self: object) -> NoReturn: ... 
- -class ImmutableListMixin(list[V]): - _hash_cache: int | None - def __hash__(self) -> int: ... # type: ignore - def __delitem__(self, key: SupportsIndex | slice) -> NoReturn: ... - def __iadd__(self, other: t.Any) -> NoReturn: ... # type: ignore - def __imul__(self, other: SupportsIndex) -> NoReturn: ... - def __setitem__(self, key: int | slice, value: V) -> NoReturn: ... # type: ignore - def append(self, value: V) -> NoReturn: ... - def remove(self, value: V) -> NoReturn: ... - def extend(self, values: Iterable[V]) -> NoReturn: ... - def insert(self, pos: SupportsIndex, value: V) -> NoReturn: ... - def pop(self, index: SupportsIndex = -1) -> NoReturn: ... - def reverse(self) -> NoReturn: ... - def sort( - self, key: Callable[[V], Any] | None = None, reverse: bool = False - ) -> NoReturn: ... - -class ImmutableDictMixin(dict[K, V]): - _hash_cache: int | None - @classmethod - def fromkeys( # type: ignore - cls, keys: Iterable[K], value: V | None = None - ) -> ImmutableDictMixin[K, V]: ... - def _iter_hashitems(self) -> Iterable[Hashable]: ... - def __hash__(self) -> int: ... # type: ignore - def setdefault(self, key: K, default: V | None = None) -> NoReturn: ... - def update(self, *args: Any, **kwargs: V) -> NoReturn: ... - def pop(self, key: K, default: V | None = None) -> NoReturn: ... # type: ignore - def popitem(self) -> NoReturn: ... - def __setitem__(self, key: K, value: V) -> NoReturn: ... - def __delitem__(self, key: K) -> NoReturn: ... - def clear(self) -> NoReturn: ... - -class ImmutableMultiDictMixin(ImmutableDictMixin[K, V]): - def _iter_hashitems(self) -> Iterable[Hashable]: ... - def add(self, key: K, value: V) -> NoReturn: ... - def popitemlist(self) -> NoReturn: ... - def poplist(self, key: K) -> NoReturn: ... - def setlist(self, key: K, new_list: Iterable[V]) -> NoReturn: ... - def setlistdefault( - self, key: K, default_list: Iterable[V] | None = None - ) -> NoReturn: ... 
- -class ImmutableHeadersMixin(Headers): - def __delitem__(self, key: Any, _index_operation: bool = True) -> NoReturn: ... - def __setitem__(self, key: Any, value: Any) -> NoReturn: ... - def set(self, _key: Any, _value: Any, **kw: Any) -> NoReturn: ... - def setlist(self, key: Any, values: Any) -> NoReturn: ... - def add(self, _key: Any, _value: Any, **kw: Any) -> NoReturn: ... - def add_header(self, _key: Any, _value: Any, **_kw: Any) -> NoReturn: ... - def remove(self, key: Any) -> NoReturn: ... - def extend(self, *args: Any, **kwargs: Any) -> NoReturn: ... - def update(self, *args: Any, **kwargs: Any) -> NoReturn: ... - def insert(self, pos: Any, value: Any) -> NoReturn: ... - def pop(self, key: Any = None, default: Any = ...) -> NoReturn: ... - def popitem(self) -> NoReturn: ... - def setdefault(self, key: Any, default: Any) -> NoReturn: ... - def setlistdefault(self, key: Any, default: Any) -> NoReturn: ... - -def _calls_update(name: str) -> Callable[[UpdateDictMixin[K, V]], Any]: ... - -class UpdateDictMixin(dict[K, V]): - on_update: Callable[[UpdateDictMixin[K, V] | None, None], None] - def setdefault(self, key: K, default: V | None = None) -> V: ... - @overload - def pop(self, key: K) -> V: ... - @overload - def pop(self, key: K, default: V | T = ...) -> V | T: ... - def __setitem__(self, key: K, value: V) -> None: ... - def __delitem__(self, key: K) -> None: ... - def clear(self) -> None: ... - def popitem(self) -> tuple[K, V]: ... - @overload - def update(self, __m: SupportsKeysAndGetItem[K, V], **kwargs: V) -> None: ... - @overload - def update(self, __m: Iterable[tuple[K, V]], **kwargs: V) -> None: ... - @overload - def update(self, **kwargs: V) -> None: ... 
diff --git a/src/werkzeug/datastructures/range.py b/src/werkzeug/datastructures/range.py index 7011ea4ae..4c9f67d44 100644 --- a/src/werkzeug/datastructures/range.py +++ b/src/werkzeug/datastructures/range.py @@ -1,5 +1,14 @@ from __future__ import annotations +import collections.abc as cabc +import typing as t +from datetime import datetime + +if t.TYPE_CHECKING: + import typing_extensions as te + +T = t.TypeVar("T") + class IfRange: """Very simple object that represents the `If-Range` header in parsed @@ -9,14 +18,14 @@ class IfRange: .. versionadded:: 0.7 """ - def __init__(self, etag=None, date=None): + def __init__(self, etag: str | None = None, date: datetime | None = None): #: The etag parsed and unquoted. Ranges always operate on strong #: etags so the weakness information is not necessary. self.etag = etag #: The date in parsed format or `None`. self.date = date - def to_header(self): + def to_header(self) -> str: """Converts the object back into an HTTP header.""" if self.date is not None: return http.http_date(self.date) @@ -24,10 +33,10 @@ def to_header(self): return http.quote_etag(self.etag) return "" - def __str__(self): + def __str__(self) -> str: return self.to_header() - def __repr__(self): + def __repr__(self) -> str: return f"<{type(self).__name__} {str(self)!r}>" @@ -44,7 +53,9 @@ class Range: .. versionadded:: 0.7 """ - def __init__(self, units, ranges): + def __init__( + self, units: str, ranges: cabc.Sequence[tuple[int, int | None]] + ) -> None: #: The units of this range. Usually "bytes". self.units = units #: A list of ``(begin, end)`` tuples for the range header provided. 
@@ -55,7 +66,7 @@ def __init__(self, units, ranges): if start is None or (end is not None and (start < 0 or start >= end)): raise ValueError(f"{(start, end)} is not a valid range.") - def range_for_length(self, length): + def range_for_length(self, length: int | None) -> tuple[int, int] | None: """If the range is for bytes, the length is not None and there is exactly one range and it is satisfiable it returns a ``(start, stop)`` tuple, otherwise `None`. @@ -71,7 +82,7 @@ def range_for_length(self, length): return start, min(end, length) return None - def make_content_range(self, length): + def make_content_range(self, length: int | None) -> ContentRange | None: """Creates a :class:`~werkzeug.datastructures.ContentRange` object from the current range and given content length. """ @@ -80,7 +91,7 @@ def make_content_range(self, length): return ContentRange(self.units, rng[0], rng[1], length) return None - def to_header(self): + def to_header(self) -> str: """Converts the object back into an HTTP header.""" ranges = [] for begin, end in self.ranges: @@ -90,7 +101,7 @@ def to_header(self): ranges.append(f"{begin}-{end - 1}") return f"{self.units}={','.join(ranges)}" - def to_content_range_header(self, length): + def to_content_range_header(self, length: int | None) -> str | None: """Converts the object into `Content-Range` HTTP header, based on given length """ @@ -99,23 +110,34 @@ def to_content_range_header(self, length): return f"{self.units} {range[0]}-{range[1] - 1}/{length}" return None - def __str__(self): + def __str__(self) -> str: return self.to_header() - def __repr__(self): + def __repr__(self) -> str: return f"<{type(self).__name__} {str(self)!r}>" -def _callback_property(name): - def fget(self): - return getattr(self, name) +class _CallbackProperty(t.Generic[T]): + def __set_name__(self, owner: type[ContentRange], name: str) -> None: + self.attr = f"_{name}" - def fset(self, value): - setattr(self, name, value) - if self.on_update is not None: - 
self.on_update(self) + @t.overload + def __get__(self, instance: None, owner: None) -> te.Self: ... + @t.overload + def __get__(self, instance: ContentRange, owner: type[ContentRange]) -> T: ... + def __get__( + self, instance: ContentRange | None, owner: type[ContentRange] | None + ) -> te.Self | T: + if instance is None: + return self - return property(fget, fset) + return instance.__dict__[self.attr] # type: ignore[no-any-return] + + def __set__(self, instance: ContentRange, value: T) -> None: + instance.__dict__[self.attr] = value + + if instance.on_update is not None: + instance.on_update(instance) class ContentRange: @@ -124,55 +146,67 @@ class ContentRange: .. versionadded:: 0.7 """ - def __init__(self, units, start, stop, length=None, on_update=None): - assert http.is_byte_range_valid(start, stop, length), "Bad range provided" + def __init__( + self, + units: str | None, + start: int | None, + stop: int | None, + length: int | None = None, + on_update: cabc.Callable[[ContentRange], None] | None = None, + ) -> None: self.on_update = on_update self.set(start, stop, length, units) #: The units to use, usually "bytes" - units = _callback_property("_units") + units: str | None = _CallbackProperty() # type: ignore[assignment] #: The start point of the range or `None`. - start = _callback_property("_start") + start: int | None = _CallbackProperty() # type: ignore[assignment] #: The stop point of the range (non-inclusive) or `None`. Can only be #: `None` if also start is `None`. - stop = _callback_property("_stop") + stop: int | None = _CallbackProperty() # type: ignore[assignment] #: The length of the range or `None`. 
- length = _callback_property("_length") - - def set(self, start, stop, length=None, units="bytes"): + length: int | None = _CallbackProperty() # type: ignore[assignment] + + def set( + self, + start: int | None, + stop: int | None, + length: int | None = None, + units: str | None = "bytes", + ) -> None: """Simple method to update the ranges.""" assert http.is_byte_range_valid(start, stop, length), "Bad range provided" - self._units = units - self._start = start - self._stop = stop - self._length = length + self._units: str | None = units + self._start: int | None = start + self._stop: int | None = stop + self._length: int | None = length if self.on_update is not None: self.on_update(self) - def unset(self): + def unset(self) -> None: """Sets the units to `None` which indicates that the header should no longer be used. """ self.set(None, None, units=None) - def to_header(self): - if self.units is None: + def to_header(self) -> str: + if self._units is None: return "" - if self.length is None: - length = "*" + if self._length is None: + length: str | int = "*" else: - length = self.length - if self.start is None: - return f"{self.units} */{length}" - return f"{self.units} {self.start}-{self.stop - 1}/{length}" + length = self._length + if self._start is None: + return f"{self._units} */{length}" + return f"{self._units} {self._start}-{self._stop - 1}/{length}" # type: ignore[operator] - def __bool__(self): - return self.units is not None + def __bool__(self) -> bool: + return self._units is not None - def __str__(self): + def __str__(self) -> str: return self.to_header() - def __repr__(self): + def __repr__(self) -> str: return f"<{type(self).__name__} {str(self)!r}>" diff --git a/src/werkzeug/datastructures/range.pyi b/src/werkzeug/datastructures/range.pyi deleted file mode 100644 index f38ad69ef..000000000 --- a/src/werkzeug/datastructures/range.pyi +++ /dev/null @@ -1,57 +0,0 @@ -from collections.abc import Callable -from datetime import datetime - -class 
IfRange: - etag: str | None - date: datetime | None - def __init__( - self, etag: str | None = None, date: datetime | None = None - ) -> None: ... - def to_header(self) -> str: ... - -class Range: - units: str - ranges: list[tuple[int, int | None]] - def __init__(self, units: str, ranges: list[tuple[int, int | None]]) -> None: ... - def range_for_length(self, length: int | None) -> tuple[int, int] | None: ... - def make_content_range(self, length: int | None) -> ContentRange | None: ... - def to_header(self) -> str: ... - def to_content_range_header(self, length: int | None) -> str | None: ... - -def _callback_property(name: str) -> property: ... - -class ContentRange: - on_update: Callable[[ContentRange], None] | None - def __init__( - self, - units: str | None, - start: int | None, - stop: int | None, - length: int | None = None, - on_update: Callable[[ContentRange], None] | None = None, - ) -> None: ... - @property - def units(self) -> str | None: ... - @units.setter - def units(self, value: str | None) -> None: ... - @property - def start(self) -> int | None: ... - @start.setter - def start(self, value: int | None) -> None: ... - @property - def stop(self) -> int | None: ... - @stop.setter - def stop(self, value: int | None) -> None: ... - @property - def length(self) -> int | None: ... - @length.setter - def length(self, value: int | None) -> None: ... - def set( - self, - start: int | None, - stop: int | None, - length: int | None = None, - units: str | None = "bytes", - ) -> None: ... - def unset(self) -> None: ... - def to_header(self) -> str: ... diff --git a/src/werkzeug/datastructures/structures.py b/src/werkzeug/datastructures/structures.py index 7ea7bee28..dbb7e8048 100644 --- a/src/werkzeug/datastructures/structures.py +++ b/src/werkzeug/datastructures/structures.py @@ -1,6 +1,7 @@ from __future__ import annotations -from collections.abc import MutableSet +import collections.abc as cabc +import typing as t from copy import deepcopy from .. 
import exceptions @@ -10,20 +11,29 @@ from .mixins import ImmutableMultiDictMixin from .mixins import UpdateDictMixin +if t.TYPE_CHECKING: + import typing_extensions as te -def is_immutable(self): - raise TypeError(f"{type(self).__name__!r} objects are immutable") +K = t.TypeVar("K") +V = t.TypeVar("V") +T = t.TypeVar("T") -def iter_multi_items(mapping): +def iter_multi_items( + mapping: ( + MultiDict[K, V] + | cabc.Mapping[K, V | list[V] | tuple[V, ...] | set[V]] + | cabc.Iterable[tuple[K, V]] + ), +) -> cabc.Iterator[tuple[K, V]]: """Iterates over the items of a mapping yielding keys and values without dropping any from more complex structures. """ if isinstance(mapping, MultiDict): yield from mapping.items(multi=True) - elif isinstance(mapping, dict): + elif isinstance(mapping, cabc.Mapping): for key, value in mapping.items(): - if isinstance(value, (tuple, list)): + if isinstance(value, (list, tuple, set)): for v in value: yield key, v else: @@ -32,7 +42,7 @@ def iter_multi_items(mapping): yield from mapping -class ImmutableList(ImmutableListMixin, list): +class ImmutableList(ImmutableListMixin, list[V]): # type: ignore[misc] """An immutable :class:`list`. .. versionadded:: 0.5 @@ -40,11 +50,11 @@ class ImmutableList(ImmutableListMixin, list): :private: """ - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}({list.__repr__(self)})" -class TypeConversionDict(dict): +class TypeConversionDict(dict[K, V]): """Works like a regular dict but the :meth:`get` method can perform type conversions. :class:`MultiDict` and :class:`CombinedMultiDict` are subclasses of this class and provide the same feature. @@ -52,7 +62,22 @@ class TypeConversionDict(dict): .. versionadded:: 0.5 """ - def get(self, key, default=None, type=None): + @t.overload # type: ignore[override] + def get(self, key: K) -> V | None: ... + @t.overload + def get(self, key: K, default: V) -> V: ... + @t.overload + def get(self, key: K, default: T) -> V | T: ... 
+ @t.overload + def get(self, key: str, type: cabc.Callable[[V], T]) -> T | None: ... + @t.overload + def get(self, key: str, default: T, type: cabc.Callable[[V], T]) -> T: ... + def get( # type: ignore[misc] + self, + key: K, + default: V | T | None = None, + type: cabc.Callable[[V], T] | None = None, + ) -> V | T | None: """Return the default value if the requested data doesn't exist. If `type` is provided and is a callable it should convert the value, return it or raise a :exc:`ValueError` if that is not possible. In @@ -70,40 +95,46 @@ def get(self, key, default=None, type=None): be looked up. If not further specified `None` is returned. :param type: A callable that is used to cast the value in the - :class:`MultiDict`. If a :exc:`ValueError` is raised - by this callable the default value is returned. + :class:`MultiDict`. If a :exc:`ValueError` or a + :exc:`TypeError` is raised by this callable the default + value is returned. + + .. versionchanged:: 3.0.2 + Returns the default value on :exc:`TypeError`, too. """ try: rv = self[key] except KeyError: return default - if type is not None: - try: - rv = type(rv) - except ValueError: - rv = default - return rv + if type is None: + return rv -class ImmutableTypeConversionDict(ImmutableDictMixin, TypeConversionDict): + try: + return type(rv) + except (ValueError, TypeError): + return default + + +class ImmutableTypeConversionDict(ImmutableDictMixin[K, V], TypeConversionDict[K, V]): # type: ignore[misc] """Works like a :class:`TypeConversionDict` but does not support modifications. .. versionadded:: 0.5 """ - def copy(self): + def copy(self) -> TypeConversionDict[K, V]: """Return a shallow mutable copy of this object. Keep in mind that the standard library's :func:`copy` function is a no-op for this class like for any other python immutable type (eg: :class:`tuple`). 
""" return TypeConversionDict(self) - def __copy__(self): + def __copy__(self) -> te.Self: return self -class MultiDict(TypeConversionDict): +class MultiDict(TypeConversionDict[K, V]): """A :class:`MultiDict` is a dictionary subclass customized to deal with multiple values for the same key which is for example used by the parsing functions in the wrappers. This is necessary because some HTML form @@ -142,42 +173,57 @@ class MultiDict(TypeConversionDict): :param mapping: the initial value for the :class:`MultiDict`. Either a regular dict, an iterable of ``(key, value)`` tuples or `None`. + + .. versionchanged:: 3.1 + Implement ``|`` and ``|=`` operators. """ - def __init__(self, mapping=None): - if isinstance(mapping, MultiDict): - dict.__init__(self, ((k, l[:]) for k, l in mapping.lists())) - elif isinstance(mapping, dict): + def __init__( + self, + mapping: ( + MultiDict[K, V] + | cabc.Mapping[K, V | list[V] | tuple[V, ...] | set[V]] + | cabc.Iterable[tuple[K, V]] + | None + ) = None, + ) -> None: + if mapping is None: + super().__init__() + elif isinstance(mapping, MultiDict): + super().__init__((k, vs[:]) for k, vs in mapping.lists()) + elif isinstance(mapping, cabc.Mapping): tmp = {} for key, value in mapping.items(): - if isinstance(value, (tuple, list)): - if len(value) == 0: - continue + if isinstance(value, (list, tuple, set)): value = list(value) + + if not value: + continue else: value = [value] tmp[key] = value - dict.__init__(self, tmp) + super().__init__(tmp) # type: ignore[arg-type] else: tmp = {} - for key, value in mapping or (): + for key, value in mapping: tmp.setdefault(key, []).append(value) - dict.__init__(self, tmp) + super().__init__(tmp) # type: ignore[arg-type] - def __getstate__(self): + def __getstate__(self) -> t.Any: return dict(self.lists()) - def __setstate__(self, value): - dict.clear(self) - dict.update(self, value) + def __setstate__(self, value: t.Any) -> None: + super().clear() + super().update(value) - def __iter__(self): - # 
Work around https://bugs.python.org/issue43246. - # (`return super().__iter__()` also works here, which makes this look - # even more like it should be a no-op, yet it isn't.) - return dict.__iter__(self) + def __iter__(self) -> cabc.Iterator[K]: + # https://github.com/python/cpython/issues/87412 + # If __iter__ is not overridden, Python uses a fast path for dict(md), + # taking the data directly and getting lists of values, rather than + # calling __getitem__ and getting only the first value. + return super().__iter__() - def __getitem__(self, key): + def __getitem__(self, key: K) -> V: """Return the first data value for this key; raises KeyError if not found. @@ -186,20 +232,20 @@ def __getitem__(self, key): """ if key in self: - lst = dict.__getitem__(self, key) - if len(lst) > 0: - return lst[0] + lst = super().__getitem__(key) + if len(lst) > 0: # type: ignore[arg-type] + return lst[0] # type: ignore[index,no-any-return] raise exceptions.BadRequestKeyError(key) - def __setitem__(self, key, value): + def __setitem__(self, key: K, value: V) -> None: """Like :meth:`add` but removes an existing key first. :param key: the key for the value. :param value: the value to set. """ - dict.__setitem__(self, key, [value]) + super().__setitem__(key, [value]) # type: ignore[assignment] - def add(self, key, value): + def add(self, key: K, value: V) -> None: """Adds a new value for the key. .. versionadded:: 0.6 @@ -207,22 +253,30 @@ def add(self, key, value): :param key: the key for the value. :param value: the value to add. """ - dict.setdefault(self, key, []).append(value) - - def getlist(self, key, type=None): + super().setdefault(key, []).append(value) # type: ignore[arg-type,attr-defined] + + @t.overload + def getlist(self, key: K) -> list[V]: ... + @t.overload + def getlist(self, key: K, type: cabc.Callable[[V], T]) -> list[T]: ... 
+ def getlist( + self, key: K, type: cabc.Callable[[V], T] | None = None + ) -> list[V] | list[T]: """Return the list of items for a given key. If that key is not in the `MultiDict`, the return value will be an empty list. Just like `get`, `getlist` accepts a `type` parameter. All items will be converted with the callable defined there. :param key: The key to be looked up. - :param type: A callable that is used to cast the value in the - :class:`MultiDict`. If a :exc:`ValueError` is raised - by this callable the value will be removed from the list. + :param type: Callable to convert each value. If a ``ValueError`` or + ``TypeError`` is raised, the value is omitted. :return: a :class:`list` of all the values for the key. + + .. versionchanged:: 3.1 + Catches ``TypeError`` in addition to ``ValueError``. """ try: - rv = dict.__getitem__(self, key) + rv: list[V] = super().__getitem__(key) # type: ignore[assignment] except KeyError: return [] if type is None: @@ -231,11 +285,11 @@ def getlist(self, key, type=None): for item in rv: try: result.append(type(item)) - except ValueError: + except (ValueError, TypeError): pass return result - def setlist(self, key, new_list): + def setlist(self, key: K, new_list: cabc.Iterable[V]) -> None: """Remove the old values for a key and add new ones. Note that the list you pass the values in will be shallow-copied before it is inserted in the dictionary. @@ -251,9 +305,13 @@ def setlist(self, key, new_list): :param new_list: An iterable with the new values for the key. Old values are removed first. """ - dict.__setitem__(self, key, list(new_list)) + super().__setitem__(key, list(new_list)) # type: ignore[assignment] - def setdefault(self, key, default=None): + @t.overload + def setdefault(self, key: K) -> None: ... + @t.overload + def setdefault(self, key: K, default: V) -> V: ... 
+ def setdefault(self, key: K, default: V | None = None) -> V | None: """Returns the value for the key if it is in the dict, otherwise it returns `default` and sets that value for `key`. @@ -262,12 +320,13 @@ def setdefault(self, key, default=None): in the dict. If not further specified it's `None`. """ if key not in self: - self[key] = default - else: - default = self[key] - return default + self[key] = default # type: ignore[assignment] - def setlistdefault(self, key, default_list=None): + return self[key] + + def setlistdefault( + self, key: K, default_list: cabc.Iterable[V] | None = None + ) -> list[V]: """Like `setdefault` but sets multiple values. The list returned is not a copy, but the list that is actually used internally. This means that you can put new values into the dict by appending items @@ -285,38 +344,42 @@ def setlistdefault(self, key, default_list=None): :return: a :class:`list` """ if key not in self: - default_list = list(default_list or ()) - dict.__setitem__(self, key, default_list) - else: - default_list = dict.__getitem__(self, key) - return default_list + super().__setitem__(key, list(default_list or ())) # type: ignore[assignment] - def items(self, multi=False): + return super().__getitem__(key) # type: ignore[return-value] + + def items(self, multi: bool = False) -> cabc.Iterable[tuple[K, V]]: # type: ignore[override] """Return an iterator of ``(key, value)`` pairs. :param multi: If set to `True` the iterator returned will have a pair for each value of each key. Otherwise it will only contain pairs for the first value of each key. 
""" - for key, values in dict.items(self): + values: list[V] + + for key, values in super().items(): # type: ignore[assignment] if multi: for value in values: yield key, value else: yield key, values[0] - def lists(self): + def lists(self) -> cabc.Iterable[tuple[K, list[V]]]: """Return a iterator of ``(key, values)`` pairs, where values is the list of all values associated with the key.""" - for key, values in dict.items(self): + values: list[V] + + for key, values in super().items(): # type: ignore[assignment] yield key, list(values) - def values(self): + def values(self) -> cabc.Iterable[V]: # type: ignore[override] """Returns an iterator of the first value on every key's value list.""" - for values in dict.values(self): + values: list[V] + + for values in super().values(): # type: ignore[assignment] yield values[0] - def listvalues(self): + def listvalues(self) -> cabc.Iterable[list[V]]: """Return an iterator of all values associated with a key. Zipping :meth:`keys` and this is the same as calling :meth:`lists`: @@ -324,17 +387,21 @@ def listvalues(self): >>> zip(d.keys(), d.listvalues()) == d.lists() True """ - return dict.values(self) + return super().values() # type: ignore[return-value] - def copy(self): + def copy(self) -> te.Self: """Return a shallow copy of this object.""" return self.__class__(self) - def deepcopy(self, memo=None): + def deepcopy(self, memo: t.Any = None) -> te.Self: """Return a deep copy of this object.""" return self.__class__(deepcopy(self.to_dict(flat=False), memo)) - def to_dict(self, flat=True): + @t.overload + def to_dict(self) -> dict[K, V]: ... + @t.overload + def to_dict(self, flat: t.Literal[False]) -> dict[K, list[V]]: ... + def to_dict(self, flat: bool = True) -> dict[K, V] | dict[K, list[V]]: """Return the contents as regular dict. If `flat` is `True` the returned dict will only have the first item present, if `flat` is `False` all values will be returned as lists. 
@@ -348,7 +415,14 @@ def to_dict(self, flat=True): return dict(self.items()) return dict(self.lists()) - def update(self, mapping): + def update( # type: ignore[override] + self, + mapping: ( + MultiDict[K, V] + | cabc.Mapping[K, V | list[V] | tuple[V, ...] | set[V]] + | cabc.Iterable[tuple[K, V]] + ), + ) -> None: """update() extends rather than replaces existing key lists: >>> a = MultiDict({'x': 1}) @@ -367,9 +441,42 @@ def update(self, mapping): MultiDict([]) """ for key, value in iter_multi_items(mapping): - MultiDict.add(self, key, value) + self.add(key, value) + + def __or__( # type: ignore[override] + self, other: cabc.Mapping[K, V | list[V] | tuple[V, ...] | set[V]] + ) -> MultiDict[K, V]: + if not isinstance(other, cabc.Mapping): + return NotImplemented + + rv = self.copy() + rv.update(other) + return rv + + def __ior__( # type: ignore[override] + self, + other: ( + cabc.Mapping[K, V | list[V] | tuple[V, ...] | set[V]] + | cabc.Iterable[tuple[K, V]] + ), + ) -> te.Self: + if not isinstance(other, (cabc.Mapping, cabc.Iterable)): + return NotImplemented + + self.update(other) + return self - def pop(self, key, default=_missing): + @t.overload + def pop(self, key: K) -> V: ... + @t.overload + def pop(self, key: K, default: V) -> V: ... + @t.overload + def pop(self, key: K, default: T) -> V | T: ... + def pop( + self, + key: K, + default: V | T = _missing, # type: ignore[assignment] + ) -> V | T: """Pop the first item for a list on the dict. Afterwards the key is removed from the dict, so additional values are discarded: @@ -383,8 +490,10 @@ def pop(self, key, default=_missing): :param default: if provided the value to return if the key was not in the dictionary. 
""" + lst: list[V] + try: - lst = dict.pop(self, key) + lst = super().pop(key) # type: ignore[assignment] if len(lst) == 0: raise exceptions.BadRequestKeyError(key) @@ -396,19 +505,21 @@ def pop(self, key, default=_missing): raise exceptions.BadRequestKeyError(key) from None - def popitem(self): + def popitem(self) -> tuple[K, V]: """Pop an item from the dict.""" + item: tuple[K, list[V]] + try: - item = dict.popitem(self) + item = super().popitem() # type: ignore[assignment] if len(item[1]) == 0: raise exceptions.BadRequestKeyError(item[0]) - return (item[0], item[1][0]) + return item[0], item[1][0] except KeyError as e: raise exceptions.BadRequestKeyError(e.args[0]) from None - def poplist(self, key): + def poplist(self, key: K) -> list[V]: """Pop the list for a key from the dict. If the key is not in the dict an empty list is returned. @@ -416,26 +527,26 @@ def poplist(self, key): If the key does no longer exist a list is returned instead of raising an error. """ - return dict.pop(self, key, []) + return super().pop(key, []) # type: ignore[return-value] - def popitemlist(self): + def popitemlist(self) -> tuple[K, list[V]]: """Pop a ``(key, list)`` tuple from the dict.""" try: - return dict.popitem(self) + return super().popitem() # type: ignore[return-value] except KeyError as e: raise exceptions.BadRequestKeyError(e.args[0]) from None - def __copy__(self): + def __copy__(self) -> te.Self: return self.copy() - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: t.Any) -> te.Self: return self.deepcopy(memo=memo) - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}({list(self.items(multi=True))!r})" -class _omd_bucket: +class _omd_bucket(t.Generic[K, V]): """Wraps values in the :class:`OrderedMultiDict`. This makes it possible to keep an order over multiple different keys. 
It requires a lot of extra memory and slows down access a lot, but makes it @@ -444,11 +555,11 @@ class _omd_bucket: __slots__ = ("prev", "key", "value", "next") - def __init__(self, omd, key, value): - self.prev = omd._last_bucket - self.key = key - self.value = value - self.next = None + def __init__(self, omd: _OrderedMultiDict[K, V], key: K, value: V) -> None: + self.prev: _omd_bucket[K, V] | None = omd._last_bucket + self.key: K = key + self.value: V = value + self.next: _omd_bucket[K, V] | None = None if omd._first_bucket is None: omd._first_bucket = self @@ -456,7 +567,7 @@ def __init__(self, omd, key, value): omd._last_bucket.next = self omd._last_bucket = self - def unlink(self, omd): + def unlink(self, omd: _OrderedMultiDict[K, V]) -> None: if self.prev: self.prev.next = self.next if self.next: @@ -467,7 +578,7 @@ def unlink(self, omd): omd._last_bucket = self.prev -class OrderedMultiDict(MultiDict): +class _OrderedMultiDict(MultiDict[K, V]): """Works like a regular :class:`MultiDict` but preserves the order of the fields. To convert the ordered multi dict into a list you can use the :meth:`items` method and pass it ``multi=True``. @@ -481,18 +592,38 @@ class OrderedMultiDict(MultiDict): multi dict into a regular dict by using ``dict(multidict)``. Instead you have to use the :meth:`to_dict` method, otherwise the internal bucket objects are exposed. + + .. deprecated:: 3.1 + Will be removed in Werkzeug 3.2. Use ``MultiDict`` instead. """ - def __init__(self, mapping=None): - dict.__init__(self) - self._first_bucket = self._last_bucket = None + def __init__( + self, + mapping: ( + MultiDict[K, V] + | cabc.Mapping[K, V | list[V] | tuple[V, ...] | set[V]] + | cabc.Iterable[tuple[K, V]] + | None + ) = None, + ) -> None: + import warnings + + warnings.warn( + "'OrderedMultiDict' is deprecated and will be removed in Werkzeug" + " 3.2. 
Use 'MultiDict' instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__() + self._first_bucket: _omd_bucket[K, V] | None = None + self._last_bucket: _omd_bucket[K, V] | None = None if mapping is not None: - OrderedMultiDict.update(self, mapping) + self.update(mapping) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, MultiDict): return NotImplemented - if isinstance(other, OrderedMultiDict): + if isinstance(other, _OrderedMultiDict): iter1 = iter(self.items(multi=True)) iter2 = iter(other.items(multi=True)) try: @@ -514,41 +645,42 @@ def __eq__(self, other): return False return True - __hash__ = None + __hash__ = None # type: ignore[assignment] - def __reduce_ex__(self, protocol): + def __reduce_ex__(self, protocol: t.SupportsIndex) -> t.Any: return type(self), (list(self.items(multi=True)),) - def __getstate__(self): + def __getstate__(self) -> t.Any: return list(self.items(multi=True)) - def __setstate__(self, values): - dict.clear(self) + def __setstate__(self, values: t.Any) -> None: + self.clear() + for key, value in values: self.add(key, value) - def __getitem__(self, key): + def __getitem__(self, key: K) -> V: if key in self: - return dict.__getitem__(self, key)[0].value + return dict.__getitem__(self, key)[0].value # type: ignore[index,no-any-return] raise exceptions.BadRequestKeyError(key) - def __setitem__(self, key, value): + def __setitem__(self, key: K, value: V) -> None: self.poplist(key) self.add(key, value) - def __delitem__(self, key): + def __delitem__(self, key: K) -> None: self.pop(key) - def keys(self): - return (key for key, value in self.items()) + def keys(self) -> cabc.Iterable[K]: # type: ignore[override] + return (key for key, _ in self.items()) - def __iter__(self): + def __iter__(self) -> cabc.Iterator[K]: return iter(self.keys()) - def values(self): + def values(self) -> cabc.Iterable[V]: # type: ignore[override] return (value for key, value in self.items()) - def items(self, 
multi=False): + def items(self, multi: bool = False) -> cabc.Iterable[tuple[K, V]]: # type: ignore[override] ptr = self._first_bucket if multi: while ptr is not None: @@ -562,7 +694,7 @@ def items(self, multi=False): yield ptr.key, ptr.value ptr = ptr.next - def lists(self): + def lists(self) -> cabc.Iterable[tuple[K, list[V]]]: returned_keys = set() ptr = self._first_bucket while ptr is not None: @@ -571,16 +703,24 @@ def lists(self): returned_keys.add(ptr.key) ptr = ptr.next - def listvalues(self): + def listvalues(self) -> cabc.Iterable[list[V]]: for _key, values in self.lists(): yield values - def add(self, key, value): - dict.setdefault(self, key, []).append(_omd_bucket(self, key, value)) + def add(self, key: K, value: V) -> None: + dict.setdefault(self, key, []).append(_omd_bucket(self, key, value)) # type: ignore[arg-type,attr-defined] + + @t.overload + def getlist(self, key: K) -> list[V]: ... + @t.overload + def getlist(self, key: K, type: cabc.Callable[[V], T]) -> list[T]: ... 
+ def getlist( + self, key: K, type: cabc.Callable[[V], T] | None = None + ) -> list[V] | list[T]: + rv: list[_omd_bucket[K, V]] - def getlist(self, key, type=None): try: - rv = dict.__getitem__(self, key) + rv = dict.__getitem__(self, key) # type: ignore[index] except KeyError: return [] if type is None: @@ -589,31 +729,50 @@ def getlist(self, key, type=None): for item in rv: try: result.append(type(item.value)) - except ValueError: + except (ValueError, TypeError): pass return result - def setlist(self, key, new_list): + def setlist(self, key: K, new_list: cabc.Iterable[V]) -> None: self.poplist(key) for value in new_list: self.add(key, value) - def setlistdefault(self, key, default_list=None): + def setlistdefault(self, key: t.Any, default_list: t.Any = None) -> t.NoReturn: raise TypeError("setlistdefault is unsupported for ordered multi dicts") - def update(self, mapping): + def update( # type: ignore[override] + self, + mapping: ( + MultiDict[K, V] + | cabc.Mapping[K, V | list[V] | tuple[V, ...] | set[V]] + | cabc.Iterable[tuple[K, V]] + ), + ) -> None: for key, value in iter_multi_items(mapping): - OrderedMultiDict.add(self, key, value) + self.add(key, value) - def poplist(self, key): - buckets = dict.pop(self, key, ()) + def poplist(self, key: K) -> list[V]: + buckets: cabc.Iterable[_omd_bucket[K, V]] = dict.pop(self, key, ()) # type: ignore[arg-type] for bucket in buckets: bucket.unlink(self) return [x.value for x in buckets] - def pop(self, key, default=_missing): + @t.overload + def pop(self, key: K) -> V: ... + @t.overload + def pop(self, key: K, default: V) -> V: ... + @t.overload + def pop(self, key: K, default: T) -> V | T: ... 
+ def pop( + self, + key: K, + default: V | T = _missing, # type: ignore[assignment] + ) -> V | T: + buckets: list[_omd_bucket[K, V]] + try: - buckets = dict.pop(self, key) + buckets = dict.pop(self, key) # type: ignore[arg-type] except KeyError: if default is not _missing: return default @@ -625,9 +784,12 @@ def pop(self, key, default=_missing): return buckets[0].value - def popitem(self): + def popitem(self) -> tuple[K, V]: + key: K + buckets: list[_omd_bucket[K, V]] + try: - key, buckets = dict.popitem(self) + key, buckets = dict.popitem(self) # type: ignore[arg-type,assignment] except KeyError as e: raise exceptions.BadRequestKeyError(e.args[0]) from None @@ -636,9 +798,12 @@ def popitem(self): return key, buckets[0].value - def popitemlist(self): + def popitemlist(self) -> tuple[K, list[V]]: + key: K + buckets: list[_omd_bucket[K, V]] + try: - key, buckets = dict.popitem(self) + key, buckets = dict.popitem(self) # type: ignore[arg-type,assignment] except KeyError as e: raise exceptions.BadRequestKeyError(e.args[0]) from None @@ -648,7 +813,7 @@ def popitemlist(self): return key, [x.value for x in buckets] -class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict): +class CombinedMultiDict(ImmutableMultiDictMixin[K, V], MultiDict[K, V]): # type: ignore[misc] """A read only :class:`MultiDict` that you can pass multiple :class:`MultiDict` instances as sequence and it will combine the return values of all wrapped dicts: @@ -671,54 +836,80 @@ class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict): exceptions. 
""" - def __reduce_ex__(self, protocol): + def __reduce_ex__(self, protocol: t.SupportsIndex) -> t.Any: return type(self), (self.dicts,) - def __init__(self, dicts=None): - self.dicts = list(dicts) or [] + def __init__(self, dicts: cabc.Iterable[MultiDict[K, V]] | None = None) -> None: + super().__init__() + self.dicts: list[MultiDict[K, V]] = list(dicts or ()) @classmethod - def fromkeys(cls, keys, value=None): + def fromkeys(cls, keys: t.Any, value: t.Any = None) -> t.NoReturn: raise TypeError(f"cannot create {cls.__name__!r} instances by fromkeys") - def __getitem__(self, key): + def __getitem__(self, key: K) -> V: for d in self.dicts: if key in d: return d[key] raise exceptions.BadRequestKeyError(key) - def get(self, key, default=None, type=None): + @t.overload # type: ignore[override] + def get(self, key: K) -> V | None: ... + @t.overload + def get(self, key: K, default: V) -> V: ... + @t.overload + def get(self, key: K, default: T) -> V | T: ... + @t.overload + def get(self, key: str, type: cabc.Callable[[V], T]) -> T | None: ... + @t.overload + def get(self, key: str, default: T, type: cabc.Callable[[V], T]) -> T: ... + def get( # type: ignore[misc] + self, + key: K, + default: V | T | None = None, + type: cabc.Callable[[V], T] | None = None, + ) -> V | T | None: for d in self.dicts: if key in d: if type is not None: try: return type(d[key]) - except ValueError: + except (ValueError, TypeError): continue return d[key] return default - def getlist(self, key, type=None): + @t.overload + def getlist(self, key: K) -> list[V]: ... + @t.overload + def getlist(self, key: K, type: cabc.Callable[[V], T]) -> list[T]: ... 
+ def getlist( + self, key: K, type: cabc.Callable[[V], T] | None = None + ) -> list[V] | list[T]: rv = [] for d in self.dicts: - rv.extend(d.getlist(key, type)) + rv.extend(d.getlist(key, type)) # type: ignore[arg-type] return rv - def _keys_impl(self): + def _keys_impl(self) -> set[K]: """This function exists so __len__ can be implemented more efficiently, saving one list creation from an iterator. """ - rv = set() - rv.update(*self.dicts) - return rv + return set(k for d in self.dicts for k in d) - def keys(self): + def keys(self) -> cabc.Iterable[K]: # type: ignore[override] return self._keys_impl() - def __iter__(self): - return iter(self.keys()) + def __iter__(self) -> cabc.Iterator[K]: + return iter(self._keys_impl()) - def items(self, multi=False): + @t.overload # type: ignore[override] + def items(self) -> cabc.Iterable[tuple[K, V]]: ... + @t.overload + def items(self, multi: t.Literal[True]) -> cabc.Iterable[tuple[K, list[V]]]: ... + def items( + self, multi: bool = False + ) -> cabc.Iterable[tuple[K, V]] | cabc.Iterable[tuple[K, list[V]]]: found = set() for d in self.dicts: for key, value in d.items(multi): @@ -728,21 +919,21 @@ def items(self, multi=False): found.add(key) yield key, value - def values(self): - for _key, value in self.items(): + def values(self) -> cabc.Iterable[V]: # type: ignore[override] + for _, value in self.items(): yield value - def lists(self): - rv = {} + def lists(self) -> cabc.Iterable[tuple[K, list[V]]]: + rv: dict[K, list[V]] = {} for d in self.dicts: for key, values in d.lists(): rv.setdefault(key, []).extend(values) - return list(rv.items()) + return rv.items() - def listvalues(self): + def listvalues(self) -> cabc.Iterable[list[V]]: return (x[1] for x in self.lists()) - def copy(self): + def copy(self) -> MultiDict[K, V]: # type: ignore[override] """Return a shallow mutable copy of this object. 
This returns a :class:`MultiDict` representing the data at the @@ -754,105 +945,118 @@ def copy(self): """ return MultiDict(self) - def to_dict(self, flat=True): - """Return the contents as regular dict. If `flat` is `True` the - returned dict will only have the first item present, if `flat` is - `False` all values will be returned as lists. - - :param flat: If set to `False` the dict returned will have lists - with all the values in it. Otherwise it will only - contain the first item for each key. - :return: a :class:`dict` - """ - if flat: - return dict(self.items()) - - return dict(self.lists()) - - def __len__(self): + def __len__(self) -> int: return len(self._keys_impl()) - def __contains__(self, key): + def __contains__(self, key: K) -> bool: # type: ignore[override] for d in self.dicts: if key in d: return True return False - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}({self.dicts!r})" -class ImmutableDict(ImmutableDictMixin, dict): +class ImmutableDict(ImmutableDictMixin[K, V], dict[K, V]): # type: ignore[misc] """An immutable :class:`dict`. .. versionadded:: 0.5 """ - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}({dict.__repr__(self)})" - def copy(self): + def copy(self) -> dict[K, V]: """Return a shallow mutable copy of this object. Keep in mind that the standard library's :func:`copy` function is a no-op for this class like for any other python immutable type (eg: :class:`tuple`). """ return dict(self) - def __copy__(self): + def __copy__(self) -> te.Self: return self -class ImmutableMultiDict(ImmutableMultiDictMixin, MultiDict): +class ImmutableMultiDict(ImmutableMultiDictMixin[K, V], MultiDict[K, V]): # type: ignore[misc] """An immutable :class:`MultiDict`. .. versionadded:: 0.5 """ - def copy(self): + def copy(self) -> MultiDict[K, V]: # type: ignore[override] """Return a shallow mutable copy of this object. 
Keep in mind that the standard library's :func:`copy` function is a no-op for this class like for any other python immutable type (eg: :class:`tuple`). """ return MultiDict(self) - def __copy__(self): + def __copy__(self) -> te.Self: return self -class ImmutableOrderedMultiDict(ImmutableMultiDictMixin, OrderedMultiDict): +class _ImmutableOrderedMultiDict( # type: ignore[misc] + ImmutableMultiDictMixin[K, V], _OrderedMultiDict[K, V] +): """An immutable :class:`OrderedMultiDict`. + .. deprecated:: 3.1 + Will be removed in Werkzeug 3.2. Use ``ImmutableMultiDict`` instead. + .. versionadded:: 0.6 """ - def _iter_hashitems(self): + def __init__( + self, + mapping: ( + MultiDict[K, V] + | cabc.Mapping[K, V | list[V] | tuple[V, ...] | set[V]] + | cabc.Iterable[tuple[K, V]] + | None + ) = None, + ) -> None: + super().__init__() + + if mapping is not None: + for k, v in iter_multi_items(mapping): + _OrderedMultiDict.add(self, k, v) + + def _iter_hashitems(self) -> cabc.Iterable[t.Any]: return enumerate(self.items(multi=True)) - def copy(self): + def copy(self) -> _OrderedMultiDict[K, V]: # type: ignore[override] """Return a shallow mutable copy of this object. Keep in mind that the standard library's :func:`copy` function is a no-op for this class like for any other python immutable type (eg: :class:`tuple`). """ - return OrderedMultiDict(self) + return _OrderedMultiDict(self) - def __copy__(self): + def __copy__(self) -> te.Self: return self -class CallbackDict(UpdateDictMixin, dict): +class CallbackDict(UpdateDictMixin[K, V], dict[K, V]): """A dict that calls a function passed every time something is changed. The function is passed the dict instance. 
""" - def __init__(self, initial=None, on_update=None): - dict.__init__(self, initial or ()) + def __init__( + self, + initial: cabc.Mapping[K, V] | cabc.Iterable[tuple[K, V]] | None = None, + on_update: cabc.Callable[[te.Self], None] | None = None, + ) -> None: + if initial is None: + super().__init__() + else: + super().__init__(initial) + self.on_update = on_update - def __repr__(self): - return f"<{type(self).__name__} {dict.__repr__(self)}>" + def __repr__(self) -> str: + return f"<{type(self).__name__} {super().__repr__()}>" -class HeaderSet(MutableSet): +class HeaderSet(cabc.MutableSet[str]): """Similar to the :class:`ETags` class this implements a set-like structure. Unlike :class:`ETags` this is case insensitive and used for vary, allow, and content-language headers. @@ -865,16 +1069,20 @@ class HeaderSet(MutableSet): HeaderSet(['foo', 'bar', 'baz']) """ - def __init__(self, headers=None, on_update=None): + def __init__( + self, + headers: cabc.Iterable[str] | None = None, + on_update: cabc.Callable[[te.Self], None] | None = None, + ) -> None: self._headers = list(headers or ()) self._set = {x.lower() for x in self._headers} self.on_update = on_update - def add(self, header): + def add(self, header: str) -> None: """Add a new header to the set.""" self.update((header,)) - def remove(self, header): + def remove(self: te.Self, header: str) -> None: """Remove a header from the set. This raises an :exc:`KeyError` if the header is not in the set. @@ -895,7 +1103,7 @@ def remove(self, header): if self.on_update is not None: self.on_update(self) - def update(self, iterable): + def update(self: te.Self, iterable: cabc.Iterable[str]) -> None: """Add all the headers from the iterable to the set. :param iterable: updates the set with the items from the iterable. 
@@ -910,7 +1118,7 @@ def update(self, iterable): if inserted_any and self.on_update is not None: self.on_update(self) - def discard(self, header): + def discard(self, header: str) -> None: """Like :meth:`remove` but ignores errors. :param header: the header to be discarded. @@ -920,7 +1128,7 @@ def discard(self, header): except KeyError: pass - def find(self, header): + def find(self, header: str) -> int: """Return the index of the header in the set or return -1 if not found. :param header: the header to be looked up. @@ -931,7 +1139,7 @@ def find(self, header): return idx return -1 - def index(self, header): + def index(self, header: str) -> int: """Return the index of the header in the set or raise an :exc:`IndexError`. @@ -942,14 +1150,15 @@ def index(self, header): raise IndexError(header) return rv - def clear(self): + def clear(self: te.Self) -> None: """Clear the set.""" self._set.clear() - del self._headers[:] + self._headers.clear() + if self.on_update is not None: self.on_update(self) - def as_set(self, preserve_casing=False): + def as_set(self, preserve_casing: bool = False) -> set[str]: """Return the set as real python set type. When calling this, all the items are converted to lowercase and the ordering is lost. 
@@ -962,20 +1171,20 @@ def as_set(self, preserve_casing=False): return set(self._headers) return set(self._set) - def to_header(self): + def to_header(self) -> str: """Convert the header set into an HTTP header string.""" return ", ".join(map(http.quote_header_value, self._headers)) - def __getitem__(self, idx): + def __getitem__(self, idx: t.SupportsIndex) -> str: return self._headers[idx] - def __delitem__(self, idx): + def __delitem__(self: te.Self, idx: t.SupportsIndex) -> None: rv = self._headers.pop(idx) self._set.remove(rv.lower()) if self.on_update is not None: self.on_update(self) - def __setitem__(self, idx, value): + def __setitem__(self: te.Self, idx: t.SupportsIndex, value: str) -> None: old = self._headers[idx] self._set.remove(old.lower()) self._headers[idx] = value @@ -983,24 +1192,48 @@ def __setitem__(self, idx, value): if self.on_update is not None: self.on_update(self) - def __contains__(self, header): + def __contains__(self, header: str) -> bool: # type: ignore[override] return header.lower() in self._set - def __len__(self): + def __len__(self) -> int: return len(self._set) - def __iter__(self): + def __iter__(self) -> cabc.Iterator[str]: return iter(self._headers) - def __bool__(self): + def __bool__(self) -> bool: return bool(self._set) - def __str__(self): + def __str__(self) -> str: return self.to_header() - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}({self._headers!r})" # circular dependencies from .. import http + + +def __getattr__(name: str) -> t.Any: + import warnings + + if name == "OrderedMultiDict": + warnings.warn( + "'OrderedMultiDict' is deprecated and will be removed in Werkzeug" + " 3.2. Use 'MultiDict' instead.", + DeprecationWarning, + stacklevel=2, + ) + return _OrderedMultiDict + + if name == "ImmutableOrderedMultiDict": + warnings.warn( + "'ImmutableOrderedMultiDict' is deprecated and will be removed in" + " Werkzeug 3.2. 
Use 'ImmutableMultiDict' instead.", + DeprecationWarning, + stacklevel=2, + ) + return _ImmutableOrderedMultiDict + + raise AttributeError(name) diff --git a/src/werkzeug/datastructures/structures.pyi b/src/werkzeug/datastructures/structures.pyi deleted file mode 100644 index 2e7af35be..000000000 --- a/src/werkzeug/datastructures/structures.pyi +++ /dev/null @@ -1,208 +0,0 @@ -from collections.abc import Callable -from collections.abc import Iterable -from collections.abc import Iterator -from collections.abc import Mapping -from typing import Any -from typing import Generic -from typing import Literal -from typing import NoReturn -from typing import overload -from typing import TypeVar - -from .mixins import ( - ImmutableDictMixin, - ImmutableListMixin, - ImmutableMultiDictMixin, - UpdateDictMixin, -) - -D = TypeVar("D") -K = TypeVar("K") -T = TypeVar("T") -V = TypeVar("V") -_CD = TypeVar("_CD", bound="CallbackDict") - -def is_immutable(self: object) -> NoReturn: ... -def iter_multi_items( - mapping: Mapping[K, V | Iterable[V]] | Iterable[tuple[K, V]] -) -> Iterator[tuple[K, V]]: ... - -class ImmutableList(ImmutableListMixin[V]): ... - -class TypeConversionDict(dict[K, V]): - @overload - def get(self, key: K, default: None = ..., type: None = ...) -> V | None: ... - @overload - def get(self, key: K, default: D, type: None = ...) -> D | V: ... - @overload - def get(self, key: K, default: D, type: Callable[[V], T]) -> D | T: ... - @overload - def get(self, key: K, type: Callable[[V], T]) -> T | None: ... - -class ImmutableTypeConversionDict(ImmutableDictMixin[K, V], TypeConversionDict[K, V]): - def copy(self) -> TypeConversionDict[K, V]: ... - def __copy__(self) -> ImmutableTypeConversionDict: ... - -class MultiDict(TypeConversionDict[K, V]): - def __init__( - self, - mapping: Mapping[K, Iterable[V] | V] | Iterable[tuple[K, V]] | None = None, - ) -> None: ... - def __getitem__(self, item: K) -> V: ... - def __setitem__(self, key: K, value: V) -> None: ... 
- def add(self, key: K, value: V) -> None: ... - @overload - def getlist(self, key: K) -> list[V]: ... - @overload - def getlist(self, key: K, type: Callable[[V], T] = ...) -> list[T]: ... - def setlist(self, key: K, new_list: Iterable[V]) -> None: ... - def setdefault(self, key: K, default: V | None = None) -> V: ... - def setlistdefault( - self, key: K, default_list: Iterable[V] | None = None - ) -> list[V]: ... - def items(self, multi: bool = False) -> Iterator[tuple[K, V]]: ... # type: ignore - def lists(self) -> Iterator[tuple[K, list[V]]]: ... - def values(self) -> Iterator[V]: ... # type: ignore - def listvalues(self) -> Iterator[list[V]]: ... - def copy(self) -> MultiDict[K, V]: ... - def deepcopy(self, memo: Any = None) -> MultiDict[K, V]: ... - @overload - def to_dict(self) -> dict[K, V]: ... - @overload - def to_dict(self, flat: Literal[False]) -> dict[K, list[V]]: ... - def update( # type: ignore - self, mapping: Mapping[K, Iterable[V] | V] | Iterable[tuple[K, V]] - ) -> None: ... - @overload - def pop(self, key: K) -> V: ... - @overload - def pop(self, key: K, default: V | T = ...) -> V | T: ... - def popitem(self) -> tuple[K, V]: ... - def poplist(self, key: K) -> list[V]: ... - def popitemlist(self) -> tuple[K, list[V]]: ... - def __copy__(self) -> MultiDict[K, V]: ... - def __deepcopy__(self, memo: Any) -> MultiDict[K, V]: ... - -class _omd_bucket(Generic[K, V]): - prev: _omd_bucket | None - next: _omd_bucket | None - key: K - value: V - def __init__(self, omd: OrderedMultiDict, key: K, value: V) -> None: ... - def unlink(self, omd: OrderedMultiDict) -> None: ... - -class OrderedMultiDict(MultiDict[K, V]): - _first_bucket: _omd_bucket | None - _last_bucket: _omd_bucket | None - def __init__(self, mapping: Mapping[K, V] | None = None) -> None: ... - def __eq__(self, other: object) -> bool: ... - def __getitem__(self, key: K) -> V: ... - def __setitem__(self, key: K, value: V) -> None: ... - def __delitem__(self, key: K) -> None: ... 
- def keys(self) -> Iterator[K]: ... # type: ignore - def __iter__(self) -> Iterator[K]: ... - def values(self) -> Iterator[V]: ... # type: ignore - def items(self, multi: bool = False) -> Iterator[tuple[K, V]]: ... # type: ignore - def lists(self) -> Iterator[tuple[K, list[V]]]: ... - def listvalues(self) -> Iterator[list[V]]: ... - def add(self, key: K, value: V) -> None: ... - @overload - def getlist(self, key: K) -> list[V]: ... - @overload - def getlist(self, key: K, type: Callable[[V], T] = ...) -> list[T]: ... - def setlist(self, key: K, new_list: Iterable[V]) -> None: ... - def setlistdefault( - self, key: K, default_list: Iterable[V] | None = None - ) -> list[V]: ... - def update( # type: ignore - self, mapping: Mapping[K, V] | Iterable[tuple[K, V]] - ) -> None: ... - def poplist(self, key: K) -> list[V]: ... - @overload - def pop(self, key: K) -> V: ... - @overload - def pop(self, key: K, default: V | T = ...) -> V | T: ... - def popitem(self) -> tuple[K, V]: ... - def popitemlist(self) -> tuple[K, list[V]]: ... - -class CombinedMultiDict(ImmutableMultiDictMixin[K, V], MultiDict[K, V]): # type: ignore - dicts: list[MultiDict[K, V]] - def __init__(self, dicts: Iterable[MultiDict[K, V]] | None) -> None: ... - @classmethod - def fromkeys(cls, keys: Any, value: Any = None) -> NoReturn: ... - def __getitem__(self, key: K) -> V: ... - @overload # type: ignore - def get(self, key: K) -> V | None: ... - @overload - def get(self, key: K, default: V | T = ...) -> V | T: ... - @overload - def get( - self, key: K, default: T | None = None, type: Callable[[V], T] = ... - ) -> T | None: ... - @overload - def getlist(self, key: K) -> list[V]: ... - @overload - def getlist(self, key: K, type: Callable[[V], T] = ...) -> list[T]: ... - def _keys_impl(self) -> set[K]: ... - def keys(self) -> set[K]: ... # type: ignore - def __iter__(self) -> set[K]: ... # type: ignore - def items(self, multi: bool = False) -> Iterator[tuple[K, V]]: ... 
# type: ignore - def values(self) -> Iterator[V]: ... # type: ignore - def lists(self) -> Iterator[tuple[K, list[V]]]: ... - def listvalues(self) -> Iterator[list[V]]: ... - def copy(self) -> MultiDict[K, V]: ... - @overload - def to_dict(self) -> dict[K, V]: ... - @overload - def to_dict(self, flat: Literal[False]) -> dict[K, list[V]]: ... - def __contains__(self, key: K) -> bool: ... # type: ignore - def has_key(self, key: K) -> bool: ... - -class ImmutableDict(ImmutableDictMixin[K, V], dict[K, V]): - def copy(self) -> dict[K, V]: ... - def __copy__(self) -> ImmutableDict[K, V]: ... - -class ImmutableMultiDict( # type: ignore - ImmutableMultiDictMixin[K, V], MultiDict[K, V] -): - def copy(self) -> MultiDict[K, V]: ... - def __copy__(self) -> ImmutableMultiDict[K, V]: ... - -class ImmutableOrderedMultiDict( # type: ignore - ImmutableMultiDictMixin[K, V], OrderedMultiDict[K, V] -): - def _iter_hashitems(self) -> Iterator[tuple[int, tuple[K, V]]]: ... - def copy(self) -> OrderedMultiDict[K, V]: ... - def __copy__(self) -> ImmutableOrderedMultiDict[K, V]: ... - -class CallbackDict(UpdateDictMixin[K, V], dict[K, V]): - def __init__( - self, - initial: Mapping[K, V] | Iterable[tuple[K, V]] | None = None, - on_update: Callable[[_CD], None] | None = None, - ) -> None: ... - -class HeaderSet(set[str]): - _headers: list[str] - _set: set[str] - on_update: Callable[[HeaderSet], None] | None - def __init__( - self, - headers: Iterable[str] | None = None, - on_update: Callable[[HeaderSet], None] | None = None, - ) -> None: ... - def add(self, header: str) -> None: ... - def remove(self, header: str) -> None: ... - def update(self, iterable: Iterable[str]) -> None: ... # type: ignore - def discard(self, header: str) -> None: ... - def find(self, header: str) -> int: ... - def index(self, header: str) -> int: ... - def clear(self) -> None: ... - def as_set(self, preserve_casing: bool = False) -> set[str]: ... - def to_header(self) -> str: ... 
- def __getitem__(self, idx: int) -> str: ... - def __delitem__(self, idx: int) -> None: ... - def __setitem__(self, idx: int, value: str) -> None: ... - def __contains__(self, header: str) -> bool: ... # type: ignore - def __len__(self) -> int: ... - def __iter__(self) -> Iterator[str]: ... diff --git a/src/werkzeug/debug/__init__.py b/src/werkzeug/debug/__init__.py index 3b04b534e..0c4cabd89 100644 --- a/src/werkzeug/debug/__init__.py +++ b/src/werkzeug/debug/__init__.py @@ -13,13 +13,16 @@ from contextlib import ExitStack from io import BytesIO from itertools import chain +from multiprocessing import Value from os.path import basename from os.path import join from zlib import adler32 from .._internal import _log from ..exceptions import NotFound +from ..exceptions import SecurityError from ..http import parse_cookie +from ..sansio.utils import host_is_trusted from ..security import gen_salt from ..utils import send_file from ..wrappers.request import Request @@ -82,7 +85,8 @@ def _generate() -> str | bytes | None: try: # subprocess may not be available, e.g. Google App Engine # https://github.com/pallets/werkzeug/issues/925 - from subprocess import Popen, PIPE + from subprocess import PIPE + from subprocess import Popen dump = Popen( ["ioreg", "-c", "IOPlatformExpertDevice", "-d", "2"], stdout=PIPE @@ -110,7 +114,7 @@ def _generate() -> str | bytes | None: guid, guid_type = winreg.QueryValueEx(rk, "MachineGuid") if guid_type == winreg.REG_SZ: - return guid.encode("utf-8") + return guid.encode() return guid except OSError: @@ -169,7 +173,8 @@ def get_pin_and_cookie_name( # App Engine. It may also raise a KeyError if the UID does not # have a username, such as in Docker. 
username = getpass.getuser() - except (ImportError, KeyError): + # Python >= 3.13 only raises OSError + except (ImportError, KeyError, OSError): username = None mod = sys.modules.get(modname) @@ -193,7 +198,7 @@ def get_pin_and_cookie_name( if not bit: continue if isinstance(bit, str): - bit = bit.encode("utf-8") + bit = bit.encode() h.update(bit) h.update(b"cookiesalt") @@ -283,7 +288,7 @@ def __init__( self.console_init_func = console_init_func self.show_hidden_frames = show_hidden_frames self.secret = gen_salt(20) - self._failed_pin_auth = 0 + self._failed_pin_auth = Value("B") self.pin_logging = pin_logging if pin_security: @@ -297,6 +302,14 @@ def __init__( else: self.pin = None + self.trusted_hosts: list[str] = [".localhost", "127.0.0.1"] + """List of domains to allow requests to the debugger from. A leading dot + allows all subdomains. This only allows ``".localhost"`` domains by + default. + + .. versionadded:: 3.0.3 + """ + @property def pin(self) -> str | None: if not hasattr(self, "_pin"): @@ -343,7 +356,7 @@ def debug_application( is_trusted = bool(self.check_pin_trust(environ)) html = tb.render_debugger_html( - evalex=self.evalex, + evalex=self.evalex and self.check_host_trust(environ), secret=self.secret, evalex_trusted=is_trusted, ) @@ -364,13 +377,16 @@ def debug_application( environ["wsgi.errors"].write("".join(tb.render_traceback_text())) - def execute_command( # type: ignore[return] + def execute_command( self, request: Request, command: str, frame: DebugFrameSummary | _ConsoleFrame, ) -> Response: """Execute a command in a console.""" + if not self.check_host_trust(request.environ): + return SecurityError() # type: ignore[return-value] + contexts = self.frame_contexts.get(id(frame), []) with ExitStack() as exit_stack: @@ -381,6 +397,9 @@ def execute_command( # type: ignore[return] def display_console(self, request: Request) -> Response: """Display a standalone shell.""" + if not self.check_host_trust(request.environ): + return SecurityError() # 
type: ignore[return-value] + if 0 not in self.frames: if self.console_init_func is None: ns = {} @@ -433,12 +452,21 @@ def check_pin_trust(self, environ: WSGIEnvironment) -> bool | None: return None return (time.time() - PIN_TIME) < ts + def check_host_trust(self, environ: WSGIEnvironment) -> bool: + return host_is_trusted(environ.get("HTTP_HOST"), self.trusted_hosts) + def _fail_pin_auth(self) -> None: - time.sleep(5.0 if self._failed_pin_auth > 5 else 0.5) - self._failed_pin_auth += 1 + with self._failed_pin_auth.get_lock(): + count = self._failed_pin_auth.value + self._failed_pin_auth.value = count + 1 + + time.sleep(5.0 if count > 5 else 0.5) def pin_auth(self, request: Request) -> Response: """Authenticates with the pin.""" + if not self.check_host_trust(request.environ): + return SecurityError() # type: ignore[return-value] + exhausted = False auth = False trust = self.check_pin_trust(request.environ) @@ -459,7 +487,7 @@ def pin_auth(self, request: Request) -> Response: auth = True # If we failed too many times, then we're locked out. 
- elif self._failed_pin_auth > 10: + elif self._failed_pin_auth.value > 10: exhausted = True # Otherwise go through pin based authentication @@ -467,7 +495,7 @@ def pin_auth(self, request: Request) -> Response: entered_pin = request.args["pin"] if entered_pin.strip().replace("-", "") == pin.replace("-", ""): - self._failed_pin_auth = 0 + self._failed_pin_auth.value = 0 auth = True else: self._fail_pin_auth() @@ -488,8 +516,11 @@ def pin_auth(self, request: Request) -> Response: rv.delete_cookie(self.pin_cookie_name) return rv - def log_pin_request(self) -> Response: + def log_pin_request(self, request: Request) -> Response: """Log the pin if needed.""" + if not self.check_host_trust(request.environ): + return SecurityError() # type: ignore[return-value] + if self.pin_logging and self.pin is not None: _log( "info", " * To enable the debugger you need to enter the security pin:" @@ -516,7 +547,7 @@ def __call__( elif cmd == "pinauth" and secret == self.secret: response = self.pin_auth(request) # type: ignore elif cmd == "printpin" and secret == self.secret: - response = self.log_pin_request() # type: ignore + response = self.log_pin_request(request) # type: ignore elif ( self.evalex and cmd is not None diff --git a/src/werkzeug/debug/console.py b/src/werkzeug/debug/console.py index 03ddc07f2..4e40475a5 100644 --- a/src/werkzeug/debug/console.py +++ b/src/werkzeug/debug/console.py @@ -13,7 +13,7 @@ from .repr import helper _stream: ContextVar[HTMLStringO] = ContextVar("werkzeug.debug.console.stream") -_ipy: ContextVar = ContextVar("werkzeug.debug.console.ipy") +_ipy: ContextVar[_InteractiveConsole] = ContextVar("werkzeug.debug.console.ipy") class HTMLStringO: diff --git a/src/werkzeug/debug/repr.py b/src/werkzeug/debug/repr.py index 3bf15a77a..2bbd9d546 100644 --- a/src/werkzeug/debug/repr.py +++ b/src/werkzeug/debug/repr.py @@ -4,6 +4,7 @@ Together with the CSS and JavaScript of the debugger this gives a colorful and more compact output. 
""" + from __future__ import annotations import codecs @@ -80,9 +81,7 @@ def __call__(self, topic: t.Any | None = None) -> None: helper = _Helper() -def _add_subclass_info( - inner: str, obj: object, base: t.Type | tuple[t.Type, ...] -) -> str: +def _add_subclass_info(inner: str, obj: object, base: type | tuple[type, ...]) -> str: if isinstance(base, tuple): for cls in base: if type(obj) is cls: @@ -96,9 +95,9 @@ def _add_subclass_info( def _sequence_repr_maker( - left: str, right: str, base: t.Type, limit: int = 8 -) -> t.Callable[[DebugReprGenerator, t.Iterable, bool], str]: - def proxy(self: DebugReprGenerator, obj: t.Iterable, recursive: bool) -> str: + left: str, right: str, base: type, limit: int = 8 +) -> t.Callable[[DebugReprGenerator, t.Iterable[t.Any], bool], str]: + def proxy(self: DebugReprGenerator, obj: t.Iterable[t.Any], recursive: bool) -> str: if recursive: return _add_subclass_info(f"{left}...{right}", obj, base) buf = [left] @@ -130,7 +129,7 @@ def __init__(self) -> None: 'collections.deque([', "])", deque ) - def regex_repr(self, obj: t.Pattern) -> str: + def regex_repr(self, obj: t.Pattern[t.AnyStr]) -> str: pattern = repr(obj.pattern) pattern = codecs.decode(pattern, "unicode-escape", "ignore") pattern = f"r{pattern}" @@ -188,7 +187,7 @@ def dict_repr( buf.append("}") return _add_subclass_info("".join(buf), d, dict) - def object_repr(self, obj: type[dict] | t.Callable | type[list] | None) -> str: + def object_repr(self, obj: t.Any) -> str: r = repr(obj) return f'{escape(r)}' diff --git a/src/werkzeug/debug/shared/debugger.js b/src/werkzeug/debug/shared/debugger.js index f463e9c77..809b14a6e 100644 --- a/src/werkzeug/debug/shared/debugger.js +++ b/src/werkzeug/debug/shared/debugger.js @@ -37,18 +37,22 @@ function wrapPlainTraceback() { plainTraceback.replaceWith(wrapper); } +function makeDebugURL(args) { + const params = new URLSearchParams(args) + params.set("s", SECRET) + return `?__debugger__=yes&${params}` +} + function initPinBox() { 
document.querySelector(".pin-prompt form").addEventListener( "submit", function (event) { event.preventDefault(); - const pin = encodeURIComponent(this.pin.value); - const encodedSecret = encodeURIComponent(SECRET); const btn = this.btn; btn.disabled = true; fetch( - `${document.location.pathname}?__debugger__=yes&cmd=pinauth&pin=${pin}&s=${encodedSecret}` + makeDebugURL({cmd: "pinauth", pin: this.pin.value}) ) .then((res) => res.json()) .then(({auth, exhausted}) => { @@ -77,10 +81,7 @@ function initPinBox() { function promptForPin() { if (!EVALEX_TRUSTED) { - const encodedSecret = encodeURIComponent(SECRET); - fetch( - `${document.location.pathname}?__debugger__=yes&cmd=printpin&s=${encodedSecret}` - ); + fetch(makeDebugURL({cmd: "printpin"})); const pinPrompt = document.getElementsByClassName("pin-prompt")[0]; fadeIn(pinPrompt); document.querySelector('.pin-prompt input[name="pin"]').focus(); @@ -237,7 +238,7 @@ function createConsoleInput() { function createIconForConsole() { const img = document.createElement("img"); - img.setAttribute("src", "?__debugger__=yes&cmd=resource&f=console.png"); + img.setAttribute("src", makeDebugURL({cmd: "resource", f: "console.png"})); img.setAttribute("title", "Open an interactive python shell in this frame"); return img; } @@ -263,24 +264,7 @@ function handleConsoleSubmit(e, command, frameID) { e.preventDefault(); return new Promise((resolve) => { - // Get input command. - const cmd = command.value; - - // Setup GET request. - const urlPath = ""; - const params = { - __debugger__: "yes", - cmd: cmd, - frm: frameID, - s: SECRET, - }; - const paramString = Object.keys(params) - .map((key) => { - return "&" + encodeURIComponent(key) + "=" + encodeURIComponent(params[key]); - }) - .join(""); - - fetch(urlPath + "?" 
+ paramString) + fetch(makeDebugURL({cmd: command.value, frm: frameID})) .then((res) => { return res.text(); }) diff --git a/src/werkzeug/debug/tbtools.py b/src/werkzeug/debug/tbtools.py index c45f56ef0..d922893ea 100644 --- a/src/werkzeug/debug/tbtools.py +++ b/src/werkzeug/debug/tbtools.py @@ -185,9 +185,9 @@ def _process_traceback( "globals": f.f_globals, } - if hasattr(fs, "colno"): + if sys.version_info >= (3, 11): frame_args["colno"] = fs.colno - frame_args["end_colno"] = fs.end_colno # type: ignore[attr-defined] + frame_args["end_colno"] = fs.end_colno new_stack.append(DebugFrameSummary(**frame_args)) @@ -265,7 +265,9 @@ def all_tracebacks( @cached_property def all_frames(self) -> list[DebugFrameSummary]: return [ - f for _, te in self.all_tracebacks for f in te.stack # type: ignore[misc] + f # type: ignore[misc] + for _, te in self.all_tracebacks + for f in te.stack ] def render_traceback_text(self) -> str: @@ -294,7 +296,12 @@ def render_traceback_html(self, include_title: bool = True) -> str: rows.append("\n".join(row_parts)) - is_syntax_error = issubclass(self._te.exc_type, SyntaxError) + if sys.version_info < (3, 13): + exc_type_str = self._te.exc_type.__name__ + else: + exc_type_str = self._te.exc_type_str + + is_syntax_error = exc_type_str == "SyntaxError" if include_title: if is_syntax_error: @@ -323,13 +330,19 @@ def render_debugger_html( ) -> str: exc_lines = list(self._te.format_exception_only()) plaintext = "".join(self._te.format()) + + if sys.version_info < (3, 13): + exc_type_str = self._te.exc_type.__name__ + else: + exc_type_str = self._te.exc_type_str + return PAGE_HTML % { "evalex": "true" if evalex else "false", "evalex_trusted": "true" if evalex_trusted else "false", "console": "false", "title": escape(exc_lines[0]), "exception": escape("".join(exc_lines)), - "exception_type": escape(self._te.exc_type.__name__), + "exception_type": escape(exc_type_str), "summary": self.render_traceback_html(include_title=False), "plaintext": 
escape(plaintext), "plaintext_cs": re.sub("-{2,}", "-", plaintext), diff --git a/src/werkzeug/exceptions.py b/src/werkzeug/exceptions.py index 253612918..1cd999773 100644 --- a/src/werkzeug/exceptions.py +++ b/src/werkzeug/exceptions.py @@ -43,6 +43,7 @@ def application(request): return e """ + from __future__ import annotations import typing as t @@ -56,6 +57,7 @@ def application(request): if t.TYPE_CHECKING: from _typeshed.wsgi import StartResponse from _typeshed.wsgi import WSGIEnvironment + from .datastructures import WWWAuthenticate from .sansio.response import Response from .wrappers.request import Request as WSGIRequest @@ -94,7 +96,7 @@ def name(self) -> str: def get_description( self, environ: WSGIEnvironment | None = None, - scope: dict | None = None, + scope: dict[str, t.Any] | None = None, ) -> str: """Get the description.""" if self.description is None: @@ -108,7 +110,7 @@ def get_description( def get_body( self, environ: WSGIEnvironment | None = None, - scope: dict | None = None, + scope: dict[str, t.Any] | None = None, ) -> str: """Get the HTML body.""" return ( @@ -122,7 +124,7 @@ def get_body( def get_headers( self, environ: WSGIEnvironment | None = None, - scope: dict | None = None, + scope: dict[str, t.Any] | None = None, ) -> list[tuple[str, str]]: """Get a list of headers.""" return [("Content-Type", "text/html; charset=utf-8")] @@ -130,7 +132,7 @@ def get_headers( def get_response( self, environ: WSGIEnvironment | WSGIRequest | None = None, - scope: dict | None = None, + scope: dict[str, t.Any] | None = None, ) -> Response: """Get a response object. If one was passed to the exception it's returned directly. @@ -195,7 +197,7 @@ class BadRequestKeyError(BadRequest, KeyError): #: useful in a debug mode. 
show_exception = False - def __init__(self, arg: str | None = None, *args: t.Any, **kwargs: t.Any): + def __init__(self, arg: object | None = None, *args: t.Any, **kwargs: t.Any): super().__init__(*args, **kwargs) if arg is None: @@ -312,7 +314,7 @@ def __init__( def get_headers( self, environ: WSGIEnvironment | None = None, - scope: dict | None = None, + scope: dict[str, t.Any] | None = None, ) -> list[tuple[str, str]]: headers = super().get_headers(environ, scope) if self.www_authenticate: @@ -376,7 +378,7 @@ def __init__( def get_headers( self, environ: WSGIEnvironment | None = None, - scope: dict | None = None, + scope: dict[str, t.Any] | None = None, ) -> list[tuple[str, str]]: headers = super().get_headers(environ, scope) if self.valid_methods: @@ -536,7 +538,7 @@ def __init__( def get_headers( self, environ: WSGIEnvironment | None = None, - scope: dict | None = None, + scope: dict[str, t.Any] | None = None, ) -> list[tuple[str, str]]: headers = super().get_headers(environ, scope) if self.length is not None: @@ -569,6 +571,19 @@ class ImATeapot(HTTPException): description = "This server is a teapot, not a coffee machine" +class MisdirectedRequest(HTTPException): + """421 Misdirected Request + + Indicates that the request was directed to a server that is not able to + produce a response. + + .. versionadded:: 3.1 + """ + + code = 421 + description = "The server is not able to produce a response." 
+ + class UnprocessableEntity(HTTPException): """*422* `Unprocessable Entity` @@ -645,7 +660,7 @@ def __init__( def get_headers( self, environ: WSGIEnvironment | None = None, - scope: dict | None = None, + scope: dict[str, t.Any] | None = None, ) -> list[tuple[str, str]]: headers = super().get_headers(environ, scope) diff --git a/src/werkzeug/formparser.py b/src/werkzeug/formparser.py index 25ef0d61b..010341497 100644 --- a/src/werkzeug/formparser.py +++ b/src/werkzeug/formparser.py @@ -1,7 +1,6 @@ from __future__ import annotations import typing as t -import warnings from io import BytesIO from urllib.parse import parse_qsl @@ -31,9 +30,12 @@ if t.TYPE_CHECKING: import typing as te + from _typeshed.wsgi import WSGIEnvironment - t_parse_result = t.Tuple[t.IO[bytes], MultiDict, MultiDict] + t_parse_result = tuple[ + t.IO[bytes], MultiDict[str, str], MultiDict[str, FileStorage] + ] class TStreamFactory(te.Protocol): def __call__( @@ -42,8 +44,7 @@ def __call__( content_type: str | None, filename: str | None, content_length: int | None = None, - ) -> t.IO[bytes]: - ... + ) -> t.IO[bytes]: ... F = t.TypeVar("F", bound=t.Callable[..., t.Any]) @@ -68,11 +69,9 @@ def default_stream_factory( def parse_form_data( environ: WSGIEnvironment, stream_factory: TStreamFactory | None = None, - charset: str | None = None, - errors: str | None = None, max_form_memory_size: int | None = None, max_content_length: int | None = None, - cls: type[MultiDict] | None = None, + cls: type[MultiDict[str, t.Any]] | None = None, silent: bool = True, *, max_form_parts: int | None = None, @@ -108,12 +107,11 @@ def parse_form_data( is exceeded, a :exc:`~exceptions.RequestEntityTooLarge` exception is raised. :return: A tuple in the form ``(stream, form, files)``. - .. versionchanged:: 2.3 - Added the ``max_form_parts`` parameter. + .. versionchanged:: 3.0 + The ``charset`` and ``errors`` parameters were removed. .. 
versionchanged:: 2.3 - The ``charset`` and ``errors`` parameters are deprecated and will be removed in - Werkzeug 3.0. + Added the ``max_form_parts`` parameter. .. versionadded:: 0.5.1 Added the ``silent`` parameter. @@ -124,8 +122,6 @@ def parse_form_data( """ return FormDataParser( stream_factory=stream_factory, - charset=charset, - errors=errors, max_form_memory_size=max_form_memory_size, max_content_length=max_content_length, max_form_parts=max_form_parts, @@ -159,13 +155,11 @@ class FormDataParser: :param max_form_parts: The maximum number of multipart parts to be parsed. If this is exceeded, a :exc:`~exceptions.RequestEntityTooLarge` exception is raised. - .. versionchanged:: 2.3 - The ``charset`` and ``errors`` parameters are deprecated and will be removed in - Werkzeug 3.0. + .. versionchanged:: 3.0 + The ``charset`` and ``errors`` parameters were removed. - .. versionchanged:: 2.3 - The ``parse_functions`` attribute and ``get_parse_func`` methods are deprecated - and will be removed in Werkzeug 3.0. + .. versionchanged:: 3.0 + The ``parse_functions`` attribute and ``get_parse_func`` methods were removed. .. versionchanged:: 2.2.3 Added the ``max_form_parts`` parameter. 
@@ -176,11 +170,9 @@ class FormDataParser: def __init__( self, stream_factory: TStreamFactory | None = None, - charset: str | None = None, - errors: str | None = None, max_form_memory_size: int | None = None, max_content_length: int | None = None, - cls: type[MultiDict] | None = None, + cls: type[MultiDict[str, t.Any]] | None = None, silent: bool = True, *, max_form_parts: int | None = None, @@ -189,78 +181,16 @@ def __init__( stream_factory = default_stream_factory self.stream_factory = stream_factory - - if charset is not None: - warnings.warn( - "The 'charset' parameter is deprecated and will be" - " removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - else: - charset = "utf-8" - - self.charset = charset - - if errors is not None: - warnings.warn( - "The 'errors' parameter is deprecated and will be" - " removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - else: - errors = "replace" - - self.errors = errors self.max_form_memory_size = max_form_memory_size self.max_content_length = max_content_length self.max_form_parts = max_form_parts if cls is None: - cls = MultiDict + cls = t.cast("type[MultiDict[str, t.Any]]", MultiDict) self.cls = cls self.silent = silent - def get_parse_func( - self, mimetype: str, options: dict[str, str] - ) -> None | ( - t.Callable[ - [FormDataParser, t.IO[bytes], str, int | None, dict[str, str]], - t_parse_result, - ] - ): - warnings.warn( - "The 'get_parse_func' method is deprecated and will be" - " removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - - if mimetype == "multipart/form-data": - return type(self)._parse_multipart - elif mimetype == "application/x-www-form-urlencoded": - return type(self)._parse_urlencoded - elif mimetype == "application/x-url-encoded": - warnings.warn( - "The 'application/x-url-encoded' mimetype is invalid, and will not be" - " treated as 'application/x-www-form-urlencoded' in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - return 
type(self)._parse_urlencoded - elif mimetype in self.parse_functions: - warnings.warn( - "The 'parse_functions' attribute is deprecated and will be removed in" - " Werkzeug 3.0. Override 'parse' instead.", - DeprecationWarning, - stacklevel=2, - ) - return self.parse_functions[mimetype] - - return None - def parse_from_environ(self, environ: WSGIEnvironment) -> t_parse_result: """Parses the information from the environment as form data. @@ -294,30 +224,14 @@ def parse( the multipart boundary for instance) :return: A tuple in the form ``(stream, form, files)``. - .. versionchanged:: 2.3 - The ``application/x-url-encoded`` content type is deprecated and will not be - treated as ``application/x-www-form-urlencoded`` in Werkzeug 3.0. + .. versionchanged:: 3.0 + The invalid ``application/x-url-encoded`` content type is not + treated as ``application/x-www-form-urlencoded``. """ if mimetype == "multipart/form-data": parse_func = self._parse_multipart elif mimetype == "application/x-www-form-urlencoded": parse_func = self._parse_urlencoded - elif mimetype == "application/x-url-encoded": - warnings.warn( - "The 'application/x-url-encoded' mimetype is invalid, and will not be" - " treated as 'application/x-www-form-urlencoded' in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - parse_func = self._parse_urlencoded - elif mimetype in self.parse_functions: - warnings.warn( - "The 'parse_functions' attribute is deprecated and will be removed in" - " Werkzeug 3.0. 
Override 'parse' instead.", - DeprecationWarning, - stacklevel=2, - ) - parse_func = self.parse_functions[mimetype].__get__(self, type(self)) else: return stream, self.cls(), self.cls() @@ -339,12 +253,8 @@ def _parse_multipart( content_length: int | None, options: dict[str, str], ) -> t_parse_result: - charset = self.charset if self.charset != "utf-8" else None - errors = self.errors if self.errors != "replace" else None parser = MultiPartParser( stream_factory=self.stream_factory, - charset=charset, - errors=errors, max_form_memory_size=self.max_form_memory_size, max_form_parts=self.max_form_parts, cls=self.cls, @@ -371,61 +281,23 @@ def _parse_urlencoded( ): raise RequestEntityTooLarge() - try: - items = parse_qsl( - stream.read().decode(), - keep_blank_values=True, - encoding=self.charset, - errors="werkzeug.url_quote", - ) - except ValueError as e: - raise RequestEntityTooLarge() from e - + items = parse_qsl( + stream.read().decode(), + keep_blank_values=True, + errors="werkzeug.url_quote", + ) return stream, self.cls(items), self.cls() - parse_functions: dict[ - str, - t.Callable[ - [FormDataParser, t.IO[bytes], str, int | None, dict[str, str]], - t_parse_result, - ], - ] = {} - class MultiPartParser: def __init__( self, stream_factory: TStreamFactory | None = None, - charset: str | None = None, - errors: str | None = None, max_form_memory_size: int | None = None, - cls: type[MultiDict] | None = None, + cls: type[MultiDict[str, t.Any]] | None = None, buffer_size: int = 64 * 1024, max_form_parts: int | None = None, ) -> None: - if charset is not None: - warnings.warn( - "The 'charset' parameter is deprecated and will be" - " removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - else: - charset = "utf-8" - - self.charset = charset - - if errors is not None: - warnings.warn( - "The 'errors' parameter is deprecated and will be" - " removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - else: - errors = "replace" - - self.errors = 
errors self.max_form_memory_size = max_form_memory_size self.max_form_parts = max_form_parts @@ -435,7 +307,7 @@ def __init__( self.stream_factory = stream_factory if cls is None: - cls = MultiDict + cls = t.cast("type[MultiDict[str, t.Any]]", MultiDict) self.cls = cls self.buffer_size = buffer_size @@ -456,7 +328,7 @@ def get_part_charset(self, headers: Headers) -> str: if ct_charset in {"ascii", "us-ascii", "utf-8", "iso-8859-1"}: return ct_charset - return self.charset + return "utf-8" def start_file_streaming( self, event: File, total_content_length: int | None @@ -478,8 +350,9 @@ def start_file_streaming( def parse( self, stream: t.IO[bytes], boundary: bytes, content_length: int | None - ) -> tuple[MultiDict, MultiDict]: + ) -> tuple[MultiDict[str, str], MultiDict[str, FileStorage]]: current_part: Field | File + field_size: int | None = None container: t.IO[bytes] | list[bytes] _write: t.Callable[[bytes], t.Any] @@ -498,18 +371,28 @@ def parse( while not isinstance(event, (Epilogue, NeedData)): if isinstance(event, Field): current_part = event + field_size = 0 container = [] _write = container.append elif isinstance(event, File): current_part = event + field_size = None container = self.start_file_streaming(event, content_length) _write = container.write elif isinstance(event, Data): + if self.max_form_memory_size is not None and field_size is not None: + # Ensure that accumulated data events do not exceed limit. + # Also checked within single event in MultipartDecoder. 
+ field_size += len(event.data) + + if field_size > self.max_form_memory_size: + raise RequestEntityTooLarge() + _write(event.data) if not event.more_data: if isinstance(current_part, Field): value = b"".join(container).decode( - self.get_part_charset(current_part.headers), self.errors + self.get_part_charset(current_part.headers), "replace" ) fields.append((current_part.name, value)) else: diff --git a/src/werkzeug/http.py b/src/werkzeug/http.py index 07d1fd496..f1dbb850a 100644 --- a/src/werkzeug/http.py +++ b/src/werkzeug/http.py @@ -136,11 +136,7 @@ class COOP(Enum): SAME_ORIGIN = "same-origin" -def quote_header_value( - value: t.Any, - extra_chars: str | None = None, - allow_token: bool = True, -) -> str: +def quote_header_value(value: t.Any, allow_token: bool = True) -> str: """Add double quotes around a header value. If the header contains only ASCII token characters, it will be returned unchanged. If the header contains ``"`` or ``\\`` characters, they will be escaped with an additional ``\\`` character. @@ -150,52 +146,33 @@ def quote_header_value( :param value: The value to quote. Will be converted to a string. :param allow_token: Disable to quote the value even if it only has token characters. - .. versionchanged:: 2.3 - The value is quoted if it is the empty string. + .. versionchanged:: 3.0 + Passing bytes is not supported. - .. versionchanged:: 2.3 - Passing bytes is deprecated and will not be supported in Werkzeug 3.0. + .. versionchanged:: 3.0 + The ``extra_chars`` parameter is removed. .. versionchanged:: 2.3 - The ``extra_chars`` parameter is deprecated and will be removed in Werkzeug 3.0. + The value is quoted if it is the empty string. .. 
versionadded:: 0.5 """ - if isinstance(value, bytes): - warnings.warn( - "Passing bytes is deprecated and will not be supported in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - value = value.decode("latin1") + value_str = str(value) - if extra_chars is not None: - warnings.warn( - "The 'extra_chars' parameter is deprecated and will be" - " removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - - value = str(value) - - if not value: + if not value_str: return '""' if allow_token: token_chars = _token_chars - if extra_chars: - token_chars |= set(extra_chars) + if token_chars.issuperset(value_str): + return value_str - if token_chars.issuperset(value): - return value + value_str = value_str.replace("\\", "\\\\").replace('"', '\\"') + return f'"{value_str}"' - value = value.replace("\\", "\\\\").replace('"', '\\"') - return f'"{value}"' - -def unquote_header_value(value: str, is_filename: bool | None = None) -> str: +def unquote_header_value(value: str) -> str: """Remove double quotes and decode slash-escaped ``"`` and ``\\`` characters in a header value. @@ -203,22 +180,12 @@ def unquote_header_value(value: str, is_filename: bool | None = None) -> str: :param value: The header value to unquote. - .. versionchanged:: 2.3 - The ``is_filename`` parameter is deprecated and will be removed in Werkzeug 3.0. + .. versionchanged:: 3.0 + The ``is_filename`` parameter is removed. 
""" - if is_filename is not None: - warnings.warn( - "The 'is_filename' parameter is deprecated and will be" - " removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - if len(value) >= 2 and value[0] == value[-1] == '"': value = value[1:-1] - - if not is_filename: - return value.replace("\\\\", "\\").replace('\\"', '"') + return value.replace("\\\\", "\\").replace('\\"', '"') return value @@ -269,10 +236,7 @@ def dump_options_header(header: str | None, options: t.Mapping[str, t.Any]) -> s return "; ".join(segments) -def dump_header( - iterable: dict[str, t.Any] | t.Iterable[t.Any], - allow_token: bool | None = None, -) -> str: +def dump_header(iterable: dict[str, t.Any] | t.Iterable[t.Any]) -> str: """Produce a header value from a list of items or ``key=value`` pairs, separated by commas ``,``. @@ -298,22 +262,12 @@ def dump_header( :param iterable: The items to create a header from. - .. versionchanged:: 2.3 - The ``allow_token`` parameter is deprecated and will be removed in Werkzeug 3.0. + .. versionchanged:: 3.0 + The ``allow_token`` parameter is removed. .. versionchanged:: 2.2.3 If a key ends with ``*``, its value will not be quoted. 
""" - if allow_token is not None: - warnings.warn( - "'The 'allow_token' parameter is deprecated and will be" - " removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - else: - allow_token = True - if isinstance(iterable, dict): items = [] @@ -323,11 +277,9 @@ def dump_header( elif key[-1] == "*": items.append(f"{key}={value}") else: - items.append( - f"{key}={quote_header_value(value, allow_token=allow_token)}" - ) + items.append(f"{key}={quote_header_value(value)}") else: - items = [quote_header_value(x, allow_token=allow_token) for x in iterable] + items = [quote_header_value(x) for x in iterable] return ", ".join(items) @@ -372,7 +324,7 @@ def parse_list_header(value: str) -> list[str]: return result -def parse_dict_header(value: str, cls: type[dict] | None = None) -> dict[str, str]: +def parse_dict_header(value: str) -> dict[str, str | None]: """Parse a list header using :func:`parse_list_header`, then parse each item as a ``key=value`` pair. @@ -391,41 +343,28 @@ def parse_dict_header(value: str, cls: type[dict] | None = None) -> dict[str, st :param value: The header value to parse. - .. versionchanged:: 2.3 - Added support for ``key*=charset''value`` encoded items. + .. versionchanged:: 3.0 + Passing bytes is not supported. - .. versionchanged:: 2.3 - Passing bytes is deprecated, support will be removed in Werkzeug 3.0. + .. versionchanged:: 3.0 + The ``cls`` argument is removed. .. versionchanged:: 2.3 - The ``cls`` argument is deprecated and will be removed in Werkzeug 3.0. + Added support for ``key*=charset''value`` encoded items. .. versionchanged:: 0.9 The ``cls`` argument was added. 
""" - if cls is None: - cls = dict - else: - warnings.warn( - "The 'cls' parameter is deprecated and will be removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - - result = cls() - - if isinstance(value, bytes): - warnings.warn( - "Passing bytes is deprecated and will be removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - value = value.decode("latin1") + result: dict[str, str | None] = {} for item in parse_list_header(value): key, has_value, value = item.partition("=") key = key.strip() + if not key: + # =value is not valid + continue + if not has_value: result[key] = None continue @@ -460,22 +399,8 @@ def parse_dict_header(value: str, cls: type[dict] | None = None) -> dict[str, st # https://httpwg.org/specs/rfc9110.html#parameter -_parameter_re = re.compile( - r""" - # don't match multiple empty parts, that causes backtracking - \s*;\s* # find the part delimiter - (?: - ([\w!#$%&'*+\-.^`|~]+) # key, one or more token chars - = # equals, with no space on either side - ( # value, token or quoted string - [\w!#$%&'*+\-.^`|~]+ # one or more token chars - | - "(?:\\\\|\\"|.)*?" # quoted string, consuming slash escapes - ) - )? # optionally match key=value, to account for empty parts - """, - re.ASCII | re.VERBOSE, -) +_parameter_key_re = re.compile(r"([\w!#$%&'*+\-.^`|~]+)=", flags=re.ASCII) +_parameter_token_value_re = re.compile(r"[\w!#$%&'*+\-.^`|~]+", flags=re.ASCII) # https://www.rfc-editor.org/rfc/rfc2231#section-4 _charset_value_re = re.compile( r""" @@ -557,18 +482,49 @@ def parse_options_header(value: str | None) -> tuple[str, dict[str, str]]: # empty (invalid) value, or value without options return value, {} - rest = f";{rest}" + # Collect all valid key=value parts without processing the value. + parts: list[tuple[str, str]] = [] + + while True: + if (m := _parameter_key_re.match(rest)) is not None: + pk = m.group(1).lower() + rest = rest[m.end() :] + + # Value may be a token. 
+ if (m := _parameter_token_value_re.match(rest)) is not None: + parts.append((pk, m.group())) + + # Value may be a quoted string, find the closing quote. + elif rest[:1] == '"': + pos = 1 + length = len(rest) + + while pos < length: + if rest[pos : pos + 2] in {"\\\\", '\\"'}: + # Consume escaped slashes and quotes. + pos += 2 + elif rest[pos] == '"': + # Stop at an unescaped quote. + parts.append((pk, rest[: pos + 1])) + rest = rest[pos + 1 :] + break + else: + # Consume any other character. + pos += 1 + + # Find the next section delimited by `;`, if any. + if (end := rest.find(";")) == -1: + break + + rest = rest[end + 1 :].lstrip() + options: dict[str, str] = {} encoding: str | None = None continued_encoding: str | None = None - for pk, pv in _parameter_re.findall(rest): - if not pk: - # empty or invalid part - continue - - pk = pk.lower() - + # For each collected part, process optional charset and continuation, + # unquote quoted values. + for pk, pv in parts: if pk[-1] == "*": # key*=charset''value becomes key=value, where value is percent encoded pk = pk[:-1] @@ -618,13 +574,11 @@ def parse_options_header(value: str | None) -> tuple[str, dict[str, str]]: @t.overload -def parse_accept_header(value: str | None) -> ds.Accept: - ... +def parse_accept_header(value: str | None) -> ds.Accept: ... @t.overload -def parse_accept_header(value: str | None, cls: type[_TAnyAccept]) -> _TAnyAccept: - ... +def parse_accept_header(value: str | None, cls: type[_TAnyAccept]) -> _TAnyAccept: ... def parse_accept_header( @@ -645,7 +599,7 @@ def parse_accept_header( Parse according to RFC 9110. Items with invalid ``q`` values are skipped. 
""" if cls is None: - cls = t.cast(t.Type[_TAnyAccept], ds.Accept) + cls = t.cast(type[_TAnyAccept], ds.Accept) if not value: return cls(None) @@ -681,26 +635,26 @@ def parse_accept_header( _TAnyCC = t.TypeVar("_TAnyCC", bound="ds.cache_control._CacheControl") -_t_cc_update = t.Optional[t.Callable[[_TAnyCC], None]] @t.overload def parse_cache_control_header( - value: str | None, on_update: _t_cc_update, cls: None = None -) -> ds.RequestCacheControl: - ... + value: str | None, + on_update: t.Callable[[ds.cache_control._CacheControl], None] | None = None, +) -> ds.RequestCacheControl: ... @t.overload def parse_cache_control_header( - value: str | None, on_update: _t_cc_update, cls: type[_TAnyCC] -) -> _TAnyCC: - ... + value: str | None, + on_update: t.Callable[[ds.cache_control._CacheControl], None] | None = None, + cls: type[_TAnyCC] = ..., +) -> _TAnyCC: ... def parse_cache_control_header( value: str | None, - on_update: _t_cc_update = None, + on_update: t.Callable[[ds.cache_control._CacheControl], None] | None = None, cls: type[_TAnyCC] | None = None, ) -> _TAnyCC: """Parse a cache control header. The RFC differs between response and @@ -720,7 +674,7 @@ def parse_cache_control_header( :return: a `cls` object. """ if cls is None: - cls = t.cast(t.Type[_TAnyCC], ds.RequestCacheControl) + cls = t.cast("type[_TAnyCC]", ds.RequestCacheControl) if not value: return cls((), on_update) @@ -729,26 +683,26 @@ def parse_cache_control_header( _TAnyCSP = t.TypeVar("_TAnyCSP", bound="ds.ContentSecurityPolicy") -_t_csp_update = t.Optional[t.Callable[[_TAnyCSP], None]] @t.overload def parse_csp_header( - value: str | None, on_update: _t_csp_update, cls: None = None -) -> ds.ContentSecurityPolicy: - ... + value: str | None, + on_update: t.Callable[[ds.ContentSecurityPolicy], None] | None = None, +) -> ds.ContentSecurityPolicy: ... @t.overload def parse_csp_header( - value: str | None, on_update: _t_csp_update, cls: type[_TAnyCSP] -) -> _TAnyCSP: - ... 
+ value: str | None, + on_update: t.Callable[[ds.ContentSecurityPolicy], None] | None = None, + cls: type[_TAnyCSP] = ..., +) -> _TAnyCSP: ... def parse_csp_header( value: str | None, - on_update: _t_csp_update = None, + on_update: t.Callable[[ds.ContentSecurityPolicy], None] | None = None, cls: type[_TAnyCSP] | None = None, ) -> _TAnyCSP: """Parse a Content Security Policy header. @@ -764,7 +718,7 @@ def parse_csp_header( :return: a `cls` object. """ if cls is None: - cls = t.cast(t.Type[_TAnyCSP], ds.ContentSecurityPolicy) + cls = t.cast("type[_TAnyCSP]", ds.ContentSecurityPolicy) if value is None: return cls((), on_update) @@ -815,65 +769,6 @@ def parse_set_header( return ds.HeaderSet(parse_list_header(value), on_update) -def parse_authorization_header( - value: str | None, -) -> ds.Authorization | None: - """Parse an HTTP basic/digest authorization header transmitted by the web - browser. The return value is either `None` if the header was invalid or - not given, otherwise an :class:`~werkzeug.datastructures.Authorization` - object. - - :param value: the authorization header to parse. - :return: a :class:`~werkzeug.datastructures.Authorization` object or `None`. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. Use :meth:`.Authorization.from_header` instead. - """ - from .datastructures import Authorization - - warnings.warn( - "'parse_authorization_header' is deprecated and will be removed in Werkzeug" - " 2.4. Use 'Authorization.from_header' instead.", - DeprecationWarning, - stacklevel=2, - ) - return Authorization.from_header(value) - - -def parse_www_authenticate_header( - value: str | None, - on_update: t.Callable[[ds.WWWAuthenticate], None] | None = None, -) -> ds.WWWAuthenticate: - """Parse an HTTP WWW-Authenticate header into a - :class:`~werkzeug.datastructures.WWWAuthenticate` object. - - :param value: a WWW-Authenticate header to parse. 
- :param on_update: an optional callable that is called every time a value - on the :class:`~werkzeug.datastructures.WWWAuthenticate` - object is changed. - :return: a :class:`~werkzeug.datastructures.WWWAuthenticate` object. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. Use :meth:`.WWWAuthenticate.from_header` - instead. - """ - from .datastructures.auth import WWWAuthenticate - - warnings.warn( - "'parse_www_authenticate_header' is deprecated and will be removed in Werkzeug" - " 2.4. Use 'WWWAuthenticate.from_header' instead.", - DeprecationWarning, - stacklevel=2, - ) - rv = WWWAuthenticate.from_header(value) - - if rv is None: - rv = WWWAuthenticate("basic") - - rv._on_update = on_update - return rv - - def parse_if_range_header(value: str | None) -> ds.IfRange: """Parses an if-range header which can be an etag or a date. Returns a :class:`~werkzeug.datastructures.IfRange` object. @@ -1019,6 +914,10 @@ def quote_etag(etag: str, weak: bool = False) -> str: return etag +@t.overload +def unquote_etag(etag: str) -> tuple[str, bool]: ... +@t.overload +def unquote_etag(etag: None) -> tuple[None, None]: ... def unquote_etag( etag: str | None, ) -> tuple[str, bool] | tuple[None, None]: @@ -1284,9 +1183,7 @@ def is_hop_by_hop_header(header: str) -> bool: def parse_cookie( header: WSGIEnvironment | str | None, - charset: str | None = None, - errors: str | None = None, - cls: type[ds.MultiDict] | None = None, + cls: type[ds.MultiDict[str, str]] | None = None, ) -> ds.MultiDict[str, str]: """Parse a cookie from a string or WSGI environ. @@ -1300,9 +1197,8 @@ def parse_cookie( :param cls: A dict-like class to store the parsed cookies in. Defaults to :class:`MultiDict`. - .. versionchanged:: 2.3 - Passing bytes, and the ``charset`` and ``errors`` parameters, are deprecated and - will be removed in Werkzeug 3.0. + .. versionchanged:: 3.0 + Passing bytes, and the ``charset`` and ``errors`` parameters, were removed. .. 
versionchanged:: 1.0 Returns a :class:`MultiDict` instead of a ``TypeConversionDict``. @@ -1313,22 +1209,13 @@ def parse_cookie( """ if isinstance(header, dict): cookie = header.get("HTTP_COOKIE") - elif isinstance(header, bytes): - warnings.warn( - "Passing bytes is deprecated and will not be supported in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - cookie = header.decode() else: cookie = header if cookie: cookie = cookie.encode("latin1").decode() - return _sansio_http.parse_cookie( - cookie=cookie, charset=charset, errors=errors, cls=cls - ) + return _sansio_http.parse_cookie(cookie=cookie, cls=cls) _cookie_no_quote_re = re.compile(r"[\w!#$%&'()*+\-./:<=>?@\[\]^`{|}~]*", re.A) @@ -1349,10 +1236,10 @@ def dump_cookie( domain: str | None = None, secure: bool = False, httponly: bool = False, - charset: str | None = None, sync_expires: bool = True, max_size: int = 4093, samesite: str | None = None, + partitioned: bool = False, ) -> str: """Create a Set-Cookie header without the ``Set-Cookie`` prefix. @@ -1389,9 +1276,17 @@ def dump_cookie( `_. Set to 0 to disable this check. :param samesite: Limits the scope of the cookie such that it will only be attached to requests if those requests are same-site. + :param partitioned: Opts the cookie into partitioned storage. This + will also set secure to True .. _`cookie`: http://browsercookielimits.squawky.net/ + .. versionchanged:: 3.1 + The ``partitioned`` parameter was added. + + .. versionchanged:: 3.0 + Passing bytes, and the ``charset`` parameter, were removed. + .. versionchanged:: 2.3.3 The ``path`` parameter is ``/`` by default. @@ -1405,46 +1300,14 @@ def dump_cookie( .. versionchanged:: 2.3 The ``path`` parameter is ``None`` by default. - .. versionchanged:: 2.3 - Passing bytes, and the ``charset`` parameter, are deprecated and will be removed - in Werkzeug 3.0. - .. versionchanged:: 1.0.0 The string ``'None'`` is accepted for ``samesite``. 
""" - if charset is not None: - warnings.warn( - "The 'charset' parameter is deprecated and will be removed" - " in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - else: - charset = "utf-8" - - if isinstance(key, bytes): - warnings.warn( - "The 'key' parameter must be a string. Bytes are deprecated" - " and will not be supported in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - key = key.decode() - - if isinstance(value, bytes): - warnings.warn( - "The 'value' parameter must be a string. Bytes are" - " deprecated and will not be supported in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - value = value.decode() - if path is not None: # safe = https://url.spec.whatwg.org/#url-path-segment-string # as well as percent for things that are already quoted # excluding semicolon since it's part of the header syntax - path = quote(path, safe="%!$&'()*+,/:=@", encoding=charset) + path = quote(path, safe="%!$&'()*+,/:=@") if domain: domain = domain.partition(":")[0].lstrip(".").encode("idna").decode("ascii") @@ -1464,12 +1327,15 @@ def dump_cookie( if samesite not in {"Strict", "Lax", "None"}: raise ValueError("SameSite must be 'Strict', 'Lax', or 'None'.") + if partitioned: + secure = True + # Quote value if it contains characters not allowed by RFC 6265. Slash-escape with # three octal digits, which matches http.cookies, although the RFC suggests base64. if not _cookie_no_quote_re.fullmatch(value): # Work with bytes here, since a UTF-8 character could be multiple bytes. 
value = _cookie_slash_re.sub( - lambda m: _cookie_slash_map[m.group()], value.encode(charset) + lambda m: _cookie_slash_map[m.group()], value.encode() ).decode("ascii") value = f'"{value}"' @@ -1485,6 +1351,7 @@ def dump_cookie( ("HttpOnly", httponly), ("Path", path), ("SameSite", samesite), + ("Partitioned", partitioned), ): if v is None or v is False: continue diff --git a/src/werkzeug/local.py b/src/werkzeug/local.py index fba80e974..302589bba 100644 --- a/src/werkzeug/local.py +++ b/src/werkzeug/local.py @@ -20,7 +20,7 @@ F = t.TypeVar("F", bound=t.Callable[..., t.Any]) -def release_local(local: Local | LocalStack) -> None: +def release_local(local: Local | LocalStack[t.Any]) -> None: """Release the data for the current context in a :class:`Local` or :class:`LocalStack` without using a :class:`LocalManager`. @@ -64,7 +64,9 @@ def __init__(self, context_var: ContextVar[dict[str, t.Any]] | None = None) -> N def __iter__(self) -> t.Iterator[tuple[str, t.Any]]: return iter(self.__storage.get({}).items()) - def __call__(self, name: str, *, unbound_message: str | None = None) -> LocalProxy: + def __call__( + self, name: str, *, unbound_message: str | None = None + ) -> LocalProxy[t.Any]: """Create a :class:`LocalProxy` that access an attribute on this local namespace. @@ -169,7 +171,7 @@ def top(self) -> T | None: def __call__( self, name: str | None = None, *, unbound_message: str | None = None - ) -> LocalProxy: + ) -> LocalProxy[t.Any]: """Create a :class:`LocalProxy` that accesses the top of this local stack. 
@@ -205,7 +207,8 @@ class LocalManager: def __init__( self, - locals: None | (Local | LocalStack | t.Iterable[Local | LocalStack]) = None, + locals: None + | (Local | LocalStack[t.Any] | t.Iterable[Local | LocalStack[t.Any]]) = None, ) -> None: if locals is None: self.locals = [] @@ -269,23 +272,27 @@ class _ProxyLookup: def __init__( self, - f: t.Callable | None = None, - fallback: t.Callable | None = None, + f: t.Callable[..., t.Any] | None = None, + fallback: t.Callable[[LocalProxy[t.Any]], t.Any] | None = None, class_value: t.Any | None = None, is_attr: bool = False, ) -> None: - bind_f: t.Callable[[LocalProxy, t.Any], t.Callable] | None + bind_f: t.Callable[[LocalProxy[t.Any], t.Any], t.Callable[..., t.Any]] | None if hasattr(f, "__get__"): # A Python function, can be turned into a bound method. - def bind_f(instance: LocalProxy, obj: t.Any) -> t.Callable: + def bind_f( + instance: LocalProxy[t.Any], obj: t.Any + ) -> t.Callable[..., t.Any]: return f.__get__(obj, type(obj)) # type: ignore elif f is not None: # A C function, use partial to bind the first argument. 
- def bind_f(instance: LocalProxy, obj: t.Any) -> t.Callable: + def bind_f( + instance: LocalProxy[t.Any], obj: t.Any + ) -> t.Callable[..., t.Any]: return partial(f, obj) else: @@ -297,10 +304,10 @@ def bind_f(instance: LocalProxy, obj: t.Any) -> t.Callable: self.class_value = class_value self.is_attr = is_attr - def __set_name__(self, owner: LocalProxy, name: str) -> None: + def __set_name__(self, owner: LocalProxy[t.Any], name: str) -> None: self.name = name - def __get__(self, instance: LocalProxy, owner: type | None = None) -> t.Any: + def __get__(self, instance: LocalProxy[t.Any], owner: type | None = None) -> t.Any: if instance is None: if self.class_value is not None: return self.class_value @@ -330,7 +337,9 @@ def __get__(self, instance: LocalProxy, owner: type | None = None) -> t.Any: def __repr__(self) -> str: return f"proxy {self.name}" - def __call__(self, instance: LocalProxy, *args: t.Any, **kwargs: t.Any) -> t.Any: + def __call__( + self, instance: LocalProxy[t.Any], *args: t.Any, **kwargs: t.Any + ) -> t.Any: """Support calling unbound methods from the class. For example, this happens with ``copy.copy``, which does ``type(x).__copy__(x)``. 
``type(x)`` can't be proxied, so it @@ -347,12 +356,14 @@ class _ProxyIOp(_ProxyLookup): __slots__ = () def __init__( - self, f: t.Callable | None = None, fallback: t.Callable | None = None + self, + f: t.Callable[..., t.Any] | None = None, + fallback: t.Callable[[LocalProxy[t.Any]], t.Any] | None = None, ) -> None: super().__init__(f, fallback) - def bind_f(instance: LocalProxy, obj: t.Any) -> t.Callable: - def i_op(self: t.Any, other: t.Any) -> LocalProxy: + def bind_f(instance: LocalProxy[t.Any], obj: t.Any) -> t.Callable[..., t.Any]: + def i_op(self: t.Any, other: t.Any) -> LocalProxy[t.Any]: f(self, other) # type: ignore return instance @@ -520,32 +531,33 @@ def _get_current_object() -> T: object.__setattr__(self, "_LocalProxy__wrapped", local) object.__setattr__(self, "_get_current_object", _get_current_object) - __doc__ = _ProxyLookup( # type: ignore + __doc__ = _ProxyLookup( # type: ignore[assignment] class_value=__doc__, fallback=lambda self: type(self).__doc__, is_attr=True ) __wrapped__ = _ProxyLookup( - fallback=lambda self: self._LocalProxy__wrapped, is_attr=True + fallback=lambda self: self._LocalProxy__wrapped, # type: ignore[attr-defined] + is_attr=True, ) # __del__ should only delete the proxy - __repr__ = _ProxyLookup( # type: ignore + __repr__ = _ProxyLookup( # type: ignore[assignment] repr, fallback=lambda self: f"<{type(self).__name__} unbound>" ) - __str__ = _ProxyLookup(str) # type: ignore + __str__ = _ProxyLookup(str) # type: ignore[assignment] __bytes__ = _ProxyLookup(bytes) - __format__ = _ProxyLookup() # type: ignore + __format__ = _ProxyLookup() # type: ignore[assignment] __lt__ = _ProxyLookup(operator.lt) __le__ = _ProxyLookup(operator.le) - __eq__ = _ProxyLookup(operator.eq) # type: ignore - __ne__ = _ProxyLookup(operator.ne) # type: ignore + __eq__ = _ProxyLookup(operator.eq) # type: ignore[assignment] + __ne__ = _ProxyLookup(operator.ne) # type: ignore[assignment] __gt__ = _ProxyLookup(operator.gt) __ge__ = _ProxyLookup(operator.ge) 
- __hash__ = _ProxyLookup(hash) # type: ignore + __hash__ = _ProxyLookup(hash) # type: ignore[assignment] __bool__ = _ProxyLookup(bool, fallback=lambda self: False) __getattr__ = _ProxyLookup(getattr) # __getattribute__ triggered through __getattr__ - __setattr__ = _ProxyLookup(setattr) # type: ignore - __delattr__ = _ProxyLookup(delattr) # type: ignore - __dir__ = _ProxyLookup(dir, fallback=lambda self: []) # type: ignore + __setattr__ = _ProxyLookup(setattr) # type: ignore[assignment] + __delattr__ = _ProxyLookup(delattr) # type: ignore[assignment] + __dir__ = _ProxyLookup(dir, fallback=lambda self: []) # type: ignore[assignment] # __get__ (proxying descriptor not supported) # __set__ (descriptor) # __delete__ (descriptor) @@ -556,9 +568,7 @@ def _get_current_object() -> T: # __weakref__ (__getattr__) # __init_subclass__ (proxying metaclass not supported) # __prepare__ (metaclass) - __class__ = _ProxyLookup( - fallback=lambda self: type(self), is_attr=True - ) # type: ignore + __class__ = _ProxyLookup(fallback=lambda self: type(self), is_attr=True) # type: ignore[assignment] __instancecheck__ = _ProxyLookup(lambda self, other: isinstance(other, self)) __subclasscheck__ = _ProxyLookup(lambda self, other: issubclass(other, self)) # __class_getitem__ triggered through __getitem__ diff --git a/src/werkzeug/middleware/dispatcher.py b/src/werkzeug/middleware/dispatcher.py index 559fea585..e11bacc52 100644 --- a/src/werkzeug/middleware/dispatcher.py +++ b/src/werkzeug/middleware/dispatcher.py @@ -30,6 +30,7 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ + from __future__ import annotations import typing as t diff --git a/src/werkzeug/middleware/http_proxy.py b/src/werkzeug/middleware/http_proxy.py index 59ba9b324..5e239156a 100644 --- a/src/werkzeug/middleware/http_proxy.py +++ b/src/werkzeug/middleware/http_proxy.py @@ -7,6 +7,7 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ + from __future__ import annotations import typing as t diff --git 
a/src/werkzeug/middleware/lint.py b/src/werkzeug/middleware/lint.py index 462959943..3714271b1 100644 --- a/src/werkzeug/middleware/lint.py +++ b/src/werkzeug/middleware/lint.py @@ -12,6 +12,7 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ + from __future__ import annotations import typing as t @@ -37,7 +38,7 @@ class HTTPWarning(Warning): """Warning class for HTTP warnings.""" -def check_type(context: str, obj: object, need: t.Type = str) -> None: +def check_type(context: str, obj: object, need: type = str) -> None: if type(obj) is not need: warn( f"{context!r} requires {need.__name__!r}, got {type(obj).__name__!r}.", @@ -180,30 +181,44 @@ def close(self) -> None: key ): warn( - f"Entity header {key!r} found in 304 response.", HTTPWarning + f"Entity header {key!r} found in 304 response.", + HTTPWarning, + stacklevel=2, ) if bytes_sent: - warn("304 responses must not have a body.", HTTPWarning) + warn( + "304 responses must not have a body.", + HTTPWarning, + stacklevel=2, + ) elif 100 <= status_code < 200 or status_code == 204: if content_length != 0: warn( f"{status_code} responses must have an empty content length.", HTTPWarning, + stacklevel=2, ) if bytes_sent: - warn(f"{status_code} responses must not have a body.", HTTPWarning) + warn( + f"{status_code} responses must not have a body.", + HTTPWarning, + stacklevel=2, + ) elif content_length is not None and content_length != bytes_sent: warn( "Content-Length and the number of bytes sent to the" " client do not match.", WSGIWarning, + stacklevel=2, ) def __del__(self) -> None: if not self.closed: try: warn( - "Iterator was garbage collected before it was closed.", WSGIWarning + "Iterator was garbage collected before it was closed.", + WSGIWarning, + stacklevel=2, ) except Exception: pass @@ -236,7 +251,7 @@ def __init__(self, app: WSGIApplication) -> None: self.app = app def check_environ(self, environ: WSGIEnvironment) -> None: - if type(environ) is not dict: + if type(environ) is not dict: # noqa: 
E721 warn( "WSGI environment is not a standard Python dict.", WSGIWarning, @@ -304,14 +319,14 @@ def check_start_response( if status_code < 100: warn("Status code < 100 detected.", WSGIWarning, stacklevel=3) - if type(headers) is not list: + if type(headers) is not list: # noqa: E721 warn("Header list is not a list.", WSGIWarning, stacklevel=3) for item in headers: if type(item) is not tuple or len(item) != 2: warn("Header items must be 2-item tuples.", WSGIWarning, stacklevel=3) name, value = item - if type(name) is not str or type(value) is not str: + if type(name) is not str or type(value) is not str: # noqa: E721 warn( "Header keys and values must be strings.", WSGIWarning, stacklevel=3 ) @@ -326,10 +341,10 @@ def check_start_response( if exc_info is not None and not isinstance(exc_info, tuple): warn("Invalid value for exc_info.", WSGIWarning, stacklevel=3) - headers = Headers(headers) - self.check_headers(headers) + headers_obj = Headers(headers) + self.check_headers(headers_obj) - return status_code, headers + return status_code, headers_obj def check_headers(self, headers: Headers) -> None: etag = headers.get("etag") @@ -402,13 +417,17 @@ def checking_start_response( ) if kwargs: - warn("'start_response' does not take keyword arguments.", WSGIWarning) + warn( + "'start_response' does not take keyword arguments.", + WSGIWarning, + stacklevel=2, + ) status: str = args[0] headers: list[tuple[str, str]] = args[1] - exc_info: None | ( - tuple[type[BaseException], BaseException, TracebackType] - ) = (args[2] if len(args) == 3 else None) + exc_info: ( + None | (tuple[type[BaseException], BaseException, TracebackType]) + ) = args[2] if len(args) == 3 else None headers_set[:] = self.check_start_response(status, headers, exc_info) return GuardedWrite(start_response(status, headers, exc_info), chunks) @@ -416,5 +435,5 @@ def checking_start_response( app_iter = self.app(environ, t.cast("StartResponse", checking_start_response)) self.check_iterator(app_iter) return 
GuardedIterator( - app_iter, t.cast(t.Tuple[int, Headers], headers_set), chunks + app_iter, t.cast(tuple[int, Headers], headers_set), chunks ) diff --git a/src/werkzeug/middleware/profiler.py b/src/werkzeug/middleware/profiler.py index 2d806154c..112b87776 100644 --- a/src/werkzeug/middleware/profiler.py +++ b/src/werkzeug/middleware/profiler.py @@ -11,6 +11,7 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ + from __future__ import annotations import os.path @@ -44,11 +45,16 @@ class ProfilerMiddleware: - ``{method}`` - The request method; GET, POST, etc. - ``{path}`` - The request path or 'root' should one not exist. - - ``{elapsed}`` - The elapsed time of the request. + - ``{elapsed}`` - The elapsed time of the request in milliseconds. - ``{time}`` - The time of the request. - If it is a callable, it will be called with the WSGI ``environ`` - dict and should return a filename. + If it is a callable, it will be called with the WSGI ``environ`` and + be expected to return a filename string. The ``environ`` dictionary + will also have the ``"werkzeug.profiler"`` key populated with a + dictionary containing the following fields (more may be added in the + future): + - ``{elapsed}`` - The elapsed time of the request in milliseconds. + - ``{time}`` - The time of the request. :param app: The WSGI application to wrap. :param stream: Write stats to this stream. Disable with ``None``. @@ -65,6 +71,10 @@ class ProfilerMiddleware: from werkzeug.middleware.profiler import ProfilerMiddleware app = ProfilerMiddleware(app) + .. versionchanged:: 3.0 + Added the ``"werkzeug.profiler"`` key to the ``filename_format(environ)`` + parameter with the ``elapsed`` and ``time`` fields. + .. versionchanged:: 0.15 Stats are written even if ``profile_dir`` is given, and can be disable by passing ``stream=None``. 
@@ -118,6 +128,10 @@ def runapp() -> None: if self._profile_dir is not None: if callable(self._filename_format): + environ["werkzeug.profiler"] = { + "elapsed": elapsed * 1000.0, + "time": time.time(), + } filename = self._filename_format(environ) else: filename = self._filename_format.format( diff --git a/src/werkzeug/middleware/proxy_fix.py b/src/werkzeug/middleware/proxy_fix.py index 8dfbb36c0..cbf4e0bae 100644 --- a/src/werkzeug/middleware/proxy_fix.py +++ b/src/werkzeug/middleware/proxy_fix.py @@ -21,6 +21,7 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ + from __future__ import annotations import typing as t diff --git a/src/werkzeug/middleware/shared_data.py b/src/werkzeug/middleware/shared_data.py index e3ec7cab8..c7c06df5a 100644 --- a/src/werkzeug/middleware/shared_data.py +++ b/src/werkzeug/middleware/shared_data.py @@ -8,8 +8,10 @@ :copyright: 2007 Pallets :license: BSD-3-Clause """ + from __future__ import annotations +import collections.abc as cabc import importlib.util import mimetypes import os @@ -28,8 +30,8 @@ from ..wsgi import get_path_info from ..wsgi import wrap_file -_TOpener = t.Callable[[], t.Tuple[t.IO[bytes], datetime, int]] -_TLoader = t.Callable[[t.Optional[str]], t.Tuple[t.Optional[str], t.Optional[_TOpener]]] +_TOpener = t.Callable[[], tuple[t.IO[bytes], datetime, int]] +_TLoader = t.Callable[[t.Optional[str]], tuple[t.Optional[str], t.Optional[_TOpener]]] if t.TYPE_CHECKING: from _typeshed.wsgi import StartResponse @@ -38,7 +40,6 @@ class SharedDataMiddleware: - """A WSGI middleware which provides static content for development environments or simple server setups. 
Its usage is quite simple:: @@ -103,7 +104,7 @@ def __init__( self, app: WSGIApplication, exports: ( - dict[str, str | tuple[str, str]] + cabc.Mapping[str, str | tuple[str, str]] | t.Iterable[tuple[str, str | tuple[str, str]]] ), disallow: None = None, @@ -116,7 +117,7 @@ def __init__( self.cache = cache self.cache_timeout = cache_timeout - if isinstance(exports, dict): + if isinstance(exports, cabc.Mapping): exports = exports.items() for key, value in exports: @@ -218,9 +219,9 @@ def loader( return loader def generate_etag(self, mtime: datetime, file_size: int, real_filename: str) -> str: - real_filename = os.fsencode(real_filename) + fn_str = os.fsencode(real_filename) timestamp = mtime.timestamp() - checksum = adler32(real_filename) & 0xFFFFFFFF + checksum = adler32(fn_str) & 0xFFFFFFFF return f"wzsdm-{timestamp}-{file_size}-{checksum}" def __call__( diff --git a/src/werkzeug/routing/__init__.py b/src/werkzeug/routing/__init__.py index 84b043fdf..62adc48fb 100644 --- a/src/werkzeug/routing/__init__.py +++ b/src/werkzeug/routing/__init__.py @@ -105,6 +105,7 @@ routing tried to match a ``POST`` request) a ``MethodNotAllowed`` exception is raised. 
""" + from .converters import AnyConverter as AnyConverter from .converters import BaseConverter as BaseConverter from .converters import FloatConverter as FloatConverter diff --git a/src/werkzeug/routing/converters.py b/src/werkzeug/routing/converters.py index c59e2abcb..6016a975e 100644 --- a/src/werkzeug/routing/converters.py +++ b/src/werkzeug/routing/converters.py @@ -3,7 +3,6 @@ import re import typing as t import uuid -import warnings from urllib.parse import quote if t.TYPE_CHECKING: @@ -42,17 +41,8 @@ def to_python(self, value: str) -> t.Any: return value def to_url(self, value: t.Any) -> str: - if isinstance(value, (bytes, bytearray)): - warnings.warn( - "Passing bytes as a URL value is deprecated and will not be supported" - " in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=7, - ) - return quote(value, safe="!$&'()*+,/:;=@") - # safe = https://url.spec.whatwg.org/#url-path-segment-string - return quote(str(value), encoding=self.map.charset, safe="!$&'()*+,/:;=@") + return quote(str(value), safe="!$&'()*+,/:;=@") class UnicodeConverter(BaseConverter): @@ -129,6 +119,7 @@ class PathConverter(BaseConverter): :param map: the :class:`Map`. """ + part_isolating = False regex = "[^/].*?" 
weight = 200 @@ -140,7 +131,7 @@ class NumberConverter(BaseConverter): """ weight = 50 - num_convert: t.Callable = int + num_convert: t.Callable[[t.Any], t.Any] = int def __init__( self, @@ -161,18 +152,18 @@ def __init__( def to_python(self, value: str) -> t.Any: if self.fixed_digits and len(value) != self.fixed_digits: raise ValidationError() - value = self.num_convert(value) - if (self.min is not None and value < self.min) or ( - self.max is not None and value > self.max + value_num = self.num_convert(value) + if (self.min is not None and value_num < self.min) or ( + self.max is not None and value_num > self.max ): raise ValidationError() - return value + return value_num def to_url(self, value: t.Any) -> str: - value = str(self.num_convert(value)) + value_str = str(self.num_convert(value)) if self.fixed_digits: - value = value.zfill(self.fixed_digits) - return value + value_str = value_str.zfill(self.fixed_digits) + return value_str @property def signed_regex(self) -> str: diff --git a/src/werkzeug/routing/exceptions.py b/src/werkzeug/routing/exceptions.py index 9d0a5281b..eeabd4ed1 100644 --- a/src/werkzeug/routing/exceptions.py +++ b/src/werkzeug/routing/exceptions.py @@ -10,10 +10,11 @@ if t.TYPE_CHECKING: from _typeshed.wsgi import WSGIEnvironment - from .map import MapAdapter - from .rules import Rule + from ..wrappers.request import Request from ..wrappers.response import Response + from .map import MapAdapter + from .rules import Rule class RoutingException(Exception): @@ -40,7 +41,7 @@ def __init__(self, new_url: str) -> None: def get_response( self, environ: WSGIEnvironment | Request | None = None, - scope: dict | None = None, + scope: dict[str, t.Any] | None = None, ) -> Response: return redirect(self.new_url, self.code) @@ -58,7 +59,7 @@ def __init__(self, path_info: str) -> None: class RequestAliasRedirect(RoutingException): # noqa: B903 """This rule is an alias and wants to redirect to the canonical URL.""" - def __init__(self, matched_values: 
t.Mapping[str, t.Any], endpoint: str) -> None: + def __init__(self, matched_values: t.Mapping[str, t.Any], endpoint: t.Any) -> None: super().__init__() self.matched_values = matched_values self.endpoint = endpoint @@ -71,7 +72,7 @@ class BuildError(RoutingException, LookupError): def __init__( self, - endpoint: str, + endpoint: t.Any, values: t.Mapping[str, t.Any], method: str | None, adapter: MapAdapter | None = None, @@ -92,7 +93,10 @@ def _score_rule(rule: Rule) -> float: [ 0.98 * difflib.SequenceMatcher( - None, rule.endpoint, self.endpoint + # endpoints can be any type, compare as strings + None, + str(rule.endpoint), + str(self.endpoint), ).ratio(), 0.01 * bool(set(self.values or ()).issubset(rule.arguments)), 0.01 * bool(rule.methods and self.method in rule.methods), diff --git a/src/werkzeug/routing/map.py b/src/werkzeug/routing/map.py index 0d02bb8b7..4d15e8824 100644 --- a/src/werkzeug/routing/map.py +++ b/src/werkzeug/routing/map.py @@ -32,9 +32,10 @@ if t.TYPE_CHECKING: from _typeshed.wsgi import WSGIApplication from _typeshed.wsgi import WSGIEnvironment + + from ..wrappers.request import Request from .converters import BaseConverter from .rules import RuleFactory - from ..wrappers.request import Request class Map: @@ -47,7 +48,6 @@ class Map: :param rules: sequence of url rules for this map. :param default_subdomain: The default subdomain for rules without a subdomain defined. - :param charset: charset of the url. defaults to ``"utf-8"`` :param strict_slashes: If a rule ends with a slash but the matched URL does not, redirect to the URL with a trailing slash. :param merge_slashes: Merge consecutive slashes when matching or @@ -62,15 +62,13 @@ class Map: :param sort_parameters: If set to `True` the url parameters are sorted. See `url_encode` for more details. :param sort_key: The sort key function for `url_encode`. 
- :param encoding_errors: the error method to use for decoding :param host_matching: if set to `True` it enables the host matching feature and disables the subdomain one. If enabled the `host` parameter to rules is used instead of the `subdomain` one. - .. versionchanged:: 2.3 - The ``charset`` and ``encoding_errors`` parameters are deprecated and will be - removed in Werkzeug 3.0. + .. versionchanged:: 3.0 + The ``charset`` and ``encoding_errors`` parameters were removed. .. versionchanged:: 1.0 If ``url_scheme`` is ``ws`` or ``wss``, only WebSocket rules will match. @@ -97,48 +95,21 @@ def __init__( self, rules: t.Iterable[RuleFactory] | None = None, default_subdomain: str = "", - charset: str | None = None, strict_slashes: bool = True, merge_slashes: bool = True, redirect_defaults: bool = True, converters: t.Mapping[str, type[BaseConverter]] | None = None, sort_parameters: bool = False, sort_key: t.Callable[[t.Any], t.Any] | None = None, - encoding_errors: str | None = None, host_matching: bool = False, ) -> None: self._matcher = StateMachineMatcher(merge_slashes) - self._rules_by_endpoint: dict[str, list[Rule]] = {} + self._rules_by_endpoint: dict[t.Any, list[Rule]] = {} self._remap = True self._remap_lock = self.lock_class() self.default_subdomain = default_subdomain - - if charset is not None: - warnings.warn( - "The 'charset' parameter is deprecated and will be" - " removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - else: - charset = "utf-8" - - self.charset = charset - - if encoding_errors is not None: - warnings.warn( - "The 'encoding_errors' parameter is deprecated and will be" - " removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - else: - encoding_errors = "replace" - - self.encoding_errors = encoding_errors self.strict_slashes = strict_slashes - self.merge_slashes = merge_slashes self.redirect_defaults = redirect_defaults self.host_matching = host_matching @@ -152,7 +123,15 @@ def __init__( for rulefactory in 
rules or (): self.add(rulefactory) - def is_endpoint_expecting(self, endpoint: str, *arguments: str) -> bool: + @property + def merge_slashes(self) -> bool: + return self._matcher.merge_slashes + + @merge_slashes.setter + def merge_slashes(self, value: bool) -> None: + self._matcher.merge_slashes = value + + def is_endpoint_expecting(self, endpoint: t.Any, *arguments: str) -> bool: """Iterate over all rules and check if the endpoint expects the arguments provided. This is for example useful if you have some URLs that expect a language code and others that do not and @@ -166,9 +145,9 @@ def is_endpoint_expecting(self, endpoint: str, *arguments: str) -> bool: checked. """ self.update() - arguments = set(arguments) + arguments_set = set(arguments) for rule in self._rules_by_endpoint[endpoint]: - if arguments.issubset(rule.arguments): + if arguments_set.issubset(rule.arguments): return True return False @@ -176,7 +155,7 @@ def is_endpoint_expecting(self, endpoint: str, *arguments: str) -> bool: def _rules(self) -> list[Rule]: return [rule for rules in self._rules_by_endpoint.values() for rule in rules] - def iter_rules(self, endpoint: str | None = None) -> t.Iterator[Rule]: + def iter_rules(self, endpoint: t.Any | None = None) -> t.Iterator[Rule]: """Iterate over all rules or the rules of an endpoint. :param endpoint: if provided only the rules for that endpoint @@ -362,7 +341,7 @@ def bind_to_environ( def _get_wsgi_string(name: str) -> str | None: val = env.get(name) if val is not None: - return _wsgi_decoding_dance(val, self.charset) + return _wsgi_decoding_dance(val) return None script_name = _get_wsgi_string("SCRIPT_NAME") @@ -401,7 +380,6 @@ def __repr__(self) -> str: class MapAdapter: - """Returned by :meth:`Map.bind` or :meth:`Map.bind_to_environ` and does the URL matching and building based on runtime information. 
""" @@ -492,15 +470,14 @@ def application(environ, start_response): raise @t.overload - def match( # type: ignore + def match( self, path_info: str | None = None, method: str | None = None, return_rule: t.Literal[False] = False, query_args: t.Mapping[str, t.Any] | str | None = None, websocket: bool | None = None, - ) -> tuple[str, t.Mapping[str, t.Any]]: - ... + ) -> tuple[t.Any, t.Mapping[str, t.Any]]: ... @t.overload def match( @@ -510,8 +487,7 @@ def match( return_rule: t.Literal[True] = True, query_args: t.Mapping[str, t.Any] | str | None = None, websocket: bool | None = None, - ) -> tuple[Rule, t.Mapping[str, t.Any]]: - ... + ) -> tuple[Rule, t.Mapping[str, t.Any]]: ... def match( self, @@ -520,7 +496,7 @@ def match( return_rule: bool = False, query_args: t.Mapping[str, t.Any] | str | None = None, websocket: bool | None = None, - ) -> tuple[str | Rule, t.Mapping[str, t.Any]]: + ) -> tuple[t.Any | Rule, t.Mapping[str, t.Any]]: """The usage is simple: you just pass the match method the current path info as well as the method (which defaults to `GET`). 
The following things can then happen: @@ -629,9 +605,7 @@ def match( result = self.map._matcher.match(domain_part, path_part, method, websocket) except RequestPath as e: # safe = https://url.spec.whatwg.org/#url-path-segment-string - new_path = quote( - e.path_info, safe="!$&'()*+,/:;=@", encoding=self.map.charset - ) + new_path = quote(e.path_info, safe="!$&'()*+,/:;=@") raise RequestRedirect( self.make_redirect_url(new_path, query_args) ) from None @@ -767,7 +741,7 @@ def get_default_redirect( def encode_query_args(self, query_args: t.Mapping[str, t.Any] | str) -> str: if not isinstance(query_args, str): - return _urlencode(query_args, encoding=self.map.charset) + return _urlencode(query_args) return query_args def make_redirect_url( @@ -796,7 +770,7 @@ def make_redirect_url( def make_alias_redirect_url( self, path: str, - endpoint: str, + endpoint: t.Any, values: t.Mapping[str, t.Any], method: str, query_args: t.Mapping[str, t.Any] | str, @@ -812,7 +786,7 @@ def make_alias_redirect_url( def _partial_build( self, - endpoint: str, + endpoint: t.Any, values: t.Mapping[str, t.Any], method: str | None, append_unknown: bool, @@ -853,7 +827,7 @@ def _partial_build( def build( self, - endpoint: str, + endpoint: t.Any, values: t.Mapping[str, t.Any] | None = None, method: str | None = None, force_external: bool = False, diff --git a/src/werkzeug/routing/matcher.py b/src/werkzeug/routing/matcher.py index 0d1210a67..1fd00efca 100644 --- a/src/werkzeug/routing/matcher.py +++ b/src/werkzeug/routing/matcher.py @@ -177,7 +177,7 @@ def _match( rv = _match(self._root, [domain, *path.split("/")], []) except SlashRequired: raise RequestPath(f"{path}/") from None - if rv is None: + if rv is None or rv[0].merge_slashes is False: raise NoMatch(have_match_for, websocket_mismatch) else: raise RequestPath(f"{path}") diff --git a/src/werkzeug/routing/rules.py b/src/werkzeug/routing/rules.py index 904a02258..2dad31dd3 100644 --- a/src/werkzeug/routing/rules.py +++ 
b/src/werkzeug/routing/rules.py @@ -67,6 +67,7 @@ class RulePart: _simple_rule_re = re.compile(r"<([^>]+)>") _converter_args_re = re.compile( r""" + \s* ((?P\w+)\s*=\s*)? (?P True|False| @@ -100,7 +101,7 @@ def _pythonize(value: str) -> None | bool | int | float | str: return _PYTHON_CONSTANTS[value] for convert in int, float: try: - return convert(value) # type: ignore + return convert(value) except ValueError: pass if value[:1] == value[-1:] and value[0] in "\"'": @@ -108,12 +109,18 @@ def _pythonize(value: str) -> None | bool | int | float | str: return str(value) -def parse_converter_args(argstr: str) -> tuple[t.Tuple, dict[str, t.Any]]: +def parse_converter_args(argstr: str) -> tuple[tuple[t.Any, ...], dict[str, t.Any]]: argstr += "," args = [] kwargs = {} + position = 0 for item in _converter_args_re.finditer(argstr): + if item.start() != position: + raise ValueError( + f"Cannot parse converter argument '{argstr[position:item.start()]}'" + ) + value = item.group("stringval") if value is None: value = item.group("value") @@ -123,6 +130,7 @@ def parse_converter_args(argstr: str) -> tuple[t.Tuple, dict[str, t.Any]]: else: name = item.group("name") kwargs[name] = value + position = item.end() return tuple(args), kwargs @@ -286,11 +294,18 @@ def get_rules(self, map: Map) -> t.Iterator[Rule]: ) -def _prefix_names(src: str) -> ast.stmt: +_ASTT = t.TypeVar("_ASTT", bound=ast.AST) + + +def _prefix_names(src: str, expected_type: type[_ASTT]) -> _ASTT: """ast parse and prefix names with `.` to avoid collision with user vars""" - tree = ast.parse(src).body[0] + tree: ast.AST = ast.parse(src).body[0] if isinstance(tree, ast.Expr): - tree = tree.value # type: ignore + tree = tree.value + if not isinstance(tree, expected_type): + raise TypeError( + f"AST node is of type {type(tree).__name__}, not {expected_type.__name__}" + ) for node in ast.walk(tree): if isinstance(node, ast.Name): node.id = f".{node.id}" @@ -305,8 +320,11 @@ def _prefix_names(src: str) -> ast.stmt: else: 
q = params = "" """ -_IF_KWARGS_URL_ENCODE_AST = _prefix_names(_IF_KWARGS_URL_ENCODE_CODE) -_URL_ENCODE_AST_NAMES = (_prefix_names("q"), _prefix_names("params")) +_IF_KWARGS_URL_ENCODE_AST = _prefix_names(_IF_KWARGS_URL_ENCODE_CODE, ast.If) +_URL_ENCODE_AST_NAMES = ( + _prefix_names("q", ast.Name), + _prefix_names("params", ast.Name), +) class Rule(RuleFactory): @@ -445,7 +463,7 @@ def __init__( subdomain: str | None = None, methods: t.Iterable[str] | None = None, build_only: bool = False, - endpoint: str | None = None, + endpoint: t.Any | None = None, strict_slashes: bool | None = None, merge_slashes: bool | None = None, redirect_to: str | t.Callable[..., str] | None = None, @@ -485,7 +503,7 @@ def __init__( ) self.methods = methods - self.endpoint: str = endpoint # type: ignore + self.endpoint: t.Any = endpoint self.redirect_to = redirect_to if defaults: @@ -566,7 +584,7 @@ def get_converter( self, variable_name: str, converter_name: str, - args: t.Tuple, + args: tuple[t.Any, ...], kwargs: t.Mapping[str, t.Any], ) -> BaseConverter: """Looks up the converter for the given parameter. 
@@ -583,7 +601,7 @@ def _encode_query_vars(self, query_vars: t.Mapping[str, t.Any]) -> str: if self.map.sort_parameters: items = sorted(items, key=self.map.sort_key) - return _urlencode(items, encoding=self.map.charset) + return _urlencode(items) def _parse_rule(self, rule: str) -> t.Iterable[RulePart]: content = "" @@ -739,22 +757,17 @@ def _compile_builder( opl.append((False, data)) elif not is_dynamic: # safe = https://url.spec.whatwg.org/#url-path-segment-string - opl.append( - ( - False, - quote(data, safe="!$&'()*+,/:;=@", encoding=self.map.charset), - ) - ) + opl.append((False, quote(data, safe="!$&'()*+,/:;=@"))) else: opl.append((True, data)) - def _convert(elem: str) -> ast.stmt: - ret = _prefix_names(_CALL_CONVERTER_CODE_FMT.format(elem=elem)) - ret.args = [ast.Name(str(elem), ast.Load())] # type: ignore # str for py2 + def _convert(elem: str) -> ast.Call: + ret = _prefix_names(_CALL_CONVERTER_CODE_FMT.format(elem=elem), ast.Call) + ret.args = [ast.Name(elem, ast.Load())] return ret - def _parts(ops: list[tuple[bool, str]]) -> list[ast.AST]: - parts = [ + def _parts(ops: list[tuple[bool, str]]) -> list[ast.expr]: + parts: list[ast.expr] = [ _convert(elem) if is_dynamic else ast.Constant(elem) for is_dynamic, elem in ops ] @@ -770,13 +783,14 @@ def _parts(ops: list[tuple[bool, str]]) -> list[ast.AST]: dom_parts = _parts(dom_ops) url_parts = _parts(url_ops) + body: list[ast.stmt] if not append_unknown: body = [] else: body = [_IF_KWARGS_URL_ENCODE_AST] url_parts.extend(_URL_ENCODE_AST_NAMES) - def _join(parts: list[ast.AST]) -> ast.AST: + def _join(parts: list[ast.expr]) -> ast.expr: if len(parts) == 1: # shortcut return parts[0] return ast.JoinedStr(parts) @@ -792,7 +806,7 @@ def _join(parts: list[ast.AST]) -> ast.AST: ] kargs = [str(k) for k in defaults] - func_ast: ast.FunctionDef = _prefix_names("def _(): pass") # type: ignore + func_ast = _prefix_names("def _(): pass", ast.FunctionDef) func_ast.name = f"" func_ast.args.args.append(ast.arg(".self", 
None)) for arg in pargs + kargs: @@ -812,13 +826,13 @@ def _join(parts: list[ast.AST]) -> ast.AST: # bad line numbers cause an assert to fail in debug builds for node in ast.walk(module): if "lineno" in node._attributes: - node.lineno = 1 + node.lineno = 1 # type: ignore[attr-defined] if "end_lineno" in node._attributes: - node.end_lineno = node.lineno + node.end_lineno = node.lineno # type: ignore[attr-defined] if "col_offset" in node._attributes: - node.col_offset = 0 + node.col_offset = 0 # type: ignore[attr-defined] if "end_col_offset" in node._attributes: - node.end_col_offset = node.col_offset + node.end_col_offset = node.col_offset # type: ignore[attr-defined] code = compile(module, "", "exec") return self._get_func_code(code, func_ast.name) @@ -909,6 +923,6 @@ def __repr__(self) -> str: parts.append(f"<{data}>") else: parts.append(data) - parts = "".join(parts).lstrip("|") + parts_str = "".join(parts).lstrip("|") methods = f" ({', '.join(self.methods)})" if self.methods is not None else "" - return f"<{type(self).__name__} {parts!r}{methods} -> {self.endpoint}>" + return f"<{type(self).__name__} {parts_str!r}{methods} -> {self.endpoint}>" diff --git a/src/werkzeug/sansio/http.py b/src/werkzeug/sansio/http.py index 21a619720..f02d7fd54 100644 --- a/src/werkzeug/sansio/http.py +++ b/src/werkzeug/sansio/http.py @@ -2,7 +2,6 @@ import re import typing as t -import warnings from datetime import datetime from .._internal import _dt_as_utc @@ -73,7 +72,6 @@ def is_resource_modified( if etag: etag, _ = unquote_etag(etag) - etag = t.cast(str, etag) if if_range is not None and if_range.etag is not None: unmodified = parse_etags(if_range.etag).contains(etag) @@ -123,9 +121,7 @@ def _cookie_unslash_replace(m: t.Match[bytes]) -> bytes: def parse_cookie( cookie: str | None = None, - charset: str | None = None, - errors: str | None = None, - cls: type[ds.MultiDict] | None = None, + cls: type[ds.MultiDict[str, str]] | None = None, ) -> ds.MultiDict[str, str]: """Parse a 
cookie from a string. @@ -138,41 +134,13 @@ def parse_cookie( :param cls: A dict-like class to store the parsed cookies in. Defaults to :class:`MultiDict`. - .. versionchanged:: 2.3 - Passing bytes, and the ``charset`` and ``errors`` parameters, are deprecated and - will be removed in Werkzeug 3.0. + .. versionchanged:: 3.0 + Passing bytes, and the ``charset`` and ``errors`` parameters, were removed. .. versionadded:: 2.2 """ if cls is None: - cls = ds.MultiDict - - if isinstance(cookie, bytes): - warnings.warn( - "The 'cookie' parameter must be a string. Passing bytes is deprecated and" - " will not be supported in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - cookie = cookie.decode() - - if charset is not None: - warnings.warn( - "The 'charset' parameter is deprecated and will be removed in Werkzeug 3.0", - DeprecationWarning, - stacklevel=2, - ) - else: - charset = "utf-8" - - if errors is not None: - warnings.warn( - "The 'errors' parameter is deprecated and will be removed in Werkzeug 3.0", - DeprecationWarning, - stacklevel=2, - ) - else: - errors = "replace" + cls = t.cast("type[ds.MultiDict[str, str]]", ds.MultiDict) if not cookie: return cls() @@ -191,7 +159,7 @@ def parse_cookie( # Work with bytes here, since a UTF-8 character could be multiple bytes. cv = _cookie_unslash_re.sub( _cookie_unslash_replace, cv[1:-1].encode() - ).decode(charset, errors) + ).decode(errors="replace") out.append((ck, cv)) diff --git a/src/werkzeug/sansio/multipart.py b/src/werkzeug/sansio/multipart.py index fc8735378..731be0336 100644 --- a/src/werkzeug/sansio/multipart.py +++ b/src/werkzeug/sansio/multipart.py @@ -140,6 +140,8 @@ def receive_data(self, data: bytes | None) -> None: self.max_form_memory_size is not None and len(self.buffer) + len(data) > self.max_form_memory_size ): + # Ensure that data within single event does not exceed limit. + # Also checked across accumulated events in MultiPartParser. 
raise RequestEntityTooLarge() else: self.buffer.extend(data) diff --git a/src/werkzeug/sansio/request.py b/src/werkzeug/sansio/request.py index 0bcda90b2..8d5fbd8f8 100644 --- a/src/werkzeug/sansio/request.py +++ b/src/werkzeug/sansio/request.py @@ -1,7 +1,6 @@ from __future__ import annotations import typing as t -import warnings from datetime import datetime from urllib.parse import parse_qsl @@ -59,105 +58,21 @@ class Request: :param headers: The headers received with the request. :param remote_addr: The address of the client sending the request. + .. versionchanged:: 3.0 + The ``charset``, ``url_charset``, and ``encoding_errors`` attributes + were removed. + .. versionadded:: 2.0 """ - _charset: str - - @property - def charset(self) -> str: - """The charset used to decode body, form, and cookie data. Defaults to UTF-8. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. Request data must always be UTF-8. - """ - warnings.warn( - "The 'charset' attribute is deprecated and will not be used in Werkzeug" - " 2.4. Interpreting bytes as text in body, form, and cookie data will" - " always use UTF-8.", - DeprecationWarning, - stacklevel=2, - ) - return self._charset - - @charset.setter - def charset(self, value: str) -> None: - warnings.warn( - "The 'charset' attribute is deprecated and will not be used in Werkzeug" - " 2.4. Interpreting bytes as text in body, form, and cookie data will" - " always use UTF-8.", - DeprecationWarning, - stacklevel=2, - ) - self._charset = value - - _encoding_errors: str - - @property - def encoding_errors(self) -> str: - """How errors when decoding bytes are handled. Defaults to "replace". - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. 
- """ - warnings.warn( - "The 'encoding_errors' attribute is deprecated and will not be used in" - " Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - return self._encoding_errors - - @encoding_errors.setter - def encoding_errors(self, value: str) -> None: - warnings.warn( - "The 'encoding_errors' attribute is deprecated and will not be used in" - " Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - self._encoding_errors = value - - _url_charset: str - - @property - def url_charset(self) -> str: - """The charset to use when decoding percent-encoded bytes in :attr:`args`. - Defaults to the value of :attr:`charset`, which defaults to UTF-8. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. Percent-encoded bytes must always be UTF-8. - - .. versionadded:: 0.6 - """ - warnings.warn( - "The 'url_charset' attribute is deprecated and will not be used in" - " Werkzeug 3.0. Percent-encoded bytes must always be UTF-8.", - DeprecationWarning, - stacklevel=2, - ) - return self._url_charset - - @url_charset.setter - def url_charset(self, value: str) -> None: - warnings.warn( - "The 'url_charset' attribute is deprecated and will not be used in" - " Werkzeug 3.0. Percent-encoded bytes must always be UTF-8.", - DeprecationWarning, - stacklevel=2, - ) - self._url_charset = value - #: the class to use for `args` and `form`. The default is an #: :class:`~werkzeug.datastructures.ImmutableMultiDict` which supports - #: multiple values per key. alternatively it makes sense to use an - #: :class:`~werkzeug.datastructures.ImmutableOrderedMultiDict` which - #: preserves order or a :class:`~werkzeug.datastructures.ImmutableDict` - #: which is the fastest but only remembers the last key. It is also + #: multiple values per key. A :class:`~werkzeug.datastructures.ImmutableDict` + #: is faster but only remembers the last key. It is also #: possible to use mutable structures, but this is not recommended. #: #: .. 
versionadded:: 0.6 - parameter_storage_class: type[MultiDict] = ImmutableMultiDict + parameter_storage_class: type[MultiDict[str, t.Any]] = ImmutableMultiDict #: The type to be used for dict values from the incoming WSGI #: environment. (For example for :attr:`cookies`.) By default an @@ -167,14 +82,14 @@ def url_charset(self, value: str) -> None: #: Changed to ``ImmutableMultiDict`` to support multiple values. #: #: .. versionadded:: 0.6 - dict_storage_class: type[MultiDict] = ImmutableMultiDict + dict_storage_class: type[MultiDict[str, t.Any]] = ImmutableMultiDict #: the type to be used for list values from the incoming WSGI environment. #: By default an :class:`~werkzeug.datastructures.ImmutableList` is used #: (for example for :attr:`access_list`). #: #: .. versionadded:: 0.6 - list_storage_class: type[t.List] = ImmutableList + list_storage_class: type[list[t.Any]] = ImmutableList user_agent_class: type[UserAgent] = UserAgent """The class used and returned by the :attr:`user_agent` property to @@ -209,40 +124,6 @@ def __init__( headers: Headers, remote_addr: str | None, ) -> None: - if not isinstance(type(self).charset, property): - warnings.warn( - "The 'charset' attribute is deprecated and will not be used in Werkzeug" - " 2.4. Interpreting bytes as text in body, form, and cookie data will" - " always use UTF-8.", - DeprecationWarning, - stacklevel=2, - ) - self._charset = self.charset - else: - self._charset = "utf-8" - - if not isinstance(type(self).encoding_errors, property): - warnings.warn( - "The 'encoding_errors' attribute is deprecated and will not be used in" - " Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - self._encoding_errors = self.encoding_errors - else: - self._encoding_errors = "replace" - - if not isinstance(type(self).url_charset, property): - warnings.warn( - "The 'url_charset' attribute is deprecated and will not be used in" - " Werkzeug 3.0. 
Percent-encoded bytes must always be UTF-8.", - DeprecationWarning, - stacklevel=2, - ) - self._url_charset = self.url_charset - else: - self._url_charset = self._charset - #: The method the request was made with, such as ``GET``. self.method = method.upper() #: The URL scheme of the protocol the request used, such as @@ -291,7 +172,6 @@ def args(self) -> MultiDict[str, str]: parse_qsl( self.query_string.decode(), keep_blank_values=True, - encoding=self._url_charset, errors="werkzeug.url_quote", ) ) @@ -360,13 +240,8 @@ def cookies(self) -> ImmutableMultiDict[str, str]: """A :class:`dict` with the contents of all cookies transmitted with the request.""" wsgi_combined_cookie = ";".join(self.headers.getlist("Cookie")) - charset = self._charset if self._charset != "utf-8" else None - errors = self._encoding_errors if self._encoding_errors != "replace" else None return parse_cookie( # type: ignore - wsgi_combined_cookie, - charset=charset, - errors=errors, - cls=self.dict_storage_class, + wsgi_combined_cookie, cls=self.dict_storage_class ) # Common Descriptors diff --git a/src/werkzeug/sansio/response.py b/src/werkzeug/sansio/response.py index e5c1df743..9fed08625 100644 --- a/src/werkzeug/sansio/response.py +++ b/src/werkzeug/sansio/response.py @@ -1,38 +1,40 @@ from __future__ import annotations import typing as t -import warnings from datetime import datetime from datetime import timedelta from datetime import timezone from http import HTTPStatus +from ..datastructures import CallbackDict +from ..datastructures import ContentRange +from ..datastructures import ContentSecurityPolicy from ..datastructures import Headers from ..datastructures import HeaderSet +from ..datastructures import ResponseCacheControl +from ..datastructures import WWWAuthenticate +from ..http import COEP +from ..http import COOP +from ..http import dump_age from ..http import dump_cookie +from ..http import dump_header +from ..http import dump_options_header +from ..http import http_date from 
..http import HTTP_STATUS_CODES +from ..http import parse_age +from ..http import parse_cache_control_header +from ..http import parse_content_range_header +from ..http import parse_csp_header +from ..http import parse_date +from ..http import parse_options_header +from ..http import parse_set_header +from ..http import quote_etag +from ..http import unquote_etag from ..utils import get_content_type -from werkzeug.datastructures import CallbackDict -from werkzeug.datastructures import ContentRange -from werkzeug.datastructures import ContentSecurityPolicy -from werkzeug.datastructures import ResponseCacheControl -from werkzeug.datastructures import WWWAuthenticate -from werkzeug.http import COEP -from werkzeug.http import COOP -from werkzeug.http import dump_age -from werkzeug.http import dump_header -from werkzeug.http import dump_options_header -from werkzeug.http import http_date -from werkzeug.http import parse_age -from werkzeug.http import parse_cache_control_header -from werkzeug.http import parse_content_range_header -from werkzeug.http import parse_csp_header -from werkzeug.http import parse_date -from werkzeug.http import parse_options_header -from werkzeug.http import parse_set_header -from werkzeug.http import quote_etag -from werkzeug.http import unquote_etag -from werkzeug.utils import header_property +from ..utils import header_property + +if t.TYPE_CHECKING: + from ..datastructures.cache_control import _CacheControl def _set_property(name: str, doc: str | None = None) -> property: @@ -81,36 +83,12 @@ class Response: :param content_type: The full content type of the response. Overrides building the value from ``mimetype``. + .. versionchanged:: 3.0 + The ``charset`` attribute was removed. + .. versionadded:: 2.0 """ - _charset: str - - @property - def charset(self) -> str: - """The charset used to encode body and cookie data. Defaults to UTF-8. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. Response data must always be UTF-8. 
- """ - warnings.warn( - "The 'charset' attribute is deprecated and will not be used in Werkzeug" - " 2.4. Text in body and cookie data will always use UTF-8.", - DeprecationWarning, - stacklevel=2, - ) - return self._charset - - @charset.setter - def charset(self, value: str) -> None: - warnings.warn( - "The 'charset' attribute is deprecated and will not be used in Werkzeug" - " 2.4. Text in body and cookie data will always use UTF-8.", - DeprecationWarning, - stacklevel=2, - ) - self._charset = value - #: the default status if none is provided. default_status = 200 @@ -139,17 +117,6 @@ def __init__( mimetype: str | None = None, content_type: str | None = None, ) -> None: - if not isinstance(type(self).charset, property): - warnings.warn( - "The 'charset' attribute is deprecated and will not be used in Werkzeug" - " 2.4. Text in body and cookie data will always use UTF-8.", - DeprecationWarning, - stacklevel=2, - ) - self._charset = self.charset - else: - self._charset = "utf-8" - if isinstance(headers, Headers): self.headers = headers elif not headers: @@ -161,7 +128,7 @@ def __init__( if mimetype is None and "content-type" not in self.headers: mimetype = self.default_mimetype if mimetype is not None: - mimetype = get_content_type(mimetype, self._charset) + mimetype = get_content_type(mimetype, "utf-8") content_type = mimetype if content_type is not None: self.headers["Content-Type"] = content_type @@ -230,6 +197,7 @@ def set_cookie( secure: bool = False, httponly: bool = False, samesite: str | None = None, + partitioned: bool = False, ) -> None: """Sets a cookie. @@ -254,8 +222,11 @@ def set_cookie( :param httponly: Disallow JavaScript access to the cookie. :param samesite: Limit the scope of the cookie to only be attached to requests that are "same-site". + :param partitioned: If ``True``, the cookie will be partitioned. + + .. versionchanged:: 3.1 + The ``partitioned`` parameter was added. 
""" - charset = self._charset if self._charset != "utf-8" else None self.headers.add( "Set-Cookie", dump_cookie( @@ -267,9 +238,9 @@ def set_cookie( domain=domain, secure=secure, httponly=httponly, - charset=charset, max_size=self.max_cookie_size, samesite=samesite, + partitioned=partitioned, ), ) @@ -281,6 +252,7 @@ def delete_cookie( secure: bool = False, httponly: bool = False, samesite: str | None = None, + partitioned: bool = False, ) -> None: """Delete a cookie. Fails silently if key doesn't exist. @@ -294,6 +266,7 @@ def delete_cookie( :param httponly: Disallow JavaScript access to the cookie. :param samesite: Limit the scope of the cookie to only be attached to requests that are "same-site". + :param partitioned: If ``True``, the cookie will be partitioned. """ self.set_cookie( key, @@ -304,6 +277,7 @@ def delete_cookie( secure=secure, httponly=httponly, samesite=samesite, + partitioned=partitioned, ) @property @@ -332,7 +306,7 @@ def mimetype(self) -> str | None: @mimetype.setter def mimetype(self, value: str) -> None: - self.headers["Content-Type"] = get_content_type(value, self._charset) + self.headers["Content-Type"] = get_content_type(value, "utf-8") @property def mimetype_params(self) -> dict[str, str]: @@ -343,7 +317,7 @@ def mimetype_params(self) -> dict[str, str]: .. versionadded:: 0.5 """ - def on_update(d: CallbackDict) -> None: + def on_update(d: CallbackDict[str, str]) -> None: self.headers["Content-Type"] = dump_options_header(self.mimetype, d) d = parse_options_header(self.headers.get("content-type", ""))[1] @@ -518,7 +492,7 @@ def cache_control(self) -> ResponseCacheControl: request/response chain. 
""" - def on_update(cache_control: ResponseCacheControl) -> None: + def on_update(cache_control: _CacheControl) -> None: if not cache_control and "cache-control" in self.headers: del self.headers["cache-control"] elif cache_control: diff --git a/src/werkzeug/sansio/utils.py b/src/werkzeug/sansio/utils.py index 48ec1bfa0..ff7ceda34 100644 --- a/src/werkzeug/sansio/utils.py +++ b/src/werkzeug/sansio/utils.py @@ -8,7 +8,7 @@ from ..urls import uri_to_iri -def host_is_trusted(hostname: str, trusted_list: t.Iterable[str]) -> bool: +def host_is_trusted(hostname: str | None, trusted_list: t.Iterable[str]) -> bool: """Check if a host matches a list of trusted names. :param hostname: The name to check. @@ -71,6 +71,9 @@ def get_host( :return: Host, with port if necessary. :raise ~werkzeug.exceptions.SecurityError: If the host is not trusted. + + .. versionchanged:: 3.1.3 + If ``SERVER_NAME`` is IPv6, it is wrapped in ``[]``. """ host = "" @@ -79,6 +82,11 @@ def get_host( elif server is not None: host = server[0] + # If SERVER_NAME is IPv6, wrap it in [] to match Host header. + # Check for : because domain or IPv4 can't have that. 
+ if ":" in host and host[0] != "[": + host = f"[{host}]" + if server[1] is not None: host = f"{host}:{server[1]}" diff --git a/src/werkzeug/security.py b/src/werkzeug/security.py index 282c4fd8c..3f49ad1b4 100644 --- a/src/werkzeug/security.py +++ b/src/werkzeug/security.py @@ -5,10 +5,9 @@ import os import posixpath import secrets -import warnings SALT_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" -DEFAULT_PBKDF2_ITERATIONS = 600000 +DEFAULT_PBKDF2_ITERATIONS = 1_000_000 _os_alt_seps: list[str] = list( sep for sep in [os.sep, os.path.altsep] if sep is not None and sep != "/" @@ -24,17 +23,9 @@ def gen_salt(length: int) -> str: def _hash_internal(method: str, salt: str, password: str) -> tuple[str, str]: - if method == "plain": - warnings.warn( - "The 'plain' password method is deprecated and will be removed in" - " Werkzeug 3.0. Migrate to the 'scrypt' method.", - stacklevel=3, - ) - return password, method - method, *args = method.split(":") - salt = salt.encode("utf-8") - password = password.encode("utf-8") + salt_bytes = salt.encode() + password_bytes = password.encode() if method == "scrypt": if not args: @@ -49,7 +40,9 @@ def _hash_internal(method: str, salt: str, password: str) -> tuple[str, str]: maxmem = 132 * n * r * p # ideally 128, but some extra seems needed return ( - hashlib.scrypt(password, salt=salt, n=n, r=r, p=p, maxmem=maxmem).hex(), + hashlib.scrypt( + password_bytes, salt=salt_bytes, n=n, r=r, p=p, maxmem=maxmem + ).hex(), f"scrypt:{n}:{r}:{p}", ) elif method == "pbkdf2": @@ -68,30 +61,26 @@ def _hash_internal(method: str, salt: str, password: str) -> tuple[str, str]: raise ValueError("'pbkdf2' takes 2 arguments.") return ( - hashlib.pbkdf2_hmac(hash_name, password, salt, iterations).hex(), + hashlib.pbkdf2_hmac( + hash_name, password_bytes, salt_bytes, iterations + ).hex(), f"pbkdf2:{hash_name}:{iterations}", ) else: - warnings.warn( - f"The '{method}' password method is deprecated and will be removed in" - " 
Werkzeug 3.0. Migrate to the 'scrypt' method.", - stacklevel=3, - ) - return hmac.new(salt, password, method).hexdigest(), method + raise ValueError(f"Invalid hash method '{method}'.") def generate_password_hash( - password: str, method: str = "pbkdf2", salt_length: int = 16 + password: str, method: str = "scrypt", salt_length: int = 16 ) -> str: """Securely hash a password for storage. A password can be compared to a stored hash using :func:`check_password_hash`. The following methods are supported: - - ``scrypt``, more secure but not available on PyPy. The parameters are ``n``, - ``r``, and ``p``, the default is ``scrypt:32768:8:1``. See - :func:`hashlib.scrypt`. - - ``pbkdf2``, the default. The parameters are ``hash_method`` and ``iterations``, + - ``scrypt``, the default. The parameters are ``n``, ``r``, and ``p``, the default + is ``scrypt:32768:8:1``. See :func:`hashlib.scrypt`. + - ``pbkdf2``, less secure. The parameters are ``hash_method`` and ``iterations``, the default is ``pbkdf2:sha256:600000``. See :func:`hashlib.pbkdf2_hmac`. Default parameters may be updated to reflect current guidelines, and methods may be @@ -103,6 +92,9 @@ def generate_password_hash( :param method: The key derivation function and parameters. :param salt_length: The number of characters to generate for the salt. + .. versionchanged:: 3.1 + The default iterations for pbkdf2 was increased to 1,000,000. + .. versionchanged:: 2.3 Scrypt support was added. @@ -162,6 +154,8 @@ def safe_join(directory: str, *pathnames: str) -> str | None: if ( any(sep in filename for sep in _os_alt_seps) or os.path.isabs(filename) + # ntpath.isabs doesn't catch this on Python < 3.11 + or filename.startswith("/") or filename == ".." 
or filename.startswith("../") ): diff --git a/src/werkzeug/serving.py b/src/werkzeug/serving.py index c031dc45e..ec166408e 100644 --- a/src/werkzeug/serving.py +++ b/src/werkzeug/serving.py @@ -11,6 +11,7 @@ from myapp import create_app from werkzeug import run_simple """ + from __future__ import annotations import errno @@ -36,6 +37,12 @@ try: import ssl + + connection_dropped_errors: tuple[type[Exception], ...] = ( + ConnectionError, + socket.timeout, + ssl.SSLEOFError, + ) except ImportError: class _SslDummy: @@ -46,6 +53,7 @@ def __getattr__(self, name: str) -> t.Any: ) ssl = _SslDummy() # type: ignore + connection_dropped_errors = (ConnectionError, socket.timeout) _log_add_style = True @@ -73,7 +81,7 @@ class ForkingMixIn: # type: ignore LISTEN_QUEUE = 128 _TSSLContextArg = t.Optional[ - t.Union["ssl.SSLContext", t.Tuple[str, t.Optional[str]], t.Literal["adhoc"]] + t.Union["ssl.SSLContext", tuple[str, t.Optional[str]], t.Literal["adhoc"]] ] if t.TYPE_CHECKING: @@ -154,9 +162,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler): @property def server_version(self) -> str: # type: ignore - from . 
import __version__ - - return f"Werkzeug/{__version__}" + return self.server._server_version def make_environ(self) -> WSGIEnvironment: request_url = urlsplit(self.path) @@ -362,7 +368,7 @@ def execute(app: WSGIApplication) -> None: try: execute(self.server.app) - except (ConnectionError, socket.timeout) as e: + except connection_dropped_errors as e: self.connection_dropped(e, environ) except Exception as e: if self.server.passthrough_errors: @@ -467,9 +473,11 @@ def log_message(self, format: str, *args: t.Any) -> None: self.log("info", format, *args) def log(self, type: str, message: str, *args: t.Any) -> None: + # an IPv6 scoped address contains "%" which breaks logging + address_string = self.address_string().replace("%", "%%") _log( type, - f"{self.address_string()} - - [{self.log_date_time_string()}] {message}\n", + f"{address_string} - - [{self.log_date_time_string()}] {message}\n", *args, ) @@ -498,10 +506,10 @@ def generate_adhoc_ssl_pair( ) -> tuple[Certificate, RSAPrivateKeyWithSerialization]: try: from cryptography import x509 - from cryptography.x509.oid import NameOID from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID except ImportError: raise TypeError( "Using ad-hoc certificates requires the cryptography library." 
@@ -533,7 +541,10 @@ def generate_adhoc_ssl_pair( .not_valid_before(dt.now(timezone.utc)) .not_valid_after(dt.now(timezone.utc) + timedelta(days=365)) .add_extension(x509.ExtendedKeyUsage([x509.OID_SERVER_AUTH]), critical=False) - .add_extension(x509.SubjectAlternativeName([x509.DNSName(cn)]), critical=False) + .add_extension( + x509.SubjectAlternativeName([x509.DNSName(cn), x509.DNSName(f"*.{cn}")]), + critical=False, + ) .sign(pkey, hashes.SHA256(), backend) ) return cert, pkey @@ -561,7 +572,7 @@ def make_ssl_devcert( """ if host is not None: - cn = f"*.{host}/CN={host}" + cn = host cert, pkey = generate_adhoc_ssl_pair(cn=cn) from cryptography.hazmat.primitives import serialization @@ -585,8 +596,8 @@ def make_ssl_devcert( def generate_adhoc_ssl_context() -> ssl.SSLContext: """Generates an adhoc SSL context for the development server.""" - import tempfile import atexit + import tempfile cert, pkey = generate_adhoc_ssl_pair() @@ -764,7 +775,7 @@ def __init__( if sys.platform == "darwin" and port == 5000: print( "On macOS, try disabling the 'AirPlay Receiver' service" - " from System Preferences -> Sharing.", + " from System Preferences -> General -> AirDrop & Handoff.", file=sys.stderr, ) @@ -796,6 +807,10 @@ def __init__( else: self.ssl_context = None + import importlib.metadata + + self._server_version = f"Werkzeug/{importlib.metadata.version('werkzeug')}" + def log(self, type: str, message: str, *args: t.Any) -> None: _log(type, message, *args) @@ -1066,6 +1081,9 @@ def run_simple( from .debug import DebuggedApplication application = DebuggedApplication(application, evalex=use_evalex) + # Allow the specified hostname to use the debugger, in addition to + # localhost domains. 
+ application.trusted_hosts.append(hostname) if not is_running_from_reloader(): fd = None diff --git a/src/werkzeug/test.py b/src/werkzeug/test.py index 968553f2b..5c3c60883 100644 --- a/src/werkzeug/test.py +++ b/src/werkzeug/test.py @@ -4,7 +4,6 @@ import mimetypes import sys import typing as t -import warnings from collections import defaultdict from datetime import datetime from io import BytesIO @@ -17,7 +16,6 @@ from urllib.parse import urlunsplit from ._internal import _get_environ -from ._internal import _make_encode_wrapper from ._internal import _wsgi_decoding_dance from ._internal import _wsgi_encoding_dance from .datastructures import Authorization @@ -48,9 +46,9 @@ from .wsgi import get_current_url if t.TYPE_CHECKING: + import typing_extensions as te from _typeshed.wsgi import WSGIApplication from _typeshed.wsgi import WSGIEnvironment - import typing_extensions as te def stream_encode_multipart( @@ -58,24 +56,14 @@ def stream_encode_multipart( use_tempfile: bool = True, threshold: int = 1024 * 500, boundary: str | None = None, - charset: str | None = None, ) -> tuple[t.IO[bytes], int, str]: """Encode a dict of values (either strings or file descriptors or :class:`FileStorage` objects.) into a multipart encoded string stored in a file descriptor. - .. versionchanged:: 2.3 - The ``charset`` parameter is deprecated and will be removed in Werkzeug 3.0 + .. versionchanged:: 3.0 + The ``charset`` parameter was removed. 
""" - if charset is not None: - warnings.warn( - "The 'charset' parameter is deprecated and will be removed in Werkzeug 3.0", - DeprecationWarning, - stacklevel=2, - ) - else: - charset = "utf-8" - if boundary is None: boundary = f"---------------WerkzeugFormPart_{time()}{random()}" @@ -144,9 +132,7 @@ def write_binary(s: bytes) -> int: if not isinstance(value, str): value = str(value) write_binary(encoder.send_event(Field(name=key, headers=Headers()))) - write_binary( - encoder.send_event(Data(data=value.encode(charset), more_data=False)) - ) + write_binary(encoder.send_event(Data(data=value.encode(), more_data=False))) write_binary(encoder.send_event(Epilogue(data=b""))) @@ -156,18 +142,16 @@ def write_binary(s: bytes) -> int: def encode_multipart( - values: t.Mapping[str, t.Any], - boundary: str | None = None, - charset: str | None = None, + values: t.Mapping[str, t.Any], boundary: str | None = None ) -> tuple[str, bytes]: """Like `stream_encode_multipart` but returns a tuple in the form (``boundary``, ``data``) where data is bytes. - .. versionchanged:: 2.3 - The ``charset`` parameter is deprecated and will be removed in Werkzeug 3.0 + .. versionchanged:: 3.0 + The ``charset`` parameter was removed. """ stream, length, boundary = stream_encode_multipart( - values, use_tempfile=False, boundary=boundary, charset=charset + values, use_tempfile=False, boundary=boundary ) return boundary, stream.read() @@ -188,7 +172,7 @@ def _iter_data(data: t.Mapping[str, t.Any]) -> t.Iterator[tuple[str, t.Any]]: yield key, value -_TAnyMultiDict = t.TypeVar("_TAnyMultiDict", bound=MultiDict) +_TAnyMultiDict = t.TypeVar("_TAnyMultiDict", bound="MultiDict[t.Any, t.Any]") class EnvironBuilder: @@ -259,8 +243,8 @@ class EnvironBuilder: ``Authorization`` header value. A ``(username, password)`` tuple is a shortcut for ``Basic`` authorization. - .. versionchanged:: 2.3 - The ``charset`` parameter is deprecated and will be removed in Werkzeug 3.0 + .. 
versionchanged:: 3.0 + The ``charset`` parameter was removed. .. versionchanged:: 2.1 ``CONTENT_TYPE`` and ``CONTENT_LENGTH`` are not duplicated as @@ -305,10 +289,10 @@ class EnvironBuilder: json_dumps = staticmethod(json.dumps) del json - _args: MultiDict | None + _args: MultiDict[str, str] | None _query_string: str | None _input_stream: t.IO[bytes] | None - _form: MultiDict | None + _form: MultiDict[str, str] | None _files: FileMultiDict | None def __init__( @@ -328,35 +312,20 @@ def __init__( data: None | (t.IO[bytes] | str | bytes | t.Mapping[str, t.Any]) = None, environ_base: t.Mapping[str, t.Any] | None = None, environ_overrides: t.Mapping[str, t.Any] | None = None, - charset: str | None = None, mimetype: str | None = None, json: t.Mapping[str, t.Any] | None = None, auth: Authorization | tuple[str, str] | None = None, ) -> None: - path_s = _make_encode_wrapper(path) - if query_string is not None and path_s("?") in path: + if query_string is not None and "?" in path: raise ValueError("Query string is defined in the path and as an argument") request_uri = urlsplit(path) - if query_string is None and path_s("?") in path: + if query_string is None and "?" 
in path: query_string = request_uri.query - if charset is not None: - warnings.warn( - "The 'charset' parameter is deprecated and will be" - " removed in Werkzeug 3.0", - DeprecationWarning, - stacklevel=2, - ) - else: - charset = "utf-8" - - self.charset = charset self.path = iri_to_uri(request_uri.path) self.request_uri = path if base_url is not None: - base_url = iri_to_uri( - base_url, charset=charset if charset != "utf-8" else None - ) + base_url = iri_to_uri(base_url) self.base_url = base_url # type: ignore if isinstance(query_string, str): self.query_string = query_string @@ -409,7 +378,7 @@ def __init__( if hasattr(data, "read"): data = data.read() if isinstance(data, str): - data = data.encode(self.charset) + data = data.encode() if isinstance(data, bytes): self.input_stream = BytesIO(data) if self.content_length is None: @@ -526,7 +495,7 @@ def mimetype(self) -> str | None: @mimetype.setter def mimetype(self, value: str) -> None: - self.content_type = get_content_type(value, self.charset) + self.content_type = get_content_type(value, "utf-8") @property def mimetype_params(self) -> t.Mapping[str, str]: @@ -537,7 +506,7 @@ def mimetype_params(self) -> t.Mapping[str, str]: .. versionadded:: 0.14 """ - def on_update(d: CallbackDict) -> None: + def on_update(d: CallbackDict[str, str]) -> None: self.headers["Content-Type"] = dump_options_header(self.mimetype, d) d = parse_options_header(self.headers.get("content-type", ""))[1] @@ -576,7 +545,7 @@ def _get_form(self, name: str, storage: type[_TAnyMultiDict]) -> _TAnyMultiDict: return rv # type: ignore - def _set_form(self, name: str, value: MultiDict) -> None: + def _set_form(self, name: str, value: MultiDict[str, t.Any]) -> None: """Common behavior for setting the :attr:`form` and :attr:`files` properties. 
@@ -587,12 +556,12 @@ def _set_form(self, name: str, value: MultiDict) -> None: setattr(self, name, value) @property - def form(self) -> MultiDict: + def form(self) -> MultiDict[str, str]: """A :class:`MultiDict` of form values.""" return self._get_form("_form", MultiDict) @form.setter - def form(self, value: MultiDict) -> None: + def form(self, value: MultiDict[str, str]) -> None: self._set_form("_form", value) @property @@ -628,7 +597,7 @@ def query_string(self) -> str: """ if self._query_string is None: if self._args is not None: - return _urlencode(self._args, encoding=self.charset) + return _urlencode(self._args) return "" return self._query_string @@ -638,7 +607,7 @@ def query_string(self, value: str | None) -> None: self._args = None @property - def args(self) -> MultiDict: + def args(self) -> MultiDict[str, str]: """The URL arguments as :class:`MultiDict`.""" if self._query_string is not None: raise AttributeError("a query string is defined") @@ -647,7 +616,7 @@ def args(self) -> MultiDict: return self._args @args.setter - def args(self, value: MultiDict | None) -> None: + def args(self, value: MultiDict[str, str] | None) -> None: self._query_string = None self._args = value @@ -687,7 +656,7 @@ def close(self) -> None: try: files = self.files.values() except AttributeError: - files = () # type: ignore + files = () for f in files: try: f.close() @@ -716,13 +685,12 @@ def get_environ(self) -> WSGIEnvironment: input_stream.seek(start_pos) content_length = end_pos - start_pos elif mimetype == "multipart/form-data": - charset = self.charset if self.charset != "utf-8" else None input_stream, content_length, boundary = stream_encode_multipart( - CombinedMultiDict([self.form, self.files]), charset=charset + CombinedMultiDict([self.form, self.files]) ) content_type = f'{mimetype}; boundary="{boundary}"' elif mimetype == "application/x-www-form-urlencoded": - form_encoded = _urlencode(self.form, encoding=self.charset).encode("ascii") + form_encoded = 
_urlencode(self.form).encode("ascii") content_length = len(form_encoded) input_stream = BytesIO(form_encoded) else: @@ -733,15 +701,15 @@ def get_environ(self) -> WSGIEnvironment: result.update(self.environ_base) def _path_encode(x: str) -> str: - return _wsgi_encoding_dance(unquote(x, encoding=self.charset), self.charset) + return _wsgi_encoding_dance(unquote(x)) - raw_uri = _wsgi_encoding_dance(self.request_uri, self.charset) + raw_uri = _wsgi_encoding_dance(self.request_uri) result.update( { "REQUEST_METHOD": self.method, "SCRIPT_NAME": _path_encode(self.script_root), "PATH_INFO": _path_encode(self.path), - "QUERY_STRING": _wsgi_encoding_dance(self.query_string, self.charset), + "QUERY_STRING": _wsgi_encoding_dance(self.query_string), # Non-standard, added by mod_wsgi, uWSGI "REQUEST_URI": raw_uri, # Non-standard, added by gunicorn @@ -841,14 +809,16 @@ def __init__( if response_wrapper in {None, Response}: response_wrapper = TestResponse - elif not isinstance(response_wrapper, TestResponse): + elif response_wrapper is not None and not issubclass( + response_wrapper, TestResponse + ): response_wrapper = type( "WrapperTestResponse", - (TestResponse, response_wrapper), # type: ignore + (TestResponse, response_wrapper), {}, ) - self.response_wrapper = t.cast(t.Type["TestResponse"], response_wrapper) + self.response_wrapper = t.cast(type["TestResponse"], response_wrapper) if use_cookies: self._cookies: dict[tuple[str, str, str], Cookie] | None = {} @@ -857,20 +827,6 @@ def __init__( self.allow_subdomain_redirects = allow_subdomain_redirects - @property - def cookie_jar(self) -> t.Iterable[Cookie] | None: - warnings.warn( - "The 'cookie_jar' attribute is a private API and will be removed in" - " Werkzeug 3.0. 
Use the 'get_cookie' method instead.", - DeprecationWarning, - stacklevel=2, - ) - - if self._cookies is None: - return None - - return self._cookies.values() - def get_cookie( self, key: str, domain: str = "localhost", path: str = "/" ) -> Cookie | None: @@ -894,7 +850,7 @@ def set_cookie( self, key: str, value: str = "", - *args: t.Any, + *, domain: str = "localhost", origin_only: bool = True, path: str = "/", @@ -920,34 +876,21 @@ def set_cookie( or as a prefix. :param kwargs: Passed to :func:`.dump_cookie`. + .. versionchanged:: 3.0 + The parameter ``server_name`` is removed. The first parameter is + ``key``. Use the ``domain`` and ``origin_only`` parameters instead. + .. versionchanged:: 2.3 The ``origin_only`` parameter was added. .. versionchanged:: 2.3 The ``domain`` parameter defaults to ``localhost``. - - .. versionchanged:: 2.3 - The first parameter ``server_name`` is deprecated and will be removed in - Werkzeug 3.0. The first parameter is ``key``. Use the ``domain`` and - ``origin_only`` parameters instead. """ if self._cookies is None: raise TypeError( "Cookies are disabled. Create a client with 'use_cookies=True'." ) - if args: - warnings.warn( - "The first parameter 'server_name' is no longer used, and will be" - " removed in Werkzeug 3.0. The positional parameters are 'key' and" - " 'value'. Use the 'domain' and 'origin_only' parameters instead.", - DeprecationWarning, - stacklevel=2, - ) - domain = key - key = value - value = args[0] - cookie = Cookie._from_response_header( domain, "/", dump_cookie(key, value, domain=domain, path=path, **kwargs) ) @@ -961,10 +904,9 @@ def set_cookie( def delete_cookie( self, key: str, - *args: t.Any, + *, domain: str = "localhost", path: str = "/", - **kwargs: t.Any, ) -> None: """Delete a cookie if it exists. Cookies are uniquely identified by ``(domain, path, key)``. @@ -973,44 +915,21 @@ def delete_cookie( :param domain: The domain the cookie was set for. :param path: The path the cookie was set for. - .. 
versionchanged:: 2.3 - The ``domain`` parameter defaults to ``localhost``. + .. versionchanged:: 3.0 + The ``server_name`` parameter is removed. The first parameter is + ``key``. Use the ``domain`` parameter instead. - .. versionchanged:: 2.3 - The first parameter ``server_name`` is deprecated and will be removed in - Werkzeug 3.0. The first parameter is ``key``. Use the ``domain`` parameter - instead. + .. versionchanged:: 3.0 + The ``secure``, ``httponly`` and ``samesite`` parameters are removed. .. versionchanged:: 2.3 - The ``secure``, ``httponly`` and ``samesite`` parameters are deprecated and - will be removed in Werkzeug 2.4. + The ``domain`` parameter defaults to ``localhost``. """ if self._cookies is None: raise TypeError( "Cookies are disabled. Create a client with 'use_cookies=True'." ) - if args: - warnings.warn( - "The first parameter 'server_name' is no longer used, and will be" - " removed in Werkzeug 2.4. The first parameter is 'key'. Use the" - " 'domain' parameter instead.", - DeprecationWarning, - stacklevel=2, - ) - domain = key - key = args[0] - - if kwargs: - kwargs_keys = ", ".join(f"'{k}'" for k in kwargs) - plural = "parameters are" if len(kwargs) > 1 else "parameter is" - warnings.warn( - f"The {kwargs_keys} {plural} deprecated and will be" - f" removed in Werkzeug 2.4.", - DeprecationWarning, - stacklevel=2, - ) - self._cookies.pop((domain, path, key), None) def _add_cookies_to_wsgi(self, environ: WSGIEnvironment) -> None: @@ -1194,8 +1113,8 @@ def open( finally: builder.close() - response = self.run_wsgi_app(request.environ, buffered=buffered) - response = self.response_wrapper(*response, request=request) + response_parts = self.run_wsgi_app(request.environ, buffered=buffered) + response = self.response_wrapper(*response_parts, request=request) redirects = set() history: list[TestResponse] = [] @@ -1512,7 +1431,7 @@ def _to_request_header(self) -> str: def _from_response_header(cls, server_name: str, path: str, header: str) -> te.Self: 
header, _, parameters_str = header.partition(";") key, _, value = header.partition("=") - decoded_key, decoded_value = next(parse_cookie(header).items()) + decoded_key, decoded_value = next(parse_cookie(header).items()) # type: ignore[call-overload] params = {} for item in parameters_str.split(";"): diff --git a/src/werkzeug/testapp.py b/src/werkzeug/testapp.py index 57f1f6fdf..cdf7fac1a 100644 --- a/src/werkzeug/testapp.py +++ b/src/werkzeug/testapp.py @@ -1,8 +1,10 @@ """A small application that can be used to test a WSGI server and check it for WSGI compliance. """ + from __future__ import annotations +import importlib.metadata import os import sys import typing as t @@ -10,7 +12,6 @@ from markupsafe import escape -from . import __version__ as _werkzeug_version from .wrappers.request import Request from .wrappers.response import Response @@ -153,13 +154,13 @@ def test_app(req: Request) -> Response: sys_path = [] for item, virtual, expanded in iter_sys_path(): - class_ = [] + css = [] if virtual: - class_.append("virtual") + css.append("virtual") if expanded: - class_.append("exp") - class_ = f' class="{" ".join(class_)}"' if class_ else "" - sys_path.append(f"{escape(item)}") + css.append("exp") + class_str = f' class="{" ".join(css)}"' if css else "" + sys_path.append(f"{escape(item)}") context = { "python_version": "
".join(escape(sys.version).splitlines()), @@ -167,7 +168,7 @@ def test_app(req: Request) -> Response: "os": escape(os.name), "api_version": sys.api_version, "byteorder": sys.byteorder, - "werkzeug_version": _werkzeug_version, + "werkzeug_version": _get_werkzeug_version(), "python_eggs": "\n".join(python_eggs), "wsgi_env": "\n".join(wsgi_env), "sys_path": "\n".join(sys_path), @@ -175,6 +176,18 @@ def test_app(req: Request) -> Response: return Response(TEMPLATE % context, mimetype="text/html") +_werkzeug_version = "" + + +def _get_werkzeug_version() -> str: + global _werkzeug_version + + if not _werkzeug_version: + _werkzeug_version = importlib.metadata.version("werkzeug") + + return _werkzeug_version + + if __name__ == "__main__": from .serving import run_simple diff --git a/src/werkzeug/urls.py b/src/werkzeug/urls.py index f5760eb4c..5bffe3928 100644 --- a/src/werkzeug/urls.py +++ b/src/werkzeug/urls.py @@ -1,796 +1,17 @@ -"""Functions for working with URLs. - -Contains implementations of functions from :mod:`urllib.parse` that -handle bytes and strings. -""" from __future__ import annotations import codecs -import os import re import typing as t -import warnings +import urllib.parse from urllib.parse import quote from urllib.parse import unquote from urllib.parse import urlencode from urllib.parse import urlsplit from urllib.parse import urlunsplit -from ._internal import _check_str_tuple -from ._internal import _decode_idna -from ._internal import _make_encode_wrapper -from ._internal import _to_str from .datastructures import iter_multi_items -if t.TYPE_CHECKING: - from . import datastructures as ds - -# A regular expression for what a valid schema looks like -_scheme_re = re.compile(r"^[a-zA-Z0-9+-.]+$") - -# Characters that are safe in any part of an URL. 
-_always_safe_chars = ( - "abcdefghijklmnopqrstuvwxyz" - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "0123456789" - "-._~" - "$!'()*+,;" # RFC3986 sub-delims set, not including query string delimiters &= -) -_always_safe = frozenset(_always_safe_chars.encode("ascii")) - -_hexdigits = "0123456789ABCDEFabcdef" -_hextobyte = { - f"{a}{b}".encode("ascii"): int(f"{a}{b}", 16) - for a in _hexdigits - for b in _hexdigits -} -_bytetohex = [f"%{char:02X}".encode("ascii") for char in range(256)] - - -class _URLTuple(t.NamedTuple): - scheme: str - netloc: str - path: str - query: str - fragment: str - - -class BaseURL(_URLTuple): - """Superclass of :py:class:`URL` and :py:class:`BytesURL`. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. Use the ``urllib.parse`` library instead. - """ - - __slots__ = () - _at: str - _colon: str - _lbracket: str - _rbracket: str - - def __new__(cls, *args: t.Any, **kwargs: t.Any) -> BaseURL: - warnings.warn( - f"'werkzeug.urls.{cls.__name__}' is deprecated and will be removed in" - " Werkzeug 3.0. Use the 'urllib.parse' library instead.", - DeprecationWarning, - stacklevel=2, - ) - return super().__new__(cls, *args, **kwargs) - - def __str__(self) -> str: - return self.to_url() - - def replace(self, **kwargs: t.Any) -> BaseURL: - """Return an URL with the same values, except for those parameters - given new values by whichever keyword arguments are specified.""" - return self._replace(**kwargs) - - @property - def host(self) -> str | None: - """The host part of the URL if available, otherwise `None`. The - host is either the hostname or the IP address mentioned in the - URL. It will not contain the port. - """ - return self._split_host()[0] - - @property - def ascii_host(self) -> str | None: - """Works exactly like :attr:`host` but will return a result that - is restricted to ASCII. If it finds a netloc that is not ASCII - it will attempt to idna decode it. 
This is useful for socket - operations when the URL might include internationalized characters. - """ - rv = self.host - if rv is not None and isinstance(rv, str): - try: - rv = rv.encode("idna").decode("ascii") - except UnicodeError: - pass - return rv - - @property - def port(self) -> int | None: - """The port in the URL as an integer if it was present, `None` - otherwise. This does not fill in default ports. - """ - try: - rv = int(_to_str(self._split_host()[1])) - if 0 <= rv <= 65535: - return rv - except (ValueError, TypeError): - pass - return None - - @property - def auth(self) -> str | None: - """The authentication part in the URL if available, `None` - otherwise. - """ - return self._split_netloc()[0] - - @property - def username(self) -> str | None: - """The username if it was part of the URL, `None` otherwise. - This undergoes URL decoding and will always be a string. - """ - rv = self._split_auth()[0] - if rv is not None: - return _url_unquote_legacy(rv) - return None - - @property - def raw_username(self) -> str | None: - """The username if it was part of the URL, `None` otherwise. - Unlike :attr:`username` this one is not being decoded. - """ - return self._split_auth()[0] - - @property - def password(self) -> str | None: - """The password if it was part of the URL, `None` otherwise. - This undergoes URL decoding and will always be a string. - """ - rv = self._split_auth()[1] - if rv is not None: - return _url_unquote_legacy(rv) - return None - - @property - def raw_password(self) -> str | None: - """The password if it was part of the URL, `None` otherwise. - Unlike :attr:`password` this one is not being decoded. - """ - return self._split_auth()[1] - - def decode_query(self, *args: t.Any, **kwargs: t.Any) -> ds.MultiDict[str, str]: - """Decodes the query part of the URL. Ths is a shortcut for - calling :func:`url_decode` on the query argument. The arguments and - keyword arguments are forwarded to :func:`url_decode` unchanged. 
- """ - return url_decode(self.query, *args, **kwargs) - - def join(self, *args: t.Any, **kwargs: t.Any) -> BaseURL: - """Joins this URL with another one. This is just a convenience - function for calling into :meth:`url_join` and then parsing the - return value again. - """ - return url_parse(url_join(self, *args, **kwargs)) - - def to_url(self) -> str: - """Returns a URL string or bytes depending on the type of the - information stored. This is just a convenience function - for calling :meth:`url_unparse` for this URL. - """ - return url_unparse(self) - - def encode_netloc(self) -> str: - """Encodes the netloc part to an ASCII safe URL as bytes.""" - rv = self.ascii_host or "" - if ":" in rv: - rv = f"[{rv}]" - port = self.port - if port is not None: - rv = f"{rv}:{port}" - auth = ":".join( - filter( - None, - [ - url_quote(self.raw_username or "", "utf-8", "strict", "/:%"), - url_quote(self.raw_password or "", "utf-8", "strict", "/:%"), - ], - ) - ) - if auth: - rv = f"{auth}@{rv}" - return rv - - def decode_netloc(self) -> str: - """Decodes the netloc part into a string.""" - host = self.host or "" - - if isinstance(host, bytes): - host = host.decode() - - rv = _decode_idna(host) - - if ":" in rv: - rv = f"[{rv}]" - port = self.port - if port is not None: - rv = f"{rv}:{port}" - auth = ":".join( - filter( - None, - [ - _url_unquote_legacy(self.raw_username or "", "/:%@"), - _url_unquote_legacy(self.raw_password or "", "/:%@"), - ], - ) - ) - if auth: - rv = f"{auth}@{rv}" - return rv - - def to_uri_tuple(self) -> BaseURL: - """Returns a :class:`BytesURL` tuple that holds a URI. This will - encode all the information in the URL properly to ASCII using the - rules a web browser would follow. - - It's usually more interesting to directly call :meth:`iri_to_uri` which - will return a string. - """ - return url_parse(iri_to_uri(self)) - - def to_iri_tuple(self) -> BaseURL: - """Returns a :class:`URL` tuple that holds a IRI. 
This will try - to decode as much information as possible in the URL without - losing information similar to how a web browser does it for the - URL bar. - - It's usually more interesting to directly call :meth:`uri_to_iri` which - will return a string. - """ - return url_parse(uri_to_iri(self)) - - def get_file_location( - self, pathformat: str | None = None - ) -> tuple[str | None, str | None]: - """Returns a tuple with the location of the file in the form - ``(server, location)``. If the netloc is empty in the URL or - points to localhost, it's represented as ``None``. - - The `pathformat` by default is autodetection but needs to be set - when working with URLs of a specific system. The supported values - are ``'windows'`` when working with Windows or DOS paths and - ``'posix'`` when working with posix paths. - - If the URL does not point to a local file, the server and location - are both represented as ``None``. - - :param pathformat: The expected format of the path component. - Currently ``'windows'`` and ``'posix'`` are - supported. Defaults to ``None`` which is - autodetect. - """ - if self.scheme != "file": - return None, None - - path = url_unquote(self.path) - host = self.netloc or None - - if pathformat is None: - if os.name == "nt": - pathformat = "windows" - else: - pathformat = "posix" - - if pathformat == "windows": - if path[:1] == "/" and path[1:2].isalpha() and path[2:3] in "|:": - path = f"{path[1:2]}:{path[3:]}" - windows_share = path[:3] in ("\\" * 3, "/" * 3) - import ntpath - - path = ntpath.normpath(path) - # Windows shared drives are represented as ``\\host\\directory``. - # That results in a URL like ``file://///host/directory``, and a - # path like ``///host/directory``. We need to special-case this - # because the path contains the hostname. 
- if windows_share and host is None: - parts = path.lstrip("\\").split("\\", 1) - if len(parts) == 2: - host, path = parts - else: - host = parts[0] - path = "" - elif pathformat == "posix": - import posixpath - - path = posixpath.normpath(path) - else: - raise TypeError(f"Invalid path format {pathformat!r}") - - if host in ("127.0.0.1", "::1", "localhost"): - host = None - - return host, path - - def _split_netloc(self) -> tuple[str | None, str]: - if self._at in self.netloc: - auth, _, netloc = self.netloc.partition(self._at) - return auth, netloc - return None, self.netloc - - def _split_auth(self) -> tuple[str | None, str | None]: - auth = self._split_netloc()[0] - if not auth: - return None, None - if self._colon not in auth: - return auth, None - - username, _, password = auth.partition(self._colon) - return username, password - - def _split_host(self) -> tuple[str | None, str | None]: - rv = self._split_netloc()[1] - if not rv: - return None, None - - if not rv.startswith(self._lbracket): - if self._colon in rv: - host, _, port = rv.partition(self._colon) - return host, port - return rv, None - - idx = rv.find(self._rbracket) - if idx < 0: - return rv, None - - host = rv[1:idx] - rest = rv[idx + 1 :] - if rest.startswith(self._colon): - return host, rest[1:] - return host, None - - -class URL(BaseURL): - """Represents a parsed URL. This behaves like a regular tuple but - also has some extra attributes that give further insight into the - URL. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. Use the ``urllib.parse`` library instead. - """ - - __slots__ = () - _at = "@" - _colon = ":" - _lbracket = "[" - _rbracket = "]" - - def encode(self, charset: str = "utf-8", errors: str = "replace") -> BytesURL: - """Encodes the URL to a tuple made out of bytes. The charset is - only being used for the path, query and fragment. 
- """ - return BytesURL( - self.scheme.encode("ascii"), - self.encode_netloc(), - self.path.encode(charset, errors), - self.query.encode(charset, errors), - self.fragment.encode(charset, errors), - ) - - -class BytesURL(BaseURL): - """Represents a parsed URL in bytes. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. Use the ``urllib.parse`` library instead. - """ - - __slots__ = () - _at = b"@" # type: ignore - _colon = b":" # type: ignore - _lbracket = b"[" # type: ignore - _rbracket = b"]" # type: ignore - - def __str__(self) -> str: - return self.to_url().decode("utf-8", "replace") # type: ignore - - def encode_netloc(self) -> bytes: # type: ignore - """Returns the netloc unchanged as bytes.""" - return self.netloc # type: ignore - - def decode(self, charset: str = "utf-8", errors: str = "replace") -> URL: - """Decodes the URL to a tuple made out of strings. The charset is - only being used for the path, query and fragment. - """ - return URL( - self.scheme.decode("ascii"), # type: ignore - self.decode_netloc(), - self.path.decode(charset, errors), # type: ignore - self.query.decode(charset, errors), # type: ignore - self.fragment.decode(charset, errors), # type: ignore - ) - - -_unquote_maps: dict[frozenset[int], dict[bytes, int]] = {frozenset(): _hextobyte} - - -def _unquote_to_bytes(string: str | bytes, unsafe: str | bytes = "") -> bytes: - if isinstance(string, str): - string = string.encode("utf-8") - - if isinstance(unsafe, str): - unsafe = unsafe.encode("utf-8") - - unsafe = frozenset(bytearray(unsafe)) - groups = iter(string.split(b"%")) - result = bytearray(next(groups, b"")) - - try: - hex_to_byte = _unquote_maps[unsafe] - except KeyError: - hex_to_byte = _unquote_maps[unsafe] = { - h: b for h, b in _hextobyte.items() if b not in unsafe - } - - for group in groups: - code = group[:2] - - if code in hex_to_byte: - result.append(hex_to_byte[code]) - result.extend(group[2:]) - else: - result.append(37) # % - result.extend(group) - - return 
bytes(result) - - -def _url_encode_impl( - obj: t.Mapping[str, str] | t.Iterable[tuple[str, str]], - charset: str, - sort: bool, - key: t.Callable[[tuple[str, str]], t.Any] | None, -) -> t.Iterator[str]: - from .datastructures import iter_multi_items - - iterable: t.Iterable[tuple[str, str]] = iter_multi_items(obj) - - if sort: - iterable = sorted(iterable, key=key) - - for key_str, value_str in iterable: - if value_str is None: - continue - - if not isinstance(key_str, bytes): - key_bytes = str(key_str).encode(charset) - else: - key_bytes = key_str - - if not isinstance(value_str, bytes): - value_bytes = str(value_str).encode(charset) - else: - value_bytes = value_str - - yield f"{_fast_url_quote_plus(key_bytes)}={_fast_url_quote_plus(value_bytes)}" - - -def _url_unquote_legacy(value: str, unsafe: str = "") -> str: - try: - return url_unquote(value, charset="utf-8", errors="strict", unsafe=unsafe) - except UnicodeError: - return url_unquote(value, charset="latin1", unsafe=unsafe) - - -def url_parse( - url: str, scheme: str | None = None, allow_fragments: bool = True -) -> BaseURL: - """Parses a URL from a string into a :class:`URL` tuple. If the URL - is lacking a scheme it can be provided as second argument. Otherwise, - it is ignored. Optionally fragments can be stripped from the URL - by setting `allow_fragments` to `False`. - - The inverse of this function is :func:`url_unparse`. - - :param url: the URL to parse. - :param scheme: the default schema to use if the URL is schemaless. - :param allow_fragments: if set to `False` a fragment will be removed - from the URL. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. Use ``urllib.parse.urlsplit`` instead. - """ - warnings.warn( - "'werkzeug.urls.url_parse' is deprecated and will be removed in Werkzeug 3.0." 
- " Use 'urllib.parse.urlsplit' instead.", - DeprecationWarning, - stacklevel=2, - ) - s = _make_encode_wrapper(url) - is_text_based = isinstance(url, str) - - if scheme is None: - scheme = s("") - netloc = query = fragment = s("") - i = url.find(s(":")) - if i > 0 and _scheme_re.match(_to_str(url[:i], errors="replace")): - # make sure "iri" is not actually a port number (in which case - # "scheme" is really part of the path) - rest = url[i + 1 :] - if not rest or any(c not in s("0123456789") for c in rest): - # not a port number - scheme, url = url[:i].lower(), rest - - if url[:2] == s("//"): - delim = len(url) - for c in s("/?#"): - wdelim = url.find(c, 2) - if wdelim >= 0: - delim = min(delim, wdelim) - netloc, url = url[2:delim], url[delim:] - if (s("[") in netloc and s("]") not in netloc) or ( - s("]") in netloc and s("[") not in netloc - ): - raise ValueError("Invalid IPv6 URL") - - if allow_fragments and s("#") in url: - url, fragment = url.split(s("#"), 1) - if s("?") in url: - url, query = url.split(s("?"), 1) - - result_type = URL if is_text_based else BytesURL - - return result_type(scheme, netloc, url, query, fragment) - - -def _make_fast_url_quote( - charset: str = "utf-8", - errors: str = "strict", - safe: str | bytes = "/:", - unsafe: str | bytes = "", -) -> t.Callable[[bytes], str]: - """Precompile the translation table for a URL encoding function. - - Unlike :func:`url_quote`, the generated function only takes the - string to quote. - - :param charset: The charset to encode the result with. - :param errors: How to handle encoding errors. - :param safe: An optional sequence of safe characters to never encode. - :param unsafe: An optional sequence of unsafe characters to always encode. 
- """ - if isinstance(safe, str): - safe = safe.encode(charset, errors) - - if isinstance(unsafe, str): - unsafe = unsafe.encode(charset, errors) - - safe = (frozenset(bytearray(safe)) | _always_safe) - frozenset(bytearray(unsafe)) - table = [chr(c) if c in safe else f"%{c:02X}" for c in range(256)] - - def quote(string: bytes) -> str: - return "".join([table[c] for c in string]) - - return quote - - -_fast_url_quote = _make_fast_url_quote() -_fast_quote_plus = _make_fast_url_quote(safe=" ", unsafe="+") - - -def _fast_url_quote_plus(string: bytes) -> str: - return _fast_quote_plus(string).replace(" ", "+") - - -def url_quote( - string: str | bytes, - charset: str = "utf-8", - errors: str = "strict", - safe: str | bytes = "/:", - unsafe: str | bytes = "", -) -> str: - """URL encode a single string with a given encoding. - - :param s: the string to quote. - :param charset: the charset to be used. - :param safe: an optional sequence of safe characters. - :param unsafe: an optional sequence of unsafe characters. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. Use ``urllib.parse.quote`` instead. - - .. versionadded:: 0.9.2 - The `unsafe` parameter was added. - """ - warnings.warn( - "'werkzeug.urls.url_quote' is deprecated and will be removed in Werkzeug 3.0." 
- " Use 'urllib.parse.quote' instead.", - DeprecationWarning, - stacklevel=2, - ) - - if not isinstance(string, (str, bytes, bytearray)): - string = str(string) - if isinstance(string, str): - string = string.encode(charset, errors) - if isinstance(safe, str): - safe = safe.encode(charset, errors) - if isinstance(unsafe, str): - unsafe = unsafe.encode(charset, errors) - safe = (frozenset(bytearray(safe)) | _always_safe) - frozenset(bytearray(unsafe)) - rv = bytearray() - for char in bytearray(string): - if char in safe: - rv.append(char) - else: - rv.extend(_bytetohex[char]) - return bytes(rv).decode(charset) - - -def url_quote_plus( - string: str, charset: str = "utf-8", errors: str = "strict", safe: str = "" -) -> str: - """URL encode a single string with the given encoding and convert - whitespace to "+". - - :param s: The string to quote. - :param charset: The charset to be used. - :param safe: An optional sequence of safe characters. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. Use ``urllib.parse.quote_plus`` instead. - """ - warnings.warn( - "'werkzeug.urls.url_quote_plus' is deprecated and will be removed in Werkzeug" - " 2.4. Use 'urllib.parse.quote_plus' instead.", - DeprecationWarning, - stacklevel=2, - ) - - return url_quote(string, charset, errors, safe + " ", "+").replace(" ", "+") - - -def url_unparse(components: tuple[str, str, str, str, str]) -> str: - """The reverse operation to :meth:`url_parse`. This accepts arbitrary - as well as :class:`URL` tuples and returns a URL as a string. - - :param components: the parsed URL as tuple which should be converted - into a URL string. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. Use ``urllib.parse.urlunsplit`` instead. - """ - warnings.warn( - "'werkzeug.urls.url_unparse' is deprecated and will be removed in Werkzeug 3.0." 
- " Use 'urllib.parse.urlunsplit' instead.", - DeprecationWarning, - stacklevel=2, - ) - _check_str_tuple(components) - scheme, netloc, path, query, fragment = components - s = _make_encode_wrapper(scheme) - url = s("") - - # We generally treat file:///x and file:/x the same which is also - # what browsers seem to do. This also allows us to ignore a schema - # register for netloc utilization or having to differentiate between - # empty and missing netloc. - if netloc or (scheme and path.startswith(s("/"))): - if path and path[:1] != s("/"): - path = s("/") + path - url = s("//") + (netloc or s("")) + path - elif path: - url += path - if scheme: - url = scheme + s(":") + url - if query: - url = url + s("?") + query - if fragment: - url = url + s("#") + fragment - return url - - -def url_unquote( - s: str | bytes, - charset: str = "utf-8", - errors: str = "replace", - unsafe: str = "", -) -> str: - """URL decode a single string with a given encoding. If the charset - is set to `None` no decoding is performed and raw bytes are - returned. - - :param s: the string to unquote. - :param charset: the charset of the query string. If set to `None` - no decoding will take place. - :param errors: the error handling for the charset decoding. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. Use ``urllib.parse.unquote`` instead. - """ - warnings.warn( - "'werkzeug.urls.url_unquote' is deprecated and will be removed in Werkzeug 3.0." - " Use 'urllib.parse.unquote' instead.", - DeprecationWarning, - stacklevel=2, - ) - rv = _unquote_to_bytes(s, unsafe) - if charset is None: - return rv - return rv.decode(charset, errors) - - -def url_unquote_plus( - s: str | bytes, charset: str = "utf-8", errors: str = "replace" -) -> str: - """URL decode a single string with the given `charset` and decode "+" to - whitespace. - - Per default encoding errors are ignored. If you want a different behavior - you can set `errors` to ``'replace'`` or ``'strict'``. 
- - :param s: The string to unquote. - :param charset: the charset of the query string. If set to `None` - no decoding will take place. - :param errors: The error handling for the `charset` decoding. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. Use ``urllib.parse.unquote_plus`` instead. - """ - warnings.warn( - "'werkzeug.urls.url_unquote_plus' is deprecated and will be removed in Werkzeug" - " 2.4. Use 'urllib.parse.unquote_plus' instead.", - DeprecationWarning, - stacklevel=2, - ) - - if isinstance(s, str): - s = s.replace("+", " ") - else: - s = s.replace(b"+", b" ") - - return url_unquote(s, charset, errors) - - -def url_fix(s: str, charset: str = "utf-8") -> str: - r"""Sometimes you get an URL by a user that just isn't a real URL because - it contains unsafe characters like ' ' and so on. This function can fix - some of the problems in a similar way browsers handle data entered by the - user: - - >>> url_fix('http://de.wikipedia.org/wiki/Elf (Begriffskl\xe4rung)') - 'http://de.wikipedia.org/wiki/Elf%20(Begriffskl%C3%A4rung)' - - :param s: the string with the URL to fix. - :param charset: The target charset for the URL if the url was given - as a string. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. - """ - warnings.warn( - "'werkzeug.urls.url_fix' is deprecated and will be removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - # First step is to switch to text processing and to convert - # backslashes (which are invalid in URLs anyways) to slashes. This is - # consistent with what Chrome does. 
- s = _to_str(s, charset, "replace").replace("\\", "/") - - # For the specific case that we look like a malformed windows URL - # we want to fix this up manually: - if s.startswith("file://") and s[7:8].isalpha() and s[8:10] in (":/", "|/"): - s = f"file:///{s[7:]}" - - url = url_parse(s) - path = url_quote(url.path, charset, safe="/%+$!*'(),") - qs = url_quote_plus(url.query, charset, safe=":&%=+$!*'(),") - anchor = url_quote_plus(url.fragment, charset, safe=":&%=+$!*'(),") - return url_unparse((url.scheme, url.encode_netloc(), path, qs, anchor)) - def _codec_error_url_quote(e: UnicodeError) -> tuple[str, int]: """Used in :func:`uri_to_iri` after unquoting to re-quote any @@ -805,7 +26,7 @@ def _codec_error_url_quote(e: UnicodeError) -> tuple[str, int]: codecs.register_error("werkzeug.url_quote", _codec_error_url_quote) -def _make_unquote_part(name: str, chars: str) -> t.Callable[[str, str, str], str]: +def _make_unquote_part(name: str, chars: str) -> t.Callable[[str], str]: """Create a function that unquotes all percent encoded characters except those given. This allows working with unquoted characters if possible while not changing the meaning of a given part of a URL. 
@@ -813,12 +34,12 @@ def _make_unquote_part(name: str, chars: str) -> t.Callable[[str, str, str], str choices = "|".join(f"{ord(c):02X}" for c in sorted(chars)) pattern = re.compile(f"((?:%(?:{choices}))+)", re.I) - def _unquote_partial(value: str, encoding: str, errors: str) -> str: + def _unquote_partial(value: str) -> str: parts = iter(pattern.split(value)) out = [] for part in parts: - out.append(unquote(part, encoding, errors)) + out.append(unquote(part, "utf-8", "werkzeug.url_quote")) out.append(next(parts, "")) return "".join(out) @@ -837,11 +58,7 @@ def _unquote_partial(value: str, encoding: str, errors: str) -> str: _unquote_user = _make_unquote_part("user", _always_unsafe + ":@/?#") -def uri_to_iri( - uri: str | tuple[str, str, str, str, str], - charset: str | None = None, - errors: str | None = None, -) -> str: +def uri_to_iri(uri: str) -> str: """Convert a URI to an IRI. All valid UTF-8 characters are unquoted, leaving all reserved and invalid characters quoted. If the URL has a domain, it is decoded from Punycode. @@ -850,13 +67,10 @@ def uri_to_iri( 'http://\\u2603.net/p\\xe5th?q=\\xe8ry%DF' :param uri: The URI to convert. - :param charset: The encoding to encode unquoted bytes with. - :param errors: Error handler to use during ``bytes.encode``. By - default, invalid bytes are left quoted. - .. versionchanged:: 2.3 - Passing a tuple or bytes, and the ``charset`` and ``errors`` parameters, are - deprecated and will be removed in Werkzeug 3.0. + .. versionchanged:: 3.0 + Passing a tuple or bytes, and the ``charset`` and ``errors`` parameters, + are removed. .. versionchanged:: 2.3 Which characters remain quoted is specific to each part of the URL. @@ -868,45 +82,10 @@ def uri_to_iri( .. 
versionadded:: 0.6 """ - if isinstance(uri, tuple): - warnings.warn( - "Passing a tuple is deprecated and will not be supported in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - uri = urlunsplit(uri) - - if isinstance(uri, bytes): - warnings.warn( - "Passing bytes is deprecated and will not be supported in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - uri = uri.decode() - - if charset is not None: - warnings.warn( - "The 'charset' parameter is deprecated and will be removed" - " in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - else: - charset = "utf-8" - - if errors is not None: - warnings.warn( - "The 'errors' parameter is deprecated and will be removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - else: - errors = "werkzeug.url_quote" - parts = urlsplit(uri) - path = _unquote_path(parts.path, charset, errors) - query = _unquote_query(parts.query, charset, errors) - fragment = _unquote_fragment(parts.fragment, charset, errors) + path = _unquote_path(parts.path) + query = _unquote_query(parts.query) + fragment = _unquote_fragment(parts.fragment) if parts.hostname: netloc = _decode_idna(parts.hostname) @@ -920,22 +99,18 @@ def uri_to_iri( netloc = f"{netloc}:{parts.port}" if parts.username: - auth = _unquote_user(parts.username, charset, errors) + auth = _unquote_user(parts.username) if parts.password: - auth = f"{auth}:{_unquote_user(parts.password, charset, errors)}" + password = _unquote_user(parts.password) + auth = f"{auth}:{password}" netloc = f"{auth}@{netloc}" return urlunsplit((parts.scheme, netloc, path, query, fragment)) -def iri_to_uri( - iri: str | tuple[str, str, str, str, str], - charset: str | None = None, - errors: str | None = None, - safe_conversion: bool | None = None, -) -> str: +def iri_to_uri(iri: str) -> str: """Convert an IRI to a URI. All non-ASCII and unsafe characters are quoted. If the URL has a domain, it is encoded to Punycode. 
@@ -943,20 +118,14 @@ def iri_to_uri( 'http://xn--n3h.net/p%C3%A5th?q=%C3%A8ry%DF' :param iri: The IRI to convert. - :param charset: The encoding of the IRI. - :param errors: Error handler to use during ``bytes.encode``. - .. versionchanged:: 2.3 - Passing a tuple or bytes, and the ``charset`` and ``errors`` parameters, are - deprecated and will be removed in Werkzeug 3.0. + .. versionchanged:: 3.0 + Passing a tuple or bytes, the ``charset`` and ``errors`` parameters, + and the ``safe_conversion`` parameter, are removed. .. versionchanged:: 2.3 Which characters remain unquoted is specific to each part of the URL. - .. versionchanged:: 2.3 - The ``safe_conversion`` parameter is deprecated and will be removed in Werkzeug - 2.4. - .. versionchanged:: 0.15 All reserved characters remain unquoted. Previously, only some reserved characters were left unquoted. @@ -966,69 +135,12 @@ def iri_to_uri( .. versionadded:: 0.6 """ - if charset is not None: - warnings.warn( - "The 'charset' parameter is deprecated and will be removed" - " in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - else: - charset = "utf-8" - - if isinstance(iri, tuple): - warnings.warn( - "Passing a tuple is deprecated and will not be supported in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - iri = urlunsplit(iri) - - if isinstance(iri, bytes): - warnings.warn( - "Passing bytes is deprecated and will not be supported in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - iri = iri.decode(charset) - - if errors is not None: - warnings.warn( - "The 'errors' parameter is deprecated and will be removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - else: - errors = "strict" - - if safe_conversion is not None: - warnings.warn( - "The 'safe_conversion' parameter is deprecated and will be removed in" - " Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - - if safe_conversion: - # If we're not sure if it's safe to normalize the URL, and it only contains - 
# ASCII characters, return it as-is. - try: - ascii_iri = iri.encode("ascii") - - # Only return if it doesn't have whitespace. (Why?) - if len(ascii_iri.split()) == 1: - return iri - except UnicodeError: - pass - parts = urlsplit(iri) # safe = https://url.spec.whatwg.org/#url-path-segment-string # as well as percent for things that are already quoted - path = quote(parts.path, safe="%!$&'()*+,/:;=@", encoding=charset, errors=errors) - query = quote(parts.query, safe="%!$&'()*+,/:;=?@", encoding=charset, errors=errors) - fragment = quote( - parts.fragment, safe="%!#$&'()*+,/:;=?@", encoding=charset, errors=errors - ) + path = quote(parts.path, safe="%!$&'()*+,/:;=@") + query = quote(parts.query, safe="%!$&'()*+,/:;=?@") + fragment = quote(parts.fragment, safe="%!#$&'()*+,/:;=?@") if parts.hostname: netloc = parts.hostname.encode("idna").decode("ascii") @@ -1045,333 +157,47 @@ def iri_to_uri( auth = quote(parts.username, safe="%!$&'()*+,;=") if parts.password: - pass_quoted = quote(parts.password, safe="%!$&'()*+,;=") - auth = f"{auth}:{pass_quoted}" + password = quote(parts.password, safe="%!$&'()*+,;=") + auth = f"{auth}:{password}" netloc = f"{auth}@{netloc}" return urlunsplit((parts.scheme, netloc, path, query, fragment)) -def _invalid_iri_to_uri(iri: str) -> str: - """The URL scheme ``itms-services://`` must contain the ``//`` even though it does - not have a host component. There may be other invalid schemes as well. Currently, - responses will always call ``iri_to_uri`` on the redirect ``Location`` header, which - removes the ``//``. For now, if the IRI only contains ASCII and does not contain - spaces, pass it on as-is. In Werkzeug 3.0, this should become a - ``response.process_location`` flag. 
- - :meta private: - """ - try: - iri.encode("ascii") - except UnicodeError: - pass - else: - if len(iri.split(None, 1)) == 1: - return iri - - return iri_to_uri(iri) - - -def url_decode( - s: t.AnyStr, - charset: str = "utf-8", - include_empty: bool = True, - errors: str = "replace", - separator: str = "&", - cls: type[ds.MultiDict] | None = None, -) -> ds.MultiDict[str, str]: - """Parse a query string and return it as a :class:`MultiDict`. - - :param s: The query string to parse. - :param charset: Decode bytes to string with this charset. If not - given, bytes are returned as-is. - :param include_empty: Include keys with empty values in the dict. - :param errors: Error handling behavior when decoding bytes. - :param separator: Separator character between pairs. - :param cls: Container to hold result instead of :class:`MultiDict`. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. Use ``urllib.parse.parse_qs`` instead. - - .. versionchanged:: 2.1 - The ``decode_keys`` parameter was removed. - - .. versionchanged:: 0.5 - In previous versions ";" and "&" could be used for url decoding. - Now only "&" is supported. If you want to use ";", a different - ``separator`` can be provided. - - .. versionchanged:: 0.5 - The ``cls`` parameter was added. - """ - warnings.warn( - "'werkzeug.urls.url_decode' is deprecated and will be removed in Werkzeug 2.4." 
- " Use 'urllib.parse.parse_qs' instead.", - DeprecationWarning, - stacklevel=2, - ) - - if cls is None: - from .datastructures import MultiDict # noqa: F811 - - cls = MultiDict - if isinstance(s, str) and not isinstance(separator, str): - separator = separator.decode(charset or "ascii") - elif isinstance(s, bytes) and not isinstance(separator, bytes): - separator = separator.encode(charset or "ascii") # type: ignore - return cls( - _url_decode_impl( - s.split(separator), charset, include_empty, errors # type: ignore - ) - ) - - -def url_decode_stream( - stream: t.IO[bytes], - charset: str = "utf-8", - include_empty: bool = True, - errors: str = "replace", - separator: bytes = b"&", - cls: type[ds.MultiDict] | None = None, - limit: int | None = None, -) -> ds.MultiDict[str, str]: - """Works like :func:`url_decode` but decodes a stream. The behavior - of stream and limit follows functions like - :func:`~werkzeug.wsgi.make_line_iter`. The generator of pairs is - directly fed to the `cls` so you can consume the data while it's - parsed. - - :param stream: a stream with the encoded querystring - :param charset: the charset of the query string. If set to `None` - no decoding will take place. - :param include_empty: Set to `False` if you don't want empty values to - appear in the dict. - :param errors: the decoding error behavior. - :param separator: the pair separator to be used, defaults to ``&`` - :param cls: an optional dict class to use. If this is not specified - or `None` the default :class:`MultiDict` is used. - :param limit: the content length of the URL data. Not necessary if - a limited stream is provided. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 2.4. Use ``urllib.parse.parse_qs`` instead. - - .. versionchanged:: 2.1 - The ``decode_keys`` and ``return_iterator`` parameters were removed. - - .. versionadded:: 0.8 - """ - warnings.warn( - "'werkzeug.urls.url_decode_stream' is deprecated and will be removed in" - " Werkzeug 2.4. 
Use 'urllib.parse.parse_qs' instead.", - DeprecationWarning, - stacklevel=2, - ) - - from .wsgi import make_chunk_iter - - pair_iter = make_chunk_iter(stream, separator, limit) - decoder = _url_decode_impl(pair_iter, charset, include_empty, errors) - - if cls is None: - from .datastructures import MultiDict # noqa: F811 - - cls = MultiDict - - return cls(decoder) +# Python < 3.12 +# itms-services was worked around in previous iri_to_uri implementations, but +# we can tell Python directly that it needs to preserve the //. +if "itms-services" not in urllib.parse.uses_netloc: + urllib.parse.uses_netloc.append("itms-services") -def _url_decode_impl( - pair_iter: t.Iterable[t.AnyStr], charset: str, include_empty: bool, errors: str -) -> t.Iterator[tuple[str, str]]: - for pair in pair_iter: - if not pair: - continue - s = _make_encode_wrapper(pair) - equal = s("=") - if equal in pair: - key, value = pair.split(equal, 1) - else: - if not include_empty: - continue - key = pair - value = s("") - yield ( - url_unquote_plus(key, charset, errors), - url_unquote_plus(value, charset, errors), - ) - - -def url_encode( - obj: t.Mapping[str, str] | t.Iterable[tuple[str, str]], - charset: str = "utf-8", - sort: bool = False, - key: t.Callable[[tuple[str, str]], t.Any] | None = None, - separator: str = "&", -) -> str: - """URL encode a dict/`MultiDict`. If a value is `None` it will not appear - in the result string. Per default only values are encoded into the target - charset strings. - - :param obj: the object to encode into a query string. - :param charset: the charset of the query string. - :param sort: set to `True` if you want parameters to be sorted by `key`. - :param separator: the separator to be used for the pairs. - :param key: an optional function to be used for sorting. For more details - check out the :func:`sorted` documentation. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 2.4. Use ``urllib.parse.urlencode`` instead. - - .. 
versionchanged:: 2.1 - The ``encode_keys`` parameter was removed. - - .. versionchanged:: 0.5 - Added the ``sort``, ``key``, and ``separator`` parameters. - """ - warnings.warn( - "'werkzeug.urls.url_encode' is deprecated and will be removed in Werkzeug 2.4." - " Use 'urllib.parse.urlencode' instead.", - DeprecationWarning, - stacklevel=2, - ) - separator = _to_str(separator, "ascii") - return separator.join(_url_encode_impl(obj, charset, sort, key)) - - -def url_encode_stream( - obj: t.Mapping[str, str] | t.Iterable[tuple[str, str]], - stream: t.IO[str] | None = None, - charset: str = "utf-8", - sort: bool = False, - key: t.Callable[[tuple[str, str]], t.Any] | None = None, - separator: str = "&", -) -> None: - """Like :meth:`url_encode` but writes the results to a stream - object. If the stream is `None` a generator over all encoded - pairs is returned. - - :param obj: the object to encode into a query string. - :param stream: a stream to write the encoded object into or `None` if - an iterator over the encoded pairs should be returned. In - that case the separator argument is ignored. - :param charset: the charset of the query string. - :param sort: set to `True` if you want parameters to be sorted by `key`. - :param separator: the separator to be used for the pairs. - :param key: an optional function to be used for sorting. For more details - check out the :func:`sorted` documentation. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 2.4. Use ``urllib.parse.urlencode`` instead. - - .. versionchanged:: 2.1 - The ``encode_keys`` parameter was removed. - - .. versionadded:: 0.8 - """ - warnings.warn( - "'werkzeug.urls.url_encode_stream' is deprecated and will be removed in" - " Werkzeug 2.4. 
Use 'urllib.parse.urlencode' instead.", - DeprecationWarning, - stacklevel=2, - ) - separator = _to_str(separator, "ascii") - gen = _url_encode_impl(obj, charset, sort, key) - if stream is None: - return gen # type: ignore - for idx, chunk in enumerate(gen): - if idx: - stream.write(separator) - stream.write(chunk) - return None - - -def url_join( - base: str | tuple[str, str, str, str, str], - url: str | tuple[str, str, str, str, str], - allow_fragments: bool = True, -) -> str: - """Join a base URL and a possibly relative URL to form an absolute - interpretation of the latter. - - :param base: the base URL for the join operation. - :param url: the URL to join. - :param allow_fragments: indicates whether fragments should be allowed. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 2.4. Use ``urllib.parse.urljoin`` instead. - """ - warnings.warn( - "'werkzeug.urls.url_join' is deprecated and will be removed in Werkzeug 2.4." - " Use 'urllib.parse.urljoin' instead.", - DeprecationWarning, - stacklevel=2, - ) - - if isinstance(base, tuple): - base = url_unparse(base) - if isinstance(url, tuple): - url = url_unparse(url) - - _check_str_tuple((base, url)) - s = _make_encode_wrapper(base) - - if not base: - return url - if not url: - return base - - bscheme, bnetloc, bpath, bquery, bfragment = url_parse( - base, allow_fragments=allow_fragments - ) - scheme, netloc, path, query, fragment = url_parse(url, bscheme, allow_fragments) - if scheme != bscheme: - return url - if netloc: - return url_unparse((scheme, netloc, path, query, fragment)) - netloc = bnetloc - - if path[:1] == s("/"): - segments = path.split(s("/")) - elif not path: - segments = bpath.split(s("/")) - if not query: - query = bquery - else: - segments = bpath.split(s("/"))[:-1] + path.split(s("/")) +def _decode_idna(domain: str) -> str: + try: + data = domain.encode("ascii") + except UnicodeEncodeError: + # If the domain is not ASCII, it's decoded already. 
+ return domain - # If the rightmost part is "./" we want to keep the slash but - # remove the dot. - if segments[-1] == s("."): - segments[-1] = s("") + try: + # Try decoding in one shot. + return data.decode("idna") + except UnicodeDecodeError: + pass - # Resolve ".." and "." - segments = [segment for segment in segments if segment != s(".")] - while True: - i = 1 - n = len(segments) - 1 - while i < n: - if segments[i] == s("..") and segments[i - 1] not in (s(""), s("..")): - del segments[i - 1 : i + 1] - break - i += 1 - else: - break + # Decode each part separately, leaving invalid parts as punycode. + parts = [] - # Remove trailing ".." if the URL is absolute - unwanted_marker = [s(""), s("..")] - while segments[:2] == unwanted_marker: - del segments[1] + for part in data.split(b"."): + try: + parts.append(part.decode("idna")) + except UnicodeDecodeError: + parts.append(part.decode("ascii")) - path = s("/").join(segments) - return url_unparse((scheme, netloc, path, query, fragment)) + return ".".join(parts) -def _urlencode( - query: t.Mapping[str, str] | t.Iterable[tuple[str, str]], encoding: str = "utf-8" -) -> str: +def _urlencode(query: t.Mapping[str, str] | t.Iterable[tuple[str, str]]) -> str: items = [x for x in iter_multi_items(query) if x[1] is not None] # safe = https://url.spec.whatwg.org/#percent-encoded-bytes - return urlencode(items, safe="!$'()*,/:;?@", encoding=encoding) + return urlencode(items, safe="!$'()*,/:;?@") diff --git a/src/werkzeug/utils.py b/src/werkzeug/utils.py index 785ac28b9..3d3bbf066 100644 --- a/src/werkzeug/utils.py +++ b/src/werkzeug/utils.py @@ -26,6 +26,7 @@ if t.TYPE_CHECKING: from _typeshed.wsgi import WSGIEnvironment + from .wrappers.request import Request from .wrappers.response import Response @@ -149,7 +150,7 @@ def lookup(self, obj: Request) -> WSGIEnvironment: class header_property(_DictAccessorProperty[_TAccessorValue]): """Like `environ_property` but for headers.""" - def lookup(self, obj: Request | Response) -> 
Headers: + def lookup(self, obj: Request | Response) -> Headers: # type: ignore[override] return obj.headers @@ -316,7 +317,7 @@ def append_slash_redirect(environ: WSGIEnvironment, code: int = 308) -> Response def send_file( - path_or_file: os.PathLike | str | t.IO[bytes], + path_or_file: os.PathLike[str] | str | t.IO[bytes], environ: WSGIEnvironment, mimetype: str | None = None, as_attachment: bool = False, @@ -327,7 +328,7 @@ def send_file( max_age: None | (int | t.Callable[[str | None], int | None]) = None, use_x_sendfile: bool = False, response_class: type[Response] | None = None, - _root_path: os.PathLike | str | None = None, + _root_path: os.PathLike[str] | str | None = None, ) -> Response: """Send the contents of a file to the client. @@ -415,7 +416,7 @@ def send_file( if isinstance(path_or_file, (os.PathLike, str)) or hasattr( path_or_file, "__fspath__" ): - path_or_file = t.cast(t.Union[os.PathLike, str], path_or_file) + path_or_file = t.cast("t.Union[os.PathLike[str], str]", path_or_file) # Flask will pass app.root_path, allowing its send_file wrapper # to not have to deal with paths. @@ -514,7 +515,7 @@ def send_file( if isinstance(etag, str): rv.set_etag(etag) elif etag and path is not None: - check = adler32(path.encode("utf-8")) & 0xFFFFFFFF + check = adler32(path.encode()) & 0xFFFFFFFF rv.set_etag(f"{mtime}-{size}-{check}") if conditional: @@ -535,8 +536,8 @@ def send_file( def send_from_directory( - directory: os.PathLike | str, - path: os.PathLike | str, + directory: os.PathLike[str] | str, + path: os.PathLike[str] | str, environ: WSGIEnvironment, **kwargs: t.Any, ) -> Response: @@ -560,20 +561,20 @@ def send_from_directory( .. versionadded:: 2.0 Adapted from Flask's implementation. 
""" - path = safe_join(os.fspath(directory), os.fspath(path)) + path_str = safe_join(os.fspath(directory), os.fspath(path)) - if path is None: + if path_str is None: raise NotFound() # Flask will pass app.root_path, allowing its send_from_directory # wrapper to not have to deal with paths. if "_root_path" in kwargs: - path = os.path.join(kwargs["_root_path"], path) + path_str = os.path.join(kwargs["_root_path"], path_str) - if not os.path.isfile(path): + if not os.path.isfile(path_str): raise NotFound() - return send_file(path, environ, **kwargs) + return send_file(path_str, environ, **kwargs) def import_string(import_name: str, silent: bool = False) -> t.Any: diff --git a/src/werkzeug/wrappers/__init__.py b/src/werkzeug/wrappers/__init__.py index b8c45d71c..b36f228f2 100644 --- a/src/werkzeug/wrappers/__init__.py +++ b/src/werkzeug/wrappers/__init__.py @@ -1,3 +1,3 @@ from .request import Request as Request from .response import Response as Response -from .response import ResponseStream +from .response import ResponseStream as ResponseStream diff --git a/src/werkzeug/wrappers/request.py b/src/werkzeug/wrappers/request.py index f4f51b1dc..719a3bc00 100644 --- a/src/werkzeug/wrappers/request.py +++ b/src/werkzeug/wrappers/request.py @@ -1,5 +1,6 @@ from __future__ import annotations +import collections.abc as cabc import functools import json import typing as t @@ -50,6 +51,10 @@ class Request(_SansIORequest): prevent consuming the form data in middleware, which would make it unavailable to the final application. + .. versionchanged:: 3.0 + The ``charset``, ``url_charset``, and ``encoding_errors`` parameters + were removed. + .. versionchanged:: 2.1 Old ``BaseRequest`` and mixin classes were removed. @@ -79,8 +84,11 @@ class Request(_SansIORequest): #: data in memory for post data is longer than the specified value a #: :exc:`~werkzeug.exceptions.RequestEntityTooLarge` exception is raised. #: + #: .. versionchanged:: 3.1 + #: Defaults to 500kB instead of unlimited. 
+ #: #: .. versionadded:: 0.5 - max_form_memory_size: int | None = None + max_form_memory_size: int | None = 500_000 #: The maximum number of multipart parts to parse, passed to #: :attr:`form_data_parser_class`. Parsing form data with more than this @@ -145,9 +153,6 @@ def from_values(cls, *args: t.Any, **kwargs: t.Any) -> Request: """ from ..test import EnvironBuilder - kwargs.setdefault( - "charset", cls.charset if not isinstance(cls.charset, property) else None - ) builder = EnvironBuilder(*args, **kwargs) try: return builder.get_request(cls) @@ -181,13 +186,13 @@ def my_wsgi_app(request): from ..exceptions import HTTPException @functools.wraps(f) - def application(*args): # type: ignore + def application(*args: t.Any) -> cabc.Iterable[bytes]: request = cls(args[-2]) with request: try: resp = f(*args[:-2] + (request,)) except HTTPException as e: - resp = e.get_response(args[-2]) + resp = t.cast("WSGIApplication", e.get_response(args[-2])) return resp(*args[-2:]) return t.cast("WSGIApplication", application) @@ -240,12 +245,8 @@ def make_form_data_parser(self) -> FormDataParser: .. versionadded:: 0.8 """ - charset = self._charset if self._charset != "utf-8" else None - errors = self._encoding_errors if self._encoding_errors != "replace" else None return self.form_data_parser_class( stream_factory=self._get_file_stream, - charset=charset, - errors=errors, max_form_memory_size=self.max_form_memory_size, max_content_length=self.max_content_length, max_form_parts=self.max_form_parts, @@ -372,13 +373,12 @@ def data(self) -> bytes: return self.get_data(parse_form_data=True) @t.overload - def get_data( # type: ignore + def get_data( self, cache: bool = True, as_text: t.Literal[False] = False, parse_form_data: bool = False, - ) -> bytes: - ... + ) -> bytes: ... @t.overload def get_data( @@ -386,8 +386,7 @@ def get_data( cache: bool = True, as_text: t.Literal[True] = ..., parse_form_data: bool = False, - ) -> str: - ... + ) -> str: ... 
def get_data( self, cache: bool = True, as_text: bool = False, parse_form_data: bool = False @@ -424,7 +423,7 @@ def get_data( if cache: self._cached_data = rv if as_text: - rv = rv.decode(self._charset, self._encoding_errors) + rv = rv.decode(errors="replace") return rv @cached_property @@ -567,14 +566,12 @@ def json(self) -> t.Any | None: @t.overload def get_json( self, force: bool = ..., silent: t.Literal[False] = ..., cache: bool = ... - ) -> t.Any: - ... + ) -> t.Any: ... @t.overload def get_json( self, force: bool = ..., silent: bool = ..., cache: bool = ... - ) -> t.Any | None: - ... + ) -> t.Any | None: ... def get_json( self, force: bool = False, silent: bool = False, cache: bool = True diff --git a/src/werkzeug/wrappers/response.py b/src/werkzeug/wrappers/response.py index c8488094e..7f01287c7 100644 --- a/src/werkzeug/wrappers/response.py +++ b/src/werkzeug/wrappers/response.py @@ -5,33 +5,33 @@ from http import HTTPStatus from urllib.parse import urljoin +from .._internal import _get_environ from ..datastructures import Headers +from ..http import generate_etag +from ..http import http_date +from ..http import is_resource_modified +from ..http import parse_etags +from ..http import parse_range_header from ..http import remove_entity_headers from ..sansio.response import Response as _SansIOResponse -from ..urls import _invalid_iri_to_uri from ..urls import iri_to_uri from ..utils import cached_property +from ..wsgi import _RangeWrapper from ..wsgi import ClosingIterator from ..wsgi import get_current_url -from werkzeug._internal import _get_environ -from werkzeug.http import generate_etag -from werkzeug.http import http_date -from werkzeug.http import is_resource_modified -from werkzeug.http import parse_etags -from werkzeug.http import parse_range_header -from werkzeug.wsgi import _RangeWrapper if t.TYPE_CHECKING: from _typeshed.wsgi import StartResponse from _typeshed.wsgi import WSGIApplication from _typeshed.wsgi import WSGIEnvironment + from 
.request import Request -def _iter_encoded(iterable: t.Iterable[str | bytes], charset: str) -> t.Iterator[bytes]: +def _iter_encoded(iterable: t.Iterable[str | bytes]) -> t.Iterator[bytes]: for item in iterable: if isinstance(item, str): - yield item.encode(charset) + yield item.encode() else: yield item @@ -260,12 +260,10 @@ def from_app( return cls(*run_wsgi_app(app, environ, buffered)) @t.overload - def get_data(self, as_text: t.Literal[False] = False) -> bytes: - ... + def get_data(self, as_text: t.Literal[False] = False) -> bytes: ... @t.overload - def get_data(self, as_text: t.Literal[True]) -> str: - ... + def get_data(self, as_text: t.Literal[True]) -> str: ... def get_data(self, as_text: bool = False) -> bytes | str: """The string representation of the response body. Whenever you call @@ -284,7 +282,7 @@ def get_data(self, as_text: bool = False) -> bytes | str: rv = b"".join(self.iter_encoded()) if as_text: - return rv.decode(self._charset) + return rv.decode() return rv @@ -296,7 +294,7 @@ def set_data(self, value: bytes | str) -> None: .. versionadded:: 0.9 """ if isinstance(value, str): - value = value.encode(self._charset) + value = value.encode() self.response = [value] if self.automatically_set_content_length: self.headers["Content-Length"] = str(len(value)) @@ -366,7 +364,7 @@ def iter_encoded(self) -> t.Iterator[bytes]: # Encode in a separate function so that self.response is fetched # early. This allows us to wrap the response with the return # value from get_app_iter or iter_encoded. - return _iter_encoded(self.response, self._charset) + return _iter_encoded(self.response) @property def is_streamed(self) -> bool: @@ -480,7 +478,7 @@ def get_wsgi_headers(self, environ: WSGIEnvironment) -> Headers: content_length = value if location is not None: - location = _invalid_iri_to_uri(location) + location = iri_to_uri(location) if self.autocorrect_location_header: # Make the location header an absolute URL. 
@@ -595,12 +593,10 @@ def json(self) -> t.Any | None: return self.get_json() @t.overload - def get_json(self, force: bool = ..., silent: t.Literal[False] = ...) -> t.Any: - ... + def get_json(self, force: bool = ..., silent: t.Literal[False] = ...) -> t.Any: ... @t.overload - def get_json(self, force: bool = ..., silent: bool = ...) -> t.Any | None: - ... + def get_json(self, force: bool = ..., silent: bool = ...) -> t.Any | None: ... def get_json(self, force: bool = False, silent: bool = False) -> t.Any | None: """Parse :attr:`data` as JSON. Useful during testing. @@ -832,4 +828,4 @@ def tell(self) -> int: @property def encoding(self) -> str: - return self.response._charset + return "utf-8" diff --git a/src/werkzeug/wsgi.py b/src/werkzeug/wsgi.py index 6061e1141..01d40af2f 100644 --- a/src/werkzeug/wsgi.py +++ b/src/werkzeug/wsgi.py @@ -1,16 +1,10 @@ from __future__ import annotations import io -import re import typing as t -import warnings from functools import partial from functools import update_wrapper -from itertools import chain -from ._internal import _make_encode_wrapper -from ._internal import _to_bytes -from ._internal import _to_str from .exceptions import ClientDisconnected from .exceptions import RequestEntityTooLarge from .sansio import utils as _sansio_utils @@ -200,45 +194,18 @@ def get_input_stream( return t.cast(t.IO[bytes], LimitedStream(stream, content_length)) -def get_path_info( - environ: WSGIEnvironment, - charset: t.Any = ..., - errors: str | None = None, -) -> str: +def get_path_info(environ: WSGIEnvironment) -> str: """Return ``PATH_INFO`` from the WSGI environment. :param environ: WSGI environment to get the path from. - .. versionchanged:: 2.3 - The ``charset`` and ``errors`` parameters are deprecated and will be removed in - Werkzeug 3.0. + .. versionchanged:: 3.0 + The ``charset`` and ``errors`` parameters were removed. .. 
versionadded:: 0.9 """ - if charset is not ...: - warnings.warn( - "The 'charset' parameter is deprecated and will be removed" - " in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - - if charset is None: - charset = "utf-8" - else: - charset = "utf-8" - - if errors is not None: - warnings.warn( - "The 'errors' parameter is deprecated and will be removed in Werkzeug 3.0", - DeprecationWarning, - stacklevel=2, - ) - else: - errors = "replace" - - path = environ.get("PATH_INFO", "").encode("latin1") - return path.decode(charset, errors) # type: ignore[no-any-return] + path: bytes = environ.get("PATH_INFO", "").encode("latin1") + return path.decode(errors="replace") class ClosingIterator: @@ -455,225 +422,6 @@ def close(self) -> None: self.iterable.close() -def _make_chunk_iter( - stream: t.Iterable[bytes] | t.IO[bytes], - limit: int | None, - buffer_size: int, -) -> t.Iterator[bytes]: - """Helper for the line and chunk iter functions.""" - warnings.warn( - "'_make_chunk_iter' is deprecated and will be removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - - if isinstance(stream, (bytes, bytearray, str)): - raise TypeError( - "Passed a string or byte object instead of true iterator or stream." - ) - if not hasattr(stream, "read"): - for item in stream: - if item: - yield item - return - stream = t.cast(t.IO[bytes], stream) - if not isinstance(stream, LimitedStream) and limit is not None: - stream = t.cast(t.IO[bytes], LimitedStream(stream, limit)) - _read = stream.read - while True: - item = _read(buffer_size) - if not item: - break - yield item - - -def make_line_iter( - stream: t.Iterable[bytes] | t.IO[bytes], - limit: int | None = None, - buffer_size: int = 10 * 1024, - cap_at_buffer: bool = False, -) -> t.Iterator[bytes]: - """Safely iterates line-based over an input stream. If the input stream - is not a :class:`LimitedStream` the `limit` parameter is mandatory. 
- - This uses the stream's :meth:`~file.read` method internally as opposite - to the :meth:`~file.readline` method that is unsafe and can only be used - in violation of the WSGI specification. The same problem applies to the - `__iter__` function of the input stream which calls :meth:`~file.readline` - without arguments. - - If you need line-by-line processing it's strongly recommended to iterate - over the input stream using this helper function. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. - - .. versionadded:: 0.11 - added support for the `cap_at_buffer` parameter. - - .. versionadded:: 0.9 - added support for iterators as input stream. - - .. versionchanged:: 0.8 - This function now ensures that the limit was reached. - - :param stream: the stream or iterate to iterate over. - :param limit: the limit in bytes for the stream. (Usually - content length. Not necessary if the `stream` - is a :class:`LimitedStream`. - :param buffer_size: The optional buffer size. - :param cap_at_buffer: if this is set chunks are split if they are longer - than the buffer size. Internally this is implemented - that the buffer size might be exhausted by a factor - of two however. 
- """ - warnings.warn( - "'make_line_iter' is deprecated and will be removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - _iter = _make_chunk_iter(stream, limit, buffer_size) - - first_item = next(_iter, "") - - if not first_item: - return - - s = _make_encode_wrapper(first_item) - empty = t.cast(bytes, s("")) - cr = t.cast(bytes, s("\r")) - lf = t.cast(bytes, s("\n")) - crlf = t.cast(bytes, s("\r\n")) - - _iter = t.cast(t.Iterator[bytes], chain((first_item,), _iter)) - - def _iter_basic_lines() -> t.Iterator[bytes]: - _join = empty.join - buffer: list[bytes] = [] - while True: - new_data = next(_iter, "") - if not new_data: - break - new_buf: list[bytes] = [] - buf_size = 0 - for item in t.cast( - t.Iterator[bytes], chain(buffer, new_data.splitlines(True)) - ): - new_buf.append(item) - buf_size += len(item) - if item and item[-1:] in crlf: - yield _join(new_buf) - new_buf = [] - elif cap_at_buffer and buf_size >= buffer_size: - rv = _join(new_buf) - while len(rv) >= buffer_size: - yield rv[:buffer_size] - rv = rv[buffer_size:] - new_buf = [rv] - buffer = new_buf - if buffer: - yield _join(buffer) - - # This hackery is necessary to merge 'foo\r' and '\n' into one item - # of 'foo\r\n' if we were unlucky and we hit a chunk boundary. - previous = empty - for item in _iter_basic_lines(): - if item == lf and previous[-1:] == cr: - previous += item - item = empty - if previous: - yield previous - previous = item - if previous: - yield previous - - -def make_chunk_iter( - stream: t.Iterable[bytes] | t.IO[bytes], - separator: bytes, - limit: int | None = None, - buffer_size: int = 10 * 1024, - cap_at_buffer: bool = False, -) -> t.Iterator[bytes]: - """Works like :func:`make_line_iter` but accepts a separator - which divides chunks. If you want newline based processing - you should use :func:`make_line_iter` instead as it - supports arbitrary newline markers. - - .. deprecated:: 2.3 - Will be removed in Werkzeug 3.0. - - .. 
versionchanged:: 0.11 - added support for the `cap_at_buffer` parameter. - - .. versionchanged:: 0.9 - added support for iterators as input stream. - - .. versionadded:: 0.8 - - :param stream: the stream or iterate to iterate over. - :param separator: the separator that divides chunks. - :param limit: the limit in bytes for the stream. (Usually - content length. Not necessary if the `stream` - is otherwise already limited). - :param buffer_size: The optional buffer size. - :param cap_at_buffer: if this is set chunks are split if they are longer - than the buffer size. Internally this is implemented - that the buffer size might be exhausted by a factor - of two however. - """ - warnings.warn( - "'make_chunk_iter' is deprecated and will be removed in Werkzeug 3.0.", - DeprecationWarning, - stacklevel=2, - ) - _iter = _make_chunk_iter(stream, limit, buffer_size) - - first_item = next(_iter, b"") - - if not first_item: - return - - _iter = t.cast(t.Iterator[bytes], chain((first_item,), _iter)) - if isinstance(first_item, str): - separator = _to_str(separator) - _split = re.compile(f"({re.escape(separator)})").split - _join = "".join - else: - separator = _to_bytes(separator) - _split = re.compile(b"(" + re.escape(separator) + b")").split - _join = b"".join - - buffer: list[bytes] = [] - while True: - new_data = next(_iter, b"") - if not new_data: - break - chunks = _split(new_data) - new_buf: list[bytes] = [] - buf_size = 0 - for item in chain(buffer, chunks): - if item == separator: - yield _join(new_buf) - new_buf = [] - buf_size = 0 - else: - buf_size += len(item) - new_buf.append(item) - - if cap_at_buffer and buf_size >= buffer_size: - rv = _join(new_buf) - while len(rv) >= buffer_size: - yield rv[:buffer_size] - rv = rv[buffer_size:] - new_buf = [rv] - buf_size = len(rv) - - buffer = new_buf - if buffer: - yield _join(buffer) - - class LimitedStream(io.RawIOBase): """Wrap a stream so that it doesn't read more than a given limit. 
This is used to limit ``wsgi.input`` to the ``Content-Length`` header value or diff --git a/tests/conftest.py b/tests/conftest.py index b73202cdb..f05fd84ed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,55 +1,149 @@ +from __future__ import annotations + +import collections.abc as cabc import http.client import json import os import socket import ssl +import subprocess import sys +import time +import typing as t +from contextlib import closing +from contextlib import ExitStack from pathlib import Path +from types import TracebackType import ephemeral_port_reserve import pytest -from xprocess import ProcessStarter - -from werkzeug.utils import cached_property -run_path = str(Path(__file__).parent / "live_apps" / "run.py") +if t.TYPE_CHECKING: + import typing_extensions as te class UnixSocketHTTPConnection(http.client.HTTPConnection): - def connect(self): + def connect(self) -> None: self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + # Raises FileNotFoundError if the server hasn't started yet. self.sock.connect(self.host) +# Used to annotate the ``DevServerClient.request`` return value. +class DataHTTPResponse(http.client.HTTPResponse): + data: bytes + json: t.Any + + class DevServerClient: - def __init__(self, kwargs): - host = kwargs.get("hostname", "127.0.0.1") + """Manage a live dev server process and make requests to it. Must be used + as a context manager. + + If ``hostname`` starts with ``unix://``, the server listens to a unix socket + file instead of a TCP socket. + + If ``port`` is not given, a random port is reserved for use by the server, + to allow multiple servers to run simultaneously. + + If ``ssl_context`` is given, the server listens with TLS enabled. It can be + the special value ``custom`` to generate and pass a context to + ``run_simple``, as opposed to ``adhoc`` which tells ``run_simple`` to + generate the context. + + :param app_name: The name of the app from the ``live_apps`` folder to load. 
+ :param tmp_path: The current test's temporary directory. The server process + sets the working dir here, it is added to the Python path, the log file + is written here, and for unix connections the socket is opened here. + :param server_kwargs: Arguments to pass to ``live_apps/run.py`` to control + how ``run_simple`` is called in the subprocess. + """ - if not host.startswith("unix"): - port = kwargs.get("port") + scheme: str + """One of ``http``, ``https``, or ``unix``. Set based on ``ssl_context`` or + ``hostname``. + """ + addr: str + """The host and port.""" + url: str + """The scheme, host, and port.""" + + def __init__( + self, app_name: str = "standard", *, tmp_path: Path, **server_kwargs: t.Any + ) -> None: + host = server_kwargs.get("hostname", "127.0.0.1") + + if not host.startswith("unix://"): + port = server_kwargs.get("port") if port is None: - kwargs["port"] = port = ephemeral_port_reserve.reserve(host) + server_kwargs["port"] = port = ephemeral_port_reserve.reserve(host) - scheme = "https" if "ssl_context" in kwargs else "http" + self.scheme = "https" if "ssl_context" in server_kwargs else "http" self.addr = f"{host}:{port}" - self.url = f"{scheme}://{self.addr}" + self.url = f"{self.scheme}://{self.addr}" else: + self.scheme = "unix" self.addr = host[7:] # strip "unix://" self.url = host - self.log = None - - def tail_log(self, path): - # surrogateescape allows for handling of file streams - # containing junk binary values as normal text streams - self.log = open(path, errors="surrogateescape") - self.log.read() - - def connect(self, **kwargs): - protocol = self.url.partition(":")[0] - - if protocol == "https": + self._app_name = app_name + self._server_kwargs = server_kwargs + self._tmp_path = tmp_path + self._log_write: t.IO[bytes] | None = None + self._log_read: t.IO[str] | None = None + self._proc: subprocess.Popen[bytes] | None = None + + def __enter__(self) -> te.Self: + """Start the server process and wait for it to be ready.""" + log_path 
= self._tmp_path / "log.txt" + self._log_write = open(log_path, "wb") + self._log_read = open(log_path, encoding="utf8", errors="surrogateescape") + tmp_dir = os.fspath(self._tmp_path) + self._proc = subprocess.Popen( + [ + sys.executable, + os.fspath(Path(__file__).parent / "live_apps/run.py"), + self._app_name, + json.dumps(self._server_kwargs), + ], + env={**os.environ, "PYTHONUNBUFFERED": "1", "PYTHONPATH": tmp_dir}, + cwd=tmp_dir, + close_fds=True, + stdout=self._log_write, + stderr=subprocess.STDOUT, + ) + self.wait_ready() + return self + + def __exit__( + self, + exc_type: type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> None: + """Clean up the server process.""" + assert self._proc is not None + self._proc.terminate() + self._proc.wait() + self._proc = None + assert self._log_read is not None + self._log_read.close() + self._log_read = None + assert self._log_write is not None + self._log_write.close() + self._log_write = None + + def connect(self, **kwargs: t.Any) -> http.client.HTTPConnection: + """Create a connection to the server, without sending a request. + Useful if a test requires lower level methods to try something that + ``HTTPClient.request`` will not do. + + If the server's scheme is HTTPS and the TLS ``context`` argument is not + given, a default permissive context is used. + + :param kwargs: Arguments to :class:`http.client.HTTPConnection`. 
+ """ + if self.scheme == "https": if "context" not in kwargs: context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) context.check_hostname = False @@ -58,21 +152,31 @@ def connect(self, **kwargs): return http.client.HTTPSConnection(self.addr, **kwargs) - if protocol == "unix": + if self.scheme == "unix": return UnixSocketHTTPConnection(self.addr, **kwargs) return http.client.HTTPConnection(self.addr, **kwargs) - def request(self, path="", **kwargs): + def request(self, url: str = "", **kwargs: t.Any) -> DataHTTPResponse: + """Open a connection and make a request to the server, returning the + response. + + The response object ``data`` parameter has the result of + ``response.read()``. If the response has a ``application/json`` content + type, the ``json`` parameter is populated with ``json.loads(data)``. + + :param url: URL to put in the request line. + :param kwargs: Arguments to :meth:`http.client.HTTPConnection.request`. + """ kwargs.setdefault("method", "GET") - kwargs.setdefault("url", path) - conn = self.connect() - conn.request(**kwargs) + kwargs["url"] = url + response: DataHTTPResponse - with conn.getresponse() as response: - response.data = response.read() + with closing(self.connect()) as conn: + conn.request(**kwargs) - conn.close() + with conn.getresponse() as response: # type: ignore[assignment] + response.data = response.read() if response.headers.get("Content-Type", "").startswith("application/json"): response.json = json.loads(response.data) @@ -81,53 +185,66 @@ def request(self, path="", **kwargs): return response - def wait_for_log(self, start): + def wait_ready(self) -> None: + """Wait until a request to ``/ensure`` is successful, indicating the + server has started and is listening. 
+ """ while True: - for line in self.log: - if line.startswith(start): - return + try: + self.request("/ensure") + return + # ConnectionRefusedError for http, FileNotFoundError for unix + except (ConnectionRefusedError, FileNotFoundError): + time.sleep(0.1) - def wait_for_reload(self): - self.wait_for_log(" * Restarting with ") + def read_log(self) -> str: + """Read from the current position to the current end of the log.""" + assert self._log_read is not None + return self._log_read.read() + def wait_for_log(self, value: str) -> None: + """Wait until a line in the log contains the given string. -@pytest.fixture() -def dev_server(xprocess, request, tmp_path): - """A function that will start a dev server in an external process - and return a client for interacting with the server. - """ + :param value: The string to search for. + """ + assert self._log_read is not None - def start_dev_server(name="standard", **kwargs): - client = DevServerClient(kwargs) + while True: + for line in self._log_read: + if value in line: + return + + time.sleep(0.1) + + def wait_for_reload(self) -> None: + """Wait until the server logs that it is restarting, then wait for it to + be ready. + """ + self.wait_for_log("Restarting with") + self.wait_ready() - class Starter(ProcessStarter): - args = [sys.executable, run_path, name, json.dumps(kwargs)] - # Extend the existing env, otherwise Windows and CI fails. - # Modules will be imported from tmp_path for the reloader. - # Unbuffered output so the logs update immediately. - env = {**os.environ, "PYTHONPATH": str(tmp_path), "PYTHONUNBUFFERED": "1"} - @cached_property - def pattern(self): - client.request("/ensure") - return "GET /ensure" +class StartDevServer(t.Protocol): + def __call__(self, name: str = "standard", **kwargs: t.Any) -> DevServerClient: ... - # Each test that uses the fixture will have a different log. 
- xp_name = f"dev_server-{request.node.name}" - _, log_path = xprocess.ensure(xp_name, Starter, restart=True) - client.tail_log(log_path) - @request.addfinalizer - def close(): - xprocess.getinfo(xp_name).terminate() - client.log.close() +@pytest.fixture() +def dev_server(tmp_path: Path) -> cabc.Iterator[StartDevServer]: + """A function that will start a dev server in a subprocess and return a + client for interacting with the server. + """ + exit_stack = ExitStack() + def start_dev_server(name: str = "standard", **kwargs: t.Any) -> DevServerClient: + client = DevServerClient(name, tmp_path=tmp_path, **kwargs) + exit_stack.enter_context(client) # type: ignore[arg-type] return client - return start_dev_server + with exit_stack: + yield start_dev_server @pytest.fixture() -def standard_app(dev_server): +def standard_app(dev_server: t.Callable[..., DevServerClient]) -> DevServerClient: """Equivalent to ``dev_server("standard")``.""" return dev_server() diff --git a/tests/live_apps/data_app.py b/tests/live_apps/data_app.py index 561390a1c..9b2e78b91 100644 --- a/tests/live_apps/data_app.py +++ b/tests/live_apps/data_app.py @@ -11,7 +11,7 @@ def app(request: Request) -> Response: { "environ": request.environ, "form": request.form.to_dict(), - "files": {k: v.read().decode("utf8") for k, v in request.files.items()}, + "files": {k: v.read().decode() for k, v in request.files.items()}, }, default=lambda x: str(x), ), diff --git a/tests/live_apps/run.py b/tests/live_apps/run.py index aacdcb664..1371e6723 100644 --- a/tests/live_apps/run.py +++ b/tests/live_apps/run.py @@ -4,6 +4,7 @@ from werkzeug.serving import generate_adhoc_ssl_context from werkzeug.serving import run_simple +from werkzeug.serving import WSGIRequestHandler from werkzeug.wrappers import Request from werkzeug.wrappers import Response @@ -23,10 +24,14 @@ def app(request): kwargs.update(hostname="127.0.0.1", port=5000, application=app) kwargs.update(json.loads(sys.argv[2])) ssl_context = 
kwargs.get("ssl_context") +override_client_addr = kwargs.pop("override_client_addr", None) if ssl_context == "custom": kwargs["ssl_context"] = generate_adhoc_ssl_context() elif isinstance(ssl_context, list): kwargs["ssl_context"] = tuple(ssl_context) +if override_client_addr: + WSGIRequestHandler.address_string = lambda _: override_client_addr + run_simple(**kwargs) diff --git a/tests/middleware/test_http_proxy.py b/tests/middleware/test_http_proxy.py index a1497c5cc..5e1f005b2 100644 --- a/tests/middleware/test_http_proxy.py +++ b/tests/middleware/test_http_proxy.py @@ -5,7 +5,6 @@ from werkzeug.wrappers import Response -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server def test_http_proxy(standard_app): app = ProxyMiddleware( diff --git a/tests/middleware/test_profiler.py b/tests/middleware/test_profiler.py new file mode 100644 index 000000000..585aeb54b --- /dev/null +++ b/tests/middleware/test_profiler.py @@ -0,0 +1,50 @@ +import datetime +import os +from unittest.mock import ANY +from unittest.mock import MagicMock +from unittest.mock import patch + +from werkzeug.middleware.profiler import Profile +from werkzeug.middleware.profiler import ProfilerMiddleware +from werkzeug.test import Client + + +def dummy_application(environ, start_response): + start_response("200 OK", [("Content-Type", "text/plain")]) + return [b"Foo"] + + +def test_filename_format_function(): + # This should be called once with the generated file name + mock_capture_name = MagicMock() + + def filename_format(env): + now = datetime.datetime.fromtimestamp(env["werkzeug.profiler"]["time"]) + timestamp = now.strftime("%Y-%m-%d:%H:%M:%S") + path = ( + "_".join(token for token in env["PATH_INFO"].split("/") if token) or "ROOT" + ) + elapsed = env["werkzeug.profiler"]["elapsed"] + name = f"{timestamp}.{env['REQUEST_METHOD']}.{path}.{elapsed:.0f}ms.prof" + mock_capture_name(name=name) + return name + + client = Client( + ProfilerMiddleware( + 
dummy_application, + stream=None, + profile_dir="profiles", + filename_format=filename_format, + ) + ) + + # Replace the Profile class with a function that simulates an __init__() + # call and returns our mock instance. + mock_profile = MagicMock(wraps=Profile()) + mock_profile.dump_stats = MagicMock() + with patch("werkzeug.middleware.profiler.Profile", lambda: mock_profile): + client.get("/foo/bar") + + mock_capture_name.assert_called_once_with(name=ANY) + name = mock_capture_name.mock_calls[0].kwargs["name"] + mock_profile.dump_stats.assert_called_once_with(os.path.join("profiles", name)) diff --git a/tests/sansio/test_multipart.py b/tests/sansio/test_multipart.py index 35109d4bd..cf36fefd6 100644 --- a/tests/sansio/test_multipart.py +++ b/tests/sansio/test_multipart.py @@ -24,11 +24,7 @@ def test_decoder_simple() -> None: asdasd -----------------------------9704338192090380615194531385$-- - """.replace( - "\n", "\r\n" - ).encode( - "utf-8" - ) + """.replace("\n", "\r\n").encode() decoder.receive_data(data) decoder.receive_data(None) events = [decoder.next_event()] @@ -147,11 +143,7 @@ def test_empty_field() -> None: Content-Type: text/plain; charset="UTF-8" --foo-- - """.replace( - "\n", "\r\n" - ).encode( - "utf-8" - ) + """.replace("\n", "\r\n").encode() decoder.receive_data(data) decoder.receive_data(None) events = [decoder.next_event()] diff --git a/tests/sansio/test_utils.py b/tests/sansio/test_utils.py index 04d02e44c..a63e7c660 100644 --- a/tests/sansio/test_utils.py +++ b/tests/sansio/test_utils.py @@ -1,7 +1,5 @@ from __future__ import annotations -import typing as t - import pytest from werkzeug.sansio.utils import get_content_length @@ -16,20 +14,24 @@ ("https", "spam", None, "spam"), ("https", "spam:443", None, "spam"), ("http", "spam:8080", None, "spam:8080"), + ("http", "127.0.0.1:8080", None, "127.0.0.1:8080"), + ("http", "[::1]:8080", None, "[::1]:8080"), ("ws", "spam", None, "spam"), ("ws", "spam:80", None, "spam"), ("wss", "spam", None, 
"spam"), ("wss", "spam:443", None, "spam"), ("http", None, ("spam", 80), "spam"), ("http", None, ("spam", 8080), "spam:8080"), + ("http", None, ("127.0.0.1", 8080), "127.0.0.1:8080"), + ("http", None, ("::1", 8080), "[::1]:8080"), ("http", None, ("unix/socket", None), "unix/socket"), ("http", "spam", ("eggs", 80), "spam"), ], ) def test_get_host( scheme: str, - host_header: t.Optional[str], - server: t.Optional[t.Tuple[str, t.Optional[int]]], + host_header: str | None, + server: tuple[str, int | None] | None, expected: str, ) -> None: assert get_host(scheme, host_header, server) == expected diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 5206aa6a2..0cd497438 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import io import pickle import tempfile @@ -10,6 +12,8 @@ from werkzeug import datastructures as ds from werkzeug import http +from werkzeug.datastructures.structures import _ImmutableOrderedMultiDict +from werkzeug.datastructures.structures import _OrderedMultiDict from werkzeug.exceptions import BadRequestKeyError @@ -42,7 +46,7 @@ def items(self, multi=1): class _MutableMultiDictTests: - storage_class: t.Type["ds.MultiDict"] + storage_class: type[ds.MultiDict] def test_pickle(self): cls = self.storage_class @@ -257,9 +261,20 @@ def test_basic_interface(self): md.setlist("foo", [1, 2]) assert md.getlist("foo") == [1, 2] + def test_or(self) -> None: + a = self.storage_class({"x": 1}) + b = a | {"y": 2} + assert isinstance(b, self.storage_class) + assert "x" in b and "y" in b + + def test_ior(self) -> None: + a = self.storage_class({"x": 1}) + a |= {"y": 2} + assert "x" in a and "y" in a + class _ImmutableDictTests: - storage_class: t.Type[dict] + storage_class: type[dict] def test_follows_dict_interface(self): cls = self.storage_class @@ -304,6 +319,17 @@ def test_dict_is_hashable(self): assert immutable in x assert immutable2 in x + def test_or(self) 
-> None: + a = self.storage_class({"x": 1}) + b = a | {"y": 2} + assert "x" in b and "y" in b + + def test_ior(self) -> None: + a = self.storage_class({"x": 1}) + + with pytest.raises(TypeError): + a |= {"y": 2} + class TestImmutableTypeConversionDict(_ImmutableDictTests): storage_class = ds.ImmutableTypeConversionDict @@ -334,8 +360,9 @@ class TestImmutableDict(_ImmutableDictTests): storage_class = ds.ImmutableDict +@pytest.mark.filterwarnings("ignore:'OrderedMultiDict':DeprecationWarning") class TestImmutableOrderedMultiDict(_ImmutableDictTests): - storage_class = ds.ImmutableOrderedMultiDict + storage_class = _ImmutableOrderedMultiDict def test_ordered_multidict_is_hashable(self): a = self.storage_class([("a", 1), ("b", 1), ("a", 2)]) @@ -413,8 +440,9 @@ def test_getitem_raise_badrequestkeyerror_for_empty_list_value(self): md["empty"] +@pytest.mark.filterwarnings("ignore:'OrderedMultiDict':DeprecationWarning") class TestOrderedMultiDict(_MutableMultiDictTests): - storage_class = ds.OrderedMultiDict + storage_class = _OrderedMultiDict def test_ordered_interface(self): cls = self.storage_class @@ -550,8 +578,9 @@ def test_value_conversion(self): assert d.get("foo", type=int) == 1 def test_return_default_when_conversion_is_not_possible(self): - d = self.storage_class(foo="bar") + d = self.storage_class(foo="bar", baz=None) assert d.get("foo", default=-1, type=int) == -1 + assert d.get("baz", default=-1, type=int) == -1 def test_propagate_exceptions_in_conversion(self): d = self.storage_class(foo="bar") @@ -795,6 +824,17 @@ def test_equality(self): assert h1 == h2 + def test_or(self) -> None: + a = ds.Headers({"x": 1}) + b = a | {"y": 2} + assert isinstance(b, ds.Headers) + assert "x" in b and "y" in b + + def test_ior(self) -> None: + a = ds.Headers({"x": 1}) + a |= {"y": 2} + assert "x" in a and "y" in a + class TestEnvironHeaders: storage_class = ds.EnvironHeaders @@ -836,6 +876,22 @@ def test_return_type_is_str(self): assert headers["Foo"] == "\xe2\x9c\x93" 
assert next(iter(headers)) == ("Foo", "\xe2\x9c\x93") + def test_or(self) -> None: + headers = ds.EnvironHeaders({"x": "1"}) + + with pytest.raises(TypeError): + headers | {"y": "2"} + + def test_ior(self) -> None: + headers = ds.EnvironHeaders({}) + + with pytest.raises(TypeError): + headers |= {"y": "2"} + + def test_str(self) -> None: + headers = ds.EnvironHeaders({"CONTENT_LENGTH": "50", "HTTP_HOST": "test"}) + assert str(headers) == "Content-Length: 50\r\nHost: test\r\n\r\n" + class TestHeaderSet: storage_class = ds.HeaderSet @@ -923,7 +979,7 @@ def test_callback_dict_writes(self): assert_calls, func = make_call_asserter() initial = {"a": "foo", "b": "bar"} dct = self.storage_class(initial=initial, on_update=func) - with assert_calls(8, "callback not triggered by write method"): + with assert_calls(9, "callback not triggered by write method"): # always-write methods dct["z"] = 123 dct["z"] = 123 # must trigger again @@ -933,6 +989,7 @@ def test_callback_dict_writes(self): dct.popitem() dct.update([]) dct.clear() + dct |= {} with assert_calls(0, "callback triggered by failed del"): pytest.raises(KeyError, lambda: dct.__delitem__("x")) with assert_calls(0, "callback triggered by failed pop"): @@ -950,7 +1007,39 @@ def test_set_none(self): cc.no_cache = None assert cc.no_cache is None cc.no_cache = False - assert cc.no_cache is False + assert cc.no_cache is None + + def test_no_transform(self): + cc = ds.RequestCacheControl([("no-transform", None)]) + assert cc.no_transform is True + cc = ds.RequestCacheControl() + assert cc.no_transform is False + + def test_min_fresh(self): + cc = ds.RequestCacheControl([("min-fresh", "0")]) + assert cc.min_fresh == 0 + cc = ds.RequestCacheControl([("min-fresh", None)]) + assert cc.min_fresh is None + cc = ds.RequestCacheControl() + assert cc.min_fresh is None + + def test_must_understand(self): + cc = ds.ResponseCacheControl([("must-understand", None)]) + assert cc.must_understand is True + cc = ds.ResponseCacheControl() + 
assert cc.must_understand is False + + def test_stale_while_revalidate(self): + cc = ds.ResponseCacheControl([("stale-while-revalidate", "1")]) + assert cc.stale_while_revalidate == 1 + cc = ds.ResponseCacheControl() + assert cc.stale_while_revalidate is None + + def test_stale_if_error(self): + cc = ds.ResponseCacheControl([("stale-if-error", "1")]) + assert cc.stale_if_error == 1 + cc = ds.ResponseCacheControl() + assert cc.stale_while_revalidate is None class TestContentSecurityPolicy: @@ -1194,3 +1283,15 @@ def test_range_to_header(ranges): def test_range_validates_ranges(ranges): with pytest.raises(ValueError): ds.Range("bytes", ranges) + + +@pytest.mark.parametrize( + ("value", "expect"), + [ + ({"a": "ab"}, [("a", "ab")]), + ({"a": ["a", "b"]}, [("a", "a"), ("a", "b")]), + ({"a": b"ab"}, [("a", b"ab")]), + ], +) +def test_iter_multi_data(value: t.Any, expect: list[tuple[t.Any, t.Any]]) -> None: + assert list(ds.iter_multi_items(value)) == expect diff --git a/tests/test_debug.py b/tests/test_debug.py index cf171d1a5..f51779cbc 100644 --- a/tests/test_debug.py +++ b/tests/test_debug.py @@ -245,7 +245,6 @@ def test_get_machine_id(): assert isinstance(rv, bytes) -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.parametrize("crash", (True, False)) @pytest.mark.dev_server def test_basic(dev_server, crash): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index e4ee58633..67d76d2b4 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -7,6 +7,7 @@ from werkzeug import exceptions from werkzeug.datastructures import Headers from werkzeug.datastructures import WWWAuthenticate +from werkzeug.exceptions import default_exceptions from werkzeug.exceptions import HTTPException from werkzeug.wrappers import Response @@ -36,6 +37,7 @@ def test_proxy_exception(): (exceptions.RequestEntityTooLarge, 413), (exceptions.RequestURITooLarge, 414), (exceptions.UnsupportedMediaType, 415), + 
(exceptions.MisdirectedRequest, 421), (exceptions.UnprocessableEntity, 422), (exceptions.Locked, 423), (exceptions.InternalServerError, 500), @@ -138,7 +140,7 @@ def test_retry_after_mixin(cls, value, expect): @pytest.mark.parametrize( "cls", sorted( - (e for e in HTTPException.__subclasses__() if e.code and e.code >= 400), + (e for e in default_exceptions.values() if e.code and e.code >= 400), key=lambda e: e.code, # type: ignore ), ) @@ -158,7 +160,7 @@ def test_description_none(): @pytest.mark.parametrize( "cls", sorted( - (e for e in HTTPException.__subclasses__() if e.code), + (e for e in default_exceptions.values() if e.code), key=lambda e: e.code, # type: ignore ), ) diff --git a/tests/test_formparser.py b/tests/test_formparser.py index 1dcb167ef..ebd7fddcf 100644 --- a/tests/test_formparser.py +++ b/tests/test_formparser.py @@ -122,13 +122,21 @@ def test_limiting(self): req.max_form_parts = 1 pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"]) - def test_x_www_urlencoded_max_form_parts(self): + def test_urlencoded_no_max(self) -> None: r = Request.from_values(method="POST", data={"a": 1, "b": 2}) r.max_form_parts = 1 assert r.form["a"] == "1" assert r.form["b"] == "2" + def test_urlencoded_silent_decode(self) -> None: + r = Request.from_values( + data=b"\x80", + content_type="application/x-www-form-urlencoded", + method="POST", + ) + assert not r.form + def test_missing_multipart_boundary(self): data = ( b"--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\n" @@ -273,7 +281,7 @@ def test_basic(self): content_type=f'multipart/form-data; boundary="{boundary}"', content_length=len(data), ) as response: - assert response.get_data() == repr(text).encode("utf-8") + assert response.get_data() == repr(text).encode() @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") def test_ie7_unc_path(self): @@ -448,3 +456,15 @@ def test_file_rfc2231_filename_continuations(self): ) as request: assert request.files["rfc2231"].filename 
== "a b c d e f.txt" assert request.files["rfc2231"].read() == b"file contents" + + +def test_multipart_max_form_memory_size() -> None: + """max_form_memory_size is tracked across multiple data events.""" + data = b"--bound\r\nContent-Disposition: form-field; name=a\r\n\r\n" + data += b"a" * 15 + b"\r\n--bound--" + # The buffer size is less than the max size, so multiple data events will be + # returned. The field size is greater than the max. + parser = formparser.MultiPartParser(max_form_memory_size=10, buffer_size=5) + + with pytest.raises(RequestEntityTooLarge): + parser.parse(io.BytesIO(data), b"bound", None) diff --git a/tests/test_http.py b/tests/test_http.py index bbd51ba33..726b40bca 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -107,14 +107,21 @@ def test_set_header(self): def test_list_header(self, value, expect): assert http.parse_list_header(value) == expect - def test_dict_header(self): - d = http.parse_dict_header('foo="bar baz", blah=42') - assert d == {"foo": "bar baz", "blah": "42"} + @pytest.mark.parametrize( + ("value", "expect"), + [ + ('foo="bar baz", blah=42', {"foo": "bar baz", "blah": "42"}), + ("foo, bar=", {"foo": None, "bar": ""}), + ("=foo, =", {}), + ], + ) + def test_dict_header(self, value, expect): + assert http.parse_dict_header(value) == expect def test_cache_control_header(self): cc = http.parse_cache_control_header("max-age=0, no-cache") assert cc.max_age == 0 - assert cc.no_cache + assert cc.no_cache is True cc = http.parse_cache_control_header( 'private, community="UCI"', None, datastructures.ResponseCacheControl ) @@ -125,17 +132,17 @@ def test_cache_control_header(self): assert c.no_cache is None assert c.private is None c.no_cache = True - assert c.no_cache == "*" + assert c.no_cache and c.no_cache is True c.private = True - assert c.private == "*" + assert c.private and c.private is True del c.private - assert c.private is None + assert not c.private and c.private is None # max_age is an int, other types are 
converted c.max_age = 3.1 - assert c.max_age == 3 + assert c.max_age == 3 and c["max-age"] == "3" del c.max_age c.s_maxage = 3.1 - assert c.s_maxage == 3 + assert c.s_maxage == 3 and c["s-maxage"] == "3" del c.s_maxage assert c.to_header() == "no-cache" @@ -204,6 +211,10 @@ def test_authorization_header(self): assert Authorization.from_header(None) is None assert Authorization.from_header("foo").type == "foo" + def test_authorization_ignore_invalid_parameters(self): + a = Authorization.from_header("Digest foo, bar=, =qux, =") + assert a.to_header() == 'Digest foo, bar=""' + def test_authorization_token_padding(self): # padded with = token = base64.b64encode(b"This has base64 padding").decode() @@ -361,8 +372,8 @@ def test_parse_options_header_empty(self, value, expect): ('v;a="b\\"c";d=e', {"a": 'b"c', "d": "e"}), # HTTP headers use \\ for internal \ ('v;a="c:\\\\"', {"a": "c:\\"}), - # Invalid trailing slash in quoted part is left as-is. - ('v;a="c:\\"', {"a": "c:\\"}), + # Part with invalid trailing slash is discarded. + ('v;a="c:\\"', {}), ('v;a="b\\\\\\"c"', {"a": 'b\\"c'}), # multipart form data uses %22 for internal " ('v;a="b%22c"', {"a": 'b"c'}), @@ -377,6 +388,8 @@ def test_parse_options_header_empty(self, value, expect): ("v;a*0=b;a*1=c;d=e", {"a": "bc", "d": "e"}), ("v;a*0*=b", {"a": "b"}), ("v;a*0*=UTF-8''b;a*1=c;a*2*=%C2%B5", {"a": "bcµ"}), + # Long invalid quoted string with trailing slashes does not freeze. 
+ ('v;a="' + "\\" * 400, {}), ], ) def test_parse_options_header(self, value, expect) -> None: @@ -576,6 +589,14 @@ def test_cookie_samesite_invalid(self): with pytest.raises(ValueError): http.dump_cookie("foo", "bar", samesite="invalid") + def test_cookie_partitioned(self): + value = http.dump_cookie("foo", "bar", partitioned=True, secure=True) + assert value == "foo=bar; Secure; Path=/; Partitioned" + + def test_cookie_partitioned_sets_secure(self): + value = http.dump_cookie("foo", "bar", partitioned=True, secure=False) + assert value == "foo=bar; Secure; Path=/; Partitioned" + class TestRange: def test_if_range_parsing(self): diff --git a/tests/test_local.py b/tests/test_local.py index 2af69d2d6..2250a5bee 100644 --- a/tests/test_local.py +++ b/tests/test_local.py @@ -170,7 +170,7 @@ class SomeClassWithWrapped: _cv_val.set(42) with pytest.raises(AttributeError): - proxy.__wrapped__ + proxy.__wrapped__ # noqa: B018 ns = local.Local(_cv_ns) ns.foo = SomeClassWithWrapped() @@ -179,7 +179,7 @@ class SomeClassWithWrapped: assert ns("foo").__wrapped__ == "wrapped" with pytest.raises(AttributeError): - ns("bar").__wrapped__ + ns("bar").__wrapped__ # noqa: B018 def test_proxy_doc(): diff --git a/tests/test_routing.py b/tests/test_routing.py index 65d2a5f90..02db898d6 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -95,6 +95,7 @@ def test_merge_slashes_match(): r.Rule("/yes/tail/", endpoint="yes_tail"), r.Rule("/with/", endpoint="with_path"), r.Rule("/no//merge", endpoint="no_merge", merge_slashes=False), + r.Rule("/no/merging", endpoint="no_merging", merge_slashes=False), ] ) adapter = url_map.bind("localhost", "/") @@ -124,6 +125,9 @@ def test_merge_slashes_match(): assert adapter.match("/no//merge")[0] == "no_merge" + assert adapter.match("/no/merging")[0] == "no_merging" + pytest.raises(NotFound, lambda: adapter.match("/no//merging")) + @pytest.mark.parametrize( ("path", "expected"), @@ -791,7 +795,7 @@ def __init__(self, url_map, *items): 
self.regex = items[0] # This is a regex pattern with nested groups - DATE_PATTERN = r"((\d{8}T\d{6}([.,]\d{1,3})?)|(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}([.,]\d{1,3})?))Z" # noqa: B950 + DATE_PATTERN = r"((\d{8}T\d{6}([.,]\d{1,3})?)|(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}([.,]\d{1,3})?))Z" # noqa: E501 map = r.Map( [ @@ -1072,6 +1076,9 @@ def test_converter_parser(): args, kwargs = r.parse_converter_args('"foo", "bar"') assert args == ("foo", "bar") + with pytest.raises(ValueError): + r.parse_converter_args("min=0;max=500") + def test_alias_redirects(): m = r.Map( diff --git a/tests/test_security.py b/tests/test_security.py index 0ef1eb052..455936879 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -1,5 +1,4 @@ import os -import posixpath import sys import pytest @@ -11,7 +10,7 @@ def test_default_password_method(): value = generate_password_hash("secret") - assert value.startswith("pbkdf2:") + assert value.startswith("scrypt:") @pytest.mark.xfail( @@ -26,7 +25,7 @@ def test_scrypt(): def test_pbkdf2(): value = generate_password_hash("secret", method="pbkdf2") assert check_password_hash(value, "secret") - assert value.startswith("pbkdf2:sha256:600000$") + assert value.startswith("pbkdf2:sha256:1000000$") def test_salted_hashes(): @@ -42,11 +41,22 @@ def test_require_salt(): generate_password_hash("secret", salt_length=0) -def test_safe_join(): - assert safe_join("foo", "bar/baz") == posixpath.join("foo", "bar/baz") - assert safe_join("foo", "../bar/baz") is None - if os.name == "nt": - assert safe_join("foo", "foo\\bar") is None +def test_invalid_method(): + with pytest.raises(ValueError, match="Invalid hash method"): + generate_password_hash("secret", "sha256") + + +@pytest.mark.parametrize( + ("path", "expect"), + [ + ("b/c", "a/b/c"), + ("../b/c", None), + ("b\\c", None if os.name == "nt" else "a/b\\c"), + ("//b/c", None), + ], +) +def test_safe_join(path, expect): + assert safe_join("a", path) == expect def test_safe_join_os_sep(): diff --git 
a/tests/test_serving.py b/tests/test_serving.py index 4abc755d9..6dd9d9dc3 100644 --- a/tests/test_serving.py +++ b/tests/test_serving.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import collections.abc as cabc import http.client import json import os @@ -5,11 +8,14 @@ import socket import ssl import sys +import typing as t from io import BytesIO from pathlib import Path +from unittest.mock import Mock from unittest.mock import patch import pytest +from watchdog import version as watchdog_version from watchdog.events import EVENT_TYPE_MODIFIED from watchdog.events import EVENT_TYPE_OPENED from watchdog.events import FileModifiedEvent @@ -23,8 +29,11 @@ from werkzeug.serving import make_ssl_devcert from werkzeug.test import stream_encode_multipart +if t.TYPE_CHECKING: + from conftest import DevServerClient + from conftest import StartDevServer + -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.parametrize( "kwargs", [ @@ -41,7 +50,9 @@ ], ) @pytest.mark.dev_server -def test_server(tmp_path, dev_server, kwargs: dict): +def test_server( + tmp_path: Path, dev_server: StartDevServer, kwargs: dict[str, t.Any] +) -> None: if kwargs.get("hostname") == "unix": kwargs["hostname"] = f"unix://{tmp_path / 'test.sock'}" @@ -51,9 +62,8 @@ def test_server(tmp_path, dev_server, kwargs: dict): assert r.json["PATH_INFO"] == "/" -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server -def test_untrusted_host(standard_app): +def test_untrusted_host(standard_app: DevServerClient) -> None: r = standard_app.request( "http://missing.test:1337/index.html#ignore", headers={"x-base-url": standard_app.url}, @@ -65,45 +75,42 @@ def test_untrusted_host(standard_app): assert r.json["SERVER_PORT"] == port -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server -def test_double_slash_path(standard_app): +def test_double_slash_path(standard_app: 
DevServerClient) -> None: r = standard_app.request("//double-slash") assert "double-slash" not in r.json["HTTP_HOST"] assert r.json["PATH_INFO"] == "/double-slash" -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server -def test_500_error(standard_app): +def test_500_error(standard_app: DevServerClient) -> None: r = standard_app.request("/crash") assert r.status == 500 assert b"Internal Server Error" in r.data -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server -def test_ssl_dev_cert(tmp_path, dev_server): - client = dev_server(ssl_context=make_ssl_devcert(tmp_path)) +def test_ssl_dev_cert(tmp_path: Path, dev_server: StartDevServer) -> None: + client = dev_server(ssl_context=make_ssl_devcert(os.fspath(tmp_path))) r = client.request() assert r.json["wsgi.url_scheme"] == "https" -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server -def test_ssl_object(dev_server): +def test_ssl_object(dev_server: StartDevServer) -> None: client = dev_server(ssl_context="custom") r = client.request() assert r.json["wsgi.url_scheme"] == "https" -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.parametrize("reloader_type", ["stat", "watchdog"]) @pytest.mark.skipif( os.name == "nt" and "CI" in os.environ, reason="unreliable on Windows during CI" ) @pytest.mark.dev_server -def test_reloader_sys_path(tmp_path, dev_server, reloader_type): +def test_reloader_sys_path( + tmp_path: Path, dev_server: StartDevServer, reloader_type: str +) -> None: """This tests the general behavior of the reloader. It also tests that fixing an import error triggers a reload, not just Python retrying the failed import. 
@@ -115,29 +122,51 @@ def test_reloader_sys_path(tmp_path, dev_server, reloader_type): assert client.request().status == 500 shutil.copyfile(Path(__file__).parent / "live_apps" / "standard_app.py", real_path) - client.wait_for_log(f" * Detected change in {str(real_path)!r}, reloading") + client.wait_for_log(f"Detected change in {str(real_path)!r}") client.wait_for_reload() assert client.request().status == 200 @patch.object(WatchdogReloaderLoop, "trigger_reload") -def test_watchdog_reloader_ignores_opened(mock_trigger_reload): +def test_watchdog_reloader_ignores_opened(mock_trigger_reload: Mock) -> None: reloader = WatchdogReloaderLoop() modified_event = FileModifiedEvent("") modified_event.event_type = EVENT_TYPE_MODIFIED reloader.event_handler.on_any_event(modified_event) mock_trigger_reload.assert_called_once() - reloader.trigger_reload.reset_mock() - + mock_trigger_reload.reset_mock() opened_event = FileModifiedEvent("") opened_event.event_type = EVENT_TYPE_OPENED reloader.event_handler.on_any_event(opened_event) - reloader.trigger_reload.assert_not_called() + mock_trigger_reload.assert_not_called() + + +@pytest.mark.skipif( + watchdog_version.VERSION_MAJOR < 5, + reason="'closed no write' event introduced in watchdog 5.0", +) +@patch.object(WatchdogReloaderLoop, "trigger_reload") +def test_watchdog_reloader_ignores_closed_no_write(mock_trigger_reload: Mock) -> None: + from watchdog.events import EVENT_TYPE_CLOSED_NO_WRITE # type: ignore[attr-defined] + + reloader = WatchdogReloaderLoop() + modified_event = FileModifiedEvent("") + modified_event.event_type = EVENT_TYPE_MODIFIED + reloader.event_handler.on_any_event(modified_event) + mock_trigger_reload.assert_called_once() + + mock_trigger_reload.reset_mock() + opened_event = FileModifiedEvent("") + opened_event.event_type = EVENT_TYPE_CLOSED_NO_WRITE + reloader.event_handler.on_any_event(opened_event) + mock_trigger_reload.assert_not_called() @pytest.mark.skipif(sys.version_info >= (3, 10), reason="not needed 
on >= 3.10") -def test_windows_get_args_for_reloading(monkeypatch, tmp_path): +def test_windows_get_args_for_reloading( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: argv = [str(tmp_path / "test.exe"), "run"] monkeypatch.setattr("sys.executable", str(tmp_path / "python.exe")) monkeypatch.setattr("sys.argv", argv) @@ -147,9 +176,10 @@ def test_windows_get_args_for_reloading(monkeypatch, tmp_path): assert rv == argv -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.parametrize("find", [_find_stat_paths, _find_watchdog_paths]) -def test_exclude_patterns(find): +def test_exclude_patterns( + find: t.Callable[[set[str], set[str]], cabc.Iterable[str]], +) -> None: # Select a path to exclude from the unfiltered list, assert that it is present and # then gets excluded. paths = find(set(), set()) @@ -161,9 +191,8 @@ def test_exclude_patterns(find): assert not any(p.startswith(path_to_exclude) for p in paths) -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server -def test_wrong_protocol(standard_app): +def test_wrong_protocol(standard_app: DevServerClient) -> None: """An HTTPS request to an HTTP server doesn't show a traceback. 
https://github.com/pallets/werkzeug/pull/838 """ @@ -172,12 +201,11 @@ def test_wrong_protocol(standard_app): with pytest.raises(ssl.SSLError): conn.request("GET", f"https://{standard_app.addr}") - assert "Traceback" not in standard_app.log.read() + assert "Traceback" not in standard_app.read_log() -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server -def test_content_type_and_length(standard_app): +def test_content_type_and_length(standard_app: DevServerClient) -> None: r = standard_app.request() assert "CONTENT_TYPE" not in r.json assert "CONTENT_LENGTH" not in r.json @@ -187,15 +215,16 @@ def test_content_type_and_length(standard_app): assert r.json["CONTENT_LENGTH"] == "2" -def test_port_is_int(): +def test_port_is_int() -> None: with pytest.raises(TypeError, match="port must be an integer"): - run_simple("127.0.0.1", "5000", None) + run_simple("127.0.0.1", "5000", None) # type: ignore[arg-type] -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.parametrize("send_length", [False, True]) @pytest.mark.dev_server -def test_chunked_request(monkeypatch, dev_server, send_length): +def test_chunked_request( + monkeypatch: pytest.MonkeyPatch, dev_server: StartDevServer, send_length: bool +) -> None: stream, length, boundary = stream_encode_multipart( { "value": "this is text", @@ -235,9 +264,8 @@ def test_chunked_request(monkeypatch, dev_server, send_length): assert environ["wsgi.input_terminated"] -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server -def test_multiple_headers_concatenated(standard_app): +def test_multiple_headers_concatenated(standard_app: DevServerClient) -> None: """A header key can be sent multiple times. The server will join all the values with commas. 
@@ -260,9 +288,8 @@ def test_multiple_headers_concatenated(standard_app): assert data["HTTP_XYZ"] == "a ,b,c ,d" -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server -def test_multiline_header_folding(standard_app): +def test_multiline_header_folding(standard_app: DevServerClient) -> None: """A header value can be split over multiple lines with a leading tab. The server will remove the newlines and preserve the tabs. @@ -280,9 +307,8 @@ def test_multiline_header_folding(standard_app): @pytest.mark.parametrize("endpoint", ["", "crash"]) -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server -def test_streaming_close_response(dev_server, endpoint): +def test_streaming_close_response(dev_server: StartDevServer, endpoint: str) -> None: """When using HTTP/1.0, chunked encoding is not supported. Fall back to Connection: close, but this allows no reliable way to distinguish between complete and truncated responses. @@ -292,9 +318,8 @@ def test_streaming_close_response(dev_server, endpoint): assert r.data == "".join(str(x) + "\n" for x in range(5)).encode() -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server -def test_streaming_chunked_response(dev_server): +def test_streaming_chunked_response(dev_server: StartDevServer) -> None: """When using HTTP/1.1, use Transfer-Encoding: chunked for streamed responses, since it can distinguish the end of the response without closing the connection. 
@@ -306,11 +331,20 @@ def test_streaming_chunked_response(dev_server): assert r.data == "".join(str(x) + "\n" for x in range(5)).encode() -@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server -def test_streaming_chunked_truncation(dev_server): +def test_streaming_chunked_truncation(dev_server: StartDevServer) -> None: """When using HTTP/1.1, chunked encoding allows the client to detect content truncated by a prematurely closed connection. """ with pytest.raises(http.client.IncompleteRead): dev_server("streaming", threaded=True).request("/crash") + + +@pytest.mark.dev_server +def test_host_with_ipv6_scope(dev_server: StartDevServer) -> None: + client = dev_server(override_client_addr="fe80::1ff:fe23:4567:890a%eth2") + r = client.request("/crash") + + assert r.status == 500 + assert b"Internal Server Error" in r.data + assert "Logging error" not in client.read_log() diff --git a/tests/test_test.py b/tests/test_test.py index c7f21fa11..d317d69c9 100644 --- a/tests/test_test.py +++ b/tests/test_test.py @@ -16,6 +16,7 @@ from werkzeug.test import EnvironBuilder from werkzeug.test import run_wsgi_app from werkzeug.test import stream_encode_multipart +from werkzeug.test import TestResponse from werkzeug.utils import redirect from werkzeug.wrappers import Request from werkzeug.wrappers import Response @@ -903,3 +904,24 @@ def test_no_content_type_header_addition(): c = Client(no_response_headers_app) response = c.open() assert response.headers == Headers([("Content-Length", "8")]) + + +def test_client_response_wrapper(): + class CustomResponse(Response): + pass + + class CustomTestResponse(TestResponse, Response): + pass + + c1 = Client(Response(), CustomResponse) + r1 = c1.open() + + assert isinstance(r1, CustomResponse) + assert type(r1) is not CustomResponse # Got subclassed + assert issubclass(type(r1), CustomResponse) + + c2 = Client(Response(), CustomTestResponse) + r2 = c2.open() + + assert isinstance(r2, 
CustomTestResponse) + assert type(r2) is CustomTestResponse # Did not get subclassed diff --git a/tests/test_urls.py b/tests/test_urls.py index 0b0f2aeed..101b886ec 100644 --- a/tests/test_urls.py +++ b/tests/test_urls.py @@ -1,239 +1,20 @@ -import io -import warnings - import pytest from werkzeug import urls -from werkzeug.datastructures import OrderedMultiDict - -pytestmark = [ - pytest.mark.filterwarnings("ignore:'werkzeug:DeprecationWarning"), - pytest.mark.filterwarnings("ignore:'_?make_chunk_iter':DeprecationWarning"), -] - - -def test_parsing(): - url = urls.url_parse("http://anon:hunter2@[2001:db8:0:1]:80/a/b/c") - assert url.netloc == "anon:hunter2@[2001:db8:0:1]:80" - assert url.username == "anon" - assert url.password == "hunter2" - assert url.port == 80 - assert url.ascii_host == "2001:db8:0:1" - - assert url.get_file_location() == (None, None) # no file scheme - - -@pytest.mark.parametrize("implicit_format", (True, False)) -@pytest.mark.parametrize("localhost", ("127.0.0.1", "::1", "localhost")) -def test_fileurl_parsing_windows(implicit_format, localhost, monkeypatch): - if implicit_format: - pathformat = None - monkeypatch.setattr("os.name", "nt") - else: - pathformat = "windows" - monkeypatch.delattr("os.name") # just to make sure it won't get used - - url = urls.url_parse("file:///C:/Documents and Settings/Foobar/stuff.txt") - assert url.netloc == "" - assert url.scheme == "file" - assert url.get_file_location(pathformat) == ( - None, - r"C:\Documents and Settings\Foobar\stuff.txt", - ) - - url = urls.url_parse("file://///server.tld/file.txt") - assert url.get_file_location(pathformat) == ("server.tld", r"file.txt") - - url = urls.url_parse("file://///server.tld") - assert url.get_file_location(pathformat) == ("server.tld", "") - - url = urls.url_parse(f"file://///{localhost}") - assert url.get_file_location(pathformat) == (None, "") - - url = urls.url_parse(f"file://///{localhost}/file.txt") - assert url.get_file_location(pathformat) == (None, 
r"file.txt") - - -def test_replace(): - url = urls.url_parse("http://de.wikipedia.org/wiki/Troll") - assert url.replace(query="foo=bar") == urls.url_parse( - "http://de.wikipedia.org/wiki/Troll?foo=bar" - ) - assert url.replace(scheme="https") == urls.url_parse( - "https://de.wikipedia.org/wiki/Troll" - ) - - -def test_quoting(): - assert urls.url_quote("\xf6\xe4\xfc") == "%C3%B6%C3%A4%C3%BC" - assert urls.url_unquote(urls.url_quote('#%="\xf6')) == '#%="\xf6' - assert urls.url_quote_plus("foo bar") == "foo+bar" - assert urls.url_unquote_plus("foo+bar") == "foo bar" - assert urls.url_quote_plus("foo+bar") == "foo%2Bbar" - assert urls.url_unquote_plus("foo%2Bbar") == "foo+bar" - assert urls.url_encode({b"a": None, b"b": b"foo bar"}) == "b=foo+bar" - assert urls.url_encode({"a": None, "b": "foo bar"}) == "b=foo+bar" - assert ( - urls.url_fix("http://de.wikipedia.org/wiki/Elf (Begriffsklärung)") - == "http://de.wikipedia.org/wiki/Elf%20(Begriffskl%C3%A4rung)" - ) - assert urls.url_quote_plus(42) == "42" - assert urls.url_quote(b"\xff") == "%FF" - - -def test_bytes_unquoting(): - assert ( - urls.url_unquote(urls.url_quote('#%="\xf6', charset="latin1"), charset=None) - == b'#%="\xf6' - ) - - -def test_url_decoding(): - x = urls.url_decode(b"foo=42&bar=23&uni=H%C3%A4nsel") - assert x["foo"] == "42" - assert x["bar"] == "23" - assert x["uni"] == "Hänsel" - - x = urls.url_decode(b"foo=42;bar=23;uni=H%C3%A4nsel", separator=b";") - assert x["foo"] == "42" - assert x["bar"] == "23" - assert x["uni"] == "Hänsel" - - x = urls.url_decode(b"%C3%9Ch=H%C3%A4nsel") - assert x["Üh"] == "Hänsel" - - -def test_url_bytes_decoding(): - x = urls.url_decode(b"foo=42&bar=23&uni=H%C3%A4nsel", charset=None) - assert x[b"foo"] == b"42" - assert x[b"bar"] == b"23" - assert x[b"uni"] == "Hänsel".encode() - - -def test_stream_decoding_string_fails(): - pytest.raises(TypeError, urls.url_decode_stream, "testing") - - -def test_url_encoding(): - assert urls.url_encode({"foo": "bar 45"}) == 
"foo=bar+45" - d = {"foo": 1, "bar": 23, "blah": "Hänsel"} - assert urls.url_encode(d, sort=True) == "bar=23&blah=H%C3%A4nsel&foo=1" - assert ( - urls.url_encode(d, sort=True, separator=";") == "bar=23;blah=H%C3%A4nsel;foo=1" - ) - - -def test_sorted_url_encode(): - assert ( - urls.url_encode( - {"a": 42, "b": 23, 1: 1, 2: 2}, sort=True, key=lambda i: str(i[0]) - ) - == "1=1&2=2&a=42&b=23" - ) - assert ( - urls.url_encode( - {"A": 1, "a": 2, "B": 3, "b": 4}, - sort=True, - key=lambda x: x[0].lower() + x[0], - ) - == "A=1&a=2&B=3&b=4" - ) - - -def test_streamed_url_encoding(): - out = io.StringIO() - urls.url_encode_stream({"foo": "bar 45"}, out) - assert out.getvalue() == "foo=bar+45" - - d = {"foo": 1, "bar": 23, "blah": "Hänsel"} - out = io.StringIO() - urls.url_encode_stream(d, out, sort=True) - assert out.getvalue() == "bar=23&blah=H%C3%A4nsel&foo=1" - out = io.StringIO() - urls.url_encode_stream(d, out, sort=True, separator=";") - assert out.getvalue() == "bar=23;blah=H%C3%A4nsel;foo=1" - - gen = urls.url_encode_stream(d, sort=True) - assert next(gen) == "bar=23" - assert next(gen) == "blah=H%C3%A4nsel" - assert next(gen) == "foo=1" - pytest.raises(StopIteration, lambda: next(gen)) - - -def test_url_fixing(): - x = urls.url_fix("http://de.wikipedia.org/wiki/Elf (Begriffskl\xe4rung)") - assert x == "http://de.wikipedia.org/wiki/Elf%20(Begriffskl%C3%A4rung)" - - x = urls.url_fix("http://just.a.test/$-_.+!*'(),") - assert x == "http://just.a.test/$-_.+!*'()," - - x = urls.url_fix("http://höhöhö.at/höhöhö/hähähä") - assert x == r"http://xn--hhh-snabb.at/h%C3%B6h%C3%B6h%C3%B6/h%C3%A4h%C3%A4h%C3%A4" - - -def test_url_fixing_filepaths(): - x = urls.url_fix(r"file://C:\Users\Administrator\My Documents\ÑÈáÇíí") - assert x == ( - r"file:///C%3A/Users/Administrator/My%20Documents/" - r"%C3%91%C3%88%C3%A1%C3%87%C3%AD%C3%AD" - ) - - a = urls.url_fix(r"file:/C:/") - b = urls.url_fix(r"file://C:/") - c = urls.url_fix(r"file:///C:/") - assert a == b == c == r"file:///C%3A/" - 
- x = urls.url_fix(r"file://host/sub/path") - assert x == r"file://host/sub/path" - - x = urls.url_fix(r"file:///") - assert x == r"file:///" - - -def test_url_fixing_qs(): - x = urls.url_fix(b"http://example.com/?foo=%2f%2f") - assert x == "http://example.com/?foo=%2f%2f" - - x = urls.url_fix( - "http://acronyms.thefreedictionary.com/" - "Algebraic+Methods+of+Solving+the+Schr%C3%B6dinger+Equation" - ) - assert x == ( - "http://acronyms.thefreedictionary.com/" - "Algebraic+Methods+of+Solving+the+Schr%C3%B6dinger+Equation" - ) def test_iri_support(): assert urls.uri_to_iri("http://xn--n3h.net/") == "http://\u2603.net/" - - with pytest.deprecated_call(): - assert ( - urls.uri_to_iri(b"http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th") - == "http://\xfcser:p\xe4ssword@\u2603.net/p\xe5th" - ) - assert urls.iri_to_uri("http://☃.net/") == "http://xn--n3h.net/" assert ( urls.iri_to_uri("http://üser:pässword@☃.net/påth") == "http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th" ) - assert ( urls.uri_to_iri("http://test.com/%3Fmeh?foo=%26%2F") == "http://test.com/%3Fmeh?foo=%26/" ) - - # this should work as well, might break on 2.4 because of a broken - # idna codec - with pytest.deprecated_call(): - assert urls.uri_to_iri(b"/foo") == "/foo" - - with pytest.deprecated_call(): - assert urls.iri_to_uri(b"/foo") == "/foo" - assert urls.iri_to_uri("/foo") == "/foo" - assert ( urls.iri_to_uri("http://föö.com:8080/bam/baz") == "http://xn--f-1gaa.com:8080/bam/baz" @@ -247,83 +28,11 @@ def test_iri_safe_quoting(): assert urls.iri_to_uri(urls.uri_to_iri(uri)) == uri -def test_ordered_multidict_encoding(): - d = OrderedMultiDict() - d.add("foo", 1) - d.add("foo", 2) - d.add("foo", 3) - d.add("bar", 0) - d.add("foo", 4) - assert urls.url_encode(d) == "foo=1&foo=2&foo=3&bar=0&foo=4" - - -def test_multidict_encoding(): - d = OrderedMultiDict() - d.add("2013-10-10T23:26:05.657975+0000", "2013-10-10T23:26:05.657975+0000") - assert ( - urls.url_encode(d) - == 
"2013-10-10T23%3A26%3A05.657975%2B0000=2013-10-10T23%3A26%3A05.657975%2B0000" - ) - - -def test_url_unquote_plus_unicode(): - # was broken in 0.6 - assert urls.url_unquote_plus("\x6d") == "\x6d" - - def test_quoting_of_local_urls(): rv = urls.iri_to_uri("/foo\x8f") assert rv == "/foo%C2%8F" -def test_url_attributes(): - rv = urls.url_parse("http://foo%3a:bar%3a@[::1]:80/123?x=y#frag") - assert rv.scheme == "http" - assert rv.auth == "foo%3a:bar%3a" - assert rv.username == "foo:" - assert rv.password == "bar:" - assert rv.raw_username == "foo%3a" - assert rv.raw_password == "bar%3a" - assert rv.host == "::1" - assert rv.port == 80 - assert rv.path == "/123" - assert rv.query == "x=y" - assert rv.fragment == "frag" - - rv = urls.url_parse("http://\N{SNOWMAN}.com/") - assert rv.host == "\N{SNOWMAN}.com" - assert rv.ascii_host == "xn--n3h.com" - - -def test_url_attributes_bytes(): - rv = urls.url_parse(b"http://foo%3a:bar%3a@[::1]:80/123?x=y#frag") - assert rv.scheme == b"http" - assert rv.auth == b"foo%3a:bar%3a" - assert rv.username == "foo:" - assert rv.password == "bar:" - assert rv.raw_username == b"foo%3a" - assert rv.raw_password == b"bar%3a" - assert rv.host == b"::1" - assert rv.port == 80 - assert rv.path == b"/123" - assert rv.query == b"x=y" - assert rv.fragment == b"frag" - - -def test_url_joining(): - assert urls.url_join("/foo", "/bar") == "/bar" - assert urls.url_join("http://example.com/foo", "/bar") == "http://example.com/bar" - assert urls.url_join("file:///tmp/", "test.html") == "file:///tmp/test.html" - assert urls.url_join("file:///tmp/x", "test.html") == "file:///tmp/test.html" - assert urls.url_join("file:///tmp/x", "../../../x.html") == "file:///x.html" - - -def test_partial_unencoded_decode(): - ref = "foo=정상처리".encode("euc-kr") - x = urls.url_decode(ref, charset="euc-kr") - assert x["foo"] == "정상처리" - - def test_iri_to_uri_idempotence_ascii_only(): uri = "http://www.idempoten.ce" uri = urls.iri_to_uri(uri) @@ -391,10 +100,7 @@ def 
test_iri_to_uri_dont_quote_valid_code_points(): assert urls.iri_to_uri("/path[bracket]?(paren)") == "/path%5Bbracket%5D?(paren)" -def test_url_parse_does_not_clear_warnings_registry(recwarn): - warnings.simplefilter("default") - warnings.simplefilter("ignore", DeprecationWarning) - for _ in range(2): - urls.url_parse("http://example.org/") - warnings.warn("test warning") - assert len(recwarn) == 1 +# Python < 3.12 +def test_itms_services() -> None: + url = "itms-services://?action=download-manifest&url=https://test.example/path" + assert urls.iri_to_uri(url) == url diff --git a/tests/test_utils.py b/tests/test_utils.py index b7f1bcb1a..c48eba556 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -176,6 +176,7 @@ def test_assign(): def test_import_string(): from datetime import date + from werkzeug.debug import DebuggedApplication assert utils.import_string("datetime.date") is date diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 8a91aefc1..8bc063c74 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -16,11 +16,11 @@ from werkzeug.datastructures import Headers from werkzeug.datastructures import ImmutableList from werkzeug.datastructures import ImmutableMultiDict -from werkzeug.datastructures import ImmutableOrderedMultiDict from werkzeug.datastructures import LanguageAccept from werkzeug.datastructures import MIMEAccept from werkzeug.datastructures import MultiDict from werkzeug.datastructures import WWWAuthenticate +from werkzeug.datastructures.structures import _ImmutableOrderedMultiDict from werkzeug.exceptions import BadRequest from werkzeug.exceptions import RequestedRangeNotSatisfiable from werkzeug.exceptions import SecurityError @@ -998,9 +998,10 @@ def generate_items(): assert resp.response == ["foo", "bar", "baz"] +@pytest.mark.filterwarnings("ignore:'OrderedMultiDict':DeprecationWarning") def test_form_data_ordering(): class MyRequest(wrappers.Request): - parameter_storage_class = ImmutableOrderedMultiDict + 
parameter_storage_class = _ImmutableOrderedMultiDict req = MyRequest.from_values("/?foo=1&bar=0&foo=3") assert list(req.args) == ["foo", "bar"] @@ -1009,7 +1010,7 @@ class MyRequest(wrappers.Request): ("bar", "0"), ("foo", "3"), ] - assert isinstance(req.args, ImmutableOrderedMultiDict) + assert isinstance(req.args, _ImmutableOrderedMultiDict) assert isinstance(req.values, CombinedMultiDict) assert req.values["foo"] == "1" assert req.values.getlist("foo") == ["1", "3"] @@ -1037,25 +1038,25 @@ class MyRequest(wrappers.Request): parameter_storage_class = dict req = MyRequest.from_values("/?foo=baz", headers={"Cookie": "foo=bar"}) - assert type(req.cookies) is dict + assert type(req.cookies) is dict # noqa: E721 assert req.cookies == {"foo": "bar"} - assert type(req.access_route) is list + assert type(req.access_route) is list # noqa: E721 - assert type(req.args) is dict - assert type(req.values) is CombinedMultiDict + assert type(req.args) is dict # noqa: E721 + assert type(req.values) is CombinedMultiDict # noqa: E721 assert req.values["foo"] == "baz" req = wrappers.Request.from_values(headers={"Cookie": "foo=bar;foo=baz"}) - assert type(req.cookies) is ImmutableMultiDict + assert type(req.cookies) is ImmutableMultiDict # noqa: E721 assert req.cookies.to_dict() == {"foo": "bar"} # it is possible to have multiple cookies with the same name assert req.cookies.getlist("foo") == ["bar", "baz"] - assert type(req.access_route) is ImmutableList + assert type(req.access_route) is ImmutableList # noqa: E721 MyRequest.list_storage_class = tuple req = MyRequest.from_values() - assert type(req.access_route) is tuple + assert type(req.access_route) is tuple # noqa: E721 def test_response_headers_passthrough(): @@ -1154,6 +1155,7 @@ class MyResponse(wrappers.Response): ("auto", "location", "expect"), ( (False, "/test", "/test"), + (False, "/\\\\test.example?q", "/%5C%5Ctest.example?q"), (True, "/test", "http://localhost/test"), (True, "test", "http://localhost/a/b/test"), (True, 
"./test", "http://localhost/a/b/test"), diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py index 5f37aca97..7f4d2e9cf 100644 --- a/tests/test_wsgi.py +++ b/tests/test_wsgi.py @@ -256,142 +256,6 @@ def test_get_current_url_invalid_utf8(): assert rv == "http://localhost/?foo=bar&baz=blah&meh=%CF" -@pytest.mark.filterwarnings("ignore:'make_line_iter:DeprecationWarning") -@pytest.mark.filterwarnings("ignore:'_make_chunk_iter:DeprecationWarning") -def test_multi_part_line_breaks(): - data = b"abcdef\r\nghijkl\r\nmnopqrstuvwxyz\r\nABCDEFGHIJK" - test_stream = io.BytesIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=16)) - assert lines == [ - b"abcdef\r\n", - b"ghijkl\r\n", - b"mnopqrstuvwxyz\r\n", - b"ABCDEFGHIJK", - ] - - data = b"abc\r\nThis line is broken by the buffer length.\r\nFoo bar baz" - test_stream = io.BytesIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=24)) - assert lines == [ - b"abc\r\n", - b"This line is broken by the buffer length.\r\n", - b"Foo bar baz", - ] - - -@pytest.mark.filterwarnings("ignore:'make_line_iter:DeprecationWarning") -@pytest.mark.filterwarnings("ignore:'_make_chunk_iter:DeprecationWarning") -def test_multi_part_line_breaks_bytes(): - data = b"abcdef\r\nghijkl\r\nmnopqrstuvwxyz\r\nABCDEFGHIJK" - test_stream = io.BytesIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=16)) - assert lines == [ - b"abcdef\r\n", - b"ghijkl\r\n", - b"mnopqrstuvwxyz\r\n", - b"ABCDEFGHIJK", - ] - - data = b"abc\r\nThis line is broken by the buffer length.\r\nFoo bar baz" - test_stream = io.BytesIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=24)) - assert lines == [ - b"abc\r\n", - b"This line is broken by the buffer length.\r\n", - b"Foo bar baz", - ] - - -@pytest.mark.filterwarnings("ignore:'make_line_iter:DeprecationWarning") -@pytest.mark.filterwarnings("ignore:'_make_chunk_iter:DeprecationWarning") -def 
test_multi_part_line_breaks_problematic(): - data = b"abc\rdef\r\nghi" - for _ in range(1, 10): - test_stream = io.BytesIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=4)) - assert lines == [b"abc\r", b"def\r\n", b"ghi"] - - -@pytest.mark.filterwarnings("ignore:'make_line_iter:DeprecationWarning") -@pytest.mark.filterwarnings("ignore:'_make_chunk_iter:DeprecationWarning") -def test_iter_functions_support_iterators(): - data = ["abcdef\r\nghi", "jkl\r\nmnopqrstuvwxyz\r", "\nABCDEFGHIJK"] - lines = list(wsgi.make_line_iter(data)) - assert lines == ["abcdef\r\n", "ghijkl\r\n", "mnopqrstuvwxyz\r\n", "ABCDEFGHIJK"] - - -@pytest.mark.filterwarnings("ignore:'_?make_chunk_iter:DeprecationWarning") -def test_make_chunk_iter(): - data = [b"abcdefXghi", b"jklXmnopqrstuvwxyzX", b"ABCDEFGHIJK"] - rv = list(wsgi.make_chunk_iter(data, b"X")) - assert rv == [b"abcdef", b"ghijkl", b"mnopqrstuvwxyz", b"ABCDEFGHIJK"] - - data = b"abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK" - test_stream = io.BytesIO(data) - rv = list(wsgi.make_chunk_iter(test_stream, b"X", limit=len(data), buffer_size=4)) - assert rv == [b"abcdef", b"ghijkl", b"mnopqrstuvwxyz", b"ABCDEFGHIJK"] - - -@pytest.mark.filterwarnings("ignore:'_?make_chunk_iter:DeprecationWarning") -def test_make_chunk_iter_bytes(): - data = [b"abcdefXghi", b"jklXmnopqrstuvwxyzX", b"ABCDEFGHIJK"] - rv = list(wsgi.make_chunk_iter(data, "X")) - assert rv == [b"abcdef", b"ghijkl", b"mnopqrstuvwxyz", b"ABCDEFGHIJK"] - - data = b"abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK" - test_stream = io.BytesIO(data) - rv = list(wsgi.make_chunk_iter(test_stream, "X", limit=len(data), buffer_size=4)) - assert rv == [b"abcdef", b"ghijkl", b"mnopqrstuvwxyz", b"ABCDEFGHIJK"] - - data = b"abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK" - test_stream = io.BytesIO(data) - rv = list( - wsgi.make_chunk_iter( - test_stream, "X", limit=len(data), buffer_size=4, cap_at_buffer=True - ) - ) - assert rv == [ - b"abcd", - b"ef", - b"ghij", - 
b"kl", - b"mnop", - b"qrst", - b"uvwx", - b"yz", - b"ABCD", - b"EFGH", - b"IJK", - ] - - -@pytest.mark.filterwarnings("ignore:'make_line_iter:DeprecationWarning") -@pytest.mark.filterwarnings("ignore:'_make_chunk_iter:DeprecationWarning") -def test_lines_longer_buffer_size(): - data = b"1234567890\n1234567890\n" - for bufsize in range(1, 15): - lines = list( - wsgi.make_line_iter(io.BytesIO(data), limit=len(data), buffer_size=bufsize) - ) - assert lines == [b"1234567890\n", b"1234567890\n"] - - -@pytest.mark.filterwarnings("ignore:'make_line_iter:DeprecationWarning") -@pytest.mark.filterwarnings("ignore:'_make_chunk_iter:DeprecationWarning") -def test_lines_longer_buffer_size_cap(): - data = b"1234567890\n1234567890\n" - for bufsize in range(1, 15): - lines = list( - wsgi.make_line_iter( - io.BytesIO(data), - limit=len(data), - buffer_size=bufsize, - cap_at_buffer=True, - ) - ) - assert len(lines[0]) == bufsize or lines[0].endswith(b"\n") - - def test_range_wrapper(): response = Response(b"Hello World") range_wrapper = _RangeWrapper(response.response, 6, 4) diff --git a/tox.ini b/tox.ini index eca667f84..cebd251fd 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,6 @@ [tox] envlist = - py3{12,11,10,9,8} + py3{13,12,11,10,9} pypy310 style typing @@ -10,6 +10,8 @@ skip_missing_interpreters = true [testenv] package = wheel wheel_build_env = .pkg +constrain_package_deps = true +use_frozen_constraints = true deps = -r requirements/tests.txt commands = pytest -v --tb=short --basetemp={envtmpdir} {posargs} @@ -24,4 +26,27 @@ commands = mypy [testenv:docs] deps = -r requirements/docs.txt -commands = sphinx-build -W -b html -d {envtmpdir}/doctrees docs {envtmpdir}/html +commands = sphinx-build -E -W -b dirhtml docs docs/_build/dirhtml + +[testenv:update-actions] +labels = update +deps = gha-update +commands = gha-update + +[testenv:update-pre_commit] +labels = update +deps = pre-commit +skip_install = true +commands = pre-commit autoupdate -j4 + +[testenv:update-requirements] 
+labels = update +deps = pip-tools +skip_install = true +change_dir = requirements +commands = + pip-compile build.in -q {posargs:-U} + pip-compile docs.in -q {posargs:-U} + pip-compile tests.in -q {posargs:-U} + pip-compile typing.in -q {posargs:-U} + pip-compile dev.in -q {posargs:-U}