Skip to content

Commit f69b59c

Browse files
authored
Update coreml conversion script and add reference to swift repo for testing (#9)
* segment any text model * Update README to mention PyTorch version * Improve formatting in README.md for commands Updated README formatting for clarity. * add hugging face link to download model * changed directory location * updated conversion script retaining accuracy * reference swift repo for testing
1 parent 33fd6ea commit f69b59c

File tree

3 files changed

+55
-18
lines changed

3 files changed

+55
-18
lines changed

models/segment-text/coreml/README.md

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ Segment Any Text is state-of-the-art sentence segmentation with 3 Transfomer lay
33

44
If you wish to skip the CoreML conversion, you can download a precompiled `SaT.mlmodelc` from [Hugging Face](https://huggingface.co/smdesai/SaT).
55

6+
## Swift Usage
7+
8+
Swift sample code for testing and integrating the Core ML model is available at [smdesai/SegmentText](https://github.com/smdesai/SegmentText).
9+
610

711
# CoreML Conversion
812

@@ -51,7 +55,7 @@ Usage: convert_sat.py [OPTIONS]
5155

5256
Run the following to compile the model.
5357
```bash
54-
python compile_mlmodelc.py --coreml-dir sat_coreml
58+
python compile_mlmodelc.py --coreml-dir sat_coreml --output-dir compiled
5559
```
5660

5761
This produces `SaT.mlmodelc` in the `compiled` directory.
@@ -61,6 +65,8 @@ Here is the complete usage:
6165
Usage: compile_mlmodelc.py [OPTIONS]
6266

6367
Options
64-
--coreml-dir PATH Directory where mlpackages and metadata are written
65-
[default: sat_coreml]
68+
--coreml-dir PATH Directory where the mlpackage is
69+
[default: sat_coreml]
70+
--output-dir PATH Directory where the compiled model is written
71+
[default: compiled]
6672
```

models/segment-text/coreml/compile_mlmodelc.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,12 @@ def gather_packages(dir: str) -> list[Path]:
4141
return packages
4242

4343

44-
def compile_package(package: Path) -> None:
44+
def compile_package(package: Path, output_dir: Path) -> None:
4545
"""Compile a single ``.mlpackage`` bundle using ``xcrun coremlcompiler``."""
4646
relative_pkg = package.relative_to(BASE_DIR)
47-
#output_dir = OUTPUT_ROOT / relative_pkg.parent
48-
output_dir = OUTPUT_ROOT
49-
output_dir.mkdir(parents=True, exist_ok=True)
50-
output_path = output_dir / f"{package.stem}.mlmodelc"
47+
resolved_output_dir = output_dir if output_dir.is_absolute() else BASE_DIR / output_dir
48+
resolved_output_dir.mkdir(parents=True, exist_ok=True)
49+
output_path = resolved_output_dir / f"{package.stem}.mlmodelc"
5150

5251
if output_path.exists():
5352
shutil.rmtree(output_path)
@@ -57,18 +56,27 @@ def compile_package(package: Path) -> None:
5756
"coremlcompiler",
5857
"compile",
5958
str(package),
60-
str(output_dir),
59+
str(resolved_output_dir),
6160
]
6261

63-
print(f"Compiling {relative_pkg} -> {output_path.relative_to(BASE_DIR)}")
62+
try:
63+
relative_output = output_path.relative_to(BASE_DIR)
64+
except ValueError:
65+
relative_output = output_path
66+
67+
print(f"Compiling {relative_pkg} -> {relative_output}")
6468
subprocess.run(cmd, check=True)
6569

6670

6771
@app.command()
6872
def compile(
6973
coreml_dir: Path = typer.Option(
7074
Path("sat_coreml"),
71-
help="Directory where mlpackages and metadata are written",
75+
help="Directory where the mlpackage is",
76+
),
77+
output_dir: Path = typer.Option(
78+
Path("compiled"),
79+
help="Directory where the compiled model is written",
7280
),
7381
):
7482
ensure_coremlcompiler()
@@ -80,7 +88,7 @@ def compile(
8088

8189
for package in packages:
8290
try:
83-
compile_package(package)
91+
compile_package(package, output_dir)
8492
except subprocess.CalledProcessError as exc:
8593
print(f"Failed to compile {package}: {exc}", file=sys.stderr)
8694
sys.exit(exc.returncode)

models/segment-text/coreml/convert_sat.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,24 @@ def convert(
8383
return_dict=False,
8484
torchscript=True,
8585
trust_remote_code=True,
86-
).eval()
86+
).eval().to("cpu")
87+
88+
class WrappedModel(torch.nn.Module):
89+
def __init__(self, base_model: torch.nn.Module):
90+
super().__init__()
91+
self.base_model = base_model
92+
93+
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
94+
input_ids = input_ids.to(dtype=torch.long)
95+
attention_mask = attention_mask.to(dtype=torch.long)
96+
outputs = self.base_model(input_ids, attention_mask)
97+
if isinstance(outputs, (tuple, list)):
98+
logits = outputs[0]
99+
else:
100+
logits = outputs
101+
return logits.to(dtype=torch.float32)
102+
103+
wrapped_model = WrappedModel(model).eval()
87104

88105
tokenizer = AutoTokenizer.from_pretrained("facebookAI/xlm-roberta-base")
89106
tokenized = tokenizer(
@@ -93,12 +110,17 @@ def convert(
93110
padding="max_length",
94111
)
95112

96-
traced_model = torch.jit.trace(
97-
model,
98-
(tokenized["input_ids"], tokenized["attention_mask"])
113+
example_inputs = (
114+
tokenized["input_ids"].to(torch.int32),
115+
tokenized["attention_mask"].to(torch.int32),
99116
)
117+
traced_model = torch.jit.trace(wrapped_model, example_inputs, strict=False)
118+
traced_model.eval()
119+
120+
with torch.no_grad():
121+
sample_output = wrapped_model(*example_inputs)
100122

101-
outputs = [ct.TensorType(name="output")]
123+
output_spec = ct.TensorType(name="logits", dtype=np.float32)
102124

103125
mlpackage = ct.convert(
104126
traced_model,
@@ -111,8 +133,9 @@ def convert(
111133
)
112134
for name, tensor in tokenized.items()
113135
],
114-
outputs=outputs,
136+
outputs=[output_spec],
115137
compute_units=ct.ComputeUnit.ALL,
138+
compute_precision=ct.precision.FLOAT32,
116139
minimum_deployment_target=ct.target.iOS18,
117140
)
118141

0 commit comments

Comments
 (0)