@@ -150,29 +150,46 @@ def __init__(self, sequence: str):
150
150
self .sequence = sequence
151
151
self .mutations = self .initialize (sequence )
152
152
153
- def initialize (self , sequence : str ) -> dict [int , list [str ]]:
153
+ def initialize (self , sequence : str ) -> dict [int , set [str ]]:
154
154
"""Initialize with no changes allowed to the sequence."""
155
- return {i : [ aa ] for i , aa in enumerate (sequence , start = 1 )}
155
+ return {i : { aa } for i , aa in enumerate (sequence , start = 1 )}
156
156
157
- def allow (self , positions : int | list [int ], amino_acids : list [str ] | str ) -> None :
157
+ def allow (
158
+ self ,
159
+ amino_acids : list [str ] | str | None = None ,
160
+ positions : int | list [int ] | None = None ,
161
+ ) -> None :
158
162
"""Allow specific amino acids at given positions."""
159
163
if isinstance (positions , int ):
160
164
positions = [positions ]
165
+ elif positions is None :
166
+ positions = [i + 1 for i in range (len (self .sequence ))]
161
167
if isinstance (amino_acids , str ):
162
168
amino_acids = list (amino_acids )
169
+ elif amino_acids is None :
170
+ amino_acids = list (self .sequence )
163
171
164
172
for position in positions :
165
173
if position in self .mutations :
166
- self .mutations [position ].extend (amino_acids )
174
+ for aa in amino_acids :
175
+ self .mutations [position ].add (aa )
167
176
else :
168
- self .mutations [position ] = amino_acids
177
+ self .mutations [position ] = set ( amino_acids )
169
178
170
- def remove (self , positions : int | list [int ], amino_acids : list [str ] | str ) -> None :
179
+ def remove (
180
+ self ,
181
+ amino_acids : list [str ] | str | None = None ,
182
+ positions : int | list [int ] | None = None ,
183
+ ) -> None :
171
184
"""Remove specific amino acids from being allowed at given positions."""
172
185
if isinstance (positions , int ):
173
186
positions = [positions ]
187
+ elif positions is None :
188
+ positions = [i + 1 for i in range (len (self .sequence ))]
174
189
if isinstance (amino_acids , str ):
175
190
amino_acids = list (amino_acids )
191
+ elif amino_acids is None :
192
+ amino_acids = list (self .sequence )
176
193
177
194
for position in positions :
178
195
if position in self .mutations :
@@ -182,4 +199,4 @@ def remove(self, positions: int | list[int], amino_acids: list[str] | str) -> No
182
199
183
200
def as_dict (self ) -> dict [int , list [str ]]:
184
201
"""Convert the internal mutations representation into a dictionary."""
185
- return self .mutations
202
+ return { i : list ( aa ) for i , aa in self .mutations . items ()}
0 commit comments