Commit fc55dfa

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent: 7cb5ae5 · commit: fc55dfa

6 files changed: +55 / -29 lines changed

tweetopic/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
-from tweetopic.dmm import DMM  # noqa: F401
 from tweetopic.btm import BTM  # noqa: F401
+from tweetopic.dmm import DMM  # noqa: F401
 from tweetopic.pipeline import TopicPipeline  # noqa: F401

tweetopic/_btm.py

Lines changed: 33 additions & 11 deletions

@@ -1,17 +1,19 @@
-"""Module for utility functions for fitting BTMs"""
+"""Module for utility functions for fitting BTMs."""
 
 import random
 from typing import Dict, Tuple, TypeVar
 
 import numba
 import numpy as np
 from numba import njit
+
 from tweetopic._prob import norm_prob, sample_categorical
 
 
 @njit
 def doc_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     (n_max_unique_words,) = doc_unique_words.shape
     biterm_counts = dict()
@@ -42,7 +44,7 @@ def doc_unique_biterms(
 
 @njit
 def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
-    """Adds one counter dict to another in place with Numba"""
+    """Adds one counter dict to another in place with Numba."""
     for key in source:
         if key in dest:
             dest[key] += source[key]
@@ -52,25 +54,28 @@ def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
 
 @njit
 def corpus_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     n_documents, _ = doc_unique_words.shape
     biterm_counts = doc_unique_biterms(
-        doc_unique_words[0], doc_unique_word_counts[0]
+        doc_unique_words[0],
+        doc_unique_word_counts[0],
     )
     for i_doc in range(1, n_documents):
         doc_unique_words_i = doc_unique_words[i_doc]
         doc_unique_word_counts_i = doc_unique_word_counts[i_doc]
         doc_biterms = doc_unique_biterms(
-            doc_unique_words_i, doc_unique_word_counts_i
+            doc_unique_words_i,
+            doc_unique_word_counts_i,
         )
         nb_add_counter(biterm_counts, doc_biterms)
     return biterm_counts
 
 
 @njit
 def compute_biterm_set(
-    biterm_counts: Dict[Tuple[int, int], int]
+    biterm_counts: Dict[Tuple[int, int], int],
 ) -> np.ndarray:
     return np.array(list(biterm_counts.keys()))
 
@@ -115,7 +120,12 @@ def add_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        True, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        True,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )
 
 
@@ -128,7 +138,12 @@ def remove_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        False, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        False,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )
 
 
@@ -146,7 +161,11 @@ def init_components(
         i_topic = random.randint(0, n_components - 1)
         biterm_topic_assignments[i_biterm] = i_topic
         add_biterm(
-            i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+            i_biterm,
+            i_topic,
+            biterms,
+            topic_word_count,
+            topic_biterm_count,
         )
     return biterm_topic_assignments, topic_word_count, topic_biterm_count
 
@@ -360,7 +379,10 @@ def predict_docs(
         )
         biterms = doc_unique_biterms(words, word_counts)
         prob_topic_given_document(
-            pred, biterms, topic_distribution, topic_word_distribution
+            pred,
+            biterms,
+            topic_distribution,
+            topic_word_distribution,
        )
         predictions[i_doc, :] = pred
     return predictions
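
Aside: the helpers reformatted above operate on biterms, i.e. unordered pairs of words that co-occur in a document. The following pure-Python sketch illustrates the idea behind doc_unique_biterms; it assumes biterm multiplicity is the product of the paired words' counts and that zero counts mark padding, which may not match the Numba implementation exactly.

    import numpy as np

    def unique_biterms_sketch(words: np.ndarray, counts: np.ndarray) -> dict:
        # words: ids of the unique words in one document (zero-padded)
        # counts: how often each of those words occurs (0 = padding)
        biterms: dict = {}
        n = words.shape[0]
        for i in range(n):
            for j in range(i + 1, n):
                if counts[i] == 0 or counts[j] == 0:
                    continue  # skip padding entries
                # order the pair so (a, b) and (b, a) count as one biterm
                pair = (min(words[i], words[j]), max(words[i], words[j]))
                biterms[pair] = biterms.get(pair, 0) + int(counts[i] * counts[j])
        return biterms

    # e.g. word ids [3, 7, 9] with counts [2, 1, 1] yield
    # {(3, 7): 2, (3, 9): 2, (7, 9): 1}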

tweetopic/_dmm.py

Lines changed: 4 additions & 4 deletions

@@ -1,12 +1,13 @@
-"""Module containing tools for fitting a Dirichlet Multinomial Mixture Model."""
+"""Module containing tools for fitting a Dirichlet Multinomial Mixture
+Model."""
 from __future__ import annotations
 
 from math import exp, log
 
 import numpy as np
 from numba import njit
 
-from tweetopic._prob import sample_categorical, norm_prob
+from tweetopic._prob import norm_prob, sample_categorical
 
 
 @njit
@@ -197,8 +198,7 @@ def _cond_prob(
     # I use logs instead of computing the products directly,
     # as it would quickly result in numerical overflow.
     log_norm_term = log(
-        (cluster_doc_count[i_cluster] + alpha)
-        / (n_docs - 1 + n_clusters * alpha),
+        (cluster_doc_count[i_cluster] + alpha) / (n_docs - 1 + n_clusters * alpha),
     )
     log_numerator = 0
     for i_unique in range(max_unique_words):
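
The comment preserved in this hunk ("I use logs instead of computing the products directly") is worth a quick illustration: long products of per-word terms leave float64's range almost immediately, while sums of logs stay finite. A toy demonstration:

    from math import log

    factors = [1e-4] * 200  # 200 small per-word probability factors

    product = 1.0
    for p in factors:
        product *= p  # 1e-800 is far below float64's range -> 0.0

    log_sum = sum(log(p) for p in factors)  # finite: 200 * log(1e-4)

    print(product)  # 0.0
    print(log_sum)  # -1842.06...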

tweetopic/_doc.py

Lines changed: 2 additions & 4 deletions

@@ -11,14 +11,12 @@ def init_doc_words(
     n_docs, _ = doc_term_matrix.shape
     doc_unique_words = np.zeros((n_docs, max_unique_words)).astype(np.uint32)
     doc_unique_word_counts = np.zeros((n_docs, max_unique_words)).astype(
-        np.uint32
+        np.uint32,
     )
     for i_doc in range(n_docs):
         unique_words = doc_term_matrix[i_doc].rows[0]  # type: ignore
         unique_word_counts = doc_term_matrix[i_doc].data[0]  # type: ignore
         for i_unique in range(len(unique_words)):
             doc_unique_words[i_doc, i_unique] = unique_words[i_unique]
-            doc_unique_word_counts[i_doc, i_unique] = unique_word_counts[
-                i_unique
-            ]
+            doc_unique_word_counts[i_doc, i_unique] = unique_word_counts[i_unique]
     return doc_unique_words, doc_unique_word_counts
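
The `.rows[0]` / `.data[0]` accesses kept in this hunk rely on scipy's LIL sparse format, where each row stores its nonzero column indices and values as Python lists. A small made-up example of what init_doc_words reads:

    import numpy as np
    import scipy.sparse as spr

    # toy doc-term matrix: 2 documents, vocabulary of 3 words
    dtm = spr.lil_matrix(np.array([[2, 0, 1], [0, 3, 0]]))

    print(dtm[0].rows[0])  # [0, 2] -> ids of the words present in document 0
    print(dtm[0].data[0])  # [2, 1] -> their term counts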

tweetopic/btm.py

Lines changed: 11 additions & 7 deletions

@@ -7,15 +7,19 @@
 import scipy.sparse as spr
 import sklearn
 from numpy.typing import ArrayLike
-from tweetopic._btm import (compute_biterm_set, corpus_unique_biterms,
-                            fit_model, predict_docs)
+
+from tweetopic._btm import (
+    compute_biterm_set,
+    corpus_unique_biterms,
+    fit_model,
+    predict_docs,
+)
 from tweetopic._doc import init_doc_words
 from tweetopic.exceptions import NotFittedException
 
 
 class BTM(sklearn.base.TransformerMixin, sklearn.base.BaseEstimator):
-    """Implementation of the Biterm Topic Model with Gibbs Sampling
-    solver.
+    """Implementation of the Biterm Topic Model with Gibbs Sampling solver.
 
     Parameters
     ----------
@@ -136,7 +140,8 @@ def fit(self, X: Union[spr.spmatrix, ArrayLike], y: None = None):
             max_unique_words=max_unique_words,
         )
         biterms = corpus_unique_biterms(
-            doc_unique_words, doc_unique_word_counts
+            doc_unique_words,
+            doc_unique_word_counts,
         )
         biterm_set = compute_biterm_set(biterms)
         self.topic_distribution, self.components_ = fit_model(
@@ -152,8 +157,7 @@ def fit(self, X: Union[spr.spmatrix, ArrayLike], y: None = None):
         # TODO: Something goes terribly wrong here, fix this
 
     def transform(self, X: Union[spr.spmatrix, ArrayLike]) -> np.ndarray:
-        """Predicts probabilities for each document belonging to each
-        topic.
+        """Predicts probabilities for each document belonging to each topic.
 
         Parameters
         ----------
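
For orientation, here is a minimal usage sketch of the class whose formatting changes above. The example data is invented and n_components is an assumed constructor parameter name; note the TODO left in fit: this commit only reformats, the flagged bug is untouched.

    from sklearn.feature_extraction.text import CountVectorizer
    from tweetopic import BTM

    texts = [
        "cats purr on the windowsill",
        "dogs bark at the mailman",
        "cats and dogs share the yard",
    ]
    # BTM.fit accepts a sparse doc-term matrix (see the signature above)
    dtm = CountVectorizer().fit_transform(texts)

    model = BTM(n_components=2)  # n_components is an assumed parameter name
    model.fit(dtm)
    doc_topic = model.transform(dtm)  # per-document topic probabilities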

tweetopic/pipeline.py

Lines changed: 4 additions & 2 deletions

@@ -47,7 +47,8 @@ def fit(self, texts: Iterable[str]) -> TopicPipeline:
         return self
 
     def fit_transform(
-        self, texts: Iterable[str]
+        self,
+        texts: Iterable[str],
     ) -> Union[ArrayLike, spr.spmatrix]:
         """Fits vectorizer and topic model and transforms the given text.
 
@@ -65,7 +66,8 @@ def fit_transform(
         return self.topic_model.fit_transform(doc_term_matrix)
 
     def transform(
-        self, texts: Iterable[str]
+        self,
+        texts: Iterable[str],
     ) -> Union[ArrayLike, spr.spmatrix]:
         """Transforms given texts with the fitted pipeline.
 
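Finally, a sketch of how the reformatted TopicPipeline methods fit together end to end. The constructor signature TopicPipeline(vectorizer, topic_model) is assumed, and the texts are invented; transform takes raw strings, since the pipeline vectorizes internally before calling the topic model (as fit_transform's body above shows).

    from sklearn.feature_extraction.text import CountVectorizer
    from tweetopic import DMM, TopicPipeline

    # assumed constructor: a vectorizer plus any tweetopic topic model
    pipeline = TopicPipeline(CountVectorizer(), DMM(n_components=5))

    texts = ["new phone who dis", "markets close higher", "rain expected today"]
    doc_topic = pipeline.fit_transform(texts)  # vectorize, then fit the model

    # once fitted, transform accepts unseen texts directly
    new_scores = pipeline.transform(["phone battery died again"])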