diff --git a/.gitignore b/.gitignore index b63da45..4c36a00 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ build/ .idea/jarRepositories.xml .idea/compiler.xml .idea/libraries/ +.idea *.iws *.iml *.ipr @@ -39,4 +40,4 @@ bin/ .vscode/ ### Mac OS ### -.DS_Store \ No newline at end of file +.DS_Store diff --git a/.idea/gradle.xml b/.idea/gradle.xml deleted file mode 100644 index 54ea192..0000000 --- a/.idea/gradle.xml +++ /dev/null @@ -1,17 +0,0 @@ - - - - - - - \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml deleted file mode 100644 index b0137f1..0000000 --- a/.idea/misc.xml +++ /dev/null @@ -1,7 +0,0 @@ - - - - - - - \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index 35eb1dd..0000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/workspace.xml b/.idea/workspace.xml deleted file mode 100644 index ebe296d..0000000 --- a/.idea/workspace.xml +++ /dev/null @@ -1,61 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 1683078027782 - - - - \ No newline at end of file diff --git a/README.md b/README.md index 6c7e6e8..ee0e3a7 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,18 @@ -# sam-java +# segment-anything.java Meta's [Segment Anything (SAM)](https://github.com/facebookresearch/segment-anything) model, ported to Java SE 17. -Eventual goal is to move this into a [JOSM](https://josm.openstreetmap.de/) [plugin](https://github.com/JOSM/josm-plugins) for segmentation of aerial imagery for OpenStreetMap. Further versions from that will be [finetuned](https://github.com/ctrlaltf2/segment-any-landuse) on aerial imagery for better results. +Eventual goal is to use this in a [JOSM](https://josm.openstreetmap.de/) [plugin](https://github.com/JOSM/josm-plugins) for segmentation of aerial imagery for OpenStreetMap. 
Further versions from that will be [finetuned](https://github.com/ctrlaltf2/segment-any-landuse) on aerial imagery for better results. + +## Roadmap + - [x] Reproduce ONNX export of encoder and decoder (ref: https://github.com/visheratin/segment-anything) + - [x] Image loading and preprocessing + - [x] Encoder forward pass + - [x] OnnxTensor to primitive matrix type conversion + - [x] Decoder forward pass, basic coordinate-based prompt + - [ ] Decoder mask post-processing (mapping back to the input image) + - [ ] Multi-mask support (currently T/F value is baked into the ONNX model at export-time) + - [ ] Error handling TODOs + - [ ] Improve upon ONNX export method + - [ ] Figure out a place to host the ONNX models (~5 GB total) + - [ ] FastSAM? It's such a different model though architecturally, might be out of the scope of this repository. + - [ ] GPU runtimes? diff --git a/build.gradle.kts b/build.gradle.kts index d4b384c..5269645 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -12,7 +12,7 @@ repositories { dependencies { testImplementation(platform("org.junit:junit-bom:5.9.1")) testImplementation("org.junit.jupiter:junit-jupiter") - implementation("com.microsoft.onnxruntime:onnxruntime:1.14.0") + implementation("com.microsoft.onnxruntime:onnxruntime:1.16.3") implementation("org.nd4j:nd4j:1.0.0-M2.1") implementation("org.nd4j:nd4j-native-platform:1.0.0-M2.1") } diff --git a/src/main/java/dev/troyer/sam/ConstraintType.java b/src/main/java/dev/troyer/sam/ConstraintType.java new file mode 100644 index 0000000..8bd584a --- /dev/null +++ b/src/main/java/dev/troyer/sam/ConstraintType.java @@ -0,0 +1,10 @@ +package dev.troyer.sam; + +/** + * Enum for representing what type of constraint is applied to a + * decoder model query point + */ +public enum ConstraintType { + BACKGROUND, + FOREGROUND +} diff --git a/src/main/java/dev/troyer/sam/Main.java b/src/main/java/dev/troyer/sam/Main.java index d30116a..f3e740b 100644 --- a/src/main/java/dev/troyer/sam/Main.java 
+++ b/src/main/java/dev/troyer/sam/Main.java @@ -1,37 +1,23 @@ package dev.troyer.sam; -import java.awt.image.BufferedImage; -import java.io.File; +import ai.onnxruntime.OrtException; + import java.io.IOException; import java.nio.file.Path; -import javax.imageio.ImageIO; - -import ai.onnxruntime.OrtEnvironment; -import ai.onnxruntime.OrtException; public class Main { - public static void main(String[] args) throws OrtException { - String imagePath = "/home/caleb/Pictures/example.jpg"; - BufferedImage image = null; - - try { - image = ImageIO.read(new File(imagePath)); - } catch (IOException e) { - System.err.println("Error while loading image: " + e.getMessage()); - return; - } + public static void main(String[] args) throws IOException, OrtException { + SamPredictor predictor = new SamPredictor( + Path.of("./src/main/resources/data/vit_b/encoder-vit_b.quant.onnx"), + Path.of("./src/main/resources/data/vit_b/decoder-vit_b.quant.onnx") + ); - OrtEnvironment env = OrtEnvironment.getEnvironment(); - SamImage samImage = new SamImage(image); + predictor.setImageFromPath( + Path.of("./test/assets/mingjun-liu-mVWqCdTHfxs-unsplash.jpg") + ); - try { - Path model_path = Path.of("./src/main/resources/data/vit_b/encoder-vit_b.quant.onnx"); - SamEncoder encoder = new SamEncoder(env, model_path); - var out = encoder.forward(env, samImage); - System.out.println(out); - } catch (OrtException e) { - System.err.println("Error while loading model: " + e.getMessage()); - return; - } + final SamResult result = predictor.predict(new SamConstraint[]{ + new SamConstraint(1980, 1200, ConstraintType.FOREGROUND) + }); } } diff --git a/src/main/java/dev/troyer/sam/SamConstraint.java b/src/main/java/dev/troyer/sam/SamConstraint.java new file mode 100644 index 0000000..4d9e9cc --- /dev/null +++ b/src/main/java/dev/troyer/sam/SamConstraint.java @@ -0,0 +1,29 @@ +package dev.troyer.sam; + +/** + * Represents a constraint to the SAM decoder model. 
+ * A constraint is the foreground/background parts you select in + * all those nice Segment Anything frontends. + */ +public class SamConstraint { + /** + * Column in image (pixel space). Origin is top left. + */ + final public int x; + + /** + * Row in image (pixel space). Origin is top left. + */ + final public int y; + + /** + * Type of constraint (foreground/background) + */ + final public ConstraintType type; + + public SamConstraint(int x, int y, ConstraintType type) { + this.x = x; + this.y = y; + this.type = type; + } +} diff --git a/src/main/java/dev/troyer/sam/SamDecoder.java b/src/main/java/dev/troyer/sam/SamDecoder.java index be12d6a..82d9a1c 100644 --- a/src/main/java/dev/troyer/sam/SamDecoder.java +++ b/src/main/java/dev/troyer/sam/SamDecoder.java @@ -1,4 +1,60 @@ package dev.troyer.sam; +import ai.onnxruntime.*; + +import java.nio.file.Path; +import java.util.LinkedHashMap; +import java.util.Map; + public class SamDecoder { + /** + * Loaded ONNX model + */ + private final OrtSession model; + + /** + * @param env ONNX environment context + * @param modelPath Path to the model file + */ + public SamDecoder(OrtEnvironment env, Path modelPath) throws OrtException { + this.model = env.createSession(modelPath.toString()); + // TODO: Verify input/output names. These could vary based on how/who exported the model. 
+ } + + /** + * @return shape of the decoder's {@code mask_input} input tensor, as reported by the loaded model + */ + public long[] getMaskInputsShape() throws OrtException { + // TODO: better error handling, don't assume input with this name exists + final Map<String, NodeInfo> inputsInfo = model.getInputInfo(); + final NodeInfo maskInputNodeInfo = inputsInfo.get("mask_input"); + final TensorInfo maskInputTensorInfo = (TensorInfo) maskInputNodeInfo.getInfo(); + return maskInputTensorInfo.getShape(); + } + + public SamResult forward(OrtEnvironment env, OnnxTensor imageEmbedding, + float[][][] pointCoords, float[][] pointLabels, + OnnxTensor maskInput, boolean hasMaskInput, + int[] originalImageSize) throws OrtException { + LinkedHashMap<String, OnnxTensor> inputs = new LinkedHashMap<>(); + inputs.put("image_embeddings", imageEmbedding); + inputs.put("point_coords", OnnxTensor.createTensor(env, pointCoords)); + inputs.put("point_labels", OnnxTensor.createTensor(env, pointLabels)); + inputs.put("mask_input", maskInput); + inputs.put("has_mask_input", OnnxTensor.createTensor(env, new float[]{hasMaskInput ?
1.0f : 0.0f})); + inputs.put("orig_im_size", OnnxTensor.createTensor(env, new float[]{originalImageSize[0], originalImageSize[1]})); + + OrtSession.Result result = model.run(inputs); + + assert result.get("masks").isPresent(); + assert result.get("iou_predictions").isPresent(); + assert result.get("low_res_masks").isPresent(); + + // TODO: use cursed onnx api to verify OnnxValue rank instead of assuming + return new SamResult( + (float[][][][]) result.get("masks").get().getValue(), + (float[][]) result.get("iou_predictions").get().getValue(), + (float[][][][]) result.get("low_res_masks").get().getValue() + ); + } } diff --git a/src/main/java/dev/troyer/sam/SamEncoder.java b/src/main/java/dev/troyer/sam/SamEncoder.java index 74ac767..da506bf 100644 --- a/src/main/java/dev/troyer/sam/SamEncoder.java +++ b/src/main/java/dev/troyer/sam/SamEncoder.java @@ -1,32 +1,42 @@ package dev.troyer.sam; import java.nio.file.Path; +import java.util.HashSet; import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; -import ai.onnxruntime.OnnxTensor; -import ai.onnxruntime.OrtEnvironment; -import ai.onnxruntime.OrtException; -import ai.onnxruntime.OrtSession; +import ai.onnxruntime.*; class SamEncoder { + /** + * Loaded ONNX model + */ private final OrtSession model; - private final int image_size; + private final int imageSize; /** - * @param model_path Path to the model file - * @param image_size Size of the image the encoder model was trained on (usually 1024) + * @param env ONNX environment context + * @param modelPath Path to the model file */ - public SamEncoder(OrtEnvironment env, Path model_path, int image_size) throws OrtException { - this.image_size = image_size; - this.model = env.createSession(model_path.toString()); - } + public SamEncoder(OrtEnvironment env, Path modelPath) throws OrtException { + this.model = env.createSession(modelPath.toString()); - /** - * @param model_path Path to the model file - */ - public SamEncoder(OrtEnvironment env, 
Path model_path) throws OrtException { - this(env, model_path, 1024); + // Verify inputs/outputs: the encoder takes a single input "x" and yields a single output "image_embeddings" + // TODO: Exception instead of this + assert this.model.getInputNames().equals( + new HashSet<>(List.of("x")) + ); + + assert this.model.getOutputNames().equals( + new HashSet<>(List.of("image_embeddings")) + ); + + Map<String, NodeInfo> inputsInfo = model.getInputInfo(); + final NodeInfo imageNodeInfo = inputsInfo.get("x"); + final TensorInfo imageTensorInfo = (TensorInfo) imageNodeInfo.getInfo(); + this.imageSize = (int) imageTensorInfo.getShape()[2]; } /** diff --git a/src/main/java/dev/troyer/sam/SamImage.java b/src/main/java/dev/troyer/sam/SamImage.java index 49783b3..e279bbb 100644 --- a/src/main/java/dev/troyer/sam/SamImage.java +++ b/src/main/java/dev/troyer/sam/SamImage.java @@ -1,19 +1,16 @@ package dev.troyer.sam; import java.awt.image.BufferedImage; -import java.io.File; -import java.io.IOException; -import javax.imageio.ImageIO; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OrtEnvironment; import ai.onnxruntime.OrtException; +import org.nd4j.enums.ImageResizeMethod; +import org.nd4j.enums.Mode; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.enums.Mode; -import org.nd4j.enums.ImageResizeMethod; class SamImage { @@ -26,17 +23,21 @@ class SamImage { /** * Original height, px */ - public final int original_height; + public final int originalHeight; /** * Original width, px */ - public final int original_width; + public final int originalWidth; /** * Image from BufferedImage */ public SamImage(BufferedImage img) { + this.originalHeight = img.getHeight(); + this.originalWidth = img.getWidth(); + + // TODO: Don't hardcode 1024 image size, use detected size from loaded models // TODO: support 16-bit images?
int[] pixels = img.getRGB(0, 0, img.getWidth(), img.getHeight(), null, 0, img.getWidth()); @@ -58,19 +59,15 @@ public SamImage(BufferedImage img) { } // Standardize the image (Z-score transform) - INDArray mean = img_tensor.mean(0, 1, 2) - .reshape(1, 1, 1, 3) - .broadcast(1, img.getHeight(), img.getWidth(), 3); + INDArray mean = img_tensor.mean(0, 1, 2).reshape(1, 1, 1, 3).broadcast(1, img.getHeight(), img.getWidth(), 3); - INDArray std = img_tensor.std(0, 1, 2) - .reshape(1, 1, 1, 3) - .broadcast(1, img.getHeight(), img.getWidth(), 3); + INDArray std = img_tensor.std(0, 1, 2).reshape(1, 1, 1, 3).broadcast(1, img.getHeight(), img.getWidth(), 3); img_tensor.subi(mean).divi(std); // Resize proportionally such that longest side is 1024 pixels final double scale = 1024.0 / Math.max(img.getWidth(), img.getHeight()); - final int new_width = Math.min((int) ( img.getWidth() * scale + 0.5), 1024); + final int new_width = Math.min((int) (img.getWidth() * scale + 0.5), 1024); final int new_height = Math.min((int) (img.getHeight() * scale + 0.5), 1024); // why are there literally 5 different ways to resize an image with this library and why are 4 of them wrong @@ -82,11 +79,8 @@ public SamImage(BufferedImage img) { final int pad_height = Math.max(1024 - new_height, 0); // docs were not clear on this __at all__, but it's the start/end padding for each axis, in order of dimension. 
- INDArray padding = Nd4j.createFromArray(new int[][]{ - {0, 0}, // batch padding - {0, pad_height}, - {0, pad_width}, - {0, 0} // pad channel + INDArray padding = Nd4j.createFromArray(new int[][]{{0, 0}, // batch padding + {0, pad_height}, {0, pad_width}, {0, 0} // pad channel }); INDArray padded = Nd4j.image.pad(resized, padding, Mode.CONSTANT, 0.0); @@ -99,17 +93,15 @@ public SamImage(BufferedImage img) { // and into a double array you go (cursed way because NDArray cannot export 4D matrices apparently) data = new float[1][3][1024][1024]; + // TODO: figure out how INDArray stores its elements, access in a CPU cache friendly manner (if that even matters for JVM) for (int i = 0; i < 1024; i++) for (int j = 0; j < 1024; j++) for (int k = 0; k < 3; k++) this.data[0][k][i][j] = (float) padded.getDouble(0, i, j, k); - - this.original_height = img.getHeight(); - this.original_width = img.getWidth(); } /** - * Get the data for the image, post transfom and ready for encoding + * Get the data for the image, post transform and ready for encoding */ public OnnxTensor asTensor(OrtEnvironment env) throws OrtException { return OnnxTensor.createTensor(env, data); diff --git a/src/main/java/dev/troyer/sam/SamPredictor.java b/src/main/java/dev/troyer/sam/SamPredictor.java index 42138a7..1984ec0 100644 --- a/src/main/java/dev/troyer/sam/SamPredictor.java +++ b/src/main/java/dev/troyer/sam/SamPredictor.java @@ -1,19 +1,137 @@ package dev.troyer.sam; -import dev.troyer.sam.SamDecoder; -import dev.troyer.sam.SamEncoder; +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.LinkedHashMap; +import javax.imageio.ImageIO; -import ai.onnxruntime.OrtSession; +import ai.onnxruntime.OnnxTensor; +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtException; public class SamPredictor { - // encoder model - private SamEncoder encoder; + /** + * Image that is currently selected + */ + private SamImage image; - // decoder model - private 
SamDecoder decoder; + /** + * Image embedding (1x256x64x64) + */ + // TODO: Couple SamImage, this embedding, and SamEncoder somehow + private OnnxTensor imageEmbedding; - // Constructor - public SamPredictor() { + /** + * Encoder model handler + */ + final private SamEncoder encoder; + /** + * Decoder model handler + */ + final private SamDecoder decoder; + + /** + * ONNX environment context + */ + final private OrtEnvironment env; + + /** + * Constructor + */ + public SamPredictor(Path encoderModelPath, Path decoderModelPath) throws OrtException { + env = OrtEnvironment.getEnvironment(); + + encoder = new SamEncoder(env, encoderModelPath); + decoder = new SamDecoder(env, decoderModelPath); + + // TODO: assert encoder.outputs.image_embeddings.shape == decoder.inputs.embeddings.shape + } + + /** + * Main interface, probably the most common use case + * + * @param constraints Queries/constraints to the model, given current image + * @return masks, IoU predictions and low-res masks, in the form of a SamResult + */ + public SamResult predict(SamConstraint[] constraints) throws OrtException { + if (image == null) throw new IllegalStateException("No image set; call setImageFromPath first"); + + // Unpack constraints into form usable by the model + float[][][] pointCoords = new float[1][constraints.length][2]; + float[][] pointLabels = new float[1][constraints.length]; + + for (int iConstraint = 0; iConstraint < constraints.length; ++iConstraint) { + SamConstraint thisConstraint = constraints[iConstraint]; + + // (x, y) order. NOTE(review): upstream SAM rescales prompt points into the resized model frame before decoding; these are passed unscaled - confirm + pointCoords[0][iConstraint][0] = (float) thisConstraint.x; + pointCoords[0][iConstraint][1] = (float) thisConstraint.y; + + pointLabels[0][iConstraint] = thisConstraint.type == ConstraintType.FOREGROUND ? 1.0f : 0.0f; + } + + // Upstream SAM ONNX decoder reads orig_im_size as (height, width) + final int[] originalImageSize = new int[]{image.originalHeight, image.originalWidth}; + + // ONNX model at the moment always expects a mask input. + // I think it's because ONNX models' parameters might be fixed, meaning optional + // items must still be populated (and tagged separately with a bool).
+ // Java zero-initializes newly allocated float arrays, so the items don't need to be explicitly set https://docs.oracle.com/javase/specs/jls/se17/html/jls-4.html#jls-4.12.5 + + final long[] maskInputShape = this.decoder.getMaskInputsShape(); + final OnnxTensor maskInput = OnnxTensor.createTensor( + env, + new float[(int) maskInputShape[0]][(int) maskInputShape[1]][(int) maskInputShape[2]][(int) maskInputShape[3]] + ); + + // TODO: return + return predict(pointCoords, pointLabels, maskInput, false, originalImageSize); + } + + /** + * Set current image by file path and re-encode it + * + * @param imagePath Path to the image. Isn't validated or anything, assumes it's an image + */ + public void setImageFromPath(Path imagePath) throws IOException, OrtException { + // TODO: Error handling + image = new SamImage(ImageIO.read(new File(imagePath.toString()))); + onImageUpdate(); + } + + /** + * Predict masks for given input prompts. + * This is the full set of inputs to the decoder, rarely to be used directly + * + * @param pointCoords 1xNx2 array of point prompts to the model. Each point is in pixel space. + * @param pointLabels 1xN array of labels for each point in pointCoords. 1 indicates foreground, + * 0 indicates background. + * @param maskInput 1x1x256x256 Low res mask input to the model, usually from a previous prediction iteration. + * @param hasMaskInput true if maskInput is to be used + * @param originalImageSize original image dimensions + * @return masks, IoU predictions and low-res masks, in the form of a SamResult + */ + private SamResult predict(float[][][] pointCoords, float[][] pointLabels, OnnxTensor maskInput, boolean hasMaskInput, int[] originalImageSize) throws OrtException { + // Hand everything to the decoder together with the cached image embedding + return this.decoder.forward( + env, + imageEmbedding, + pointCoords, + pointLabels, + maskInput, + hasMaskInput, + originalImageSize + ); + } + + /** + * Should run when the current image updates. Re-encodes the image for efficient prompting later.
+ */ + private void onImageUpdate() throws OrtException { + var forwardResult = encoder.forward(env, image); + assert forwardResult.get("image_embeddings").isPresent(); + imageEmbedding = (OnnxTensor) forwardResult.get("image_embeddings").get(); } } diff --git a/src/main/java/dev/troyer/sam/SamResult.java b/src/main/java/dev/troyer/sam/SamResult.java new file mode 100644 index 0000000..a998c9f --- /dev/null +++ b/src/main/java/dev/troyer/sam/SamResult.java @@ -0,0 +1,33 @@ +package dev.troyer.sam; + +/** + * Represents a result of a query to the segment-anything model as a whole + */ +public class SamResult { + /** + * Mask(s) returned by the model + */ + final public float[][][][] masks; + + /** + * iou_predictions + */ + final public float[][] iouPredictions; + + /** + * low_res_masks + */ + final public float[][][][] lowResMasks; + + /** + * Constructor, initializes members + * @param masks Mask(s) returned by the model + * @param iouPredictions iou_predictions + * @param lowResMasks low_res_masks + */ + public SamResult(float[][][][] masks, float[][] iouPredictions, float[][][][] lowResMasks) { + this.masks = masks; + this.iouPredictions = iouPredictions; + this.lowResMasks = lowResMasks; + } +} diff --git a/test/assets/mingjun-liu-mVWqCdTHfxs-unsplash.jpg b/test/assets/mingjun-liu-mVWqCdTHfxs-unsplash.jpg new file mode 100644 index 0000000..28ee7b8 Binary files /dev/null and b/test/assets/mingjun-liu-mVWqCdTHfxs-unsplash.jpg differ