@@ -8,6 +8,7 @@ from typing import Iterator, List, Optional
88from libc.stdint cimport uint32_t
99from libc.string cimport memcpy
1010from murmurhash.mrmr cimport hash32, hash64
11+ from preshed.maps cimport map_clear
1112
1213import srsly
1314
@@ -125,10 +126,9 @@ cdef class StringStore:
125126 self .mem = Pool()
126127 self ._non_temp_mem = self .mem
127128 self ._map = PreshMap()
128- self ._transient_map = None
129129 if strings is not None :
130130 for string in strings:
131- self .add(string)
131+ self .add(string, allow_transient = False )
132132
133133 def __getitem__ (self , object string_or_id ):
134134 """ Retrieve a string from a given hash, or vice versa.
@@ -158,17 +158,17 @@ cdef class StringStore:
158158 return SYMBOLS_BY_INT[str_hash]
159159 else :
160160 utf8str = < Utf8Str* > self ._map.get(str_hash)
161- if utf8str is NULL and self ._transient_map is not None :
162- utf8str = < Utf8Str* > self ._transient_map.get(str_hash)
161+ if utf8str is NULL :
162+ raise KeyError (Errors.E018.format(hash_value = string_or_id))
163+ else :
164+ return decode_Utf8Str(utf8str)
163165 else :
164166 # TODO: Raise an error instead
165167 utf8str = < Utf8Str* > self ._map.get(string_or_id)
166- if utf8str is NULL and self ._transient_map is not None :
167- utf8str = < Utf8Str* > self ._transient_map.get(str_hash)
168- if utf8str is NULL :
169- raise KeyError (Errors.E018.format(hash_value = string_or_id))
170- else :
171- return decode_Utf8Str(utf8str)
168+ if utf8str is NULL :
169+ raise KeyError (Errors.E018.format(hash_value = string_or_id))
170+ else :
171+ return decode_Utf8Str(utf8str)
172172
173173 def as_int (self , key ):
174174 """ If key is an int, return it; otherwise, get the int value."""
@@ -184,16 +184,12 @@ cdef class StringStore:
184184 else :
185185 return self [key]
186186
187- def __reduce__ (self ):
188- strings = list (self .non_transient_keys())
189- return (StringStore, (strings,), None , None , None )
190-
191187 def __len__ (self ) -> int:
192188 """The number of strings in the store.
193189
194190 RETURNS (int ): The number of strings in the store.
195191 """
196- return self._keys .size() + self._transient_keys.size()
192+ return self.keys .size() + self._transient_keys.size()
197193
198194 @contextmanager
199195 def memory_zone(self, mem: Optional[Pool] = None) -> Pool:
@@ -209,13 +205,13 @@ cdef class StringStore:
209205 if mem is None:
210206 mem = Pool()
211207 self.mem = mem
212- self._transient_map = PreshMap()
213208 yield mem
214- self.mem = self._non_temp_mem
215- self._transient_map = None
209+ for key in self._transient_keys:
210+ map_clear( self._map.c_map, key)
216211 self._transient_keys.clear()
212+ self.mem = self._non_temp_mem
217213
218- def add(self, string: str, allow_transient: bool = False ) -> int:
214+ def add(self, string: str, allow_transient: Optional[ bool] = None ) -> int:
219215 """ Add a string to the StringStore.
220216
221217 string (str ): The string to add.
@@ -226,6 +222,8 @@ cdef class StringStore:
226222 internally should not .
227223 RETURNS (uint64): The string' s hash value.
228224 """
225+ if allow_transient is None:
226+ allow_transient = self.mem is not self._non_temp_mem
229227 cdef hash_t str_hash
230228 if isinstance(string, str):
231229 if string in SYMBOLS_BY_STR:
@@ -273,17 +271,13 @@ cdef class StringStore:
273271 # TODO: Raise an error instead
274272 if self._map.get(string_or_id) is not NULL:
275273 return True
276- elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL:
277- return True
278274 else:
279275 return False
280276 if str_hash < len(SYMBOLS_BY_INT):
281277 return True
282278 else:
283279 if self._map.get(str_hash) is not NULL:
284280 return True
285- elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL:
286- return True
287281 else:
288282 return False
289283
@@ -292,32 +286,21 @@ cdef class StringStore:
292286
293287 YIELDS (str ): A string in the store.
294288 """
295- yield from self.non_transient_keys()
296- yield from self.transient_keys()
297-
298- def non_transient_keys(self) -> Iterator[str]:
299- """ Iterate over the stored strings in insertion order.
300-
301- RETURNS: A list of strings.
302- """
303289 cdef int i
304290 cdef hash_t key
305291 for i in range(self.keys.size()):
306292 key = self.keys[i]
307293 utf8str = <Utf8Str*>self._map.get(key)
308294 yield decode_Utf8Str(utf8str)
295+ for i in range(self._transient_keys.size()):
296+ key = self._transient_keys[i]
297+ utf8str = <Utf8Str*>self._map.get(key)
298+ yield decode_Utf8Str(utf8str)
309299
310300 def __reduce__(self):
311301 strings = list(self)
312302 return (StringStore, (strings,), None, None, None)
313303
314- def transient_keys(self) -> Iterator[str]:
315- if self._transient_map is None:
316- return []
317- for i in range(self._transient_keys.size()):
318- utf8str = <Utf8Str*>self._transient_map.get(self._transient_keys[i])
319- yield decode_Utf8Str(utf8str)
320-
321304 def values(self) -> List[int]:
322305 """ Iterate over the stored strings hashes in insertion order.
323306
@@ -327,12 +310,9 @@ cdef class StringStore:
327310 hashes = [None] * self._keys.size()
328311 for i in range(self._keys.size()):
329312 hashes[i] = self._keys[i]
330- if self._transient_map is not None:
331- transient_hashes = [None] * self._transient_keys.size()
332- for i in range(self._transient_keys.size()):
333- transient_hashes[i] = self._transient_keys[i]
334- else:
335- transient_hashes = []
313+ transient_hashes = [None] * self._transient_keys.size()
314+ for i in range(self._transient_keys.size()):
315+ transient_hashes[i] = self._transient_keys[i]
336316 return hashes + transient_hashes
337317
338318 def to_disk(self, path):
@@ -383,8 +363,10 @@ cdef class StringStore:
383363
384364 def _reset_and_load(self, strings):
385365 self.mem = Pool()
366+ self._non_temp_mem = self.mem
386367 self._map = PreshMap()
387368 self.keys.clear()
369+ self._transient_keys.clear()
388370 for string in strings:
389371 self.add(string, allow_transient=False)
390372
@@ -401,19 +383,10 @@ cdef class StringStore:
401383 cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
402384 if value is not NULL:
403385 return value
404- if allow_transient and self._transient_map is not None:
405- # If we've already allocated a transient string, and now we
406- # want to intern it permanently, we'll end up with the string
407- # in both places. That seems fine -- I don't see why we need
408- # to remove it from the transient map.
409- value = <Utf8Str*>self._transient_map.get(key)
410- if value is not NULL:
411- return value
412386 value = _allocate(self.mem, <unsigned char*>utf8_string, length)
413- if allow_transient and self._transient_map is not None:
414- self._transient_map.set(key, value)
387+ self._map.set(key, value)
388+ if allow_transient and self.mem is not self._non_temp_mem:
415389 self._transient_keys.push_back(key)
416390 else:
417- self._map.set(key, value)
418391 self.keys.push_back(key)
419392 return value
0 commit comments