@@ -1,17 +1,19 @@
-"""Module for utility functions for fitting BTMs"""
+"""Module for utility functions for fitting BTMs."""
 
 import random
 from typing import Dict, Tuple, TypeVar
 
 import numba
 import numpy as np
 from numba import njit
+
 from tweetopic._prob import norm_prob, sample_categorical
 
 
 @njit
 def doc_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     (n_max_unique_words,) = doc_unique_words.shape
     biterm_counts = dict()
@@ -42,7 +44,7 @@ def doc_unique_biterms(
 
 @njit
 def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
-    """Adds one counter dict to another in place with Numba"""
+    """Adds one counter dict to another in place with Numba."""
     for key in source:
         if key in dest:
             dest[key] += source[key]
@@ -52,25 +54,28 @@ def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
 
 @njit
 def corpus_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     n_documents, _ = doc_unique_words.shape
     biterm_counts = doc_unique_biterms(
-        doc_unique_words[0], doc_unique_word_counts[0]
+        doc_unique_words[0],
+        doc_unique_word_counts[0],
     )
     for i_doc in range(1, n_documents):
         doc_unique_words_i = doc_unique_words[i_doc]
         doc_unique_word_counts_i = doc_unique_word_counts[i_doc]
         doc_biterms = doc_unique_biterms(
-            doc_unique_words_i, doc_unique_word_counts_i
+            doc_unique_words_i,
+            doc_unique_word_counts_i,
         )
         nb_add_counter(biterm_counts, doc_biterms)
     return biterm_counts
 
 
 @njit
 def compute_biterm_set(
-    biterm_counts: Dict[Tuple[int, int], int]
+    biterm_counts: Dict[Tuple[int, int], int],
 ) -> np.ndarray:
     return np.array(list(biterm_counts.keys()))
 
@@ -115,7 +120,12 @@ def add_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        True, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        True,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )
 
 
@@ -128,7 +138,12 @@ def remove_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        False, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        False,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )
 
 
@@ -146,7 +161,11 @@ def init_components(
         i_topic = random.randint(0, n_components - 1)
         biterm_topic_assignments[i_biterm] = i_topic
         add_biterm(
-            i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+            i_biterm,
+            i_topic,
+            biterms,
+            topic_word_count,
+            topic_biterm_count,
         )
     return biterm_topic_assignments, topic_word_count, topic_biterm_count
 
@@ -360,7 +379,10 @@ def predict_docs(
         )
         biterms = doc_unique_biterms(words, word_counts)
         prob_topic_given_document(
-            pred, biterms, topic_distribution, topic_word_distribution
+            pred,
+            biterms,
+            topic_distribution,
+            topic_word_distribution,
         )
         predictions[i_doc, :] = pred
     return predictions
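
For readers skimming the diff, here is a minimal sketch of how the reformatted helpers above fit together. It is not part of the commit; the import path tweetopic._btm is an assumption based on the module docstring and the tweetopic._prob import shown above.

import numpy as np

from tweetopic._btm import compute_biterm_set, corpus_unique_biterms  # assumed path

# Two toy documents: each row of doc_unique_words lists the indices of a
# document's unique words, and the matching row of doc_unique_word_counts
# gives how often each of those words occurs.
doc_unique_words = np.array([[0, 1, 2], [1, 3, 2]], dtype=np.int64)
doc_unique_word_counts = np.array([[2, 1, 1], [1, 2, 1]], dtype=np.int64)

# Count co-occurring word pairs (biterms) across the whole corpus, then
# collect the distinct biterms into an array for the sampler.
biterm_counts = corpus_unique_biterms(doc_unique_words, doc_unique_word_counts)
biterms = compute_biterm_set(biterm_counts)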