Skip to content

Commit

Permalink
Extract extra-flattening routine from source tree resolver (#10820)
Browse files Browse the repository at this point in the history
## Summary

I needed this for #10794, but it
makes sense as a standalone change, since it's much more testable. We
can also reuse this in at least one more place.
  • Loading branch information
charliermarsh authored Jan 21, 2025
1 parent 7863918 commit 61bc818
Show file tree
Hide file tree
Showing 4 changed files with 295 additions and 71 deletions.
3 changes: 2 additions & 1 deletion crates/uv-distribution/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ pub use download::LocalWheel;
pub use error::Error;
pub use index::{BuiltWheelIndex, RegistryWheelIndex};
pub use metadata::{
ArchiveMetadata, BuildRequires, LoweredRequirement, Metadata, MetadataError, RequiresDist,
ArchiveMetadata, BuildRequires, FlatRequiresDist, LoweredRequirement, Metadata, MetadataError,
RequiresDist,
};
pub use reporter::Reporter;
pub use source::prune;
Expand Down
2 changes: 1 addition & 1 deletion crates/uv-distribution/src/metadata/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use uv_workspace::WorkspaceError;
pub use crate::metadata::build_requires::BuildRequires;
pub use crate::metadata::lowering::LoweredRequirement;
use crate::metadata::lowering::LoweringError;
pub use crate::metadata::requires_dist::RequiresDist;
pub use crate::metadata::requires_dist::{FlatRequiresDist, RequiresDist};

mod build_requires;
mod lowering;
Expand Down
281 changes: 280 additions & 1 deletion crates/uv-distribution/src/metadata/requires_dist.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use std::collections::BTreeMap;
use std::collections::{BTreeMap, VecDeque};
use std::path::Path;
use std::slice;

use rustc_hash::FxHashSet;

use uv_configuration::{LowerBound, SourceStrategy};
use uv_distribution_types::IndexLocations;
use uv_normalize::{ExtraName, GroupName, PackageName, DEV_DEPENDENCIES};
use uv_pep508::MarkerTree;
use uv_workspace::dependency_groups::FlatDependencyGroups;
use uv_workspace::pyproject::{Sources, ToolUvSources};
use uv_workspace::{DiscoveryOptions, ProjectWorkspace};
Expand Down Expand Up @@ -314,18 +318,167 @@ impl From<Metadata> for RequiresDist {
}
}

/// Like [`uv_pypi_types::RequiresDist`], but with any recursive (or self-referential) dependencies
/// resolved.
///
/// For example, given:
/// ```toml
/// [project]
/// name = "example"
/// version = "0.1.0"
/// requires-python = ">=3.13.0"
/// dependencies = []
///
/// [project.optional-dependencies]
/// all = [
/// "example[async]",
/// ]
/// async = [
/// "fastapi",
/// ]
/// ```
///
/// A build backend could return:
/// ```txt
/// Metadata-Version: 2.2
/// Name: example
/// Version: 0.1.0
/// Requires-Python: >=3.13.0
/// Provides-Extra: all
/// Requires-Dist: example[async]; extra == "all"
/// Provides-Extra: async
/// Requires-Dist: fastapi; extra == "async"
/// ```
///
/// Or:
/// ```txt
/// Metadata-Version: 2.4
/// Name: example
/// Version: 0.1.0
/// Requires-Python: >=3.13.0
/// Provides-Extra: all
/// Requires-Dist: fastapi; extra == 'all'
/// Provides-Extra: async
/// Requires-Dist: fastapi; extra == 'async'
/// ```
///
/// The [`FlatRequiresDist`] struct is used to flatten out the recursive dependencies, i.e., convert
/// from the former to the latter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FlatRequiresDist(Vec<uv_pypi_types::Requirement>);

impl FlatRequiresDist {
/// Flatten a set of requirements, resolving any self-references.
pub fn from_requirements(
requirements: Vec<uv_pypi_types::Requirement>,
name: &PackageName,
) -> Self {
// If there are no self-references, we can return early.
if requirements.iter().all(|req| req.name != *name) {
return Self(requirements);
}

// Transitively process all extras that are recursively included.
let mut flattened = requirements.clone();
let mut seen = FxHashSet::<(ExtraName, MarkerTree)>::default();
let mut queue: VecDeque<_> = flattened
.iter()
.filter(|req| req.name == *name)
.flat_map(|req| req.extras.iter().cloned().map(|extra| (extra, req.marker)))
.collect();
while let Some((extra, marker)) = queue.pop_front() {
if !seen.insert((extra.clone(), marker)) {
continue;
}

// Find the requirements for the extra.
for requirement in &requirements {
if requirement.marker.top_level_extra_name().as_ref() == Some(&extra) {
let requirement = {
let mut marker = marker;
marker.and(requirement.marker);
uv_pypi_types::Requirement {
name: requirement.name.clone(),
extras: requirement.extras.clone(),
groups: requirement.groups.clone(),
source: requirement.source.clone(),
origin: requirement.origin.clone(),
marker: marker.simplify_extras(slice::from_ref(&extra)),
}
};
if requirement.name == *name {
// Add each transitively included extra.
queue.extend(
requirement
.extras
.iter()
.cloned()
.map(|extra| (extra, requirement.marker)),
);
} else {
// Add the requirements for that extra.
flattened.push(requirement);
}
}
}
}

// Drop all the self-references now that we've flattened them out.
flattened.retain(|req| req.name != *name);

// Retain any self-constraints for that extra, e.g., if `project[foo]` includes
// `project[bar]>1.0`, as a dependency, we need to propagate `project>1.0`, in addition to
// transitively expanding `project[bar]`.
for req in &requirements {
if req.name == *name {
if !req.source.is_empty() {
flattened.push(uv_pypi_types::Requirement {
name: req.name.clone(),
extras: vec![],
groups: req.groups.clone(),
source: req.source.clone(),
origin: req.origin.clone(),
marker: req.marker,
});
}
}
}

Self(flattened)
}

/// Consume the [`FlatRequiresDist`] and return the inner vector.
pub fn into_inner(self) -> Vec<uv_pypi_types::Requirement> {
self.0
}
}

impl IntoIterator for FlatRequiresDist {
type Item = uv_pypi_types::Requirement;
type IntoIter = std::vec::IntoIter<uv_pypi_types::Requirement>;

fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}

#[cfg(test)]
mod test {
use std::path::Path;
use std::str::FromStr;

use anyhow::Context;
use indoc::indoc;
use insta::assert_snapshot;

use uv_configuration::{LowerBound, SourceStrategy};
use uv_distribution_types::IndexLocations;
use uv_normalize::PackageName;
use uv_pep508::Requirement;
use uv_workspace::pyproject::PyProjectToml;
use uv_workspace::{DiscoveryOptions, ProjectWorkspace};

use crate::metadata::requires_dist::FlatRequiresDist;
use crate::RequiresDist;

async fn requires_dist_from_pyproject_toml(contents: &str) -> anyhow::Result<RequiresDist> {
Expand Down Expand Up @@ -601,4 +754,130 @@ mod test {
error: metadata field project not found
"###);
}

#[test]
fn test_flat_requires_dist_noop() {
let name = PackageName::from_str("pkg").unwrap();
let requirements = vec![
Requirement::from_str("requests>=2.0.0").unwrap().into(),
Requirement::from_str("pytest; extra == 'test'")
.unwrap()
.into(),
Requirement::from_str("black; extra == 'dev'")
.unwrap()
.into(),
];

let expected = FlatRequiresDist(vec![
Requirement::from_str("requests>=2.0.0").unwrap().into(),
Requirement::from_str("pytest; extra == 'test'")
.unwrap()
.into(),
Requirement::from_str("black; extra == 'dev'")
.unwrap()
.into(),
]);

let actual = FlatRequiresDist::from_requirements(requirements, &name);

assert_eq!(actual, expected);
}

#[test]
fn test_flat_requires_dist_basic() {
let name = PackageName::from_str("pkg").unwrap();
let requirements = vec![
Requirement::from_str("requests>=2.0.0").unwrap().into(),
Requirement::from_str("pytest; extra == 'test'")
.unwrap()
.into(),
Requirement::from_str("pkg[dev]; extra == 'test'")
.unwrap()
.into(),
Requirement::from_str("black; extra == 'dev'")
.unwrap()
.into(),
];

let expected = FlatRequiresDist(vec![
Requirement::from_str("requests>=2.0.0").unwrap().into(),
Requirement::from_str("pytest; extra == 'test'")
.unwrap()
.into(),
Requirement::from_str("black; extra == 'dev'")
.unwrap()
.into(),
Requirement::from_str("black; extra == 'test'")
.unwrap()
.into(),
]);

let actual = FlatRequiresDist::from_requirements(requirements, &name);

assert_eq!(actual, expected);
}

#[test]
fn test_flat_requires_dist_with_markers() {
let name = PackageName::from_str("pkg").unwrap();
let requirements = vec![
Requirement::from_str("requests>=2.0.0").unwrap().into(),
Requirement::from_str("pytest; extra == 'test'")
.unwrap()
.into(),
Requirement::from_str("pkg[dev]; extra == 'test' and sys_platform == 'win32'")
.unwrap()
.into(),
Requirement::from_str("black; extra == 'dev' and sys_platform == 'win32'")
.unwrap()
.into(),
];

let expected = FlatRequiresDist(vec![
Requirement::from_str("requests>=2.0.0").unwrap().into(),
Requirement::from_str("pytest; extra == 'test'")
.unwrap()
.into(),
Requirement::from_str("black; extra == 'dev' and sys_platform == 'win32'")
.unwrap()
.into(),
Requirement::from_str("black; extra == 'test' and sys_platform == 'win32'")
.unwrap()
.into(),
]);

let actual = FlatRequiresDist::from_requirements(requirements, &name);

assert_eq!(actual, expected);
}

#[test]
fn test_flat_requires_dist_self_constraint() {
let name = PackageName::from_str("pkg").unwrap();
let requirements = vec![
Requirement::from_str("requests>=2.0.0").unwrap().into(),
Requirement::from_str("pytest; extra == 'test'")
.unwrap()
.into(),
Requirement::from_str("black; extra == 'dev'")
.unwrap()
.into(),
Requirement::from_str("pkg[async]==1.0.0").unwrap().into(),
];

let expected = FlatRequiresDist(vec![
Requirement::from_str("requests>=2.0.0").unwrap().into(),
Requirement::from_str("pytest; extra == 'test'")
.unwrap()
.into(),
Requirement::from_str("black; extra == 'dev'")
.unwrap()
.into(),
Requirement::from_str("pkg==1.0.0").unwrap().into(),
]);

let actual = FlatRequiresDist::from_requirements(requirements, &name);

assert_eq!(actual, expected);
}
}
Loading

0 comments on commit 61bc818

Please sign in to comment.