Merge pull request #557 from roboflow/sam2
SAM2
probicheaux committed Aug 2, 2024
2 parents 13295b6 + c718be0 commit ecdfe11
Showing 26 changed files with 1,192 additions and 2 deletions.
18 changes: 17 additions & 1 deletion .github/workflows/test.nvidia_t4.yml
@@ -23,6 +23,22 @@ jobs:
run: |
python3 -m pip install --upgrade pip
python3 -m pip install -r requirements/requirements.test.integration.txt
- name: 🔨 Build and Push Test Docker - SAM2
run: |
docker build -t roboflow/roboflow-inference-server-sam2:test -f docker/dockerfiles/Dockerfile.sam2 .
- name: 🔋 Start Test Docker - SAM2
run: |
PORT=9101 INFERENCE_SERVER_REPO=roboflow-inference-server-sam2 make start_test_docker_gpu
- name: 🧪 Regression Tests - SAM2
id: sam2_tests
run: |
PORT=9101 API_KEY=${{ secrets.API_KEY }} SKIP_SAM2_TESTS=False python3 -m pytest tests/inference/integration_tests/test_sam2.py
- name: 🚨 Show server logs on error
run: docker logs inference-test
if: ${{ steps.sam2_tests.outcome != 'success' }}
- name: 🧹 Cleanup Test Docker - SAM2
run: make stop_test_docker
if: success() || failure()
- name: 🔨 Build and Push Test Docker - GPU
run: |
docker build -t roboflow/roboflow-inference-server-gpu:test -f docker/dockerfiles/Dockerfile.onnx.gpu .
@@ -77,4 +93,4 @@ jobs:
if: ${{ steps.florence_tests.outcome != 'success' }}
- name: 🧹 Cleanup Test Docker - GPU
run: make stop_test_docker
if: success() || failure()
1 change: 1 addition & 0 deletions .release/pypi/inference.core.setup.py
@@ -65,6 +65,7 @@ def read_requirements(path):
"hosted": read_requirements("requirements/requirements.hosted.txt"),
"http": read_requirements("requirements/requirements.http.txt"),
"sam": read_requirements("requirements/requirements.sam.txt"),
"sam2": read_requirements("requirements/requirements.sam2.txt"),
"waf": read_requirements("requirements/requirements.waf.txt"),
"yolo-world": read_requirements("requirements/requirements.yolo_world.txt"),
},
1 change: 1 addition & 0 deletions .release/pypi/inference.cpu.setup.py
@@ -64,6 +64,7 @@ def read_requirements(path):
"hosted": read_requirements("requirements/requirements.hosted.txt"),
"http": read_requirements("requirements/requirements.http.txt"),
"sam": read_requirements("requirements/requirements.sam.txt"),
"sam2": read_requirements("requirements/requirements.sam2.txt"),
"waf": read_requirements("requirements/requirements.waf.txt"),
"yolo-world": read_requirements("requirements/requirements.yolo_world.txt"),
},
1 change: 1 addition & 0 deletions .release/pypi/inference.gpu.setup.py
@@ -64,6 +64,7 @@ def read_requirements(path):
"hosted": read_requirements("requirements/requirements.hosted.txt"),
"http": read_requirements("requirements/requirements.http.txt"),
"sam": read_requirements("requirements/requirements.sam.txt"),
"sam2": read_requirements("requirements/requirements.sam2.txt"),
"waf": read_requirements("requirements/requirements.waf.txt"),
"yolo-world": read_requirements("requirements/requirements.yolo_world.txt"),
},
1 change: 1 addition & 0 deletions .release/pypi/inference.setup.py
@@ -64,6 +64,7 @@ def read_requirements(path):
"hosted": read_requirements("requirements/requirements.hosted.txt"),
"http": read_requirements("requirements/requirements.http.txt"),
"sam": read_requirements("requirements/requirements.sam.txt"),
"sam2": read_requirements("requirements/requirements.sam2.txt"),
"waf": read_requirements("requirements/requirements.waf.txt"),
"yolo-world": read_requirements("requirements/requirements.yolo_world.txt"),
},
56 changes: 56 additions & 0 deletions docker/dockerfiles/Dockerfile.sam2
@@ -0,0 +1,56 @@
FROM nvcr.io/nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04

WORKDIR /app

RUN rm -rf /var/lib/apt/lists/* && apt-get clean && apt-get update -y && DEBIAN_FRONTEND=noninteractive apt-get install -y \
ffmpeg \
libxext6 \
libopencv-dev \
uvicorn \
python3-pip \
git \
libgdal-dev \
wget \
gcc \
&& rm -rf /var/lib/apt/lists/*

COPY requirements/requirements.sam2.txt \
requirements/requirements.http.txt \
requirements/_requirements.txt \
requirements/requirements.gpu.txt \
requirements/requirements.sdk.http.txt \
./

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --extra-index-url https://download.pytorch.org/whl/cu124 \
-r _requirements.txt \
-r requirements.sam2.txt \
-r requirements.http.txt \
-r requirements.gpu.txt \
-r requirements.sdk.http.txt \
--upgrade \
&& rm -rf ~/.cache/pip

WORKDIR /sam
RUN git clone https://github.com/facebookresearch/segment-anything-2
COPY inference/models/sam2/sam.patch ./
RUN cd segment-anything-2 && git checkout 0e78a118995e66bb27d78518c4bd9a3e95b4e266 && git apply ../sam.patch && TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6" python3 -m pip install -e . && rm -rf ~/.cache/pip

WORKDIR /app/
COPY inference inference
COPY inference_sdk inference_sdk
COPY docker/config/gpu_http.py gpu_http.py

ENV VERSION_CHECK_MODE=continuous
ENV PROJECT=roboflow-platform
ENV NUM_WORKERS=1
ENV HOST=0.0.0.0
ENV PORT=9001
ENV WORKFLOWS_STEP_EXECUTION_MODE=local
ENV WORKFLOWS_MAX_CONCURRENT_STEPS=1
ENV API_LOGGING_ENABLED=True
ENV LMM_ENABLED=True
ENV PYTHONPATH=/app/
ENV CORE_MODEL_SAM2_ENABLED=True

ENTRYPOINT uvicorn gpu_http:app --workers $NUM_WORKERS --host $HOST --port $PORT
216 changes: 216 additions & 0 deletions docs/foundation/sam2.md
@@ -0,0 +1,216 @@
<a href="https://github.com/facebookresearch/segment-anything-2" target="_blank">Segment Anything 2</a> is an open source image segmentation model.

You can use Segment Anything 2 to identify the precise location of objects in an image. You can refine masks iteratively by specifying points that should be included in, or excluded from, the segmentation mask.

## How to Use Segment Anything 2

To use Segment Anything 2 with Inference, you will need a Roboflow API key. If you don't already have a Roboflow account, <a href="https://app.roboflow.com" target="_blank">sign up for a free Roboflow account</a>. Then, retrieve your API key from the Roboflow dashboard.

## How To Use SAM2 Locally With Inference

We will follow along with the example located at `examples/sam2/sam2.py`.

We start with the following image,

![Input image](https://media.roboflow.com/inference/sam2/hand.png)

compute the most prominent mask,

![Most prominent mask](https://media.roboflow.com/inference/sam2/sam.png)

and negative prompt the wrist to obtain only the fist.

![Negative prompt](https://media.roboflow.com/inference/sam2/sam_negative_prompted.png)

### Running Within Docker
Build the Docker image (make sure your working directory is the root of the `inference` repository) with
```bash
docker build -f docker/dockerfiles/Dockerfile.sam2 -t sam2 .
```

Start up an interactive terminal with
```bash
docker run -it --rm --entrypoint bash -v $(pwd)/scratch/:/app/scratch/ -v /tmp/cache/:/tmp/cache/ -v $(pwd)/inference/:/app/inference/ --gpus=all --net=host sam2
```
You can save files to `/app/scratch/` to use them on the host device.

Or, start a SAM2 server with
```bash
docker run -it --rm -v /tmp/cache/:/tmp/cache/ -v $(pwd)/inference/:/app/inference/ --gpus=all --net=host sam2
```

and interact with it over HTTP.
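
For example, you can run a quick reachability check from the host before sending real requests; a minimal sketch, assuming the server is listening on the default port 9001:

```python
import requests

# assumes the container above is running with --net=host on the default port
base_url = "http://localhost:9001"

# simple reachability check before sending SAM2 requests
res = requests.get(base_url)
print(res.status_code)  # expect 200 once the server is up
```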

### Imports
Set up your API key and install <a href="https://github.com/facebookresearch/segment-anything-2" target="_blank">Segment Anything 2</a>.

!!! note

    There's <a href="https://github.com/facebookresearch/segment-anything-2/issues/48" target="_blank">currently a problem</a> with SAM2 + Flash Attention on certain GPUs, such as the L4 or A100. Use the fix in the linked thread, or use the Docker image we provide for SAM2.

```python
import os
os.environ["API_KEY"] = "<YOUR-API-KEY>"
from inference.models.sam2 import SegmentAnything2
from inference.core.utils.postprocess import masks2poly
import supervision as sv
from PIL import Image
import numpy as np
image_path = "./examples/sam2/hand.png"
```
### Model Loading
Load the model with
```python
m = SegmentAnything2(model_id="sam2/hiera_large")
```

Other supported values for `model_id` are `"sam2/hiera_small"`, `"sam2/hiera_tiny"`, and `"sam2/hiera_b_plus"`.
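
For example, to trade some accuracy for speed you could load the tiny checkpoint instead; a minimal sketch using the ids listed above:

```python
# load the tiny variant for faster, lighter inference
m_tiny = SegmentAnything2(model_id="sam2/hiera_tiny")
```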

### Compute the Most Prominent Mask

```python
# call embed_image before segment_image to precompute embeddings
embedding, img_shape, id_ = m.embed_image(image_path)
# segments image using cached embedding if it exists, else computes it on the fly
raw_masks, raw_low_res_masks = m.segment_image(image_path)
# threshold the raw mask scores into binary masks
raw_masks = raw_masks >= m.predictor.mask_threshold
# convert the binary masks to polygons
poly_masks = masks2poly(raw_masks)
```
Note that you can embed the image as soon as you know you want to process it, and the embeddings are cached automatically for faster downstream processing.

The resulting mask will look like this:

![Most prominent mask](https://media.roboflow.com/inference/sam2/sam.png)
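
As noted above, the embedding is cached after `embed_image`, so repeated segmentation calls skip the most expensive step. A minimal sketch to observe this (the timing is illustrative, not from the source):

```python
import time

start = time.perf_counter()
m.segment_image(image_path)  # reuses the embedding computed above
print(f"segmentation with cached embedding: {time.perf_counter() - start:.2f}s")
```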

### Negative Prompt the Model
```python
point = [250, 800]
# give a negative point (point_label 0) or a positive example (point_label 1)
# uses cached masks from prior call
raw_masks2, raw_low_res_masks2 = m.segment_image(
image_path,
point_coords=[point],
point_labels=[0],
)
raw_masks2 = raw_masks2 >= m.predictor.mask_threshold
```
Here we tell the model that the cached mask should not include the wrist.

The resulting mask will look like this:

![Negative prompt](https://media.roboflow.com/inference/sam2/sam_negative_prompted.png)
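
You can also mix positive and negative points in a single call; a sketch reusing the wrist point, with a hypothetical positive point on the fist:

```python
# combine a positive point (label 1) on the fist with the negative
# wrist point (label 0); the positive coordinates are hypothetical
raw_masks3, raw_low_res_masks3 = m.segment_image(
    image_path,
    point_coords=[[400, 400], point],
    point_labels=[1, 0],
)
raw_masks3 = raw_masks3 >= m.predictor.mask_threshold
```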

### Annotate
Use <a href="https://github.com/roboflow/supervision" target="_blank">Supervision</a> to draw the results of the model.

```python
image = np.array(Image.open(image_path).convert("RGB"))
mask_annotator = sv.MaskAnnotator()
dot_annotator = sv.DotAnnotator()
detections = sv.Detections(
xyxy=np.array([[0, 0, 100, 100]]), mask=np.array([raw_masks])
)
detections.class_id = [i for i in range(len(detections))]
annotated_image = mask_annotator.annotate(image.copy(), detections)
im = Image.fromarray(annotated_image)
im.save("sam.png")
detections = sv.Detections(
xyxy=np.array([[0, 0, 100, 100]]), mask=np.array([raw_masks2])
)
detections.class_id = [i for i in range(len(detections))]
annotated_image = mask_annotator.annotate(image.copy(), detections)
dot_detections = sv.Detections(
xyxy=np.array([[point[0] - 1, point[1] - 1, point[0] + 1, point[1] + 1]]),
class_id=np.array([1]),
)
annotated_image = dot_annotator.annotate(annotated_image, dot_detections)
im = Image.fromarray(annotated_image)
im.save("sam_negative_prompted.png")
```
## How To Use SAM2 With a Local Docker Container HTTP Server

### Build and Start The Server

Build the Docker image (make sure your working directory is the root of the `inference` repository) with
```bash
docker build -f docker/dockerfiles/Dockerfile.sam2 -t sam2 .
```
and start a SAM2 server with
```bash
docker run -it --rm -v /tmp/cache/:/tmp/cache/ -v $(pwd)/inference/:/app/inference/ --gpus=all --net=host sam2
```

### Embed an Image

An embedding is a numeric representation of an image. SAM2 uses embeddings as input to calculate the location of objects in an image.

Create a new Python file and add the following code:

```python
import requests

infer_payload = {
    "image": {
        "type": "url",
        "value": "https://i.imgur.com/Q6lDy8B.jpg",
    },
    "image_id": "example_image_id",
}

base_url = "http://localhost:9001"

# Define your Roboflow API Key
api_key = "YOUR ROBOFLOW API KEY"

res = requests.post(
f"{base_url}/sam2/embed_image?api_key={api_key}",
json=infer_payload,
)

```

This code makes a request to Inference to embed an image using SAM2.

The `example_image_id` is used to cache the embeddings so that future segmentation requests can reuse them without re-embedding the image.

### Segment an Object

To segment an object, you need at least one point in the image that lies on the object you want to segment.

!!! tip

    For testing with a single image, you can upload an image to the <a href="https://roboflow.github.io/polygonzone/" target="_blank">Polygon Zone web interface</a> and hover over a point in the image to see the coordinates of that point.

You may also opt to use an object detection model to identify an object, then use the center of its bounding box as a point prompt for segmentation, as sketched below.
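
Deriving the point prompt from a bounding box takes only a couple of lines; a minimal sketch, assuming an `xyxy` box from any detector (the coordinates here are made up):

```python
# hypothetical detection box in xyxy format
x_min, y_min, x_max, y_max = 260, 220, 500, 480

# use the box center as a positive point prompt (label 1 = include)
point_coords = [[(x_min + x_max) // 2, (y_min + y_max) // 2]]
point_labels = [1]
```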

Create a new Python file and add the following code:

```python
# Define the request payload
infer_payload = {
    "image": {
        "type": "url",
        "value": "https://i.imgur.com/Q6lDy8B.jpg",
    },
    "point_coords": [[380, 350]],
    "point_labels": [1],
    "image_id": "example_image_id",
}

res = requests.post(
    f"{base_url}/sam2/segment_image?api_key={api_key}",
    json=infer_payload,
)

masks = res.json()['masks']
```

This request returns segmentation masks that represent the object of interest.
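
To inspect the result visually, you can draw the returned masks on the image; a sketch, assuming each mask is returned as a polygon (a list of `[x, y]` vertices):

```python
import cv2
import numpy as np

image = cv2.imread("example.jpg")  # hypothetical local copy of the input image
for polygon in masks:
    pts = np.array(polygon, dtype=np.int32).reshape(-1, 1, 2)
    cv2.polylines(image, [pts], isClosed=True, color=(0, 255, 0), thickness=2)
cv2.imwrite("masks_preview.jpg", image)
```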
Binary file added examples/sam2/hand.png
59 changes: 59 additions & 0 deletions examples/sam2/sam2.py
@@ -0,0 +1,59 @@
import os

os.environ["API_KEY"] = "<YOUR-API-KEY>"
from inference.models.sam2 import SegmentAnything2
from inference.core.utils.postprocess import masks2poly
import supervision as sv
from PIL import Image
import numpy as np

image_path = "./examples/sam2/hand.png"
m = SegmentAnything2(model_id="sam2/hiera_large")

# call embed_image before segment_image to precompute embeddings
embedding, img_shape, id_ = m.embed_image(image_path)

# segments image using cached embedding if it exists, else computes it on the fly
raw_masks, raw_low_res_masks = m.segment_image(image_path)

# threshold the raw mask scores into binary masks
raw_masks = raw_masks >= m.predictor.mask_threshold
# convert the binary masks to polygons
poly_masks = masks2poly(raw_masks)

point = [250, 800]
# give a negative point (point_label 0) or a positive example (point_label 1)
# uses cached masks from prior call
raw_masks2, raw_low_res_masks2 = m.segment_image(
image_path,
point_coords=[point],
point_labels=[0],
)

raw_masks2 = raw_masks2 >= m.predictor.mask_threshold

image = np.array(Image.open(image_path).convert("RGB"))

mask_annotator = sv.MaskAnnotator()
dot_annotator = sv.DotAnnotator()

detections = sv.Detections(
xyxy=np.array([[0, 0, 100, 100]]), mask=np.array([raw_masks])
)
detections.class_id = [i for i in range(len(detections))]
annotated_image = mask_annotator.annotate(image.copy(), detections)
im = Image.fromarray(annotated_image)
im.save("sam.png")

detections = sv.Detections(
xyxy=np.array([[0, 0, 100, 100]]), mask=np.array([raw_masks2])
)
detections.class_id = [i for i in range(len(detections))]
annotated_image = mask_annotator.annotate(image.copy(), detections)

dot_detections = sv.Detections(
xyxy=np.array([[point[0] - 1, point[1] - 1, point[0] + 1, point[1] + 1]]),
class_id=np.array([1]),
)
annotated_image = dot_annotator.annotate(annotated_image, dot_detections)
im = Image.fromarray(annotated_image)
im.save("sam_negative_prompted.png")