
Commit 59e3a40

Merge pull request #165 from henk717/united
1.19.1
2 parents 6af0e84 + 64715b1 commit 59e3a40


3 files changed (+203 -16 lines)


aiserver.py

Lines changed: 118 additions & 3 deletions
@@ -1,7 +1,7 @@
 #!/usr/bin/python3
 #==================================================================#
 # KoboldAI
-# Version: 1.19.0
+# Version: 1.19.1
 # By: The KoboldAI Community
 #==================================================================#

@@ -377,6 +377,7 @@ class vars:
     comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)') # Pattern for matching comments to remove them before sending them to the AI
     comregex_ui = re.compile(r'(&lt;\|(?:.|\n)*?\|&gt;)') # Pattern for matching comments in the editor
     sampler_order = utils.default_sampler_order.copy()
+    rng_states = {} # Used by the POST /generate endpoint to store sampler RNG states
     chatmode = False
     chatname = "You"
     adventure = False
@@ -630,7 +631,7 @@ def delete(self, rule: str, **kwargs):
 api_version = None # This gets set automatically so don't change this value

 api_v1 = KoboldAPISpec(
-    version="1.1.4",
+    version="1.2.0",
     prefixes=["/api/v1", "/api/latest"],
     tags=tags,
 )
@@ -2963,7 +2964,7 @@ def load_lua_scripts():
         if(vars.serverstarted):
             emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True)
             sendUSStatItems()
-        logger.debug('LUA ERROR: ' + str(e).replace("\033", ""))
+        logger.error('LUA ERROR: ' + str(e).replace("\033", ""))
         logger.warning("Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.")
         if(vars.serverstarted):
             set_aibusy(0)
@@ -7450,6 +7451,13 @@ def story_load_validator(name: str):
         raise ValidationError("Must be a valid story name.")
     return True

+def permutation_validator(lst: list):
+    if any(not isinstance(e, int) for e in lst):
+        return
+    if min(lst) != 0 or max(lst) != len(lst) - 1 or len(set(lst)) != len(lst):
+        raise ValidationError("Must be a permutation of the first N non-negative integers, where N is the length of this array")
+    return True
+
 class GenerationInputSchema(SamplerSettingsSchema):
     prompt: str = fields.String(required=True, metadata={"description": "This is the submission."})
     use_memory: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the memory from the KoboldAI GUI when generating text."})
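Note: the validator above accepts a list only when it is a permutation of the first N non-negative integers, i.e. every value from 0 through len(lst) - 1 appears exactly once. A standalone sketch of the same check (not part of the commit; the helper name is made up and ValueError stands in for marshmallow's ValidationError):

def check_sampler_order(lst):
    # Same condition as permutation_validator above.
    if min(lst) != 0 or max(lst) != len(lst) - 1 or len(set(lst)) != len(lst):
        raise ValueError("not a permutation of the first N non-negative integers")
    return True

check_sampler_order([6, 0, 1, 2, 3, 4, 5])    # OK: every value 0..6 appears once
try:
    check_sampler_order([0, 1, 2, 3, 4, 6])   # rejected: 5 is missing and 6 is out of range
except ValueError as e:
    print(e)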
@@ -7469,6 +7477,9 @@ class GenerationInputSchema(SamplerSettingsSchema):
     disable_input_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, all input formatting options default to `false` instead of the value in the KoboldAI GUI"})
     frmtadsnsp: Optional[bool] = fields.Boolean(metadata={"description": "Input formatting option. When enabled, adds a leading space to your input if there is no trailing whitespace at the end of the previous action.\n\nIf `disable_input_formatting` is `true`, this defaults to `false` instead of the value in the KoboldAI GUI."})
     quiet: Optional[bool] = fields.Boolean(metadata={"description": "When enabled, Generated output will not be displayed in the console."})
+    sampler_order: Optional[List[int]] = fields.List(fields.Integer(), validate=[validate.Length(min=6), permutation_validator], metadata={"description": "Sampler order to be used. If N is the length of this array, then N must be greater than or equal to 6 and the array must be a permutation of the first N non-negative integers."})
+    sampler_seed: Optional[int] = fields.Integer(validate=validate.Range(min=0, max=2**64 - 1), metadata={"description": "RNG seed to use for sampling. If not specified, the global RNG will be used."})
+    sampler_full_determinism: Optional[bool] = fields.Boolean(metadata={"description": "If enabled, the generated text will always be the same as long as you use the same RNG seed, input and settings. If disabled, only the *sequence* of generated texts that you get when repeatedly generating text will be the same given the same RNG seed, input and settings."})

 class GenerationResultSchema(KoboldSchema):
     text: str = fields.String(required=True, metadata={"description": "Generated output as plain text."})
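The three new fields are optional per-request overrides. A sketch of how a client might use them against the POST /generate endpoint (illustrative only; the base URL assumes a locally running KoboldAI instance, which this commit does not define):

import requests

BASE = "http://localhost:5000/api/v1"  # assumed local server, adjust as needed

payload = {
    "prompt": "Once upon a time",
    "sampler_seed": 42,                       # any integer in 0 .. 2**64 - 1
    "sampler_order": [6, 0, 1, 2, 3, 4, 5],   # permutation of 0..6; length must be >= 6
    "sampler_full_determinism": True,         # same seed + input + settings -> same text
}
response = requests.post(f"{BASE}/generate", json=payload)
print(response.json())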
@@ -7559,6 +7570,29 @@ def _generate_text(body: GenerationInputSchema):
             "msg": "Server is busy; please try again later.",
             "type": "service_unavailable",
         }}), mimetype="application/json", status=503))
+    if vars.use_colab_tpu:
+        import tpu_mtj_backend
+    if hasattr(body, "sampler_seed"):
+        # If a seed was specified, we need to save the global RNG state so we
+        # can restore it later
+        old_seed = vars.seed
+        old_rng_state = tpu_mtj_backend.get_rng_state() if vars.use_colab_tpu else torch.get_rng_state()
+        vars.seed = body.sampler_seed
+        # We should try to use a previously saved RNG state with the same seed
+        if body.sampler_seed in vars.rng_states:
+            if vars.use_colab_tpu:
+                tpu_mtj_backend.set_rng_state(vars.rng_states[body.sampler_seed])
+            else:
+                torch.set_rng_state(vars.rng_states[body.sampler_seed])
+        else:
+            if vars.use_colab_tpu:
+                tpu_mtj_backend.set_rng_state(tpu_mtj_backend.new_rng_state(body.sampler_seed))
+            else:
+                torch.manual_seed(body.sampler_seed)
+        vars.rng_states[body.sampler_seed] = tpu_mtj_backend.get_rng_state() if vars.use_colab_tpu else torch.get_rng_state()
+    if hasattr(body, "sampler_order"):
+        if len(body.sampler_order) < 7:
+            body.sampler_order = [6] + body.sampler_order
     # This maps each property of the setting to use when sending the generate idempotently
     # To the object which typically contains it's value
     # This allows to set the property only for the API generation, and then revert the setting
@@ -7584,6 +7618,8 @@ def _generate_text(body: GenerationInputSchema):
         "max_context_length": ("vars", "max_length", None),
         "n": ("vars", "numseqs", None),
         "quiet": ("vars", "quiet", None),
+        "sampler_order": ("vars", "sampler_order", None),
+        "sampler_full_determinism": ("vars", "full_determinism", None),
     }
     saved_settings = {}
     set_aibusy(1)
@@ -7633,6 +7669,12 @@ def _generate_text(body: GenerationInputSchema):
         vars.output_streaming = output_streaming
     if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
         spRequest(old_spfilename)
+    if hasattr(body, "sampler_seed"):
+        vars.seed = old_seed
+        if vars.use_colab_tpu:
+            tpu_mtj_backend.set_rng_state(old_rng_state)
+        else:
+            torch.set_rng_state(old_rng_state)
     set_aibusy(0)
     return output

@@ -9838,6 +9880,60 @@ def put_config_soft_prompt(body: SoftPromptSettingSchema):
     settingschanged()
     return {}

+class SamplerSeedSettingSchema(KoboldSchema):
+    value: int = fields.Integer(validate=validate.Range(min=0, max=2**64 - 1), required=True)
+
+@api_v1.get("/config/sampler_seed")
+@api_schema_wrap
+def get_config_sampler_seed():
+    """---
+    get:
+      summary: Retrieve the current global sampler seed value
+      tags:
+        - config
+      responses:
+        200:
+          description: Successful request
+          content:
+            application/json:
+              schema: SamplerSeedSettingSchema
+              example:
+                value: 3475097509890965500
+    """
+    return {"value": __import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()}
+
+@api_v1.put("/config/sampler_seed")
+@api_schema_wrap
+def put_config_sampler_seed(body: SamplerSeedSettingSchema):
+    """---
+    put:
+      summary: Set the global sampler seed value
+      tags:
+        - config
+      requestBody:
+        required: true
+        content:
+          application/json:
+            schema: SamplerSeedSettingSchema
+            example:
+              value: 3475097509890965500
+      responses:
+        200:
+          description: Successful request
+          content:
+            application/json:
+              schema: EmptySchema
+        {api_validation_error_response}
+    """
+    if vars.use_colab_tpu:
+        import tpu_mtj_backend
+        tpu_mtj_backend.set_rng_seed(body.value)
+    else:
+        import torch
+        torch.manual_seed(body.value)
+    vars.seed = body.value
+    return {}
+
 config_endpoint_schemas: List[Type[KoboldSchema]] = []

 def config_endpoint_schema(c: Type[KoboldSchema]):
@@ -10035,6 +10131,25 @@ class KoboldMeta:
         name = "add sentence spacing (input formatting)"
         example_yaml_value = "false"

+@config_endpoint_schema
+class SamplerOrderSettingSchema(KoboldSchema):
+    value = fields.List(fields.Integer(), validate=[validate.Length(min=6), permutation_validator], required=True)
+    class KoboldMeta:
+        route_name = "sampler_order"
+        obj = "vars"
+        var_name = "sampler_order"
+        name = "sampler order"
+        example_yaml_value = "[6, 0, 1, 2, 3, 4, 5]"
+
+@config_endpoint_schema
+class SamplerFullDeterminismSettingSchema(KoboldSchema):
+    value = fields.Boolean(required=True)
+    class KoboldMeta:
+        route_name = "sampler_full_determinism"
+        obj = "vars"
+        var_name = "full_determinism"
+        name = "sampler full determinism"
+        example_yaml_value = "false"


 for schema in config_endpoint_schemas:
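Because the two schemas above are registered via config_endpoint_schema, the existing loop over config_endpoint_schemas appears to generate matching GET/PUT routes for them, so the new settings should also be reachable as /config/sampler_order and /config/sampler_full_determinism alongside the hand-written /config/sampler_seed endpoints. A sketch of the resulting calls (illustrative only; the base URL assumes a locally running instance):

import requests

BASE = "http://localhost:5000/api/v1"  # assumed local server

print(requests.get(f"{BASE}/config/sampler_seed").json())  # e.g. {"value": 3475097509890965500}

requests.put(f"{BASE}/config/sampler_seed", json={"value": 1234})
requests.put(f"{BASE}/config/sampler_order", json={"value": [6, 0, 1, 2, 3, 4, 5]})
requests.put(f"{BASE}/config/sampler_full_determinism", json={"value": True})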

torch_lazy_loader.py

Lines changed: 70 additions & 9 deletions
@@ -50,9 +50,12 @@
 import zipfile
 import pickle
 import torch
+import numpy as np
+import collections
+import _codecs
 import utils
 from torch.nn import Module
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union


 _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
@@ -111,8 +114,50 @@ def materialize(self, checkpoint: Union[zipfile.ZipFile, zipfile.ZipExtFile], ma
         tensor._backward_hooks = self.backward_hooks
         return tensor

+class RestrictedUnpickler(pickle.Unpickler):
+    def original_persistent_load(self, saved_id):
+        return super().persistent_load(saved_id)

-class _LazyUnpickler(pickle.Unpickler):
+    def forced_persistent_load(self, saved_id):
+        if saved_id[0] != "storage":
+            raise pickle.UnpicklingError("`saved_id[0]` must be 'storage'")
+        return self.original_persistent_load(saved_id)
+
+    def find_class(self, module, name):
+        if module == "collections" and name == "OrderedDict":
+            return collections.OrderedDict
+        elif module == "torch._utils" and name == "_rebuild_tensor_v2":
+            return torch._utils._rebuild_tensor_v2
+        elif module == "torch" and name in (
+            "DoubleStorage",
+            "FloatStorage",
+            "HalfStorage",
+            "LongStorage",
+            "IntStorage",
+            "ShortStorage",
+            "CharStorage",
+            "ByteStorage",
+            "BoolStorage",
+            "BFloat16Storage",
+        ):
+            return getattr(torch, name)
+        elif module == "numpy.core.multiarray" and name == "scalar":
+            return np.core.multiarray.scalar
+        elif module == "numpy" and name == "dtype":
+            return np.dtype
+        elif module == "_codecs" and name == "encode":
+            return _codecs.encode
+        else:
+            # Forbid everything else.
+            qualified_name = name if module == "__builtin__" else f"{module}.{name}"
+            raise pickle.UnpicklingError(f"`{qualified_name}` is forbidden; the model you are loading probably contains malicious code")
+
+    def load(self, *args, **kwargs):
+        self.original_persistent_load = getattr(self, "persistent_load", pickle.Unpickler.persistent_load)
+        self.persistent_load = self.forced_persistent_load
+        return super().load(*args, **kwargs)
+
+class _LazyUnpickler(RestrictedUnpickler):
     lazy_loaded_storages: Dict[str, LazyTensor]

     def __init__(self, *args, **kwargs):
@@ -127,7 +172,6 @@ def forced_persistent_load(self, saved_id):
         return LazyTensor(storage_type, key, location)

     def load(self, *args, **kwargs):
-        self.persistent_load = self.forced_persistent_load
         retval = super().load(*args, **kwargs)
         self.lazy_loaded_storages = {}
         return retval
@@ -213,16 +257,33 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
                     unexpected_keys.append(key)


+@contextlib.contextmanager
+def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
+    try:
+        old_unpickler = pickle.Unpickler
+        pickle.Unpickler = unpickler
+
+        old_pickle_load = pickle.load
+
+        def new_pickle_load(*args, **kwargs):
+            return pickle.Unpickler(*args, **kwargs).load()
+
+        pickle.load = new_pickle_load
+
+        yield
+
+    finally:
+        pickle.Unpickler = old_unpickler
+        pickle.load = old_pickle_load
+
 @contextlib.contextmanager
 def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
     if not enable:
-        yield False
+        with use_custom_unpickler(RestrictedUnpickler):
+            yield False
         return

     try:
-        old_unpickler = pickle.Unpickler
-        pickle.Unpickler = _LazyUnpickler
-
         old_rebuild_tensor = torch._utils._rebuild_tensor
         torch._utils._rebuild_tensor = _rebuild_tensor

@@ -261,10 +322,10 @@ def layernorm_init(self, *args, device=None, **kwargs):
         old_load_from_state_dict = torch.nn.Module._load_from_state_dict
         torch.nn.Module._load_from_state_dict = _load_from_state_dict

-        yield True
+        with use_custom_unpickler(_LazyUnpickler):
+            yield True

     finally:
-        pickle.Unpickler = old_unpickler
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load
         if dematerialized_modules:
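The net effect of this file's changes is that checkpoint unpickling is routed through RestrictedUnpickler (or its subclass _LazyUnpickler when lazy loading is active), so a pickle that references anything outside the small allow-list in find_class raises UnpicklingError instead of importing arbitrary code. A minimal sketch of using the new context manager directly (the checkpoint path is hypothetical):

import torch
import torch_lazy_loader

# torch.load and plain pickle.load inside this block go through RestrictedUnpickler.
with torch_lazy_loader.use_custom_unpickler(torch_lazy_loader.RestrictedUnpickler):
    state_dict = torch.load("pytorch_model.bin", map_location="cpu")  # hypothetical path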

tpu_mtj_backend.py

Lines changed: 15 additions & 4 deletions
@@ -55,7 +55,7 @@
 
 params: Dict[str, Any] = {}
 
-__seed = random.randrange(sys.maxsize)
+__seed = random.randrange(2**64)
 rng = random.Random(__seed)


@@ -69,8 +69,17 @@ def set_rng_seed(seed: int):
     return seed

 def randomize_rng_seed():
-    return set_rng_seed(random.randrange(sys.maxsize))
+    return set_rng_seed(random.randrange(2**64))

+def get_rng_state():
+    return rng
+
+def set_rng_state(state):
+    global rng
+    rng = state
+
+def new_rng_state(seed: int):
+    return random.Random(seed)

 def warper_callback(logits) -> np.array:
     raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
@@ -946,6 +955,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
 
     import torch
     import torch.utils.dlpack
+    import torch_lazy_loader
     from tqdm.auto import tqdm

     move_xmap = jax.experimental.maps.xmap(
@@ -987,8 +997,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
             continue
         layer = checkpoint_layer - 2
         shards = []
-        for checkpoint_shard in range(checkpoint_shards):
-            shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
+        with torch_lazy_loader.use_custom_unpickler(torch_lazy_loader.RestrictedUnpickler):
+            for checkpoint_shard in range(checkpoint_shards):
+                shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
         for key in shards[0]:
             if key == "attention.rotary_emb.inv_freq":
                 continue
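The new get_rng_state/set_rng_state/new_rng_state helpers give the TPU backend the same save-seed-restore pattern that aiserver.py now applies around seeded /generate requests. A sketch of that pattern (illustrative only):

import tpu_mtj_backend

old_state = tpu_mtj_backend.get_rng_state()                        # remember the global RNG
tpu_mtj_backend.set_rng_state(tpu_mtj_backend.new_rng_state(42))   # fresh RNG seeded for one request
# ... generate text here ...
tpu_mtj_backend.set_rng_state(old_state)                           # restore the global RNG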
