Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions sdks/python/apache_beam/runners/portability/stager.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def create_job_resources(
'--requirements_file command line option.' %
setup_options.requirements_file)
extra_packages, thinned_requirements_file = (
Stager._extract_local_packages(setup_options.requirements_file))
Stager._extract_local_packages(setup_options.requirements_file, temp_dir))
if extra_packages:
setup_options.extra_packages = (
setup_options.extra_packages or []) + extra_packages
Expand Down Expand Up @@ -701,14 +701,30 @@ def _remove_dependency_from_requirements(
return tmp_requirements_filename

@staticmethod
def _extract_local_packages(requirements_file):
def _extract_local_packages(requirements_file, temp_dir):
local_deps = []
pypi_deps = []
with open(requirements_file, 'r') as fin:
staging_temp_dir = tempfile.mkdtemp(dir=temp_dir)
for line in fin:
dep = line.strip()
parsed_url = urlparse(dep)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The urlparse function is used here but it is not imported in this file. This will cause a NameError at runtime. Please add from urllib.parse import urlparse to the imports section of the file.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incorrect analysis (import is available at the top of the file).

if os.path.exists(dep):
local_deps.append(dep)
elif parsed_url.scheme:
last_component = os.path.basename(parsed_url.path)
if not last_component:
_LOGGER.warning(
'Extra package %s has an unexpected format hence will not be downloaded locally.'
% dep)
continue
# Download remote package.
_LOGGER.info(
'Downloading remote extra package: %s locally before staging',
dep)
local_file_path = FileSystems.join(staging_temp_dir, last_component)
Stager._download_file(dep, local_file_path)
local_deps.append(local_file_path)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There is a bug in the requirements thinning logic. The Stager._remove_dependency_from_requirements method (called after this loop) uses the items in local_deps to identify which lines to remove from the original requirements file.

By appending the local temporary path (local_file_path) to local_deps instead of the original URL string (dep), the thinning logic will fail to match the URL line in the file. Consequently, the URL will remain in the thinned requirements file, causing the worker's pip to attempt to download it again, which defeats the purpose of staging the package locally.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this analysis is correct. AFAICT Stager._remove_dependency_from_requirements is called one time specifically to remove the "apache-beam" package. Not as a way to trim the requirements file based on local deps.

tmp_requirements_filepath = Stager._remove_dependency_from_requirements(

Comment on lines +725 to +727
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This implementation has two potential issues:

  1. Filename Collision: If multiple URLs have the same basename (e.g., .../v1/pkg.tgz and .../v2/pkg.tgz), they will overwrite each other in the temporary directory, and local_deps will contain the same path twice.
  2. Unsupported Schemes: FileSystems.open (used in _download_file) does not support VCS schemes like git+https. If such a URL is encountered, the staging process will crash.

Adding a unique prefix to the filename and wrapping the download in a try-except block to fall back to pypi_deps (letting pip handle it on the worker) makes this more robust.

Suggested change
local_file_path = FileSystems.join(staging_temp_dir, last_component)
Stager._download_file(dep, local_file_path)
local_deps.append(local_file_path)
local_file_path = FileSystems.join(
staging_temp_dir, f"{len(local_deps)}_{last_component}")
try:
Stager._download_file(dep, local_file_path)
local_deps.append(local_file_path)
except Exception as e:
_LOGGER.warning(
'Failed to download remote package %s, falling back to pip: %s',
dep,
e)
pypi_deps.append(dep)

else:
pypi_deps.append(dep)
if local_deps:
Expand Down
33 changes: 33 additions & 0 deletions sdks/python/apache_beam/runners/portability/stager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,39 @@ def test_download_file_unrecognized(
self.stager._download_file(from_url, to_path)
assert mock_mkdir.called

@mock.patch('apache_beam.runners.portability.stager.Stager._download_file')
def test_extract_local_packages(self, mock_download):
temp_dir = self.make_temp_dir()
req_file = os.path.join(temp_dir, 'requirements.txt')

local_file = os.path.join(temp_dir, 'local.tar.gz')
self.create_temp_file(local_file, 'nothing')

url = 'http://example.com/remote.tar.gz'
invalid_url = 'http://example.com/'
pypi_dep = 'pytest'

contents = '\n'.join([local_file, url, invalid_url, pypi_dep])
self.create_temp_file(req_file, contents)

def fake_download(src, dst):
with open(dst, 'w') as f:
f.write('downloaded')

mock_download.side_effect = fake_download

local_deps, thinned_req = stager.Stager._extract_local_packages(req_file, temp_dir)

self.assertEqual(len(local_deps), 2)
self.assertEqual(local_deps[0], local_file)
self.assertTrue(local_deps[1].endswith('remote.tar.gz'))

mock_download.assert_called_once_with(url, mock.ANY)

with open(thinned_req, 'r') as f:
lines = f.read().splitlines()
self.assertEqual(lines, [pypi_dep])

def test_no_staging_location(self):
with self.assertRaises(RuntimeError) as cm:
self.stager.stage_job_resources([], staging_location=None)
Expand Down
Loading