Add mini web demo

This commit is contained in:
Nikhila Ravi
2023-04-11 10:37:41 -07:00
parent 7fa17d78c4
commit 426edb2ed3
23 changed files with 766 additions and 0 deletions

View File

@@ -0,0 +1,43 @@
import React, { useContext } from "react";
import * as _ from "underscore";
import Tool from "./Tool";
import { modelInputProps } from "./helpers/Interfaces";
import AppContext from "./hooks/createContext";
/**
 * Stage hosts the interactive image area.
 *
 * It builds the pointer handler that is handed to <Tool />: the handler
 * converts the pointer position to image-scale coordinates and stores a
 * single click in the shared context via setClicks. (Per the original note,
 * a useEffect in App.tsx reacts to that state change and runs the ONNX model
 * to produce a new mask.)
 */
const Stage = () => {
  const {
    clicks: [, setClicks],
    image: [image],
  } = useContext(AppContext)!;

  // The throttled wrapper is memoized so the SAME throttled function (and
  // therefore its internal timer state) survives re-renders. Creating it
  // inline would produce a fresh, un-throttled wrapper every time setClicks
  // re-renders this component.
  const handleMouseMove = React.useMemo(
    () =>
      _.throttle((e: any) => {
        const el = e.nativeEvent.target;
        const rect = el.getBoundingClientRect();
        // Touch events carry their coordinates on touches[0]; mouse events
        // carry them on the event itself.
        const pointer = e.touches?.[0] ?? e;
        // Scale the on-screen offset back to the scale of the image object.
        const imageScale = image ? image.width / el.offsetWidth : 1;
        const x = (pointer.clientX - rect.left) * imageScale;
        const y = (pointer.clientY - rect.top) * imageScale;
        // Every hover point is reported with clickType 1.
        const click: modelInputProps = { x, y, clickType: 1 };
        setClicks([click]);
      }, 15),
    [image, setClicks]
  );

  const flexCenterClasses = "flex items-center justify-center";
  return (
    <div className={`${flexCenterClasses} w-full h-full`}>
      <div className={`${flexCenterClasses} relative w-[90%] h-[90%]`}>
        <Tool handleMouseMove={handleMouseMove} />
      </div>
    </div>
  );
};
export default Stage;

View File

@@ -0,0 +1,67 @@
import React, { useContext, useEffect, useState } from "react";
import AppContext from "./hooks/createContext";
import { ToolProps } from "./helpers/Interfaces";
import * as _ from "underscore";
/**
 * Tool renders the source image and, when one is available, the predicted
 * mask image layered on top of it.
 *
 * @param handleMouseMove - pointer handler attached to the source image for
 *   both mouse-move and touch-start events.
 */
const Tool = ({ handleMouseMove }: ToolProps) => {
  const {
    image: [image],
    maskImg: [maskImg, setMaskImg],
  } = useContext(AppContext)!;

  // true  -> images are stretched to the container width ("w-full")
  // false -> images are stretched to the container height ("h-full")
  const [shouldFitToWidth, setShouldFitToWidth] = useState(true);

  // Decide whether to fit by width or by height, and re-evaluate whenever the
  // page body is resized. The observer is created inside the effect so that
  // exactly one instance exists per image, instead of a new one being
  // allocated on every render.
  useEffect(() => {
    if (!image) return undefined;

    const fitToPage = () => {
      const imageAspectRatio = image.width / image.height;
      const screenAspectRatio = window.innerWidth / window.innerHeight;
      setShouldFitToWidth(imageAspectRatio > screenAspectRatio);
    };
    fitToPage();

    const bodyEl = document.body;
    // Only the body is observed, so every notification concerns the body.
    const resizeObserver = new ResizeObserver(() => fitToPage());
    resizeObserver.observe(bodyEl);
    return () => {
      resizeObserver.disconnect();
    };
  }, [image]);

  const fitClass = shouldFitToWidth ? "w-full" : "h-full";
  const imageClasses = "";
  const maskImageClasses = `absolute opacity-40 pointer-events-none`;

  // Render the image and the predicted mask image on top
  return (
    <>
      {image && (
        <img
          onMouseMove={handleMouseMove}
          // Deferred so the mask is cleared after the current call stack.
          onMouseOut={() => _.defer(() => setMaskImg(null))}
          onTouchStart={handleMouseMove}
          src={image.src}
          className={`${fitClass} ${imageClasses}`}
        />
      )}
      {maskImg && (
        <img
          src={maskImg.src}
          className={`${fitClass} ${maskImageClasses}`}
        />
      )}
    </>
  );
};
export default Tool;

View File

@@ -0,0 +1,23 @@
import { Tensor } from "onnxruntime-web";
/**
 * Image dimensions together with the scale factor applied to click
 * coordinates before they are handed to the model.
 * NOTE(review): handleImageScale returns an object of this shape —
 * presumably that is where these values come from; confirm at the call site.
 */
export interface modelScaleProps {
// Factor that click x/y coordinates are multiplied by in modelData.
samScale: number;
// Image height; forwarded as the first entry of orig_im_size.
height: number;
// Image width; forwarded as the second entry of orig_im_size.
width: number;
}
/**
 * A single point prompt.
 * Stage produces these from pointer positions scaled back to image
 * coordinates; modelData multiplies x and y by samScale.
 */
export interface modelInputProps {
// Horizontal position of the point.
x: number;
// Vertical position of the point.
y: number;
// Label copied into the point_labels tensor. Stage always emits 1;
// modelData reserves -1 for its padding point.
clickType: number;
}
/** Arguments accepted by modelData when assembling the model inputs. */
export interface modeDataProps {
// Point prompts; when omitted, modelData returns undefined.
clicks?: Array<modelInputProps>;
// Forwarded unchanged as the image_embeddings input.
tensor: Tensor;
// Scale factor and image dimensions (see modelScaleProps).
modelScale: modelScaleProps;
}
/** Props accepted by the Tool component. */
export interface ToolProps {
// Handler Tool attaches to the image's onMouseMove and onTouchStart events.
handleMouseMove: (e: any) => void;
}

View File

@@ -0,0 +1,43 @@
// Functions for handling mask output from the ONNX model
/**
 * Convert the ONNX model mask prediction into an ImageData overlay.
 *
 * Each input value greater than 0.0 becomes an opaque blue pixel; every
 * other pixel stays fully transparent.
 *
 * NOTE(review): the ImageData constructor takes (data, width, height), but
 * it is called here with (pixels, height, width). The `height` argument
 * therefore becomes the image width and vice versa — confirm that callers
 * pass the dimensions in the order this relies on before changing it.
 *
 * @param input - indexable collection of mask values, one per pixel
 * @param width - first dimension; used to size the pixel buffer
 * @param height - second dimension; used to size the pixel buffer
 * @returns ImageData holding the RGBA overlay
 */
function arrayToImageData(input: any, width: number, height: number) {
  // RGBA channels of the mask's blue color.
  const maskColor = [0, 114, 189, 255];
  const pixels = new Uint8ClampedArray(4 * width * height).fill(0);

  for (let idx = 0, base = 0; idx < input.length; idx++, base += 4) {
    // Threshold the prediction at 0.0. Per the original note, this matches
    // thresholding with predictor.model.mask_threshold in python.
    if (!(input[idx] > 0.0)) continue;
    for (let channel = 0; channel < 4; channel++) {
      pixels[base + channel] = maskColor[channel];
    }
  }

  return new ImageData(pixels, height, width);
}
/**
 * Produce an image element from ImageData by drawing it onto a canvas and
 * using the canvas's data URL as the image source.
 *
 * @param imageData - pixel data to convert
 * @returns a new Image whose src is the canvas data URL
 */
function imageDataToImage(imageData: ImageData) {
  const dataUrl = imageDataToCanvas(imageData).toDataURL();
  const img = new Image();
  img.src = dataUrl;
  return img;
}
/**
 * Create a canvas sized to the given ImageData and paint the data onto it.
 *
 * @param imageData - pixel data to draw at the canvas origin
 * @returns the canvas element (left unpainted when no 2D context is available)
 */
function imageDataToCanvas(imageData: ImageData) {
  const surface = document.createElement("canvas");
  const context2d = surface.getContext("2d");

  surface.width = imageData.width;
  surface.height = imageData.height;

  if (context2d) {
    context2d.putImageData(imageData, 0, 0);
  }
  return surface;
}
/**
 * Convert the ONNX model mask output to an HTMLImageElement.
 *
 * @param input - mask values, one per pixel
 * @param width - forwarded unchanged to arrayToImageData
 * @param height - forwarded unchanged to arrayToImageData
 * @returns an image element showing the thresholded mask
 */
export function onnxMaskToImage(input: any, width: number, height: number) {
  const maskPixels = arrayToImageData(input, width, height);
  return imageDataToImage(maskPixels);
}

View File

@@ -0,0 +1,65 @@
import { Tensor } from "onnxruntime-web";
import { modeDataProps } from "./Interfaces";
/**
 * Assemble the named input tensors for the ONNX model from the current
 * click prompts.
 *
 * @param clicks - point prompts; when absent no inputs are produced
 * @param tensor - forwarded unchanged as `image_embeddings`
 * @param modelScale - scale factor for click coordinates plus image size
 * @returns the model input map, or undefined when there are no click prompts
 */
const modelData = ({ clicks, tensor, modelScale }: modeDataProps) => {
  const { samScale, height, width } = modelScale;
  let coordsTensor;
  let labelsTensor;

  // Only build point tensors when click prompts were supplied.
  if (clicks) {
    const clickCount = clicks.length;
    // There is no box input, so one padding point with label -1 at
    // (0.0, 0.0) is appended: reserve room for clickCount + 1 points.
    const slotCount = clickCount + 1;
    const coords = new Float32Array(2 * slotCount);
    const labels = new Float32Array(slotCount);

    // Copy each click, scaled to what SAM expects.
    for (const [slot, click] of clicks.entries()) {
      coords[2 * slot] = click.x * samScale;
      coords[2 * slot + 1] = click.y * samScale;
      labels[slot] = click.clickType;
    }

    // The padding point occupies the final slot.
    coords[2 * clickCount] = 0.0;
    coords[2 * clickCount + 1] = 0.0;
    labels[clickCount] = -1.0;

    coordsTensor = new Tensor("float32", coords, [1, slotCount, 2]);
    labelsTensor = new Tensor("float32", labels, [1, slotCount]);
  }

  const origImSize = new Tensor("float32", [height, width]);

  if (coordsTensor === undefined || labelsTensor === undefined) return;

  return {
    image_embeddings: tensor,
    point_coords: coordsTensor,
    point_labels: labelsTensor,
    orig_im_size: origImSize,
    // There is no previous mask, so pass an all-zero mask ...
    mask_input: new Tensor(
      "float32",
      new Float32Array(256 * 256),
      [1, 1, 256, 256]
    ),
    // ... and flag its absence with 0.
    has_mask_input: new Tensor("float32", [0]),
  };
};
export { modelData };

View File

@@ -0,0 +1,12 @@
/**
 * Helper for the image scaling SAM needs.
 *
 * @param image - image whose natural dimensions are read
 * @returns the natural height and width plus `samScale`, the factor that
 *   maps the image's longest side to 1024
 */
const handleImageScale = (image: HTMLImageElement) => {
  // Input images to SAM must be resized so the longest side is 1024
  const LONG_SIDE_LENGTH = 1024;
  const { naturalWidth: width, naturalHeight: height } = image;
  const longestSide = Math.max(height, width);
  return { height, width, samScale: LONG_SIDE_LENGTH / longestSide };
};
export { handleImageScale };

View File

@@ -0,0 +1,25 @@
import React, { useState } from "react";
import { modelInputProps } from "../helpers/Interfaces";
import AppContext from "./createContext";
/**
 * Owns the shared application state (clicks, image, mask image) and exposes
 * each piece through AppContext as a [value, setter] pair.
 *
 * @param children - the element rendered inside the provider
 */
function AppContextProvider({
  children,
}: {
  children: React.ReactElement<any, string | React.JSXElementConstructor<any>>;
}) {
  // Each useState result is already the [value, setter] pair the context expects.
  const clicksState = useState<Array<modelInputProps> | null>(null);
  const imageState = useState<HTMLImageElement | null>(null);
  const maskImgState = useState<HTMLImageElement | null>(null);

  return (
    <AppContext.Provider
      value={{
        clicks: clicksState,
        image: imageState,
        maskImg: maskImgState,
      }}
    >
      {children}
    </AppContext.Provider>
  );
}
export default AppContextProvider;

View File

@@ -0,0 +1,21 @@
import { createContext } from "react";
import { modelInputProps } from "../helpers/Interfaces";
/**
 * Shape of the value carried by AppContext.
 * Every entry is a [current value, setter] pair; AppContextProvider fills
 * them with useState pairs.
 */
interface contextProps {
// Current point prompts (null until set) and their setter.
clicks: [
clicks: modelInputProps[] | null,
setClicks: (e: modelInputProps[] | null) => void
];
// The source image (null until set) and its setter.
image: [
image: HTMLImageElement | null,
setImage: (e: HTMLImageElement | null) => void
];
// The mask image (null when there is none) and its setter.
maskImg: [
maskImg: HTMLImageElement | null,
setMaskImg: (e: HTMLImageElement | null) => void
];
}
// The context defaults to null; Stage and Tool read it with a non-null
// assertion, so they expect to be rendered inside a provider.
const AppContext = createContext<contextProps | null>(null);
export default AppContext;