Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions build.zig
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ pub fn build(b: *std.Build) void {
target = b.resolveTargetQuery(target_query);
optimize = b.standardOptimizeOption(.{});

const use_bfs = b.option(bool, "bfs", "use BFS instead of DFS for graph linearization (DFS default)") orelse false;

const build_options = b.addOptions();
build_options.addOption(bool, "use_bfs", use_bfs);

if (use_bfs) {
std.log.scoped(.build).info("graph trasversal: BFS", .{});
} else {
std.log.scoped(.build).info("graph trasversal: DFS", .{});
}

// ************************************************ UNIT TESTS **************************************************
// $ zig build test --summary all
unit_test_creation(b, zantBuild);
Expand All @@ -44,11 +55,11 @@ pub fn build(b: *std.Build) void {

// ************************************************ GENERATED LIBRARY TESTS **************************************
// $ zig build lib-test -Dmodel="myModel" ...
lib_test(b, zantBuild);
lib_test(b, zantBuild, build_options);

// ************************************************ STATIC LIBRARY CREATION **************************************
// $ zig build lib -Dmodel="myModel" [ -Dtarget=... -Dcpu=... -Doptimize=[ReleaseSmall, ReleaseFast]]
const static_lib: *std.Build.Step.Compile = lib_creation(b, zantBuild) catch unreachable;
const static_lib: *std.Build.Step.Compile = lib_creation(b, zantBuild, build_options) catch unreachable;

// ************************************************ ONEOP CODEGEN ************************************************
// $ zig build op-codegen-gen [ -Dop="OpName" ]
Expand Down Expand Up @@ -182,7 +193,7 @@ inline fn lib_exe(b: *std.Build, zantBuild: ZantBuild) void {
model_exe_step.dependOn(&model_exe_cmd.step);
}

inline fn lib_test(b: *std.Build, zantBuild: ZantBuild) void {
inline fn lib_test(b: *std.Build, zantBuild: ZantBuild, build_options: *std.Build.Step.Options) void {
//
// OPTIONS: see codegen_options
//
Expand All @@ -204,6 +215,8 @@ inline fn lib_test(b: *std.Build, zantBuild: ZantBuild) void {
}),
});

test_generated_lib.root_module.addOptions("build_options", build_options);

if (zantBuild.zantOptions.stm32n6_flags.stm32n6_accel) build_utils.configureStm32n6Support(
b,
test_generated_lib,
Expand All @@ -220,7 +233,7 @@ inline fn lib_test(b: *std.Build, zantBuild: ZantBuild) void {
test_step_generated_lib.dependOn(&run_test_generated_lib.step);
}

inline fn lib_creation(b: *std.Build, zantBuild: ZantBuild) !*std.Build.Step.Compile {
inline fn lib_creation(b: *std.Build, zantBuild: ZantBuild, build_options: *std.Build.Step.Options) !*std.Build.Step.Compile {
const lib_model_path = std.fmt.allocPrint(b.allocator, "{s}lib_{s}.zig", .{ zantBuild.zantOptions.codegen_flags.generated_path_option, zantBuild.zantOptions.codegen_flags.model_name_option }) catch |err| {
std.log.scoped(.build).warn("Error allocating lib model path: {}\n", .{err});
return err;
Expand All @@ -235,6 +248,8 @@ inline fn lib_creation(b: *std.Build, zantBuild: ZantBuild) !*std.Build.Step.Com
}),
});

static_lib.root_module.addOptions("build_options", build_options);

if (zantBuild.zantOptions.stm32n6_flags.stm32n6_accel) build_utils.configureStm32n6Support(
b,
static_lib,
Expand Down
15 changes: 14 additions & 1 deletion docs/ZANT_CLI.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Zant is a tensor computation framework with ONNX support. This document provides
| `-Dtype` | string | `"f32"` | Input tensor data type | `lib-gen`, `lib-exe` |
| `-Doutput_type` | string | `"f32"` | Output tensor data type | `lib-gen`, `lib-exe` |
| `-Dcomm` | bool | `false` | Generate code with comments | `lib-gen`, `lib-exe` |
| `-Ddynamic` | bool | `false` | Enable dynamic allocation | `lib-gen`, `lib-exe` |
| `-Ddynamic` | bool | `false` | Enable dynamic allocation (default: static allocation with backing buffers) | `lib-gen`, `lib-exe` |
| `-Ddo_export` | bool | `false` | Generate exportable functions | `lib-gen`, `lib-exe` |
| `-Dv` | string | `"v1"` | Codegen version ("v1" or "v2") | `lib-gen`, `lib-exe` |
| `-Dlog` | bool | `false` | Enable logging during generation | `lib-gen`, `lib-exe` |
Expand All @@ -45,6 +45,19 @@ zig build lib-gen -Dmodel="custom" -Ddynamic -Dcomm=true
zig build lib-exe -Dmodel="mnist-8" -Dlog
```

## Memory allocation

Two methods exist:

-Static (default):
- Generate: zig build lib-gen -Dmodel=my_model
- Quick check: grep -c "backing_buffer" generated/my_model/*.zig # high count expected

-Dynamic (use -Ddynamic=true):
- Generate: zig build lib-gen -Dmodel=my_model -Ddynamic=true
- Quick check: grep -c "fromShape" generated/my_model/*.zig # high count expected


## Extractor Commands

### Available Commands
Expand Down
83 changes: 82 additions & 1 deletion src/IR_zant/graphZant.zig
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ const onnx = zant.onnx;
const pattern_matcher = @import("fusion/pattern_matcher.zig");
const PatternConfig = pattern_matcher.PatternConfig;

// import build options
const build_option = @import("build_options");

/// Compile-time switch selecting breadth-first graph linearization.
/// Reads `use_bfs` from the imported `build_options` module when that
/// declaration exists, and falls back to `false` otherwise.
/// NOTE(review): confirm that the options step registered for this module under
/// the name "build_options" actually declares `use_bfs`; if it does not, this
/// silently stays `false` regardless of the `-Dbfs` build flag.
const use_bfs: bool = if (@hasDecl(build_option, "use_bfs")) build_option.use_bfs else false;

pub const GraphZant = struct {
name: ?[]const u8,
nodes: std.ArrayList(*NodeZant),
Expand Down Expand Up @@ -147,8 +155,17 @@ pub const GraphZant = struct {
try pattern_matcher.fusePatterns(self, pattern_configs);
}

// linearize the graph
/// Produces a linear ordering of the graph nodes.
/// Dispatches on the `use_bfs` build option: the breadth-first variant is used
/// when it is set, the depth-first variant otherwise. Errors from the selected
/// variant are propagated unchanged.
pub fn linearize(self: *GraphZant, alloc: std.mem.Allocator) !std.ArrayList(*NodeZant) {
    const ordered = if (use_bfs)
        try self.linearize_bfs(alloc)
    else
        try self.linearize_dfs(alloc);
    return ordered;
}

// linearize the graph with DFS
pub fn linearize_dfs(self: *GraphZant, alloc: std.mem.Allocator) !std.ArrayList(*NodeZant) {
var visited = std.AutoHashMap(*NodeZant, bool).init(alloc);
var result: std.ArrayList(*NodeZant) = .empty;
defer visited.deinit();
Expand Down Expand Up @@ -178,6 +195,70 @@ pub const GraphZant = struct {
try result.append(allocator, node);
}

/// Linearizes the graph with a breadth-first traversal.
///
/// Nodes for which `get_predecessors` reports no predecessors are used as
/// traversal roots, in the order they appear in `self.nodes`. If no such node
/// exists, every node is used as a root so that all nodes are still emitted.
/// Each node appears at most once in the result.
///
/// The returned list is allocated with `alloc`; the caller owns it and must
/// release it with `deinit(alloc)`. On error nothing is leaked.
///
/// NOTE(review): a plain breadth-first visit does not by itself guarantee that
/// a node is emitted after all of its predecessors (with edges A->C, A->B and
/// B->C the order can be A, C, B). Confirm that consumers of this ordering do
/// not require a topological order.
pub fn linearize_bfs(self: *GraphZant, alloc: std.mem.Allocator) !std.ArrayList(*NodeZant) {
    // Same map type that `bfs` expects in its signature.
    var visited = std.AutoHashMap(*NodeZant, bool).init(alloc);
    defer visited.deinit();

    var result: std.ArrayList(*NodeZant) = .empty;
    // Release the partially built list if any later step fails.
    errdefer result.deinit(alloc);

    var root_nodes: std.ArrayList(*NodeZant) = .empty;
    defer root_nodes.deinit(alloc);

    for (self.nodes.items) |node| {
        // NOTE(review): assumes `get_predecessors` returns a list whose
        // `deinit()` takes no arguments — confirm against its definition.
        const preds = try self.get_predecessors(node);
        defer preds.deinit();

        if (preds.items.len == 0) {
            try root_nodes.append(alloc, node);
        }
    }

    // No node without predecessors: fall back to treating every node as a
    // root; the shared `visited` map prevents duplicates in the result.
    if (root_nodes.items.len == 0) {
        try root_nodes.appendSlice(alloc, self.nodes.items);
    }

    // Run one breadth-first visit per root.
    for (root_nodes.items) |root| {
        try bfs(root, &visited, &result, alloc);
    }

    return result;
}

/// Breadth-first visit starting at `start_node`.
///
/// Appends every node reachable through `next` that is not yet in `visited`
/// to `result`, in breadth-first order, marking each one in `visited` as it is
/// enqueued. Returns immediately if `start_node` was already visited.
/// `alloc` is used both for the internal queue and for growing `result`.
pub fn bfs(
    start_node: *NodeZant,
    visited: *std.AutoHashMap(*NodeZant, bool),
    result: *std.ArrayList(*NodeZant),
    alloc: std.mem.Allocator,
) !void {
    if (visited.contains(start_node)) return;

    // FIFO queue kept as a growing list plus a read cursor, so dequeuing does
    // not shift the remaining elements on every step.
    var queue: std.ArrayList(*NodeZant) = .empty;
    defer queue.deinit(alloc);

    try queue.append(alloc, start_node);
    try visited.put(start_node, true);

    var head: usize = 0;
    while (head < queue.items.len) : (head += 1) {
        // Copy the pointer out before appending: `items` may be reallocated.
        const current = queue.items[head];

        try result.append(alloc, current);

        // Enqueue the successors that have not been visited yet.
        for (current.next.items) |child| {
            const entry = try visited.getOrPut(child);
            if (entry.found_existing) continue;
            entry.value_ptr.* = true;
            try queue.append(alloc, child);
        }
    }
}

// TODO: unit tests for this
pub fn isDag(
self: *GraphZant,
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/cg_v1/codegen_v1.zig
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ pub fn codegnenerateFromGraphZant(model_name: []const u8, generated_path: []cons
}
}

if (!codegen_options.dynamic and codegen_options.static_planning) {
if (!codegen_options.dynamic) {
// NOTE: Not a strict requirement for the future, but the first draft
// will assume that there are no cycles (simplifies the implementation
// and works for non-recurrent neural networks)
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/cg_v1/predict/emit.zig
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ pub const TensorEmitter = struct {
.type = type_str,
.return_code = templates.RC.INIT_ERROR,
});
} else if (codegen_options.static_planning) {
} else if (!codegen_options.dynamic) {
try writer.print(" var tensor_{[name]s} = Tensor({[type]s}).fromConstBuffer(&fba, backing_buffer_{[buffer_id]d}[0..{[tensor_size]d}], &shape_tensor_{[name]s});", .{
.name = sanitized_name,
.type = type_str,
Expand Down
48 changes: 17 additions & 31 deletions src/codegen/cg_v1/predict/predict.zig
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ fn write_linkersInitialization(writer: *std.Io.Writer, codegen_parameters: cg_v1

const linkers: []TensorZant = try IR_utils.getLinkers(tensorZantMap);

if (!codegen_options.dynamic and codegen_options.static_planning) {
if (!codegen_options.dynamic) {
var arena = std.heap.ArenaAllocator.init(allocator);
defer arena.deinit();

Expand All @@ -161,7 +161,7 @@ fn write_linkersInitialization(writer: *std.Io.Writer, codegen_parameters: cg_v1
var value_it = allocators.valueIterator();
var emitted_buffers = try std.bit_set.DynamicBitSet.initEmpty(
arena_alloc,
if (codegen_options.static_planning) blk: {
if (!codegen_options.dynamic) blk: {
break :blk codegen_parameters.tensors_backing_buffers.?.count();
} else 0,
);
Expand All @@ -181,7 +181,7 @@ fn write_linkersInitialization(writer: *std.Io.Writer, codegen_parameters: cg_v1
for (linkers) |*tz| {
const size = try emit.ShapeEmitter.emit(writer, tz);
var backing_buffer_id: ?cg_v1.static_memory_planning.BufferId = null;
if (!codegen_options.dynamic and codegen_options.static_planning) {
if (!codegen_options.dynamic) {
backing_buffer_id = codegen_parameters.tensors_backing_buffers.?.get(tz.name).?.id;
}
try emit.TensorEmitter.emitAllocation(writer, tz, size, codegen_options.dynamic, backing_buffer_id);
Expand Down Expand Up @@ -215,27 +215,20 @@ fn write_linkersResetMethod(writer: *std.Io.Writer, codegen_parameters: cg_v1.Co

var emitted_buffers = try std.bit_set.DynamicBitSet.initEmpty(
arena_alloc,
if (codegen_options.static_planning) blk: {
if (!codegen_options.dynamic) blk: {
break :blk codegen_parameters.tensors_backing_buffers.?.count();
} else 0,
);
for (linkers) |*tz| {
if (!codegen_options.dynamic) {
if (!codegen_options.static_planning) {
const backing_buffers = codegen_parameters.tensors_backing_buffers orelse return error.MissingTensorsBackingBuffers;
const backing_buffer = backing_buffers.get(tz.name).?;
if (!emitted_buffers.isSet(backing_buffer.id)) {
_ = try writer.print(
\\
\\ @memset(array_{s}[0..], 0);
, .{try tz.getNameSanitized()});
} else {
const backing_buffers = codegen_parameters.tensors_backing_buffers orelse return error.MissingTensorsBackingBuffers;
const backing_buffer = backing_buffers.get(tz.name).?;
if (!emitted_buffers.isSet(backing_buffer.id)) {
_ = try writer.print(
\\
\\ @memset(backing_buffer_{d}[0..], 0);
, .{backing_buffer.id});
emitted_buffers.set(backing_buffer.id);
}
\\ @memset(backing_buffer_{d}[0..], 0);
, .{backing_buffer.id});
emitted_buffers.set(backing_buffer.id);
}
}

Expand All @@ -254,21 +247,14 @@ fn write_linkersResetMethod(writer: *std.Io.Writer, codegen_parameters: cg_v1.Co

for (outputs) |*tz| {
if (!codegen_options.dynamic) {
if (!codegen_options.static_planning) {
const backing_buffers = codegen_parameters.tensors_backing_buffers orelse return error.MissingTensorsBackingBuffers;
const backing_buffer = backing_buffers.get(tz.name).?;
if (!emitted_buffers.isSet(backing_buffer.id)) {
_ = try writer.print(
\\
\\ @memset(array_{s}[0..], 0);
, .{try tz.getNameSanitized()});
} else {
const backing_buffers = codegen_parameters.tensors_backing_buffers orelse return error.MissingTensorsBackingBuffers;
const backing_buffer = backing_buffers.get(tz.name).?;
if (!emitted_buffers.isSet(backing_buffer.id)) {
_ = try writer.print(
\\
\\ @memset(backing_buffer_{d}[0..], 0);
, .{backing_buffer.id});
emitted_buffers.set(backing_buffer.id);
}
\\ @memset(backing_buffer_{d}[0..], 0);
, .{backing_buffer.id});
emitted_buffers.set(backing_buffer.id);
}
}

Expand Down Expand Up @@ -305,7 +291,7 @@ fn write_outputsInitialization(writer: *std.Io.Writer, codegen_parameters: cg_v1
for (outputs) |*tz| {
const size = try emit.ShapeEmitter.emit(writer, tz);
var backing_buffer_id: ?cg_v1.static_memory_planning.BufferId = null;
if (!codegen_options.dynamic and codegen_options.static_planning) {
if (!codegen_options.dynamic) {
backing_buffer_id = codegen_parameters.tensors_backing_buffers.?.get(tz.name).?.id;
}
try emit.TensorEmitter.emitAllocation(writer, tz, size, codegen_options.dynamic, backing_buffer_id);
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/cg_v1/predict_writer.zig
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ fn write_FBA(writer: *std.Io.Writer) !void {

const arena_alloc = arena.allocator();

if (!codegen_options.static_planning) {
if (codegen_options.dynamic) {
const section = link_section orelse ".tensor_pool";
try writer.print(old_static_format, .{
.link_section = if (should_use_tensor_pool) blk: {
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/main.zig
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub fn main() !void {
std.debug.print("\n output_type:{s} ", .{codegen_options.output_type});
std.debug.print("\n comm:{} ", .{codegen_options.comm});
std.debug.print("\n dynamic:{} ", .{codegen_options.dynamic});
std.debug.print("\n static_planning:{} ", .{codegen_options.static_planning});
//std.debug.print("\n static_planning:{} ", .{codegen_options.static_planning});
std.debug.print("\n version:{s} ", .{codegen_options.version});

var gpa = std.heap.GeneralPurposeAllocator(.{}){};
Expand Down
6 changes: 3 additions & 3 deletions zantBuild/codegen_flags.zig
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub const Codegen_flags = struct {
output_type_option: []const u8,
comm_option: bool,
dynamic_option: bool,
static_planning_option: bool,
//static_planning_option: bool,
fuse_option: bool,
export_option: bool,
codegen_version_option: []const u8,
Expand Down Expand Up @@ -51,8 +51,8 @@ pub const Codegen_flags = struct {
.input_type_option = b.option([]const u8, "type", "Input type") orelse "f32",
.output_type_option = b.option([]const u8, "output_type", "Output type") orelse "f32",
.comm_option = b.option(bool, "comm", "Codegen with comments") orelse false,
.dynamic_option = b.option(bool, "dynamic", "Dynamic allocation") orelse true,
.static_planning_option = b.option(bool, "static_planning", "Perform static memory planning to optimize memory allocation (ignored when -dynamic=true)") orelse false,
.dynamic_option = b.option(bool, "dynamic", "Dynamic allocation") orelse false,
//.static_planning_option = b.option(bool, "static_planning", "Perform static memory planning to optimize memory allocation (ignored when -dynamic=true)") orelse true,
.fuse_option = b.option(bool, "fuse", "enable Kernel fusion") orelse false,
.export_option = b.option(bool, "do_export", "codegen Exportable ") orelse false,
.codegen_version_option = b.option([]const u8, "v", "Version, v1 or v2") orelse "v1",
Expand Down
1 change: 1 addition & 0 deletions zantBuild/zantModules.zig
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub const ZantModules = struct {
codegen_mod.addImport("IR_zant", IR_zant_mod);
codegen_mod.addOptions("codegen_options", zantStepOptions.codegen_step_option); //<<--OSS!! it is an option!
IR_zant_mod.addImport("codegen", codegen_mod);
IR_zant_mod.addOptions("build_options", zantStepOptions.bench_step_option);

const core_mod = b.createModule(.{ .root_source_file = b.path("src/Core/core.zig") });
core_mod.addImport("zant", zant_mod);
Expand Down
2 changes: 1 addition & 1 deletion zantBuild/zantStepOptions.zig
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ pub const ZantStepOptions = struct {
codegen_flags.addOption([]const u8, "output_type", zantOptions.codegen_flags.output_type_option); //codegen
codegen_flags.addOption(bool, "comm", zantOptions.codegen_flags.comm_option); //codegen
codegen_flags.addOption(bool, "dynamic", zantOptions.codegen_flags.dynamic_option); //codegen
codegen_flags.addOption(bool, "static_planning", zantOptions.codegen_flags.static_planning_option); //codegen
//codegen_flags.addOption(bool, "static_planning", zantOptions.codegen_flags.static_planning_option); //codegen
codegen_flags.addOption(bool, "fuse", zantOptions.codegen_flags.fuse_option); //codegen
codegen_flags.addOption([]const u8, "version", zantOptions.codegen_flags.codegen_version_option); //codegen
codegen_flags.addOption(bool, "xip", zantOptions.codegen_flags.xip_enabled); //codegen
Expand Down
Loading