Skip to content

Commit 1b85030

Browse files
authored
Add result types: numpy, pandas, polars, arrow (#25)
* Add numpy, pandas, polars, and arrow result types
1 parent 420096d commit 1b85030

File tree

14 files changed

+2613
-252
lines changed

14 files changed

+2613
-252
lines changed

accel.c

Lines changed: 126 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
#define ACCEL_OUT_STRUCTSEQUENCES 1
1717
#define ACCEL_OUT_DICTS 2
1818
#define ACCEL_OUT_NAMEDTUPLES 3
19+
#define ACCEL_OUT_NUMPY 4
20+
#define ACCEL_OUT_PANDAS 5
21+
#define ACCEL_OUT_POLARS 6
22+
#define ACCEL_OUT_ARROW 7
1923

2024
#define NUMPY_BOOL 1
2125
#define NUMPY_INT8 2
@@ -389,6 +393,9 @@ typedef struct {
389393
PyObject *Series;
390394
PyObject *array;
391395
PyObject *vectorize;
396+
PyObject *DataFrame;
397+
PyObject *Table;
398+
PyObject *from_pylist;
392399
} PyStrings;
393400

394401
static PyStrings PyStr = {0};
@@ -406,6 +413,10 @@ typedef struct {
406413
PyObject *collections_namedtuple;
407414
PyObject *numpy_array;
408415
PyObject *numpy_vectorize;
416+
PyObject *pandas_DataFrame;
417+
PyObject *polars_DataFrame;
418+
PyObject *pyarrow_Table;
419+
PyObject *pyarrow_Table_from_pylist;
409420
} PyFunctions;
410421

411422
static PyFunctions PyFunc = {0};
@@ -466,10 +477,88 @@ typedef struct {
466477
char *encoding_errors;
467478
} StateObject;
468479

469-
static void read_options(MySQLAccelOptions *options, PyObject *dict);
480+
static int read_options(MySQLAccelOptions *options, PyObject *dict);
470481

471482
#define DESTROY(x) do { if (x) { free((void*)x); (x) = NULL; } } while (0)
472483

484+
int ensure_numpy() {
485+
if (PyFunc.numpy_array && PyFunc.numpy_vectorize) goto exit;
486+
487+
// Import numpy if it exists
488+
PyObject *numpy_mod = PyImport_ImportModule("numpy");
489+
if (!numpy_mod) goto error;
490+
491+
PyFunc.numpy_array = PyObject_GetAttr(numpy_mod, PyStr.array);
492+
if (!PyFunc.numpy_array) goto error;
493+
494+
PyFunc.numpy_vectorize = PyObject_GetAttr(numpy_mod, PyStr.vectorize);
495+
if (!PyFunc.numpy_vectorize) goto error;
496+
497+
exit:
498+
return 0;
499+
500+
error:
501+
return -1;
502+
}
503+
504+
505+
int ensure_pandas() {
506+
if (PyFunc.pandas_DataFrame) goto exit;
507+
508+
// Import pandas if it exists
509+
PyObject *pandas_mod = PyImport_ImportModule("pandas");
510+
if (!pandas_mod) goto error;
511+
512+
PyFunc.pandas_DataFrame = PyObject_GetAttr(pandas_mod, PyStr.DataFrame);
513+
if (!PyFunc.pandas_DataFrame) goto error;
514+
515+
exit:
516+
return 0;
517+
518+
error:
519+
return -1;
520+
}
521+
522+
523+
int ensure_polars() {
524+
if (PyFunc.polars_DataFrame) goto exit;
525+
526+
// Import polars if it exists
527+
PyObject *polars_mod = PyImport_ImportModule("polars");
528+
if (!polars_mod) goto error;
529+
530+
PyFunc.polars_DataFrame = PyObject_GetAttr(polars_mod, PyStr.DataFrame);
531+
if (!PyFunc.polars_DataFrame) goto error;
532+
533+
exit:
534+
return 0;
535+
536+
error:
537+
return -1;
538+
}
539+
540+
541+
int ensure_pyarrow() {
542+
if (PyFunc.pyarrow_Table_from_pylist) goto exit;
543+
544+
// Import pyarrow if it exists
545+
PyObject *pyarrow_mod = PyImport_ImportModule("pyarrow");
546+
if (!pyarrow_mod) goto error;
547+
548+
PyFunc.pyarrow_Table = PyObject_GetAttr(pyarrow_mod, PyStr.Table);
549+
if (!PyFunc.pyarrow_Table) goto error;
550+
551+
PyFunc.pyarrow_Table_from_pylist = PyObject_GetAttr(PyFunc.pyarrow_Table, PyStr.from_pylist);
552+
if (!PyFunc.pyarrow_Table_from_pylist) goto error;
553+
554+
exit:
555+
return 0;
556+
557+
error:
558+
return -1;
559+
}
560+
561+
473562
static void State_clear_fields(StateObject *self) {
474563
if (!self) return;
475564
DESTROY(self->offsets);
@@ -680,7 +769,7 @@ static int State_init(StateObject *self, PyObject *args, PyObject *kwds) {
680769
Py_XINCREF(self->py_invalid_values[i]);
681770

682771
self->py_converters[i] = (!py_converter
683-
|| py_converter == Py_None
772+
// || py_converter == Py_None
684773
|| py_converter == py_default_converter) ?
685774
NULL : py_converter;
686775
Py_XINCREF(self->py_converters[i]);
@@ -709,7 +798,8 @@ static int State_init(StateObject *self, PyObject *args, PyObject *kwds) {
709798
Py_XDECREF(py_next_seq_id);
710799

711800
if (py_options && PyDict_Check(py_options)) {
712-
read_options(&self->options, py_options);
801+
rc = read_options(&self->options, py_options);
802+
if (rc) goto error;
713803
}
714804

715805
switch (self->options.results_type) {
@@ -825,12 +915,13 @@ static PyType_Spec StateType_spec = {
825915
// End State
826916
//
827917

828-
static void read_options(MySQLAccelOptions *options, PyObject *dict) {
829-
if (!options || !dict) return;
918+
static int read_options(MySQLAccelOptions *options, PyObject *dict) {
919+
if (!options || !dict) return 0;
830920

831921
PyObject *key = NULL;
832922
PyObject *value = NULL;
833923
Py_ssize_t pos = 0;
924+
int rc = 0;
834925

835926
while (PyDict_Next(dict, &pos, &key, &value)) {
836927
if (PyUnicode_CompareWithASCIIString(key, "results_type") == 0) {
@@ -846,6 +937,23 @@ static void read_options(MySQLAccelOptions *options, PyObject *dict) {
846937
PyUnicode_CompareWithASCIIString(value, "structsequences") == 0) {
847938
options->results_type = ACCEL_OUT_STRUCTSEQUENCES;
848939
}
940+
else if (PyUnicode_CompareWithASCIIString(value, "numpy") == 0) {
941+
options->results_type = ACCEL_OUT_NUMPY;
942+
rc = ensure_numpy();
943+
}
944+
else if (PyUnicode_CompareWithASCIIString(value, "pandas") == 0) {
945+
options->results_type = ACCEL_OUT_PANDAS;
946+
rc = ensure_pandas();
947+
}
948+
else if (PyUnicode_CompareWithASCIIString(value, "polars") == 0) {
949+
options->results_type = ACCEL_OUT_POLARS;
950+
rc = ensure_polars();
951+
}
952+
else if (PyUnicode_CompareWithASCIIString(value, "arrow") == 0 ||
953+
PyUnicode_CompareWithASCIIString(value, "pyarrow") == 0) {
954+
options->results_type = ACCEL_OUT_ARROW;
955+
rc = ensure_pyarrow();
956+
}
849957
else {
850958
options->results_type = ACCEL_OUT_TUPLES;
851959
}
@@ -857,6 +965,8 @@ static void read_options(MySQLAccelOptions *options, PyObject *dict) {
857965
}
858966
}
859967
}
968+
969+
return rc;
860970
}
861971

862972
static void raise_exception(
@@ -1323,6 +1433,7 @@ static PyObject *read_row_from_packet(
13231433

13241434
switch (py_state->options.results_type) {
13251435
case ACCEL_OUT_DICTS:
1436+
case ACCEL_OUT_ARROW:
13261437
py_result = PyDict_New();
13271438
break;
13281439
case ACCEL_OUT_STRUCTSEQUENCES: {
@@ -1362,8 +1473,12 @@ static PyObject *read_row_from_packet(
13621473
py_str = PyUnicode_Decode(out, out_l, py_state->encodings[i], py_state->encoding_errors);
13631474
if (!py_str) goto error;
13641475
}
1365-
py_item = PyObject_CallFunctionObjArgs(py_state->py_converters[i], py_str, NULL);
1366-
Py_CLEAR(py_str);
1476+
if (py_state->py_converters[i] == Py_None) {
1477+
py_item = py_str;
1478+
} else {
1479+
py_item = PyObject_CallFunctionObjArgs(py_state->py_converters[i], py_str, NULL);
1480+
Py_CLEAR(py_str);
1481+
}
13671482
if (!py_item) goto error;
13681483
}
13691484

@@ -1586,6 +1701,7 @@ static PyObject *read_row_from_packet(
15861701
PyStructSequence_SetItem(py_result, i, py_item);
15871702
break;
15881703
case ACCEL_OUT_DICTS:
1704+
case ACCEL_OUT_ARROW:
15891705
PyDict_SetItem(py_result, py_state->py_names[i], py_item);
15901706
Py_INCREF(py_state->py_names[i]);
15911707
Py_DECREF(py_item);
@@ -1847,27 +1963,6 @@ static PyObject *create_numpy_array(PyObject *py_memview, char *data_format, int
18471963
}
18481964

18491965

1850-
int ensure_numpy() {
1851-
if (PyFunc.numpy_array && PyFunc.numpy_vectorize) goto exit;
1852-
1853-
// Import numpy if it exists
1854-
PyObject *numpy_mod = PyImport_ImportModule("numpy");
1855-
if (!numpy_mod) goto error;
1856-
1857-
PyFunc.numpy_array = PyObject_GetAttr(numpy_mod, PyStr.array);
1858-
if (!PyFunc.numpy_array) goto error;
1859-
1860-
PyFunc.numpy_vectorize = PyObject_GetAttr(numpy_mod, PyStr.vectorize);
1861-
if (!PyFunc.numpy_vectorize) goto error;
1862-
1863-
exit:
1864-
return 0;
1865-
1866-
error:
1867-
return -1;
1868-
}
1869-
1870-
18711966
static PyObject *load_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *kwargs) {
18721967
PyObject *py_data = NULL;
18731968
PyObject *py_out = NULL;
@@ -4372,6 +4467,9 @@ PyMODINIT_FUNC PyInit__singlestoredb_accel(void) {
43724467
PyStr.Series = PyUnicode_FromString("Series");
43734468
PyStr.array = PyUnicode_FromString("array");
43744469
PyStr.vectorize = PyUnicode_FromString("vectorize");
4470+
PyStr.DataFrame = PyUnicode_FromString("DataFrame");
4471+
PyStr.Table = PyUnicode_FromString("Table");
4472+
PyStr.from_pylist = PyUnicode_FromString("from_pylist");
43754473

43764474
PyObject *decimal_mod = PyImport_ImportModule("decimal");
43774475
if (!decimal_mod) goto error;

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ exclude =
7878
docs/*
7979
resources/*
8080
licenses/*
81-
max-complexity = 30
81+
max-complexity = 35
8282
max-line-length = 90
8383
per-file-ignores =
8484
singlestoredb/__init__.py:F401

singlestoredb/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@
231231
valid_values=[
232232
'tuple', 'tuples', 'namedtuple', 'namedtuples',
233233
'dict', 'dicts', 'structsequence', 'structsequences',
234+
'numpy', 'pandas', 'polars', 'arrow', 'pyarrow',
234235
],
235236
),
236237
'tuples',

0 commit comments

Comments
 (0)