Skip to content

Commit b9b9940

Browse files
committed
add word2vec with local optimizatio
1 parent d51151b commit b9b9940

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

src/apps/word2vec/w2v_local.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#include "word2vec.h"
2+
//#include "word2vec_global.h"
3+
//#include <gflags/gflags.h>
4+
using namespace std;
5+
int main(int argc, char* argv[]) {
6+
GlobalMPI::initialize(argc, argv);
7+
// init config
8+
fms::CMDLine cmdline(argc, argv);
9+
std::string param_help = cmdline.registerParameter("help", "this screen");
10+
std::string param_config_path = cmdline.registerParameter("config", "path of config file \t[string]");
11+
std::string param_data_path = cmdline.registerParameter("data", "path of dataset, text only! \t[string]");
12+
std::string param_niters = cmdline.registerParameter("niters", "number of iterations \t[int]");
13+
std::string param_param_output = cmdline.registerParameter("output", "path to output the parameters\t[string]");
14+
15+
if(cmdline.hasParameter(param_help) || argc == 1) {
16+
cout << endl;
17+
cout << "===================================================================" << endl;
18+
cout << " Word2Vec application" << endl;
19+
cout << " Author: Suprjom <[email protected]>" << endl;
20+
cout << "===================================================================" << endl;
21+
cmdline.print_help();
22+
cout << endl;
23+
cout << endl;
24+
return 0;
25+
}
26+
if (!cmdline.hasParameter(param_config_path) ||
27+
!cmdline.hasParameter(param_data_path) ||
28+
!cmdline.hasParameter(param_niters)
29+
) {
30+
LOG(ERROR) << "missing parameter";
31+
cmdline.print_help();
32+
return 0;
33+
}
34+
std::string config_path = cmdline.getValue(param_config_path);
35+
std::string data_path = cmdline.getValue(param_data_path);
36+
std::string output_path = cmdline.getValue(param_param_output);
37+
int niters = stoi(cmdline.getValue(param_niters));
38+
global_config().load_conf(config_path);
39+
global_config().parse();
40+
41+
// init cluster
42+
Cluster<ClusterWorker, server_t, w2v_key_t> cluster;
43+
cluster.initialize();
44+
45+
Word2Vec<MiniBatch> w2v(data_path, niters);
46+
w2v.train();
47+
swift_snails::format_string(output_path, "-%d.txt", global_mpi().rank());
48+
RAW_LOG_WARNING ("server output parameter to %s", output_path.c_str());
49+
cluster.finalize(output_path);
50+
51+
LOG(WARNING) << "cluster exit.";
52+
53+
return 0;
54+
}

0 commit comments

Comments
 (0)