diff --git a/tpchgen-cli/src/main.rs b/tpchgen-cli/src/main.rs index 65656e9..d1caa42 100644 --- a/tpchgen-cli/src/main.rs +++ b/tpchgen-cli/src/main.rs @@ -91,11 +91,11 @@ struct Cli { tables: Option>, /// Number of parts to generate (manual parallel generation) - #[arg(short, long, default_value_t = 1)] + #[arg(short, long, default_value_t = -1)] parts: i32, /// Which part to generate (1-based, only relevant if parts > 1) - #[arg(long, default_value_t = 1)] + #[arg(long, default_value_t = -1)] part: i32, /// Output format: tbl, csv, parquet (default: tbl) @@ -260,6 +260,7 @@ macro_rules! define_generate { self.scale_factor, self.part, self.parts, + self.num_threads, ); let scale_factor = self.scale_factor; info!("Writing table {} (SF={scale_factor}) to {filename}", $TABLE); diff --git a/tpchgen-cli/src/plan.rs b/tpchgen-cli/src/plan.rs index 6fa9a72..5c72c0d 100644 --- a/tpchgen-cli/src/plan.rs +++ b/tpchgen-cli/src/plan.rs @@ -1,6 +1,7 @@ //! [`GenerationPlan`] that describes how to generate a TPC-H dataset. use crate::{OutputFormat, Table}; +use log::debug; use std::fmt::Display; use tpchgen::generators::{ CustomerGenerator, OrderGenerator, PartGenerator, PartSuppGenerator, SupplierGenerator, @@ -66,12 +67,49 @@ impl GenerationPlan { scale_factor: f64, cli_part: i32, cli_part_count: i32, + num_threads: usize, ) -> Self { - // parallel generation disabled if user specifies a part explicitly + // If a single part is specified, split it into chunks to enable parallel generation. if cli_part != -1 || cli_part_count != -1 { + // These tables are small and not parameterized by part count, + // so we must create only a single part. + if table == &Table::Nation || table == &Table::Region { + return Self { + part_count: 1, + part_list: vec![1], + }; + } + + // sanity check arguments (TODO: real Errors) + if cli_part < 1 || cli_part_count < 1 || cli_part > cli_part_count { + panic!( + "Invalid CLI part or part count. \ + Expect at least 1 and cli_part <= cli_part_count. 
\ + Got: cli_part={cli_part}, cli_part_count={cli_part_count}", + ); + } + + let num_chunks = num_threads as i32; + + // The new total number of parts is the original number of parts multiplied by the number of chunks. + let new_total_parts = cli_part_count * num_chunks; + + // The new part numbers to generate are the chunks that make up the original part. + let start_part = (cli_part - 1) * num_chunks + 1; + let end_part = cli_part * num_chunks; + let new_parts_to_generate = (start_part..=end_part).collect(); + debug!( + "Generating {} parts for table {:?} with scale factor {}", + new_total_parts, table, scale_factor + ); + debug!( + "CLI part: {}, CLI part count: {}, num_threads: {}", + cli_part, cli_part_count, num_threads + ); + debug!("New parts to generate: {:?}", new_parts_to_generate); return Self { - part_count: cli_part_count, - part_list: vec![cli_part], + part_count: new_total_parts, + part_list: new_parts_to_generate, }; } @@ -223,7 +261,33 @@ mod tests { } #[test] - fn sf1_lineitem_cli_parts() { + fn sf1_nation_cli_parts() { + Test::new() + .with_table(Table::Nation) + .with_format(OutputFormat::Tbl) + .with_scale_factor(1.0) + // nation table is small, so it can not be made in parts + .with_cli_part(1) + .with_cli_part_count(10) + // we expect there is still only one part + .assert(1, [1]) + } + + #[test] + fn sf1_region_cli_parts() { + Test::new() + .with_table(Table::Region) + .with_format(OutputFormat::Tbl) + .with_scale_factor(1.0) + // region table is small, so it can not be made in parts + .with_cli_part(1) + .with_cli_part_count(10) + // we expect there is still only one part + .assert(1, [1]) + } + + #[test] + fn sf1_lineitem_cli_parts_1() { Test::new() .with_table(Table::Lineitem) .with_format(OutputFormat::Tbl) @@ -231,7 +295,60 @@ mod tests { // Generate only part 1 of the lineitem table .with_cli_part(1) .with_cli_part_count(10) - .assert(10, [1]) + // we expect there are num_threads * 10 parts + .assert(40, [1, 2, 3, 4]) + } + + #[test] 
+ fn sf1_lineitem_cli_parts_4() { + Test::new() + .with_table(Table::Lineitem) + .with_format(OutputFormat::Tbl) + .with_scale_factor(1.0) + .with_cli_part(4) // part 4 of 10 + .with_cli_part_count(10) + // we expect there are num_threads * 10 parts + .assert(40, [13, 14, 15, 16]) + } + + #[test] + fn sf1_lineitem_cli_parts_10() { + Test::new() + .with_table(Table::Lineitem) + .with_format(OutputFormat::Tbl) + .with_scale_factor(1.0) + .with_cli_part(10) // part 10 of 10 + .with_cli_part_count(10) + // we expect there are num_threads * 10 parts + .assert(40, [37, 38, 39, 40]) + } + + #[test] + #[should_panic( + expected = "Invalid CLI part or part count. Expect at least 1 and cli_part <= cli_part_count. Got: cli_part=0, cli_part_count=10" + )] + fn sf1_lineitem_cli_parts_invalid_small() { + Test::new() + .with_table(Table::Lineitem) + .with_format(OutputFormat::Tbl) + .with_scale_factor(1.0) + .with_cli_part(0) // part 0 of 10 (invalid) + .with_cli_part_count(10) + .assert(40, [13, 14, 15, 16]) + } + + #[test] + #[should_panic( + expected = "Invalid CLI part or part count. Expect at least 1 and cli_part <= cli_part_count. Got: cli_part=11, cli_part_count=10" + )] + fn sf1_lineitem_cli_parts_invalid_big() { + Test::new() + .with_table(Table::Lineitem) + .with_format(OutputFormat::Tbl) + .with_scale_factor(1.0) + .with_cli_part(11) // part 11 of 10 (invalid) + .with_cli_part_count(10) + .assert(40, [13, 14, 15, 16]) } #[test] @@ -278,6 +395,7 @@ mod tests { scale_factor: f64, cli_part: i32, cli_part_count: i32, + num_cpus: usize, } impl Test { @@ -298,6 +416,7 @@ mod tests { self.scale_factor, self.cli_part, self.cli_part_count, + self.num_cpus, ); assert_eq!(plan.part_count, expected_part_count); let expected_part_numbers: Vec = expected_part_numbers.into_iter().collect(); @@ -343,6 +462,7 @@ mod tests { scale_factor: 1.0, cli_part: -1, cli_part_count: -1, + num_cpus: 4, // hard code 4 cores for testing } } }