diff --git a/sqllex/core/entities/abc/sql_database.py b/sqllex/core/entities/abc/sql_database.py index 04065a1..ed82303 100644 --- a/sqllex/core/entities/abc/sql_database.py +++ b/sqllex/core/entities/abc/sql_database.py @@ -688,7 +688,8 @@ def content_gen(parameters, column=None) -> str: 'col': ["INTEGER", "NOT NULL"], } """ - parameters = sorted(parameters, key=lambda par: sort.column_types(par)) + parameters = sorted(parameters, + key=lambda par: sort.column_types(par)) return script_gen.column(name=column, params=tuple(parameters)) elif isinstance(parameters, tuple): @@ -697,7 +698,8 @@ def content_gen(parameters, column=None) -> str: 'col': ("INTEGER", "NOT NULL"), } """ - parameters = sorted(list(parameters), key=lambda par: sort.column_types(par)) + parameters = sorted(list(parameters), + key=lambda par: sort.column_types(par)) return script_gen.column(name=column, params=tuple(parameters)) elif isinstance(parameters, dict): @@ -709,20 +711,46 @@ def content_gen(parameters, column=None) -> str: } """ if column != FOREIGN_KEY: - raise TypeError(f'Incorrect column "{column}" initialisation: {parameters}') + raise TypeError(f'Incorrect column "{column}" ' + f'initialisation: {parameters}') res = "" for (key, refs) in parameters.items(): if isinstance(refs, (list, tuple)): - res += script_gen.column_with_foreign_key(key=key, table=refs[0], column=refs[1]) + res += script_gen.column_with_foreign_key( + key=key, table=refs[0], column=refs[1]) if isinstance(refs, AbstractColumn): - res += script_gen.column_with_foreign_key(key=key, table=refs.table, column=refs.name) + res += script_gen.column_with_foreign_key( + key=key, table=refs.table, column=refs.name) return res[:-1] else: - raise TypeError(f'Incorrect column "{column}" initialisation, parameters type {type(parameters)}, ' - f'expected tuple, list or str') + raise TypeError(f'Incorrect column "{column}" initialisation, ' + f'parameters type {type(parameters)}, ' + 'expected tuple, list or str') + + def translate_param(param: Union[type, str]) -> str: + dictionary = {int: "INTEGER", + str: "TEXT", + float: "REAL", + None: "NULL"} + + translation = dictionary.get(param) + return translation if translation else param + + def translate_params(parameters: Union[type, str]) -> str: + if isinstance(parameters, (str, dict)): + return parameters + + if isinstance(parameters, list): + parameters = [translate_param(param) for param in parameters] + elif isinstance(parameters, list): + parameters = (translate_param(param) for param in parameters) + else: + parameters = translate_param(parameters) + + return parameters if not columns: raise ValueError("Zero-column tables aren't supported in SQLite") @@ -731,7 +759,7 @@ def content_gen(parameters, column=None) -> str: values = () for (col, params) in columns.items(): - content += content_gen(params, column=col) + content += content_gen(translate_params(params), column=col) script = script_gen.create( temp=temp, diff --git a/sqllex/types/types.py b/sqllex/types/types.py index 92eba6e..7e4ccd7 100644 --- a/sqllex/types/types.py +++ b/sqllex/types/types.py @@ -44,7 +44,7 @@ ConstantType, str ] -ColumnsType = Mapping[str, ColumnType] +ColumnsType = Mapping[Union[str, type], ColumnType] # Type for databases template DBTemplateType = Mapping[ diff --git a/tests/temp_tests.py b/tests/temp_tests.py index c07d163..3f5e4be 100644 --- a/tests/temp_tests.py +++ b/tests/temp_tests.py @@ -8,14 +8,13 @@ db.create_table( 'suggestions', { - # 'sid': [sx.TEXT, sx.PRIMARY_KEY, sx.AUTOINCREMENT], - 'sid': [sx.INTEGER, sx.PRIMARY_KEY, sx.AUTOINCREMENT], - 'uid': [sx.INTEGER, sx.NOT_NULL], - 'category': [sx.TEXT, sx.NOT_NULL], # header|article|idea - 'status': [sx.TEXT, sx.NOT_NULL, sx.DEFAULT, "sent"], + 'sid': [int, sx.PRIMARY_KEY, sx.AUTOINCREMENT], + 'uid': [int, sx.NOT_NULL], + 'category': [str, sx.NOT_NULL], # header|article|idea + 'status': [str, sx.NOT_NULL, sx.DEFAULT, "sent"], # sent|rejected|in_work|posted - 'comment': [sx.TEXT], - 'date': [sx.TEXT, sx.NOT_NULL] + 'comment': [str], + 'date': [str, sx.NOT_NULL] }, IF_NOT_EXIST=True ) @@ -23,7 +22,7 @@ 'suggestions', { "uid": 1, - "category": 'category', + "category": 'header', "date": datetime.now().strftime("%Y-%m-%d %H:%M") } ) diff --git a/tests/test_postgresqlx.py b/tests/test_postgresqlx.py index b092ffb..c1ef097 100644 --- a/tests/test_postgresqlx.py +++ b/tests/test_postgresqlx.py @@ -219,7 +219,7 @@ def test_create_table_basic(self): """ Testing table creating """ - columns = {'id': INTEGER} + columns = {'id': int} self.db.create_table( 'test_table_1', @@ -241,10 +241,10 @@ def test_create_table_all_columns(self): self.db.create_table( name='test_table', columns={ - 'id': [INTEGER, PRIMARY_KEY], - 'user': [TEXT, UNIQUE, NOT_NULL], - 'about': [TEXT, DEFAULT, NULL], - 'status': [TEXT, DEFAULT, "'offline'"] + 'id': [int, PRIMARY_KEY], + 'user': [str, UNIQUE, NOT_NULL], + 'about': [str, DEFAULT, NULL], + 'status': [str, DEFAULT, "'offline'"] } ) @@ -254,9 +254,9 @@ def test_create_table_all_columns(self): name='test_table_1', columns={ 'id': [SERIAL, PRIMARY_KEY], - 'user': [TEXT, UNIQUE, NOT_NULL], - 'about': [TEXT, DEFAULT, NULL], - 'status': [TEXT, DEFAULT, "'offline'"] + 'user': [str, UNIQUE, NOT_NULL], + 'about': [str, DEFAULT, NULL], + 'status': [str, DEFAULT, "'offline'"] } ) self.assertEqual(self.raw_sql_get_tables_names(), ('test_table', 'test_table_1')) @@ -265,7 +265,7 @@ def test_create_table_inx(self): """ Testing if not exist kwarg """ - columns = {'id': INTEGER} + columns = {'id': int} self.db.create_table('test_table_1', columns, IF_NOT_EXIST=True) self.db.create_table('test_table_2', columns, IF_NOT_EXIST=True)