Skip to content

Commit a0204f7

Browse files
committed
feat: add_hparams
1 parent 1809620 commit a0204f7

File tree

8 files changed

+605
-29
lines changed

8 files changed

+605
-29
lines changed

BUILD.bazel

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ proto_library(
1414
name = "proto",
1515
srcs = [":proto_files"],
1616
strip_import_prefix = "proto",
17+
deps = [
18+
"@protobuf//:struct_proto",
19+
],
1720
)
1821

1922
cc_proto_library(

MODULE.bazel

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,15 @@ module(
66
)
77

88
bazel_dep(name = "protobuf", version = "31.1")
9+
10+
# Hedron's Compile Commands Extractor for Bazel
11+
# https://github.com/hedronvision/bazel-compile-commands-extractor
12+
bazel_dep(name = "hedron_compile_commands", dev_dependency = True)
13+
git_override(
14+
module_name = "hedron_compile_commands",
15+
remote = "https://github.com/hedronvision/bazel-compile-commands-extractor.git",
16+
commit = "4f28899228fb3ad0126897876f147ca15026151e",
17+
patches = [
18+
":third_party/hedron_compile_commands.patch",
19+
],
20+
)

MODULE.bazel.lock

Lines changed: 8 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

include/tensorboard_logger.h

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
#ifndef TENSORBOARD_LOGGER_H
22
#define TENSORBOARD_LOGGER_H
33

4+
#include <atomic>
45
#include <exception>
56
#include <fstream>
7+
#include <mutex>
68
#include <string>
7-
#include <vector>
8-
#include <atomic>
99
#include <thread>
10-
#include <mutex>
10+
#include <vector>
1111

1212
#include "crc.h"
1313
#include "event.pb.h"
14-
14+
#include "plugin_data.pb.h"
15+
using ::google::protobuf::Value;
16+
using std::map;
17+
using std::string;
18+
using tensorboard::hparams::HParamsPluginData;
19+
using tensorboard::hparams::SessionStartInfo;
1520
using tensorflow::Event;
1621
using tensorflow::Summary;
1722

@@ -23,11 +28,9 @@ const std::string kProjectorConfigFile = "projector_config.pbtxt";
2328
const std::string kProjectorPluginName = "projector";
2429
const std::string kTextPluginName = "text";
2530

26-
27-
struct TensorBoardLoggerOptions
28-
{
29-
// Log is flushed whenever this many entries have been written since the last
30-
// forced flush.
31+
struct TensorBoardLoggerOptions {
32+
// Log is flushed whenever this many entries have been written since the
33+
// last forced flush.
3134
size_t max_queue_size_ = 100000;
3235
TensorBoardLoggerOptions &max_queue_size(size_t max_queue_size) {
3336
max_queue_size_ = max_queue_size;
@@ -50,15 +53,15 @@ struct TensorBoardLoggerOptions
5053

5154
class TensorBoardLogger {
5255
public:
53-
5456
explicit TensorBoardLogger(const std::string &log_file,
55-
const TensorBoardLoggerOptions &options={}) {
57+
const TensorBoardLoggerOptions &options = {}) {
5658
this->options = options;
5759
auto basename = get_basename(log_file);
5860
if (basename.find("tfevents") == std::string::npos) {
5961
throw std::runtime_error(
6062
"A valid event file must contain substring \"tfevents\" in its "
61-
"basename, got " + basename);
63+
"basename, got " +
64+
basename);
6265
}
6366
bucket_limits_ = nullptr;
6467
ofs_ = new std::ofstream(
@@ -85,6 +88,9 @@ class TensorBoardLogger {
8588
flushing_thread.join();
8689
}
8790
}
91+
92+
int add_hparams(const map<string, Value> &hparams, const string &group_name,
93+
double start_time_secs);
8894
int add_scalar(const std::string &tag, int step, double value);
8995
int add_scalar(const std::string &tag, int step, float value);
9096

@@ -189,6 +195,7 @@ class TensorBoardLogger {
189195

190196
private:
191197
int generate_default_buckets();
198+
int add_session_start_info(SessionStartInfo *session_start_info);
192199
int add_event(int64_t step, Summary *summary);
193200
int write(Event &event);
194201
void flusher();
@@ -197,7 +204,7 @@ class TensorBoardLogger {
197204
std::ofstream *ofs_;
198205
std::vector<double> *bucket_limits_;
199206
TensorBoardLoggerOptions options;
200-
207+
201208
std::atomic<bool> stop{false};
202209
size_t queue_size{0};
203210
std::thread flushing_thread;

0 commit comments

Comments
 (0)