diff --git a/.gitignore b/.gitignore
index b63da45..4c36a00 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,6 +9,7 @@ build/
.idea/jarRepositories.xml
.idea/compiler.xml
.idea/libraries/
+.idea
*.iws
*.iml
*.ipr
@@ -39,4 +40,4 @@ bin/
.vscode/
### Mac OS ###
-.DS_Store
\ No newline at end of file
+.DS_Store
diff --git a/.idea/gradle.xml b/.idea/gradle.xml
deleted file mode 100644
index 54ea192..0000000
--- a/.idea/gradle.xml
+++ /dev/null
@@ -1,17 +0,0 @@
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
deleted file mode 100644
index b0137f1..0000000
--- a/.idea/misc.xml
+++ /dev/null
@@ -1,7 +0,0 @@
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
deleted file mode 100644
index 35eb1dd..0000000
--- a/.idea/vcs.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/workspace.xml b/.idea/workspace.xml
deleted file mode 100644
index ebe296d..0000000
--- a/.idea/workspace.xml
+++ /dev/null
@@ -1,61 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- 1683078027782
-
-
- 1683078027782
-
-
-
-
\ No newline at end of file
diff --git a/README.md b/README.md
index 6c7e6e8..ee0e3a7 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,18 @@
-# sam-java
+# segment-anything.java
Meta's [Segment Anything (SAM)](https://github.com/facebookresearch/segment-anything) model, ported to Java SE 17.
-Eventual goal is to move this into a [JOSM](https://josm.openstreetmap.de/) [plugin](https://github.com/JOSM/josm-plugins) for segmentation of aerial imagery for OpenStreetMap. Further versions from that will be [finetuned](https://github.com/ctrlaltf2/segment-any-landuse) on aerial imagery for better results.
+Eventual goal is to use this in a [JOSM](https://josm.openstreetmap.de/) [plugin](https://github.com/JOSM/josm-plugins) for segmentation of aerial imagery for OpenStreetMap. Further versions from that will be [finetuned](https://github.com/ctrlaltf2/segment-any-landuse) on aerial imagery for better results.
+
+## Roadmap
+ - [x] Reproduce ONNX export of encoder and decoder (ref: https://github.com/visheratin/segment-anything)
+ - [x] Image loading and preprocessing
+ - [x] Encoder forward pass
+ - [x] OnnxTensor to primitive matrix type conversion
+ - [x] Decoder forward pass, basic coordinate-based prompt
+ - [ ] Decoder mask post-processing (mapping back to the input image)
+ - [ ] Multi-mask support (currently T/F value is baked into the ONNX model at export-time)
+ - [ ] Error handling TODOs
+ - [ ] Improve upon ONNX export method
+ - [ ] Figure out a place to host the ONNX models (~5 GB total)
+ - [ ] FastSAM? It's such a different model though architecturally, might be out of the scope of this repository.
+ - [ ] GPU runtimes?
diff --git a/build.gradle.kts b/build.gradle.kts
index d4b384c..5269645 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -12,7 +12,7 @@ repositories {
dependencies {
testImplementation(platform("org.junit:junit-bom:5.9.1"))
testImplementation("org.junit.jupiter:junit-jupiter")
- implementation("com.microsoft.onnxruntime:onnxruntime:1.14.0")
+ implementation("com.microsoft.onnxruntime:onnxruntime:1.16.3")
implementation("org.nd4j:nd4j:1.0.0-M2.1")
implementation("org.nd4j:nd4j-native-platform:1.0.0-M2.1")
}
diff --git a/src/main/java/dev/troyer/sam/ConstraintType.java b/src/main/java/dev/troyer/sam/ConstraintType.java
new file mode 100644
index 0000000..8bd584a
--- /dev/null
+++ b/src/main/java/dev/troyer/sam/ConstraintType.java
@@ -0,0 +1,10 @@
+package dev.troyer.sam;
+
+/**
+ * Enum for representing what type of constraint is applied to a
+ * decoder model query point
+ */
+public enum ConstraintType {
+ BACKGROUND,
+ FOREGROUND
+}
diff --git a/src/main/java/dev/troyer/sam/Main.java b/src/main/java/dev/troyer/sam/Main.java
index d30116a..f3e740b 100644
--- a/src/main/java/dev/troyer/sam/Main.java
+++ b/src/main/java/dev/troyer/sam/Main.java
@@ -1,37 +1,23 @@
package dev.troyer.sam;
-import java.awt.image.BufferedImage;
-import java.io.File;
+import ai.onnxruntime.OrtException;
+
import java.io.IOException;
import java.nio.file.Path;
-import javax.imageio.ImageIO;
-
-import ai.onnxruntime.OrtEnvironment;
-import ai.onnxruntime.OrtException;
public class Main {
- public static void main(String[] args) throws OrtException {
- String imagePath = "/home/caleb/Pictures/example.jpg";
- BufferedImage image = null;
-
- try {
- image = ImageIO.read(new File(imagePath));
- } catch (IOException e) {
- System.err.println("Error while loading image: " + e.getMessage());
- return;
- }
+ public static void main(String[] args) throws IOException, OrtException {
+ SamPredictor predictor = new SamPredictor(
+ Path.of("./src/main/resources/data/vit_b/encoder-vit_b.quant.onnx"),
+ Path.of("./src/main/resources/data/vit_b/decoder-vit_b.quant.onnx")
+ );
- OrtEnvironment env = OrtEnvironment.getEnvironment();
- SamImage samImage = new SamImage(image);
+ predictor.setImageFromPath(
+ Path.of("./test/assets/mingjun-liu-mVWqCdTHfxs-unsplash.jpg")
+ );
- try {
- Path model_path = Path.of("./src/main/resources/data/vit_b/encoder-vit_b.quant.onnx");
- SamEncoder encoder = new SamEncoder(env, model_path);
- var out = encoder.forward(env, samImage);
- System.out.println(out);
- } catch (OrtException e) {
- System.err.println("Error while loading model: " + e.getMessage());
- return;
- }
+ final SamResult result = predictor.predict(new SamConstraint[]{
+ new SamConstraint(1980, 1200, ConstraintType.FOREGROUND)
+ });
}
}
diff --git a/src/main/java/dev/troyer/sam/SamConstraint.java b/src/main/java/dev/troyer/sam/SamConstraint.java
new file mode 100644
index 0000000..4d9e9cc
--- /dev/null
+++ b/src/main/java/dev/troyer/sam/SamConstraint.java
@@ -0,0 +1,29 @@
+package dev.troyer.sam;
+
+/**
+ * Represents a constraint to the SAM decoder model.
+ * A constraint is the foreground/background parts you select in
+ * all those nice Segment Anything frontends.
+ */
+public class SamConstraint {
+ /**
+ * Column in image (pixel space). Origin is top left.
+ */
+ final public int x;
+
+ /**
+ * Row in image (pixel space). Origin is top left.
+ */
+ final public int y;
+
+ /**
+ * Type of constraint (foreground/background)
+ */
+ final public ConstraintType type;
+
+ public SamConstraint(int x, int y, ConstraintType type) {
+ this.x = x;
+ this.y = y;
+ this.type = type;
+ }
+}
diff --git a/src/main/java/dev/troyer/sam/SamDecoder.java b/src/main/java/dev/troyer/sam/SamDecoder.java
index be12d6a..82d9a1c 100644
--- a/src/main/java/dev/troyer/sam/SamDecoder.java
+++ b/src/main/java/dev/troyer/sam/SamDecoder.java
@@ -1,4 +1,60 @@
package dev.troyer.sam;
+import ai.onnxruntime.*;
+
+import java.nio.file.Path;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
public class SamDecoder {
+ /**
+ * Loaded ONNX model
+ */
+ private final OrtSession model;
+
+ /**
+ * @param env ONNX environment context
+ * @param modelPath Path to the model file
+ */
+ public SamDecoder(OrtEnvironment env, Path modelPath) throws OrtException {
+ this.model = env.createSession(modelPath.toString());
+ // TODO: Verify input/output names. These could vary based on how/who exported the model.
+ }
+
+ /**
+ * @return shape of the mask inputs parameter
+ */
+ public long[] getMaskInputsShape() throws OrtException {
+ // TODO: better error handling, don't assume output with this name exists
+ final Map inputsInfo = model.getInputInfo();
+ final NodeInfo maskInputNodeInfo = inputsInfo.get("mask_input");
+ final TensorInfo maskInputTensorInfo = (TensorInfo) maskInputNodeInfo.getInfo();
+ return maskInputTensorInfo.getShape();
+ }
+
+ public SamResult forward(OrtEnvironment env, OnnxTensor imageEmbedding,
+ float[][][] pointCoords, float[][] pointLabels,
+ OnnxTensor maskInput, boolean hasMaskInput,
+ int[] originalImageSize) throws OrtException {
+ LinkedHashMap inputs = new LinkedHashMap<>();
+ inputs.put("image_embeddings", imageEmbedding);
+ inputs.put("point_coords", OnnxTensor.createTensor(env, pointCoords));
+ inputs.put("point_labels", OnnxTensor.createTensor(env, pointLabels));
+ inputs.put("mask_input", maskInput);
+ inputs.put("has_mask_input", OnnxTensor.createTensor(env, new float[]{hasMaskInput ? 1.0f : 0.0f}));
+ inputs.put("orig_im_size", OnnxTensor.createTensor(env, new float[]{originalImageSize[0], originalImageSize[1]}));
+
+ OrtSession.Result result = model.run(inputs);
+
+ assert result.get("masks").isPresent();
+ assert result.get("iou_predictions").isPresent();
+ assert result.get("low_res_masks").isPresent();
+
+ // TODO: use cursed onnx api to verify OnnxValue rank instead of assuming
+ return new SamResult(
+ (float[][][][]) result.get("masks").get().getValue(),
+ (float[][]) result.get("iou_predictions").get().getValue(),
+ (float[][][][]) result.get("low_res_masks").get().getValue()
+ );
+ }
}
diff --git a/src/main/java/dev/troyer/sam/SamEncoder.java b/src/main/java/dev/troyer/sam/SamEncoder.java
index 74ac767..da506bf 100644
--- a/src/main/java/dev/troyer/sam/SamEncoder.java
+++ b/src/main/java/dev/troyer/sam/SamEncoder.java
@@ -1,32 +1,42 @@
package dev.troyer.sam;
import java.nio.file.Path;
+import java.util.HashSet;
import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
-import ai.onnxruntime.OnnxTensor;
-import ai.onnxruntime.OrtEnvironment;
-import ai.onnxruntime.OrtException;
-import ai.onnxruntime.OrtSession;
+import ai.onnxruntime.*;
class SamEncoder {
+ /**
+ * Loaded ONNX model
+ */
private final OrtSession model;
- private final int image_size;
+ private final int imageSize;
/**
- * @param model_path Path to the model file
- * @param image_size Size of the image the encoder model was trained on (usually 1024)
+ * @param env ONNX environment context
+ * @param modelPath Path to the model file
*/
- public SamEncoder(OrtEnvironment env, Path model_path, int image_size) throws OrtException {
- this.image_size = image_size;
- this.model = env.createSession(model_path.toString());
- }
+ public SamEncoder(OrtEnvironment env, Path modelPath) throws OrtException {
+ this.model = env.createSession(modelPath.toString());
- /**
- * @param model_path Path to the model file
- */
- public SamEncoder(OrtEnvironment env, Path model_path) throws OrtException {
- this(env, model_path, 1024);
+ // Verify inputs/outputs
+ // TODO: Exception instead of this
+ assert this.model.getInputNames().equals(
+ new HashSet<>(List.of("x"))
+ );
+
+ assert this.model.getOutputNames().equals(
+ new HashSet<>(List.of("image_embeddings"))
+ );
+
+ Map inputsInfo = model.getInputInfo();
+ final NodeInfo imageNodeInfo = inputsInfo.get("x");
+ final TensorInfo imageTensorInfo = (TensorInfo) imageNodeInfo.getInfo();
+ this.imageSize = (int) imageTensorInfo.getShape()[2];
}
/**
diff --git a/src/main/java/dev/troyer/sam/SamImage.java b/src/main/java/dev/troyer/sam/SamImage.java
index 49783b3..e279bbb 100644
--- a/src/main/java/dev/troyer/sam/SamImage.java
+++ b/src/main/java/dev/troyer/sam/SamImage.java
@@ -1,19 +1,16 @@
package dev.troyer.sam;
import java.awt.image.BufferedImage;
-import java.io.File;
-import java.io.IOException;
-import javax.imageio.ImageIO;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
+import org.nd4j.enums.ImageResizeMethod;
+import org.nd4j.enums.Mode;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.enums.Mode;
-import org.nd4j.enums.ImageResizeMethod;
class SamImage {
@@ -26,17 +23,21 @@ class SamImage {
/**
* Original height, px
*/
- public final int original_height;
+ public final int originalHeight;
/**
* Original width, px
*/
- public final int original_width;
+ public final int originalWidth;
/**
* Image from BufferedImage
*/
public SamImage(BufferedImage img) {
+ this.originalHeight = img.getHeight();
+ this.originalWidth = img.getWidth();
+
+ // TODO: Don't hardcode 1024 image size, use detected size from loaded models
// TODO: support 16-bit images?
int[] pixels = img.getRGB(0, 0, img.getWidth(), img.getHeight(), null, 0, img.getWidth());
@@ -58,19 +59,15 @@ public SamImage(BufferedImage img) {
}
// Standardize the image (Z-score transform)
- INDArray mean = img_tensor.mean(0, 1, 2)
- .reshape(1, 1, 1, 3)
- .broadcast(1, img.getHeight(), img.getWidth(), 3);
+ INDArray mean = img_tensor.mean(0, 1, 2).reshape(1, 1, 1, 3).broadcast(1, img.getHeight(), img.getWidth(), 3);
- INDArray std = img_tensor.std(0, 1, 2)
- .reshape(1, 1, 1, 3)
- .broadcast(1, img.getHeight(), img.getWidth(), 3);
+ INDArray std = img_tensor.std(0, 1, 2).reshape(1, 1, 1, 3).broadcast(1, img.getHeight(), img.getWidth(), 3);
img_tensor.subi(mean).divi(std);
// Resize proportionally such that longest side is 1024 pixels
final double scale = 1024.0 / Math.max(img.getWidth(), img.getHeight());
- final int new_width = Math.min((int) ( img.getWidth() * scale + 0.5), 1024);
+ final int new_width = Math.min((int) (img.getWidth() * scale + 0.5), 1024);
final int new_height = Math.min((int) (img.getHeight() * scale + 0.5), 1024);
// why are there literally 5 different ways to resize an image with this library and why are 4 of them wrong
@@ -82,11 +79,8 @@ public SamImage(BufferedImage img) {
final int pad_height = Math.max(1024 - new_height, 0);
// docs were not clear on this __at all__, but it's the start/end padding for each axis, in order of dimension.
- INDArray padding = Nd4j.createFromArray(new int[][]{
- {0, 0}, // batch padding
- {0, pad_height},
- {0, pad_width},
- {0, 0} // pad channel
+ INDArray padding = Nd4j.createFromArray(new int[][]{{0, 0}, // batch padding
+ {0, pad_height}, {0, pad_width}, {0, 0} // pad channel
});
INDArray padded = Nd4j.image.pad(resized, padding, Mode.CONSTANT, 0.0);
@@ -99,17 +93,15 @@ public SamImage(BufferedImage img) {
// and into a double array you go (cursed way because NDArray cannot export 4D matrices apparently)
data = new float[1][3][1024][1024];
+ // TODO: figure out how INDArray stores its elements, access in a CPU cache friendly manner (if that even matters for JVM)
for (int i = 0; i < 1024; i++)
for (int j = 0; j < 1024; j++)
for (int k = 0; k < 3; k++)
this.data[0][k][i][j] = (float) padded.getDouble(0, i, j, k);
-
- this.original_height = img.getHeight();
- this.original_width = img.getWidth();
}
/**
- * Get the data for the image, post transfom and ready for encoding
+ * Get the data for the image, post transform and ready for encoding
*/
public OnnxTensor asTensor(OrtEnvironment env) throws OrtException {
return OnnxTensor.createTensor(env, data);
diff --git a/src/main/java/dev/troyer/sam/SamPredictor.java b/src/main/java/dev/troyer/sam/SamPredictor.java
index 42138a7..1984ec0 100644
--- a/src/main/java/dev/troyer/sam/SamPredictor.java
+++ b/src/main/java/dev/troyer/sam/SamPredictor.java
@@ -1,19 +1,137 @@
package dev.troyer.sam;
-import dev.troyer.sam.SamDecoder;
-import dev.troyer.sam.SamEncoder;
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.LinkedHashMap;
+import javax.imageio.ImageIO;
-import ai.onnxruntime.OrtSession;
+import ai.onnxruntime.OnnxTensor;
+import ai.onnxruntime.OrtEnvironment;
+import ai.onnxruntime.OrtException;
public class SamPredictor {
- // encoder model
- private SamEncoder encoder;
+ /**
+ * Image that is currently selected
+ */
+ private SamImage image;
- // decoder model
- private SamDecoder decoder;
+ /**
+ * Image embedding (1x256x64x64)
+ */
+ // TODO: Couple SamImage, this embedding, and SamEncoder somehow
+ private OnnxTensor imageEmbedding;
- // Constructor
- public SamPredictor() {
+ /**
+ * Encoder model handler
+ */
+ final private SamEncoder encoder;
+ /**
+ * Decoder model handler
+ */
+ final private SamDecoder decoder;
+
+ /**
+ * ONNX environment context
+ */
+ final private OrtEnvironment env;
+
+ /**
+ * Constructor
+ */
+ public SamPredictor(Path encoderModelPath, Path decoderModelPath) throws OrtException {
+ env = OrtEnvironment.getEnvironment();
+
+ encoder = new SamEncoder(env, encoderModelPath);
+ decoder = new SamDecoder(env, decoderModelPath);
+
+ // TODO: assert encoder.outputs.image_embeddings.shape == decoder.inputs.embeddings.shape
+ }
+
+ /**
+ * Main interface, probably the most common use case
+ *
+ * @param constraints Queries/constraints to the model, given current image
+ * @return Masks, TBD, TBD, in the form of a SamResult
+ */
+ public SamResult predict(SamConstraint[] constraints) throws OrtException {
+ assert image != null;
+
+ // Unpack constraints into form usable by the model
+ float[][][] pointCoords = new float[1][constraints.length][2];
+ float[][] pointLabels = new float[1][constraints.length];
+
+ for (int iConstraint = 0; iConstraint < constraints.length; ++iConstraint) {
+ SamConstraint thisConstraint = constraints[iConstraint];
+
+ // TODO: confirm ordering
+ pointCoords[0][iConstraint][0] = (float) thisConstraint.x;
+ pointCoords[0][iConstraint][1] = (float) thisConstraint.y;
+
+ pointLabels[0][iConstraint] = (float) thisConstraint.type.ordinal();
+ }
+
+ // SAM's exported ONNX decoder expects orig_im_size in (H, W) order
+ final int[] originalImageSize = new int[]{image.originalHeight, image.originalWidth};
+
+ // ONNX model at the moment always expects a mask input.
+ // I think it's because ONNX models' parameters might be fixed, meaning optional
+ // items must still be populated (and tagged separately with a bool).
+ // Array elements don't need to be explicitly set; they are zero-initialized by default: https://docs.oracle.com/javase/specs/jls/se17/html/jls-4.html#jls-4.12.5
+
+ final long[] maskInputShape = this.decoder.getMaskInputsShape();
+ final OnnxTensor maskInput = OnnxTensor.createTensor(
+ env,
+ new float[(int) maskInputShape[0]][(int) maskInputShape[1]][(int) maskInputShape[2]][(int) maskInputShape[3]]
+ );
+
+ // TODO: return
+ return predict(pointCoords, pointLabels, maskInput, false, originalImageSize);
+ }
+
+ /**
+ * Set current image by file path
+ *
+ * @param imagePath Path to the image. Isn't validated or anything, assumes it's an image
+ */
+ public void setImageFromPath(Path imagePath) throws IOException, OrtException {
+ // TODO: Error handling
+ image = new SamImage(ImageIO.read(new File(imagePath.toString())));
+ onImageUpdate();
+ }
+
+ /**
+ * Predict masks for given input prompts.
+ * This is the full set of inputs to the decoder, rarely to be used directly
+ *
+ * @param pointCoords Nx2 array of point prompts to the model. Each point is in pixel space.
+ * @param pointLabels Nx1 array of labels for each point in pointCoords. 1 indicates foreground,
+ * 0 indicates background.
+ * @param maskInput 1x1x256x256 Low res mask input to the model, usually from a previous prediction iteration.
+ * @param hasMaskInput true if maskInput is to be used
+ * @param originalImageSize original image dimensions
+ * @return Masks, TBD, TBD, in the form of a SamResult
+ */
+ private SamResult predict(float[][][] pointCoords, float[][] pointLabels, OnnxTensor maskInput, boolean hasMaskInput, int[] originalImageSize) throws OrtException {
+ // Prep native parameters to be parameters to the model
+ return this.decoder.forward(
+ env,
+ imageEmbedding,
+ pointCoords,
+ pointLabels,
+ maskInput,
+ hasMaskInput,
+ originalImageSize
+ );
+ }
+
+ /**
+ * Should run when this->image updates. Re-encodes the image for efficient prompting later.
+ */
+ private void onImageUpdate() throws OrtException {
+ var forwardResult = encoder.forward(env, image);
+ assert forwardResult.get("image_embeddings").isPresent();
+ imageEmbedding = (OnnxTensor) forwardResult.get("image_embeddings").get();
}
}
diff --git a/src/main/java/dev/troyer/sam/SamResult.java b/src/main/java/dev/troyer/sam/SamResult.java
new file mode 100644
index 0000000..a998c9f
--- /dev/null
+++ b/src/main/java/dev/troyer/sam/SamResult.java
@@ -0,0 +1,33 @@
+package dev.troyer.sam;
+
+/**
+ * Represents a result of a query to the segment-anything model as a whole
+ */
+public class SamResult {
+ /**
+ * Mask(s) returned by the model
+ */
+ final public float[][][][] masks;
+
+ /**
+ * iou_predictions
+ */
+ final public float[][] iouPredictions;
+
+ /**
+ * low_res_masks
+ */
+ final public float[][][][] lowResMasks;
+
+ /**
+ * Constructor, initializes members
+ * @param masks Mask(s) returned by the model
+ * @param iouPredictions iou_predictions
+ * @param lowResMasks low_res_masks
+ */
+ public SamResult(float[][][][] masks, float[][] iouPredictions, float[][][][] lowResMasks) {
+ this.masks = masks;
+ this.iouPredictions = iouPredictions;
+ this.lowResMasks = lowResMasks;
+ }
+}
diff --git a/test/assets/mingjun-liu-mVWqCdTHfxs-unsplash.jpg b/test/assets/mingjun-liu-mVWqCdTHfxs-unsplash.jpg
new file mode 100644
index 0000000..28ee7b8
Binary files /dev/null and b/test/assets/mingjun-liu-mVWqCdTHfxs-unsplash.jpg differ