|
| 1 | +import datetime |
| 2 | +import json |
| 3 | +import sqlite3 |
| 4 | + |
| 5 | +import ssh_ca |
| 6 | + |
| 7 | + |
| 8 | +class SqliteAuthority(ssh_ca.Authority): |
| 9 | + def __init__(self, config, ssh_ca_section, ca_key): |
| 10 | + super(SqliteAuthority, self).__init__(ca_key) |
| 11 | + |
| 12 | + self.dbfile = ssh_ca.get_config_value( |
| 13 | + config, ssh_ca_section, 'dbfile', required=True) |
| 14 | + self.conn = sqlite3.connect(self.dbfile) |
| 15 | + self._check_schema() |
| 16 | + |
| 17 | + def _check_schema(self): |
| 18 | + version = self.conn.execute('PRAGMA user_version').fetchone() |
| 19 | + if version[0] == 0: |
| 20 | + with self.conn: |
| 21 | + self.conn.execute('PRAGMA user_version=1') |
| 22 | + self.conn.execute( |
| 23 | + 'create table keys (name, environment, public_key)') |
| 24 | + self.conn.execute( |
| 25 | + 'create table serial (row, serial integer)') |
| 26 | + self.conn.execute( |
| 27 | + 'create table audit_log (entry integer primary key, log)') |
| 28 | + self.conn.execute( |
| 29 | + 'insert into serial (row, serial) values (1, 0)') |
| 30 | + |
| 31 | + def increment_serial_number(self): |
| 32 | + with self.conn: |
| 33 | + self.conn.execute('update serial set serial=serial+1 where row=1') |
| 34 | + cur = self.conn.execute('select serial from serial where row=1') |
| 35 | + new_serial = cur.fetchone()[0] |
| 36 | + return new_serial |
| 37 | + |
| 38 | + def get_public_key(self, username, environment): |
| 39 | + select = 'select public_key from keys where name is ?' |
| 40 | + args = (username, ) |
| 41 | + cur = self.conn.execute(select, args) |
| 42 | + result = cur.fetchone() |
| 43 | + if result: |
| 44 | + return result[0] |
| 45 | + else: |
| 46 | + return None |
| 47 | + |
| 48 | + def upload_public_key(self, username, key_file): |
| 49 | + key = open(key_file).read() |
| 50 | + arglist = (username, key) |
| 51 | + insert_stmt = 'insert into keys (name, public_key) values (?, ?)' |
| 52 | + with self.conn: |
| 53 | + self.conn.execute(insert_stmt, arglist) |
| 54 | + |
| 55 | + def upload_public_key_cert(self, username, cert_contents): |
| 56 | + return "%s: %s" % (username, cert_contents) |
| 57 | + |
| 58 | + def make_host_audit_log(self, serial, valid_for, ca_key_filename, |
| 59 | + reason, hostnames): |
| 60 | + audit_info = { |
| 61 | + 'valid_for': valid_for, |
| 62 | + 'ca_key_filename': ca_key_filename, |
| 63 | + 'reason': reason, |
| 64 | + 'hostnames': hostnames, |
| 65 | + } |
| 66 | + return self.drop_audit_blob(serial, audit_info) |
| 67 | + |
| 68 | + def make_audit_log(self, serial, valid_for, username, |
| 69 | + ca_key_filename, reason, principals): |
| 70 | + audit_info = { |
| 71 | + 'username': username, |
| 72 | + 'valid_for': valid_for, |
| 73 | + 'ca_key_filename': ca_key_filename, |
| 74 | + 'reason': reason, |
| 75 | + 'principals': principals, |
| 76 | + } |
| 77 | + return self.drop_audit_blob(serial, audit_info) |
| 78 | + |
| 79 | + def drop_audit_blob(self, serial, blob): |
| 80 | + timestamp = datetime.datetime.strftime( |
| 81 | + datetime.datetime.utcnow(), '%Y-%m-%d-%H:%M:%S.%f') |
| 82 | + blob['serial'] = serial |
| 83 | + blob['timestamp'] = timestamp |
| 84 | + |
| 85 | + arglist = (None, json.dumps(blob)) |
| 86 | + with self.conn: |
| 87 | + self.conn.execute('insert into audit_log values (?, ?)', arglist) |
0 commit comments