
Commit

Merge pull request #637 from roboflow/fix/vlm_as_classifier
Fix problem with VLM as classifier block
PawelPeczek-Roboflow committed Sep 6, 2024
2 parents 97445aa + c9d2d44 commit 3f4a264
Showing 9 changed files with 633 additions and 8 deletions.
2 changes: 1 addition & 1 deletion inference/core/version.py
@@ -1,4 +1,4 @@
-__version__ = "0.18.0"
+__version__ = "0.18.1"


if __name__ == "__main__":
@@ -177,15 +177,15 @@ def parse_multi_class_classification_results(
    if top_class not in class2id_mapping:
        predictions.append(
            {
-               "class_name": top_class,
+               "class": top_class,
                "class_id": -1,
                "confidence": confidences.get(top_class, 0.0),
            }
        )
    for class_name, class_id in class2id_mapping.items():
        predictions.append(
            {
-               "class_name": class_name,
+               "class": class_name,
                "class_id": class_id,
                "confidence": confidences.get(class_name, 0.0),
            }
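The rename above is the substance of this fix: the multi-class branch of the VLM-as-classifier formatter emitted predictions under a "class_name" key, bringing it in line with the "class" key used by classification predictions elsewhere and asserted on in the new integration tests below. As a minimal sketch, each entry appended to predictions now looks like this (values are hypothetical; only the keys come from the hunk above):

# Sketch of a single prediction entry emitted after this commit.
# Values are made up for illustration; only the keys are taken from the diff.
prediction = {
    "class": "beagle",  # previously emitted as "class_name"
    "class_id": 2,  # -1 when the top class is absent from class2id_mapping
    "confidence": 0.87,
}
assert set(prediction) == {"class", "class_id", "confidence"}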
@@ -240,3 +240,101 @@ def test_object_detection_workflow(
"dog",
"dog",
], "Expected 2 dogs to be detected"


VLM_AS_SECONDARY_CLASSIFIER_WORKFLOW = {
    "version": "1.0",
    "inputs": [
        {"type": "WorkflowImage", "name": "image"},
        {"type": "WorkflowParameter", "name": "api_key"},
        {"type": "WorkflowParameter", "name": "model_id"},
        {
            "type": "WorkflowParameter",
            "name": "classes",
            "default_value": [
                "russell-terrier",
                "wirehaired-pointing-griffon",
                "beagle",
            ],
        },
    ],
    "steps": [
        {
            "type": "ObjectDetectionModel",
            "name": "general_detection",
            "image": "$inputs.image",
            "model_id": "$inputs.model_id",
            "class_filter": ["dog"],
        },
        {
            "type": "Crop",
            "name": "cropping",
            "image": "$inputs.image",
            "predictions": "$steps.general_detection.predictions",
        },
        {
            "type": "roboflow_core/anthropic_claude@v1",
            "name": "claude",
            "images": "$steps.cropping.crops",
            "task_type": "classification",
            "classes": "$inputs.classes",
            "api_key": "$inputs.api_key",
        },
        {
            "type": "roboflow_core/vlm_as_classifier@v1",
            "name": "parser",
            "image": "$steps.cropping.crops",
            "vlm_output": "$steps.claude.output",
            "classes": "$steps.claude.classes",
        },
        {
            "type": "roboflow_core/detections_classes_replacement@v1",
            "name": "classes_replacement",
            "object_detection_predictions": "$steps.general_detection.predictions",
            "classification_predictions": "$steps.parser.predictions",
        },
    ],
    "outputs": [
        {
            "type": "JsonField",
            "name": "predictions",
            "selector": "$steps.classes_replacement.predictions",
        },
    ],
}


@pytest.mark.skipif(ANTHROPIC_API_KEY is None, reason="No Anthropic API key provided")
@pytest.mark.flaky(retries=4, delay=1)
def test_workflow_with_secondary_classifier(
    object_detection_service_url: str,
    dogs_image: np.ndarray,
    yolov8n_640_model_id: str,
) -> None:
    # given
    client = InferenceHTTPClient(
        api_url=object_detection_service_url,
        api_key=ROBOFLOW_API_KEY,
    )

    # when
    result = client.run_workflow(
        specification=VLM_AS_SECONDARY_CLASSIFIER_WORKFLOW,
        images={
            "image": dogs_image,
        },
        parameters={
            "api_key": ANTHROPIC_API_KEY,
            "classes": ["russell-terrier", "wirehaired-pointing-griffon", "beagle"],
            "model_id": yolov8n_640_model_id,
        },
    )

    # then
    assert len(result) == 1, "Single image given, expected single output"
    assert set(result[0].keys()) == {
        "predictions",
    }, "Expected all outputs to be delivered"
    assert "dog" not in set(
        [e["class"] for e in result[0]["predictions"]["predictions"]]
    ), "Expected classes to be substituted"
@@ -240,3 +240,100 @@ def test_object_detection_workflow(
"dog",
"dog",
], "Expected 2 dogs to be detected"


VLM_AS_SECONDARY_CLASSIFIER_WORKFLOW = {
    "version": "1.0",
    "inputs": [
        {"type": "WorkflowImage", "name": "image"},
        {"type": "WorkflowParameter", "name": "api_key"},
        {"type": "WorkflowParameter", "name": "model_id"},
        {
            "type": "WorkflowParameter",
            "name": "classes",
            "default_value": [
                "russell-terrier",
                "wirehaired-pointing-griffon",
                "beagle",
            ],
        },
    ],
    "steps": [
        {
            "type": "ObjectDetectionModel",
            "name": "general_detection",
            "image": "$inputs.image",
            "model_id": "$inputs.model_id",
            "class_filter": ["dog"],
        },
        {
            "type": "Crop",
            "name": "cropping",
            "image": "$inputs.image",
            "predictions": "$steps.general_detection.predictions",
        },
        {
            "type": "roboflow_core/google_gemini@v1",
            "name": "gemini",
            "images": "$steps.cropping.crops",
            "task_type": "classification",
            "classes": "$inputs.classes",
            "api_key": "$inputs.api_key",
        },
        {
            "type": "roboflow_core/vlm_as_classifier@v1",
            "name": "parser",
            "image": "$steps.cropping.crops",
            "vlm_output": "$steps.gemini.output",
            "classes": "$steps.gemini.classes",
        },
        {
            "type": "roboflow_core/detections_classes_replacement@v1",
            "name": "classes_replacement",
            "object_detection_predictions": "$steps.general_detection.predictions",
            "classification_predictions": "$steps.parser.predictions",
        },
    ],
    "outputs": [
        {
            "type": "JsonField",
            "name": "predictions",
            "selector": "$steps.classes_replacement.predictions",
        },
    ],
}


@pytest.mark.skipif(GOOGLE_API_KEY is None, reason="No Google API key provided")
@pytest.mark.flaky(retries=4, delay=1)
def test_workflow_with_secondary_classifier(
    object_detection_service_url: str,
    dogs_image: np.ndarray,
    yolov8n_640_model_id: str,
) -> None:
    client = InferenceHTTPClient(
        api_url=object_detection_service_url,
        api_key=ROBOFLOW_API_KEY,
    )

    # when
    result = client.run_workflow(
        specification=VLM_AS_SECONDARY_CLASSIFIER_WORKFLOW,
        images={
            "image": dogs_image,
        },
        parameters={
            "api_key": GOOGLE_API_KEY,
            "classes": ["russell-terrier", "wirehaired-pointing-griffon", "beagle"],
            "model_id": yolov8n_640_model_id,
        },
    )

    # then
    assert len(result) == 1, "Single image given, expected single output"
    assert set(result[0].keys()) == {
        "predictions",
    }, "Expected all outputs to be delivered"
    assert "dog" not in set(
        [e["class"] for e in result[0]["predictions"]["predictions"]]
    ), "Expected classes to be substituted"
@@ -253,3 +253,100 @@ def test_structured_prompting_workflow(
    assert len(result) == 1, "Single image given, expected single output"
    assert set(result[0].keys()) == {"result"}, "Expected all outputs to be delivered"
    assert result[0]["result"] == "2"


VLM_AS_SECONDARY_CLASSIFIER_WORKFLOW = {
    "version": "1.0",
    "inputs": [
        {"type": "WorkflowImage", "name": "image"},
        {"type": "WorkflowParameter", "name": "api_key"},
        {"type": "WorkflowParameter", "name": "model_id"},
        {
            "type": "WorkflowParameter",
            "name": "classes",
            "default_value": [
                "russell-terrier",
                "wirehaired-pointing-griffon",
                "beagle",
            ],
        },
    ],
    "steps": [
        {
            "type": "ObjectDetectionModel",
            "name": "general_detection",
            "image": "$inputs.image",
            "model_id": "$inputs.model_id",
            "class_filter": ["dog"],
        },
        {
            "type": "Crop",
            "name": "cropping",
            "image": "$inputs.image",
            "predictions": "$steps.general_detection.predictions",
        },
        {
            "type": "roboflow_core/open_ai@v2",
            "name": "gpt",
            "images": "$steps.cropping.crops",
            "task_type": "classification",
            "classes": "$inputs.classes",
            "api_key": "$inputs.api_key",
        },
        {
            "type": "roboflow_core/vlm_as_classifier@v1",
            "name": "parser",
            "image": "$steps.cropping.crops",
            "vlm_output": "$steps.gpt.output",
            "classes": "$steps.gpt.classes",
        },
        {
            "type": "roboflow_core/detections_classes_replacement@v1",
            "name": "classes_replacement",
            "object_detection_predictions": "$steps.general_detection.predictions",
            "classification_predictions": "$steps.parser.predictions",
        },
    ],
    "outputs": [
        {
            "type": "JsonField",
            "name": "predictions",
            "selector": "$steps.classes_replacement.predictions",
        },
    ],
}


@pytest.mark.skipif(OPENAI_KEY is None, reason="No OpenAI API key provided")
@pytest.mark.flaky(retries=4, delay=1)
def test_structured_prompting_workflow(
    object_detection_service_url: str,
    dogs_image: np.ndarray,
    yolov8n_640_model_id: str,
) -> None:
    client = InferenceHTTPClient(
        api_url=object_detection_service_url,
        api_key=ROBOFLOW_API_KEY,
    )

    # when
    result = client.run_workflow(
        specification=VLM_AS_SECONDARY_CLASSIFIER_WORKFLOW,
        images={
            "image": dogs_image,
        },
        parameters={
            "api_key": OPENAI_KEY,
            "classes": ["russell-terrier", "wirehaired-pointing-griffon", "beagle"],
            "model_id": yolov8n_640_model_id,
        },
    )

    # then
    assert len(result) == 1, "Single image given, expected single output"
    assert set(result[0].keys()) == {
        "predictions",
    }, "Expected all outputs to be delivered"
    assert "dog" not in set(
        [e["class"] for e in result[0]["predictions"]["predictions"]]
    ), "Expected classes to be substituted"
