From 14c1b9cb27436ea21242a5c5240499896b1b5606 Mon Sep 17 00:00:00 2001 From: Patrick Orlando Date: Thu, 23 Feb 2023 23:08:52 +1100 Subject: [PATCH] fix: remove extra candidate_ids for factorized_metrics - If only candidate_embeddings are sliced a shape mismatch occurs --- tensorflow_recommenders/tasks/retrieval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_recommenders/tasks/retrieval.py b/tensorflow_recommenders/tasks/retrieval.py index bd44e05d..bc219c55 100644 --- a/tensorflow_recommenders/tasks/retrieval.py +++ b/tensorflow_recommenders/tasks/retrieval.py @@ -199,8 +199,8 @@ def call(self, query_embeddings, # Slice to the size of query embeddings # if `candidate_embeddings` contains extra negatives. - candidate_embeddings[:tf.shape(query_embeddings)[0]], - true_candidate_ids=candidate_ids) + candidate_embeddings[:num_queries], + true_candidate_ids=candidate_ids[:num_queries]) ) if compute_batch_metrics: