diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index ab8c4cf1e..db76edebc 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -48,6 +48,7 @@ jobs: pip install torch-scatter -f https://data.pyg.org/whl/torch-`python -c "import torch;print(torch.__version__)"`.html pip install setuptools==59.5.0 pip install plotly + pip install kmeans-pytorch # Use "python -m pytest" instead of "pytest" to fix imports - name: Test Overall run: | @@ -90,4 +91,4 @@ jobs: - name: Apply code-format changes uses: stefanzweifel/git-auto-commit-action@v4 with: - commit_message: Format Python code according to PEP8 \ No newline at end of file + commit_message: Format Python code according to PEP8 diff --git a/asset/dataset_list.json b/asset/dataset_list.json index 4e2aaedd5..7f3fe07e5 100644 --- a/asset/dataset_list.json +++ b/asset/dataset_list.json @@ -8,7 +8,7 @@ "inter_num": "-", "sparsity": "-", "type": "Rating", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/MovieLens.md" }, { @@ -19,7 +19,7 @@ "inter_num": "7,813,737", "sparsity": "99.05%", "type": "Rating [-1, 1-10]", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/Anime.md" }, { @@ -30,7 +30,7 @@ "inter_num": "188,478", "sparsity": "99.99%", "type": "Rating [1-5]", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/Epinions.md" }, { @@ -41,7 +41,7 @@ "inter_num": "-", "sparsity": "-", "type": "Rating", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/Yelp.md" }, { @@ -52,7 +52,7 @@ "inter_num": "100,480,507", "sparsity": "98.82%", "type": "Rating [1-5]", - "link_name": "scipt", + "link_name": "script", 
"link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/Netflix.md" }, { @@ -63,7 +63,7 @@ "inter_num": "1,149,780", "sparsity": "99.99%", "type": "Rating [0-10]", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/Book-Crossing.md" }, { @@ -74,7 +74,7 @@ "inter_num": "4,136,360", "sparsity": "44.22%", "type": "Rating [-10, 10]", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/Jester.md" }, { @@ -85,7 +85,7 @@ "inter_num": "2,125,056", "sparsity": "89.73%", "type": "Rating [0, 5]", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/Douban.md" }, { @@ -93,10 +93,10 @@ "dataset_link": "", "user_num": "1,948,882", "item_num": "98,211", - "inter_num": "11,557,943", + "inter_num": "111,557,943", "sparsity": "99.99%", "type": "Rating [0, 100]", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/YahooMusic.md" }, { @@ -107,7 +107,7 @@ "inter_num": "-", "sparsity": "-", "type": "Rating", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/KDD2010.md" }, { @@ -118,7 +118,7 @@ "inter_num": "-", "sparsity": "-", "type": "Rating [0, 5]", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/Amazon.md" }, { @@ -129,7 +129,7 @@ "inter_num": "1,445,622", "sparsity": "99.74%", "type": "-", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/Pinterest.md" }, { @@ -140,7 +140,7 @@ "inter_num": "6,442,892", "sparsity": "99.99%", "type": "Check-in", - "link_name": 
"scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/Gowalla.md" }, { @@ -151,7 +151,7 @@ "inter_num": "92,834", "sparsity": "99.72%", "type": "Click", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/LastFM.md" }, { @@ -162,7 +162,7 @@ "inter_num": "993,483", "sparsity": "99.99%", "type": "Click", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/DIGINETICA.md" }, { @@ -173,7 +173,7 @@ "inter_num": "7,793,069", "sparsity": "99.99%", "type": "Buy", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/Steam.md" }, { @@ -184,7 +184,7 @@ "inter_num": "817,741", "sparsity": "99.89%", "type": "Click", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/TaFeng.md" }, { @@ -195,7 +195,7 @@ "inter_num": "-", "sparsity": "-", "type": "Check-in", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/Foursquare.md" }, { @@ -206,7 +206,7 @@ "inter_num": "44,528,127", "sparsity": "99.99%", "type": "Click/Buy", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/Tmall.md" }, { @@ -217,7 +217,7 @@ "inter_num": "34,154,697", "sparsity": "99.99%", "type": "Click/Buy", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/YOOCHOOSE.md" }, { @@ -228,7 +228,7 @@ "inter_num": "2,756,101", "sparsity": "99.99%", "type": "View/Addtocart/Transaction", - "link_name": "scipt", + "link_name": "script", "link_url": 
"https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/Retailrocket.md" }, { @@ -239,7 +239,7 @@ "inter_num": "1,088,161,692", "sparsity": "99.71%", "type": "Click", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/LFM-1b.md" }, { @@ -250,7 +250,7 @@ "inter_num": "-", "sparsity": "-", "type": "Click", - "link_name": "scipt", + "link_name": "script", "link_url": "https://github.com/RUCAIBox/RecDatasets/blob/master/conversion_tools/usage/MIND.md" }, { @@ -452,4 +452,4 @@ "link_url": "https://tianchi.aliyun.com/dataset/dataDetail?dataId=56#1" } ] -} \ No newline at end of file +} diff --git a/asset/questionnaire.xlsx b/asset/questionnaire.xlsx new file mode 100644 index 000000000..9333bd13b Binary files /dev/null and b/asset/questionnaire.xlsx differ diff --git a/docs/source/asset/diffrec.png b/docs/source/asset/diffrec.png new file mode 100644 index 000000000..0251077a0 Binary files /dev/null and b/docs/source/asset/diffrec.png differ diff --git a/docs/source/asset/ldiffrec.png b/docs/source/asset/ldiffrec.png new file mode 100644 index 000000000..876f659e2 Binary files /dev/null and b/docs/source/asset/ldiffrec.png differ diff --git a/docs/source/index.rst b/docs/source/index.rst index 290b360c6..c2dc7ab57 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -11,7 +11,7 @@ Introduction RecBole is a unified, comprehensive and efficient framework developed based on PyTorch. It aims to help the researchers to reproduce and develop recommendation models. 
-In the lastest release, our library includes 87 recommendation algorithms `[Model List]`_, covering four major categories: +In the lastest release, our library includes 89 recommendation algorithms `[Model List]`_, covering four major categories: - General Recommendation - Sequential Recommendation diff --git a/docs/source/user_guide/model/general/diffrec.rst b/docs/source/user_guide/model/general/diffrec.rst new file mode 100644 index 000000000..dab976473 --- /dev/null +++ b/docs/source/user_guide/model/general/diffrec.rst @@ -0,0 +1,94 @@ +DiffRec +=========== + +Introduction +--------------------- + +`[paper] `_ + +**Title:** Diffusion Recommender Model + +**Authors:** Wenjie Wang, Yiyan Xu, Fuli Feng, Xinyu Lin, Xiangnan He, Tat-Seng Chua + +**Abstract:** Generative models such as Generative Adversarial Networks (GANs) and Variational Auto-Encoders (VAEs) are widely utilized to model the generative process of user interactions. However, they suffer from intrinsic limitations such as the instability of GANs and the restricted representation ability of VAEs. Such limitations hinder the accurate modeling of the complex user interaction generation procedure, such as noisy interactions caused by various interference factors. In light of the impressive advantages of Diffusion Models (DMs) over traditional generative models in image synthesis, we propose a novel Diffusion Recommender Model (named DiffRec) to learn the generative process in a denoising manner. To retain personalized information in user interactions, DiffRec reduces the added noises and avoids corrupting users’ interactions into pure noises like in image synthesis. In addition, we extend traditional DMs to tackle the unique challenges in recommendation: high resource costs for large-scale item prediction and temporal shifts of user preference. 
To this end, we propose two extensions of DiffRec: L-DiffRec clusters items for dimension compression and conducts the diffusion processes in the latent space; and T-DiffRec reweights user interactions based on the interaction timestamps to encode temporal information. We conduct extensive experiments on three datasets under multiple settings (e.g., clean training, noisy training, and temporal training). The empirical results validate the superiority of DiffRec with two extensions over competitive baselines. + +.. image:: ../../../asset/diffrec.png + :width: 500 + :align: center + +Running with RecBole +------------------------- + +**Model Hyper-Parameters:** + +- ``noise_schedule (str)`` : The schedule for noise generating: ['linear', 'linear-var', 'cosine', 'binomial']. Defaults to ``'linear'``. +- ``noise_scale (float)`` : The scale for noise generating. Defaults to ``0.001``. +- ``noise_min (float)`` : Noise lower bound for noise generating. Defaults to ``0.0005``. +- ``noise_max (float)`` : Noise upper bound for noise generating. Defaults to ``0.005``. +- ``sampling_noise (bool)`` : Whether to use sampling noise. Defaults to ``False``. +- ``sampling_steps (int)`` : Steps of the forward process during inference. Defaults to ``0``. +- ``reweight (bool)`` : Assign different weight to different timestep or not. Defaults to ``True``. +- ``mean_type (str)`` : MeanType for diffusion: ['x0', 'eps']. Defaults to ``'x0'``. +- ``steps (int)`` : Diffusion steps. Defaults to ``5``. +- ``history_num_per_term (int)`` : The number of history items needed to calculate loss weight. Defaults to ``10``. +- ``beta_fixed (bool)`` : Whether to fix the variance of the first step to prevent overfitting. Defaults to ``True``. +- ``dims_dnn (list of int)`` : The dims for the DNN. Defaults to ``[300]``. +- ``embedding_size (int)`` : Timestep embedding size. Defaults to ``10``. +- ``mlp_act_func (str)`` : Activation function for MLP. Defaults to ``'tanh'``. +- ``time-aware (bool)`` : T-DiffRec or not. Defaults to ``False``. +- ``w_max (int)`` : The upper bound of the time-aware interaction weight. Defaults to ``1``. +- ``w_min (float)`` : The lower bound of the time-aware interaction weight. Defaults to ``0.1``. + + +**A Running Example:** + +Write the following code to a python file, such as `run.py` + +.. code:: python + + from recbole.quick_start import run_recbole + + run_recbole(model='DiffRec', dataset='ml-100k') + +And then: + +.. code:: bash + + python run.py + +**Notes:** + +- ``w_max`` and ``w_min`` are unused when ``time-aware`` is False. + +Tuning Hyper Parameters +------------------------- + +If you want to use ``HyperTuning`` to tune hyper parameters of this model, you can copy the following settings and name it as ``hyper.test``. + +.. code:: bash + + learning_rate choice [1e-3,1e-4,1e-5] + dims_dnn choice ['[300]','[200,600]','[1000]'] + steps choice [2,5,10,50] + noise_scale choice [0,1e-5,1e-4,1e-3,1e-2,1e-1] + noise_min choice [5e-4,1e-3,5e-3] + noise_max choice [5e-3,1e-2] + w_min choice [0.1,0.2,0.3] + +Note that we just provide these hyper parameter ranges for reference only, and we can not guarantee that they are the optimal range of this model. + +Then, with the source code of RecBole (you can download it from GitHub), you can run the ``run_hyper.py`` to tuning: + +.. code:: bash + + python run_hyper.py --model=[model_name] --dataset=[dataset_name] --config_files=[config_files_path] --params_file=hyper.test + +For more details about Parameter Tuning, refer to :doc:`../../../user_guide/usage/parameter_tuning`.
+ + +If you want to change parameters, dataset or evaluation settings, take a look at + +- :doc:`../../../user_guide/config_settings` +- :doc:`../../../user_guide/data_intro` +- :doc:`../../../user_guide/train_eval_intro` +- :doc:`../../../user_guide/usage` \ No newline at end of file diff --git a/docs/source/user_guide/model/general/ldiffrec.rst b/docs/source/user_guide/model/general/ldiffrec.rst new file mode 100644 index 000000000..28f3aa5de --- /dev/null +++ b/docs/source/user_guide/model/general/ldiffrec.rst @@ -0,0 +1,106 @@ +LDiffRec +=========== + +Introduction +--------------------- + +`[paper] `_ + +**Title:** Diffusion Recommender Model + +**Authors:** Wenjie Wang, Yiyan Xu, Fuli Feng, Xinyu Lin, Xiangnan He, Tat-Seng Chua + +**Abstract:** Generative models such as Generative Adversarial Networks (GANs) and Variational Auto-Encoders (VAEs) are widely utilized to model the generative process of user interactions. However, they suffer from intrinsic limitations such as the instability of GANs and the restricted representation ability of VAEs. Such limitations hinder the accurate modeling of the complex user interaction generation procedure, such as noisy interactions caused by various interference factors. In light of the impressive advantages of Diffusion Models (DMs) over traditional generative models in image synthesis, we propose a novel Diffusion Recommender Model (named DiffRec) to learn the generative process in a denoising manner. To retain personalized information in user interactions, DiffRec reduces the added noises and avoids corrupting users’ interactions into pure noises like in image synthesis. In addition, we extend traditional DMs to tackle the unique challenges in recommendation: high resource costs for large-scale item prediction and temporal shifts of user preference. 
To this end, we propose two extensions of DiffRec: L-DiffRec clusters items for dimension compression and conducts the diffusion processes in the latent space; and T-DiffRec reweights user interactions based on the interaction timestamps to encode temporal information. We conduct extensive experiments on three datasets under multiple settings (e.g., clean training, noisy training, and temporal training). The empirical results validate the superiority of DiffRec with two extensions over competitive baselines. + +.. image:: ../../../asset/ldiffrec.png + :width: 500 + :align: center + +Running with RecBole +------------------------- + +**Model Hyper-Parameters:** + +- ``noise_schedule (str)`` : The schedule for noise generating: [linear, linear-var, cosine, binomial]. Defaults to ``'linear'``. +- ``noise_scale (float)`` : The scale for noise generating. Defaults to ``0.1``. +- ``noise_min (float)`` : Noise lower bound for noise generating. Defaults to ``0.001``. +- ``noise_max (float)`` : Noise upper bound for noise generating. Defaults to ``0.005``. +- ``sampling_noise (bool)`` : Whether to use sampling noise. Defaults to ``False``. +- ``sampling_steps (int)`` : Steps of the forward process during inference. Defaults to ``0``. +- ``reweight (bool)`` : Assign different weight to different timestep or not. Defaults to ``True``. +- ``mean_type (str)`` : MeanType for diffusion: ['x0', 'eps']. Defaults to ``'x0'``. +- ``steps (int)`` : Diffusion steps. Defaults to ``5``. +- ``history_num_per_term (int)`` : The number of history items needed to calculate loss weight. Defaults to ``10``. +- ``beta_fixed (bool)`` : Whether to fix the variance of the first step to prevent overfitting. Defaults to ``True``. +- ``dims_dnn (list of int)`` : The dims for the DNN. Defaults to ``[300]``. +- ``embedding_size (int)`` : Timestep embedding size. Defaults to ``10``. +- ``mlp_act_func (str)`` : Activation function for MLP. Defaults to ``'tanh'``. +- ``time-aware (bool)`` : LT-DiffRec or not. Defaults to ``False``. +- ``w_max (int)`` : The upper bound of the time-aware interaction weight. Defaults to ``1``. +- ``w_min (float)`` : The lower bound of the time-aware interaction weight. Defaults to ``0.1``. +- ``n_cate (int)`` : Category num of items. Defaults to ``1``. +- ``reparam (bool)`` : Autoencoder with variational inference or not. Defaults to ``True``. +- ``in_dims (list of int)`` : The dims for the encoder. Defaults to ``[300]``. +- ``out_dims (list of int)`` : The hidden dims for the decoder. Defaults to ``[]``. +- ``ae_act_func (str)`` : Activation function for AutoEncoder. Defaults to ``'tanh'``. +- ``lamda (float)`` : Hyper-parameter of multinomial log-likelihood for AE. Defaults to ``0.03``. +- ``anneal_cap (float)`` : The upper bound of the annealing weight. Defaults to ``0.005``. +- ``anneal_steps (int)`` : The steps of annealing. Defaults to ``1000``. +- ``vae_anneal_cap (float)`` : The upper bound of the VAE annealing weight. Defaults to ``0.3``. +- ``vae_anneal_steps (int)`` : The steps of VAE annealing. Defaults to ``200``. + + +**A Running Example:** + +Write the following code to a python file, such as `run.py` + +.. code:: python + + from recbole.quick_start import run_recbole + + run_recbole(model='LDiffRec', dataset='ml-100k') + +And then: + +.. code:: bash + + python run.py + +**Notes:** + +- ``w_max`` and ``w_min`` are unused when ``time-aware`` is False. + +- The item embedding file is needed if ``n_cate`` is greater than 1. + +Tuning Hyper Parameters +------------------------- + +If you want to use ``HyperTuning`` to tune hyper parameters of this model, you can copy the following settings and name it as ``hyper.test``. + +.. code:: bash + + learning_rate choice [1e-3,1e-4,1e-5] + dims_dnn choice ['[300]','[200,600]','[1000]'] + steps choice [2,5,10,50] + noise_scale choice [0,1e-5,1e-4,1e-3,1e-2,1e-1] + noise_min choice [5e-4,1e-3,5e-3] + noise_max choice [5e-3,1e-2] + w_min choice [0.1,0.2,0.3] + +Note that we just provide these hyper parameter ranges for reference only, and we can not guarantee that they are the optimal range of this model. + +Then, with the source code of RecBole (you can download it from GitHub), you can run the ``run_hyper.py`` to tuning: + +.. code:: bash + + python run_hyper.py --model=[model_name] --dataset=[dataset_name] --config_files=[config_files_path] --params_file=hyper.test + +For more details about Parameter Tuning, refer to :doc:`../../../user_guide/usage/parameter_tuning`. + + +If you want to change parameters, dataset or evaluation settings, take a look at + +- :doc:`../../../user_guide/config_settings` +- :doc:`../../../user_guide/data_intro` +- :doc:`../../../user_guide/train_eval_intro` +- :doc:`../../../user_guide/usage` \ No newline at end of file diff --git a/docs/source/user_guide/model_intro.rst b/docs/source/user_guide/model_intro.rst index 4d6c2e303..c4b31315e 100644 --- a/docs/source/user_guide/model_intro.rst +++ b/docs/source/user_guide/model_intro.rst @@ -1,6 +1,6 @@ Model Introduction ===================== -We implement 86 recommendation models covering general recommendation, sequential recommendation, +We implement 88 recommendation models covering general recommendation, sequential recommendation, context-aware recommendation and knowledge-based recommendation. A brief introduction to these models are as follows: @@ -43,6 +43,9 @@ task of top-n recommendation.
All the collaborative filter(CF) based models are model/general/simplex model/general/ncl model/general/random + model/general/diffrec + model/general/ldiffrec + Context-aware Recommendation ------------------------------- diff --git a/docs/source/user_guide/usage/parameter_tuning.rst b/docs/source/user_guide/usage/parameter_tuning.rst index 642da9b62..2ec723842 100644 --- a/docs/source/user_guide/usage/parameter_tuning.rst +++ b/docs/source/user_guide/usage/parameter_tuning.rst @@ -34,7 +34,7 @@ The user can also use an encapsulated :attr:`objective_function`, that is: dataset = create_dataset(config) train_data, valid_data, test_data = data_preparation(config, dataset) model_name = config['model'] - model = get_model(model_name)(config, train_data).to(config['device']) + model = get_model(model_name)(config, train_data._dataset).to(config['device']) trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model) best_valid_score, best_valid_result = trainer.fit(train_data, valid_data, verbose=False) test_result = trainer.evaluate(test_data) diff --git a/recbole/data/dataloader/abstract_dataloader.py b/recbole/data/dataloader/abstract_dataloader.py index 5bc630b94..21a8ce840 100644 --- a/recbole/data/dataloader/abstract_dataloader.py +++ b/recbole/data/dataloader/abstract_dataloader.py @@ -136,8 +136,7 @@ def _set_neg_sample_args(self, config, dataset, dl_format, neg_sample_args): self.neg_sample_args = neg_sample_args self.times = 1 if ( - self.neg_sample_args["distribution"] == "uniform" - or "popularity" + self.neg_sample_args["distribution"] in ["uniform", "popularity"] and self.neg_sample_args["sample_num"] != "none" ): self.neg_sample_num = self.neg_sample_args["sample_num"] diff --git a/recbole/data/dataloader/general_dataloader.py b/recbole/data/dataloader/general_dataloader.py index b04b5736d..cfd4d290f 100644 --- a/recbole/data/dataloader/general_dataloader.py +++ b/recbole/data/dataloader/general_dataloader.py @@ -138,8 +138,12 @@ def 
_init_batch_size_and_step(self): self.set_batch_size(batch_size) def update_config(self, config): + phase = self._sampler.phase if self._sampler.phase is not None else "test" self._set_neg_sample_args( - config, self._dataset, InputType.POINTWISE, config["eval_neg_sample_args"] + config, + self._dataset, + InputType.POINTWISE, + config[f"{phase}_neg_sample_args"], ) super().update_config(config) @@ -248,6 +252,9 @@ def _init_batch_size_and_step(self): self.step = batch_size self.set_batch_size(batch_size) + def update_config(self, config): + super().update_config(config) + def collate_fn(self, index): index = np.array(index) if not self.is_sequential: diff --git a/recbole/data/dataset/dataset.py b/recbole/data/dataset/dataset.py index d0d00673c..c151abaa1 100644 --- a/recbole/data/dataset/dataset.py +++ b/recbole/data/dataset/dataset.py @@ -1810,7 +1810,9 @@ def save(self): """Saving this :class:`Dataset` object to :attr:`config['checkpoint_dir']`.""" save_dir = self.config["checkpoint_dir"] ensure_dir(save_dir) - file = os.path.join(save_dir, f'{self.config["dataset"]}-dataset.pth') + file = os.path.join( + save_dir, f'{self.config["dataset"]}-{self.__class__.__name__}.pth' + ) self.logger.info( set_color("Saving filtered dataset into ", "pink") + f"[{file}]" ) diff --git a/recbole/data/dataset/kg_dataset.py b/recbole/data/dataset/kg_dataset.py index 77b06315f..5713b776d 100644 --- a/recbole/data/dataset/kg_dataset.py +++ b/recbole/data/dataset/kg_dataset.py @@ -390,7 +390,7 @@ def _add_auxiliary_relation(self): reverse_kg_data = { self.head_entity_field: original_tids, self.relation_field: reverse_rels, - self.head_entity_field: original_hids, + self.tail_entity_field: original_hids, } reverse_kg_feat = pd.DataFrame(reverse_kg_data) self.kg_feat = pd.concat([self.kg_feat, reverse_kg_feat]) diff --git a/recbole/data/utils.py b/recbole/data/utils.py index fce37c93d..9a81431b6 100644 --- a/recbole/data/utils.py +++ b/recbole/data/utils.py @@ -160,6 +160,7 @@ def 
data_preparation(config, dataset): dataloaders = load_split_dataloaders(config) if dataloaders is not None: train_data, valid_data, test_data = dataloaders + dataset._change_feat_format() else: model_type = config["MODEL_TYPE"] built_datasets = dataset.build() @@ -245,6 +246,8 @@ def get_dataloader(config, phase: Literal["train", "valid", "test", "evaluation" "ENMF": _get_AE_dataloader, "RaCT": _get_AE_dataloader, "RecVAE": _get_AE_dataloader, + "DiffRec": _get_AE_dataloader, + "LDiffRec": _get_AE_dataloader, } if config["model"] in register_table: diff --git a/recbole/model/abstract_recommender.py b/recbole/model/abstract_recommender.py index 8d7ecbd16..17d4ece6a 100644 --- a/recbole/model/abstract_recommender.py +++ b/recbole/model/abstract_recommender.py @@ -3,9 +3,9 @@ # @Email : slmu@ruc.edu.cn # UPDATE: -# @Time : 2022/7/16, 2020/8/6, 2020/8/25 -# @Author : Zhen Tian, Shanlei Mu, Yupeng Hou -# @Email : chenyuwuxinn@gmail.com, slmu@ruc.edu.cn, houyupeng@ruc.edu.cn +# @Time : 2022/7/16, 2020/8/6, 2020/8/25, 2023/4/24 +# @Author : Zhen Tian, Shanlei Mu, Yupeng Hou, Chenglong Ma +# @Email : chenyuwuxinn@gmail.com, slmu@ruc.edu.cn, houyupeng@ruc.edu.cn, chenglong.m@outlook.com """ recbole.model.abstract_recommender @@ -117,6 +117,8 @@ class AutoEncoderMixin(object): def build_histroy_items(self, dataset): self.history_item_id, self.history_item_value, _ = dataset.history_item_matrix() + self.history_item_id = self.history_item_id.to(self.device) + self.history_item_value = self.history_item_value.to(self.device) def get_rating_matrix(self, user): r"""Get a batch of user's feature with the user's id and history interaction matrix. 
@@ -132,11 +134,12 @@ def get_rating_matrix(self, user): row_indices = torch.arange(user.shape[0]).repeat_interleave( self.history_item_id.shape[1], dim=0 ) - rating_matrix = torch.zeros(1).repeat(user.shape[0], self.n_items) + rating_matrix = torch.zeros(1, device=self.device).repeat( + user.shape[0], self.n_items + ) rating_matrix.index_put_( (row_indices, col_indices), self.history_item_value[user].flatten() ) - rating_matrix = rating_matrix.to(self.device) return rating_matrix @@ -335,7 +338,6 @@ def embed_float_fields(self, float_fields): Args: float_fields (torch.FloatTensor): The input dense tensor. shape of [batch_size, num_float_field] - embed (bool): Return the embedding of columns or just the columns itself. Defaults to ``True``. Returns: torch.FloatTensor: The result embedding tensor of float columns. diff --git a/recbole/model/context_aware_recommender/__init__.py b/recbole/model/context_aware_recommender/__init__.py index 7ec7b8fc2..489f13fea 100644 --- a/recbole/model/context_aware_recommender/__init__.py +++ b/recbole/model/context_aware_recommender/__init__.py @@ -15,3 +15,4 @@ from recbole.model.context_aware_recommender.xdeepfm import xDeepFM from recbole.model.context_aware_recommender.fignn import FiGNN from recbole.model.context_aware_recommender.kd_dagfm import KD_DAGFM +from recbole.model.context_aware_recommender.eulernet import EulerNet diff --git a/recbole/model/context_aware_recommender/eulernet.py b/recbole/model/context_aware_recommender/eulernet.py new file mode 100644 index 000000000..f6f8b536b --- /dev/null +++ b/recbole/model/context_aware_recommender/eulernet.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +# @Time : 2023/4/21 12:00 +# @Author : Zhen Tian +# @Email : chenyuwuxinn@gmail.com +# @File : eulernet.py + +r""" +EulerNet +################################################ +Reference: + Zhen Tian et al. "EulerNet: Adaptive Feature Interaction Learning via Euler's Formula for CTR Prediction." in SIGIR 2023. 
+ +Reference code: + https://github.com/chenyuwuxin/EulerNet + +""" + +import torch +import torch.nn as nn +from torch.nn.init import xavier_normal_, constant_ +from recbole.model.abstract_recommender import ContextRecommender +from recbole.model.loss import RegLoss + + +class EulerNet(ContextRecommender): + r"""EulerNet is a context-based recommendation model. + It can adaptively learn the arbitrary-order feature interactions in a complex vector space + by conducting space mapping according to Euler's formula. Meanwhile, it can jointly capture + the explicit and implicit feature interactions in a unified model architecture. + """ + + def __init__(self, config, dataset): + super(EulerNet, self).__init__(config, dataset) + field_num = self.field_num = self.num_feature_field + shape_list = [config.embedding_size * field_num] + [ + num_neurons * config.embedding_size for num_neurons in config.order_list + ] + + interaction_shapes = [] + for inshape, outshape in zip(shape_list[:-1], shape_list[1:]): + interaction_shapes.append(EulerInteractionLayer(config, inshape, outshape)) + + self.Euler_interaction_layers = nn.Sequential(*interaction_shapes) + self.mu = nn.Parameter(torch.ones(1, field_num, 1)) + self.reg = nn.Linear(shape_list[-1], 1) + self.reg_weight = config.reg_weight + nn.init.normal_(self.reg.weight, mean=0, std=0.01) + self.sigmoid = nn.Sigmoid() + self.reg_loss = RegLoss() + self.loss = nn.BCEWithLogitsLoss() + self.apply(self._init_other_weights) + + def _init_other_weights(self, module): + if isinstance(module, nn.Embedding): + xavier_normal_(module.weight.data) + elif isinstance(module, nn.Linear): + if module.bias is not None: + constant_(module.bias.data, 0) + + def forward(self, interaction): + fm_all_embeddings = self.concat_embed_input_fields( + interaction + ) # [batch_size, num_field, embed_dim] + r, p = self.mu * torch.cos(fm_all_embeddings), self.mu * torch.sin( + fm_all_embeddings + ) + o_r, o_p = self.Euler_interaction_layers((r, p)) + o_r, 
o_p = o_r.reshape(o_r.shape[0], -1), o_p.reshape(o_p.shape[0], -1) + re, im = self.reg(o_r), self.reg(o_p) + logits = re + im + return logits.squeeze(-1) + + def calculate_loss(self, interaction): + label = interaction[self.LABEL] + output = self.forward(interaction) + return self.loss(output, label) + self.RegularLoss(self.reg_weight) + + def predict(self, interaction): + return self.sigmoid(self.forward(interaction)) + + def RegularLoss(self, weight): + if weight == 0: + return 0 + loss = 0 + for _ in ["Euler_interaction_layers", "mu", "reg"]: + comp = getattr(self, _) + if isinstance(comp, nn.Parameter): + loss += torch.norm(comp, p=2) + continue + for params in comp.parameters(): + loss += torch.norm(params, p=2) + return loss * weight + + +class EulerInteractionLayer(nn.Module): + r"""Euler interaction layer is the core component of EulerNet, + which enables the adaptive learning of explicit feature interactions. An Euler + interaction layer performs the feature interaction under the complex space one time, + taking as input a complex representation and outputting a transformed complex representation. 
+ """ + + def __init__(self, config, inshape, outshape): + super().__init__() + self.feature_dim = config.embedding_size + self.apply_norm = config.apply_norm + + init_orders = torch.softmax( + torch.randn(inshape // self.feature_dim, outshape // self.feature_dim) + / 0.01, + dim=0, + ) + self.inter_orders = nn.Parameter(init_orders) + self.im = nn.Linear(inshape, outshape) + + self.bias_lam = nn.Parameter( + torch.randn(1, self.feature_dim, outshape // self.feature_dim) * 0.01 + ) + self.bias_theta = nn.Parameter( + torch.randn(1, self.feature_dim, outshape // self.feature_dim) * 0.01 + ) + nn.init.normal_(self.im.weight, mean=0, std=0.1) + + self.drop_ex = nn.Dropout(p=config.drop_ex) + self.drop_im = nn.Dropout(p=config.drop_im) + self.norm_r = nn.LayerNorm([self.feature_dim]) + self.norm_p = nn.LayerNorm([self.feature_dim]) + + def forward(self, complex_features): + r, p = complex_features + + lam = r**2 + p**2 + 1e-8 + theta = torch.atan2(p, r) + lam, theta = lam.reshape(lam.shape[0], -1, self.feature_dim), theta.reshape( + theta.shape[0], -1, self.feature_dim + ) + r, p = self.drop_im(r), self.drop_im(p) + + lam = 0.5 * torch.log(lam) + lam, theta = torch.transpose(lam, -2, -1), torch.transpose(theta, -2, -1) + lam, theta = self.drop_ex(lam), self.drop_ex(theta) + lam, theta = ( + lam @ (self.inter_orders) + self.bias_lam, + theta @ (self.inter_orders) + self.bias_theta, + ) + lam = torch.exp(lam) + lam, theta = torch.transpose(lam, -2, -1), torch.transpose(theta, -2, -1) + + r, p = r.reshape(r.shape[0], -1), p.reshape(p.shape[0], -1) + r, p = self.im(r), self.im(p) + r, p = torch.relu(r), torch.relu(p) + r, p = r.reshape(r.shape[0], -1, self.feature_dim), p.reshape( + p.shape[0], -1, self.feature_dim + ) + + o_r, o_p = r + lam * torch.cos(theta), p + lam * torch.sin(theta) + o_r, o_p = o_r.reshape(o_r.shape[0], -1, self.feature_dim), o_p.reshape( + o_p.shape[0], -1, self.feature_dim + ) + if self.apply_norm: + o_r, o_p = self.norm_r(o_r), self.norm_p(o_p) + 
return o_r, o_p diff --git a/recbole/model/general_recommender/__init__.py b/recbole/model/general_recommender/__init__.py index fa4c9ffff..04187e129 100644 --- a/recbole/model/general_recommender/__init__.py +++ b/recbole/model/general_recommender/__init__.py @@ -29,3 +29,5 @@ from recbole.model.general_recommender.sgl import SGL from recbole.model.general_recommender.admmslim import ADMMSLIM from recbole.model.general_recommender.simplex import SimpleX +from recbole.model.general_recommender.diffrec import DiffRec +from recbole.model.general_recommender.ldiffrec import LDiffRec diff --git a/recbole/model/general_recommender/dgcf.py b/recbole/model/general_recommender/dgcf.py index d4a74f061..a20b670b7 100644 --- a/recbole/model/general_recommender/dgcf.py +++ b/recbole/model/general_recommender/dgcf.py @@ -335,7 +335,7 @@ def _create_centered_distance(X): r = torch.sum(X * X, dim=1, keepdim=True) # (N, 1) # (x^2 - 2xy + y^2) -> l2 distance between all vectors - value = r - 2 * torch.mm(X, X.T + r.T) + value = r - 2 * torch.mm(X, X.T) + r.T zero_value = torch.zeros_like(value) value = torch.where(value > 0.0, value, zero_value) D = torch.sqrt(value + 1e-8) diff --git a/recbole/model/general_recommender/diffrec.py b/recbole/model/general_recommender/diffrec.py new file mode 100644 index 000000000..965ea7f8c --- /dev/null +++ b/recbole/model/general_recommender/diffrec.py @@ -0,0 +1,634 @@ +# -*- coding: utf-8 -*- +# @Time : 2023/10/6 +# @Author : Enze Liu +# @Email : enzeeliu@foxmail.com + +r""" +DiffRec +################################################ +Reference: + Wenjie Wang et al. "Diffusion Recommender Model." in SIGIR 2023. 
+
+Reference code:
+    https://github.com/YiyanXu/DiffRec
+"""
+
+import enum
+import math
+import copy
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from recbole.model.init import xavier_normal_initialization
+from recbole.utils.enum_type import InputType
+from recbole.model.abstract_recommender import AutoEncoderMixin, GeneralRecommender
+from recbole.model.layers import MLPLayers
+import typing
+
+
+class ModelMeanType(enum.Enum):
+    START_X = enum.auto()  # the model predicts x_0
+    EPSILON = enum.auto()  # the model predicts epsilon
+
+
+class DNN(nn.Module):
+    r"""
+    A deep neural network for the reverse diffusion process.
+    """
+
+    def __init__(
+        self,
+        dims: typing.List,
+        emb_size: int,
+        time_type="cat",
+        act_func="tanh",
+        norm=False,
+        dropout=0.5,
+    ):
+        super(DNN, self).__init__()
+        self.dims = dims
+        self.time_type = time_type
+        self.time_emb_dim = emb_size
+        self.norm = norm
+
+        self.emb_layer = nn.Linear(self.time_emb_dim, self.time_emb_dim)
+
+        if self.time_type == "cat":
+            # Concatenate timestep embedding with input
+            self.dims[0] += self.time_emb_dim
+        else:
+            raise ValueError(
+                "Unimplemented timestep embedding type %s" % self.time_type
+            )
+
+        self.mlp_layers = MLPLayers(
+            layers=self.dims, dropout=0, activation=act_func, last_activation=False
+        )
+        self.drop = nn.Dropout(dropout)
+
+        self.apply(xavier_normal_initialization)
+
+    def forward(self, x, timesteps):
+        time_emb = timestep_embedding(timesteps, self.time_emb_dim).to(x.device)
+        emb = self.emb_layer(time_emb)
+        if self.norm:
+            x = F.normalize(x)
+        x = self.drop(x)
+        h = torch.cat([x, emb], dim=-1)
+        h = self.mlp_layers(h)
+        return h
+
+
+class DiffRec(GeneralRecommender, AutoEncoderMixin):
+    r"""
+    DiffRec is a generative recommender model which infers users' interaction probabilities in a denoising manner.
+    Note that DiffRec simultaneously ranks all items for each user.
+    We implement the DiffRec model with only user dataloader.
+ """ + input_type = InputType.LISTWISE + + def __init__(self, config, dataset): + super(DiffRec, self).__init__(config, dataset) + + if config["mean_type"] == "x0": + self.mean_type = ModelMeanType.START_X + elif config["mean_type"] == "eps": + self.mean_type = ModelMeanType.EPSILON + else: + raise ValueError("Unimplemented mean type %s" % config["mean_type"]) + self.time_aware = config["time-aware"] + self.w_max = config["w_max"] + self.w_min = config["w_min"] + self.build_histroy_items(dataset) + + self.noise_schedule = config["noise_schedule"] + self.noise_scale = config["noise_scale"] + self.noise_min = config["noise_min"] + self.noise_max = config["noise_max"] + self.steps = config["steps"] + self.beta_fixed = config["beta_fixed"] + self.emb_size = config["embedding_size"] + self.norm = config["norm"] # True or False + self.reweight = config["reweight"] # reweight the loss for different timesteps + self.sampling_noise = config[ + "sampling_noise" + ] # whether sample noise during predict + self.sampling_steps = config["sampling_steps"] + self.mlp_act_func = config["mlp_act_func"] + assert self.sampling_steps <= self.steps, "Too much steps in inference." + + self.history_num_per_term = config["history_num_per_term"] + self.Lt_history = torch.zeros( + self.steps, self.history_num_per_term, dtype=torch.float64 + ).to(self.device) + self.Lt_count = torch.zeros(self.steps, dtype=int).to(self.device) + + dims = [self.n_items] + config["dims_dnn"] + [self.n_items] + + self.mlp = DNN( + dims=dims, + emb_size=self.emb_size, + time_type="cat", + norm=self.norm, + act_func=self.mlp_act_func, + ).to(self.device) + + if self.noise_scale != 0.0: + self.betas = torch.tensor(self.get_betas(), dtype=torch.float64).to( + self.device + ) + if self.beta_fixed: + self.betas[ + 0 + ] = 0.00001 # Deep Unsupervised Learning using Noneequilibrium Thermodynamics 2.4.1 + # The variance \beta_1 of the first step is fixed to a small constant to prevent overfitting. 
+ assert len(self.betas.shape) == 1, "betas must be 1-D" + assert ( + len(self.betas) == self.steps + ), "num of betas must equal to diffusion steps" + assert (self.betas > 0).all() and ( + self.betas <= 1 + ).all(), "betas out of range" + + self.calculate_for_diffusion() + + def build_histroy_items(self, dataset): + r""" + Add time-aware reweighting to the original user-item interaction matrix when config['time-aware'] is True. + """ + if not self.time_aware: + super().build_histroy_items(dataset) + else: + inter_feat = copy.deepcopy(dataset.inter_feat) + inter_feat.sort(dataset.time_field) + user_ids, item_ids = ( + inter_feat[dataset.uid_field].numpy(), + inter_feat[dataset.iid_field].numpy(), + ) + + w_max = self.w_max + w_min = self.w_min + values = np.zeros(len(inter_feat)) + + row_num = dataset.user_num + row_ids, col_ids = user_ids, item_ids + + for uid in range(1, row_num + 1): + uindex = np.argwhere(user_ids == uid).flatten() + int_num = len(uindex) + weight = np.linspace(w_min, w_max, int_num) + values[uindex] = weight + + history_len = np.zeros(row_num, dtype=np.int64) + for row_id in row_ids: + history_len[row_id] += 1 + + max_inter_num = np.max(history_len) + col_num = max_inter_num + + history_matrix = np.zeros((row_num, col_num), dtype=np.int64) + history_value = np.zeros((row_num, col_num)) + history_len[:] = 0 + + for row_id, value, col_id in zip(row_ids, values, col_ids): + if history_len[row_id] >= col_num: + continue + history_matrix[row_id, history_len[row_id]] = col_id + history_value[row_id, history_len[row_id]] = value + history_len[row_id] += 1 + + self.history_item_id = torch.LongTensor(history_matrix) + self.history_item_value = torch.FloatTensor(history_value) + self.history_item_id = self.history_item_id.to(self.device) + self.history_item_value = self.history_item_value.to(self.device) + + def get_betas(self): + r""" + Given the schedule name, create the betas for the diffusion process. 
+ """ + if self.noise_schedule == "linear" or self.noise_schedule == "linear-var": + start = self.noise_scale * self.noise_min + end = self.noise_scale * self.noise_max + if self.noise_schedule == "linear": + return np.linspace(start, end, self.steps, dtype=np.float64) + else: + return betas_from_linear_variance( + self.steps, np.linspace(start, end, self.steps, dtype=np.float64) + ) + elif self.noise_schedule == "cosine": + return betas_for_alpha_bar( + self.steps, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + ) + # Deep Unsupervised Learning using Noneequilibrium Thermodynamics 2.4.1 + elif self.noise_schedule == "binomial": + ts = np.arange(self.steps) + betas = [1 / (self.steps - t + 1) for t in ts] + return betas + else: + raise NotImplementedError(f"unknown beta schedule: {self.noise_schedule}!") + + def calculate_for_diffusion(self): + r""" + Calculate the coefficients for the diffusion process. + """ + alphas = 1.0 - self.betas + # [alpha_{1}, ..., alpha_{1}*...*alpha_{T}] shape (steps,) + self.alphas_cumprod = torch.cumprod(alphas, axis=0).to(self.device) + # alpha_{t-1} + self.alphas_cumprod_prev = torch.cat( + [torch.tensor([1.0]).to(self.device), self.alphas_cumprod[:-1]] + ).to(self.device) + # alpha_{t+1} + self.alphas_cumprod_next = torch.cat( + [self.alphas_cumprod[1:], torch.tensor([0.0]).to(self.device)] + ).to(self.device) + assert self.alphas_cumprod_prev.shape == (self.steps,) + + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1) + + self.posterior_variance = ( + self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + + self.posterior_log_variance_clipped = torch.log( + torch.cat( + 
[self.posterior_variance[1].unsqueeze(0), self.posterior_variance[1:]] + ) + ) + # Eq.10 coef for x_theta + self.posterior_mean_coef1 = ( + self.betas + * torch.sqrt(self.alphas_cumprod_prev) + / (1.0 - self.alphas_cumprod) + ) + # Eq.10 coef for x_t + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) + * torch.sqrt(alphas) + / (1.0 - self.alphas_cumprod) + ) + + def p_sample(self, x_start): + r""" + Generate users' interaction probabilities in a denoising manner. + Args: + x_start (torch.FloatTensor): the input tensor that contains user's history interaction matrix, + for DiffRec shape: [batch_size, n_items] + for LDiffRec shape: [batch_size, hidden_size] + Returns: + torch.FloatTensor: the interaction probabilities, + for DiffRec shape: [batch_size, n_items] + for LDiffRec shape: [batch_size, hidden_size] + """ + steps = self.sampling_steps + if steps == 0: + x_t = x_start + else: + t = torch.tensor([steps - 1] * x_start.shape[0]).to(x_start.device) + x_t = self.q_sample(x_start, t) + + indices = list(range(self.steps))[::-1] + + if self.noise_scale == 0.0: + for i in indices: + t = torch.tensor([i] * x_t.shape[0]).to(x_start.device) + x_t = self.mlp(x_t, t) + return x_t + + for i in indices: + t = torch.tensor([i] * x_t.shape[0]).to(x_start.device) + out = self.p_mean_variance(x_t, t) + if self.sampling_noise: + noise = torch.randn_like(x_t) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1))) + ) # no noise when t == 0 + x_t = ( + out["mean"] + + nonzero_mask * torch.exp(0.5 * out["log_variance"]) * noise + ) + else: + x_t = out["mean"] + return x_t + + def full_sort_predict(self, interaction): + user = interaction[self.USER_ID] + x_start = self.get_rating_matrix(user) + scores = self.p_sample(x_start) + return scores + + def predict(self, interaction): + item = interaction[self.ITEM_ID] + x_t = self.full_sort_predict(interaction) + scores = x_t[:, item] + return scores + + def calculate_loss(self, interaction): + user = 
interaction[self.USER_ID] + x_start = self.get_rating_matrix(user) + + batch_size, device = x_start.size(0), x_start.device + ts, pt = self.sample_timesteps(batch_size, device, "importance") + noise = torch.randn_like(x_start) + if self.noise_scale != 0.0: + x_t = self.q_sample(x_start, ts, noise) + else: + x_t = x_start + + model_output = self.mlp(x_t, ts) + target = { + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.mean_type] + + assert model_output.shape == target.shape == x_start.shape + + mse = mean_flat((target - model_output) ** 2) + + reloss = self.reweight_loss(x_start, x_t, mse, ts, target, model_output, device) + self.update_Lt_history(ts, reloss) + + # importance sampling + reloss /= pt + mean_loss = reloss.mean() + return mean_loss + + def reweight_loss(self, x_start, x_t, mse, ts, target, model_output, device): + if self.reweight: + if self.mean_type == ModelMeanType.START_X: + # Eq.11 + weight = self.SNR(ts - 1) - self.SNR(ts) + # Eq.12 + weight = torch.where((ts == 0), 1.0, weight) + loss = mse + elif self.mean_type == ModelMeanType.EPSILON: + weight = (1 - self.alphas_cumprod[ts]) / ( + (1 - self.alphas_cumprod_prev[ts]) ** 2 * (1 - self.betas[ts]) + ) + weight = torch.where((ts == 0), 1.0, weight) + likelihood = mean_flat( + (x_start - self._predict_xstart_from_eps(x_t, ts, model_output)) + ** 2 + / 2.0 + ) + loss = torch.where((ts == 0), likelihood, mse) + else: + weight = torch.tensor([1.0] * len(target)).to(device) + loss = mse + reloss = weight * loss + return reloss + + def update_Lt_history(self, ts, reloss): + # update Lt_history & Lt_count + for t, loss in zip(ts, reloss): + if self.Lt_count[t] == self.history_num_per_term: + Lt_history_old = self.Lt_history.clone() + self.Lt_history[t, :-1] = Lt_history_old[t, 1:] + self.Lt_history[t, -1] = loss.detach() + else: + try: + self.Lt_history[t, self.Lt_count[t]] = loss.detach() + self.Lt_count[t] += 1 + except: + print(t) + print(self.Lt_count[t]) + print(loss) + 
raise ValueError + + def sample_timesteps( + self, batch_size, device, method="uniform", uniform_prob=0.001 + ): + if method == "importance": # importance sampling + if not (self.Lt_count == self.history_num_per_term).all(): + return self.sample_timesteps(batch_size, device, method="uniform") + + Lt_sqrt = torch.sqrt(torch.mean(self.Lt_history**2, axis=-1)) + pt_all = Lt_sqrt / torch.sum(Lt_sqrt) + pt_all *= 1 - uniform_prob + pt_all += uniform_prob / len(pt_all) # ensure the least prob > uniform_prob + + assert pt_all.sum(-1) - 1.0 < 1e-5 + + t = torch.multinomial(pt_all, num_samples=batch_size, replacement=True) + pt = pt_all.gather(dim=0, index=t) * len(pt_all) + + return t, pt + + elif method == "uniform": # uniform sampling + t = torch.randint(0, self.steps, (batch_size,), device=device).long() + pt = torch.ones_like(t).float() + + return t, pt + + else: + raise ValueError + + def q_sample(self, x_start, t, noise=None): + if noise is None: + noise = torch.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + self._extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) + * x_start + + self._extract_into_tensor( + self.sqrt_one_minus_alphas_cumprod, t, x_start.shape + ) + * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + r""" + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + self._extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + self._extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = self._extract_into_tensor( + self.posterior_variance, t, x_t.shape + ) + posterior_log_variance_clipped = self._extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, 
posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t): + r""" + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + """ + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = self.mlp(x, t) + + model_variance = self.posterior_variance + model_log_variance = self.posterior_log_variance_clipped + + model_variance = self._extract_into_tensor(model_variance, t, x.shape) + model_log_variance = self._extract_into_tensor(model_log_variance, t, x.shape) + + if self.mean_type == ModelMeanType.START_X: + pred_xstart = model_output + elif self.mean_type == ModelMeanType.EPSILON: + pred_xstart = self._predict_xstart_from_eps(x, t, eps=model_output) + else: + raise NotImplementedError(self.mean_type) + + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t + ) + + assert ( + model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + ) + + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + self._extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) + * x_t + - self._extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + * eps + ) + + def SNR(self, t): + r""" + Compute the signal-to-noise ratio for a single timestep. + """ + self.alphas_cumprod = self.alphas_cumprod.to(t.device) + return self.alphas_cumprod[t] / (1 - self.alphas_cumprod[t]) + + def _extract_into_tensor(self, arr, timesteps, broadcast_shape): + r""" + Extract values from a 1-D torch tensor for a batch of indices. + + Args: + arr (torch.Tensor): the 1-D torch tensor. + timesteps (torch.Tensor): a tensor of indices into the array to extract. + broadcast_shape (torch.Size): a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. 
+ Returns: + torch.Tensor: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + # res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + arr = arr.to(timesteps.device) + res = arr[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) + + +def betas_from_linear_variance(steps, variance, max_beta=0.999): + alpha_bar = 1 - variance + betas = [] + betas.append(1 - alpha_bar[0]) + for i in range(1, steps): + betas.append(min(1 - alpha_bar[i] / alpha_bar[i - 1], max_beta)) + return np.array(betas) + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + r""" + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + Args: + num_diffusion_timesteps (int): the number of betas to produce. + alpha_bar (Callable): a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + max_beta (int): the maximum beta to use; use values lower than 1 to + prevent singularities. + Returns: + np.ndarray: a 1-D array of beta values. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def normal_kl(mean1, logvar1, mean2, logvar2): + r""" + Compute the KL divergence between two gaussians. + + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. 
Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) + + +def mean_flat(tensor): + r""" + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def timestep_embedding(timesteps, dim, max_period=10000): + r""" + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. (N,) + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to( + timesteps.device + ) # shape (dim//2,) + args = timesteps[:, None].float() * freqs[None] # (N, dim//2) + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) # (N, (dim//2)*2) + if dim % 2: + # zero pad in the last dimension to ensure shape (N, dim) + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding diff --git a/recbole/model/general_recommender/ldiffrec.py b/recbole/model/general_recommender/ldiffrec.py new file mode 100644 index 000000000..c9b41e2e3 --- /dev/null +++ b/recbole/model/general_recommender/ldiffrec.py @@ -0,0 +1,344 @@ +# -*- coding: utf-8 -*- +# @Time : 2023/10/6 +# @Author : Enze Liu +# @Email : enzeeliu@foxmail.com + +r""" +DiffRec +################################################ +Reference: + Wenjie Wang et al. "Diffusion Recommender Model." in SIGIR 2023. 
+
+Reference code:
+    https://github.com/YiyanXu/DiffRec
+"""
+
+import os
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from recbole.model.init import xavier_normal_initialization
+from recbole.model.layers import MLPLayers
+from recbole.model.general_recommender.diffrec import (
+    DiffRec,
+    DNN,
+    ModelMeanType,
+    mean_flat,
+)
+from kmeans_pytorch import kmeans
+
+
+class AutoEncoder(nn.Module):
+    r"""
+    Gaussian Diffusion for large-scale recommendation.
+    """
+
+    def __init__(
+        self,
+        item_emb,
+        n_cate,
+        in_dims,
+        out_dims,
+        device,
+        act_func,
+        reparam=True,
+        dropout=0.1,
+    ):
+        super(AutoEncoder, self).__init__()
+
+        self.item_emb = item_emb
+        self.n_cate = n_cate
+        self.in_dims = in_dims
+        self.out_dims = out_dims
+        self.act_func = act_func
+        self.n_item = len(item_emb)
+        self.reparam = reparam
+        self.dropout = nn.Dropout(dropout)
+
+        if n_cate == 1:  # no clustering
+            in_dims_temp = (
+                [self.n_item + 1] + self.in_dims[:-1] + [self.in_dims[-1] * 2]
+            )
+            out_dims_temp = [self.in_dims[-1]] + self.out_dims + [self.n_item + 1]
+
+            self.encoder = MLPLayers(in_dims_temp, activation=self.act_func)
+            self.decoder = MLPLayers(
+                out_dims_temp, activation=self.act_func, last_activation=False
+            )
+
+        else:
+            self.cluster_ids, _ = kmeans(
+                X=item_emb, num_clusters=n_cate, distance="euclidean", device=device
+            )
+            # cluster_ids(labels): [0, 1, 2, 2, 1, 0, 0, ...]
+ category_idx = [] + for i in range(n_cate): + idx = np.argwhere(self.cluster_ids.numpy() == i).flatten().tolist() + category_idx.append(torch.tensor(idx, dtype=int) + 1) + self.category_idx = category_idx # [cate1: [iid1, iid2, ...], cate2: [iid3, iid4, ...], cate3: [iid5, iid6, ...]] + self.category_map = torch.cat(tuple(category_idx), dim=-1) # map + self.category_len = [ + len(self.category_idx[i]) for i in range(n_cate) + ] # item num in each category + print("category length: ", self.category_len) + assert sum(self.category_len) == self.n_item + + ##### Build the Encoder and Decoder ##### + encoders = [] + decode_dim = [] + for i in range(n_cate): + if i == n_cate - 1: + latent_dims = list(self.in_dims - np.array(decode_dim).sum(axis=0)) + else: + latent_dims = [ + int(self.category_len[i] / self.n_item * self.in_dims[j]) + for j in range(len(self.in_dims)) + ] + latent_dims = [ + latent_dims[j] if latent_dims[j] != 0 else 1 + for j in range(len(self.in_dims)) + ] + in_dims_temp = ( + [self.category_len[i]] + latent_dims[:-1] + [latent_dims[-1] * 2] + ) + encoders.append(MLPLayers(in_dims_temp, activation=self.act_func)) + decode_dim.append(latent_dims) + + self.encoder = nn.ModuleList(encoders) + print("Latent dims of each category: ", decode_dim) + + self.decode_dim = [decode_dim[i][::-1] for i in range(len(decode_dim))] + + if len(out_dims) == 0: # one-layer decoder: [encoder_dim_sum, n_item] + out_dim = self.in_dims[-1] + self.decoder = MLPLayers([out_dim, self.n_item], activation=None) + else: # multi-layer decoder: [encoder_dim, hidden_size, cate_num] + # decoder_modules = [[] for _ in range(n_cate)] + decoders = [] + for i in range(n_cate): + out_dims_temp = self.decode_dim[i] + [self.category_len[i]] + decoders.append( + MLPLayers( + out_dims_temp, + activation=self.act_func, + last_activation=False, + ) + ) + self.decoder = nn.ModuleList(decoders) + + self.apply(xavier_normal_initialization) + + def Encode(self, batch): + batch = self.dropout(batch) 
+ if self.n_cate == 1: + hidden = self.encoder(batch) + mu = hidden[:, : self.in_dims[-1]] + logvar = hidden[:, self.in_dims[-1] :] + + if self.training and self.reparam: + latent = self.reparamterization(mu, logvar) + else: + latent = mu + + kl_divergence = -0.5 * torch.mean( + torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1) + ) + + return batch, latent, kl_divergence + + else: + batch_cate = [] + for i in range(self.n_cate): + batch_cate.append(batch[:, self.category_idx[i]]) + # [batch_size, n_items] -> [[batch_size, n1_items], [batch_size, n2_items], [batch_size, n3_items]] + latent_mu = [] + latent_logvar = [] + for i in range(self.n_cate): + hidden = self.encoder[i](batch_cate[i]) + latent_mu.append(hidden[:, : self.decode_dim[i][0]]) + latent_logvar.append(hidden[:, self.decode_dim[i][0] :]) + # latent: [[batch_size, latent_size1], [batch_size, latent_size2], [batch_size, latent_size3]] + + mu = torch.cat(tuple(latent_mu), dim=-1) + logvar = torch.cat(tuple(latent_logvar), dim=-1) + if self.training and self.reparam: + latent = self.reparamterization(mu, logvar) + else: + latent = mu + + kl_divergence = -0.5 * torch.mean( + torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1) + ) + + return torch.cat(tuple(batch_cate), dim=-1), latent, kl_divergence + + def reparamterization(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return eps.mul(std).add_(mu) + + def Decode(self, batch): + if len(self.out_dims) == 0 or self.n_cate == 1: # one-layer decoder + return self.decoder(batch) + else: + batch_cate = [] + start = 0 + for i in range(self.n_cate): + end = start + self.decode_dim[i][0] + batch_cate.append(batch[:, start:end]) + start = end + pred_cate = [] + for i in range(self.n_cate): + pred_cate.append(self.decoder[i](batch_cate[i])) + pred = torch.cat(tuple(pred_cate), dim=-1) + + return pred + + +class LDiffRec(DiffRec): + r""" + L-DiffRec clusters items into groups, compresses the interaction vector over each 
group into a + low-dimensional latent vector via a group-specific VAE, and conducts the forward and reverse + diffusion processes in the latent space. + """ + + def __init__(self, config, dataset): + super(LDiffRec, self).__init__(config, dataset) + self.n_cate = config["n_cate"] + self.reparam = config["reparam"] + self.ae_act_func = config["ae_act_func"] + self.in_dims = config["in_dims"] + self.out_dims = config["out_dims"] + + # control loss in training + self.update_count = 0 + self.update_count_vae = 0 + self.lamda = config["lamda"] + self.anneal_cap = config["anneal_cap"] + self.anneal_steps = config["anneal_steps"] + self.vae_anneal_cap = config["vae_anneal_cap"] + self.vae_anneal_steps = config["vae_anneal_steps"] + + out_dims = self.out_dims + in_dims = self.in_dims[::-1] + emb_path = os.path.join(dataset.dataset_path, f"item_emb.npy") + if self.n_cate > 1: + if not os.path.exists(emb_path): + self.logger.exception( + "The item embedding file must be given when n_cate>1." + ) + item_emb = torch.from_numpy(np.load(emb_path, allow_pickle=True)) + else: + item_emb = torch.zeros((self.n_items - 1, 64)) + self.autoencoder = AutoEncoder( + item_emb, + self.n_cate, + in_dims, + out_dims, + self.device, + self.ae_act_func, + self.reparam, + ).to(self.device) + + self.latent_size = in_dims[-1] + dims = [self.latent_size] + config["dims_dnn"] + [self.latent_size] + self.mlp = DNN( + dims=dims, + emb_size=self.emb_size, + time_type="cat", + norm=self.norm, + act_func=self.mlp_act_func, + ).to(self.device) + + def calculate_loss(self, interaction): + user = interaction[self.USER_ID] + batch = self.get_rating_matrix(user) + + batch_cate, batch_latent, vae_kl = self.autoencoder.Encode(batch) + + # calculate loss in diffusion + batch_size, device = batch_latent.size(0), batch_latent.device + ts, pt = self.sample_timesteps(batch_size, device, "importance") + noise = torch.randn_like(batch_latent) + if self.noise_scale != 0.0: + x_t = self.q_sample(batch_latent, ts, 
noise) + else: + x_t = batch_latent + + model_output = self.mlp(x_t, ts) + target = { + ModelMeanType.START_X: batch_latent, + ModelMeanType.EPSILON: noise, + }[self.mean_type] + + assert model_output.shape == target.shape == batch_latent.shape + + mse = mean_flat((target - model_output) ** 2) + + reloss = self.reweight_loss( + batch_latent, x_t, mse, ts, target, model_output, device + ) + + if self.mean_type == ModelMeanType.START_X: + batch_latent_recon = model_output + else: + batch_latent_recon = self._predict_xstart_from_eps(x_t, ts, model_output) + + self.update_Lt_history(ts, reloss) + + diff_loss = (reloss / pt).mean() + + batch_recon = self.autoencoder.Decode(batch_latent_recon) + + if self.anneal_steps > 0: + lamda = max( + (1.0 - self.update_count / self.anneal_steps) * self.lamda, + self.anneal_cap, + ) + else: + lamda = max(self.lamda, self.anneal_cap) + + if self.vae_anneal_steps > 0: + anneal = min( + self.vae_anneal_cap, 1.0 * self.update_count_vae / self.vae_anneal_steps + ) + else: + anneal = self.vae_anneal_cap + + self.update_count_vae += 1 + self.update_count += 1 + vae_loss = compute_loss(batch_recon, batch_cate) + anneal * vae_kl + + loss = lamda * diff_loss + vae_loss + + return loss + + def full_sort_predict(self, interaction): + user = interaction[self.USER_ID] + batch = self.get_rating_matrix(user) + _, batch_latent, _ = self.autoencoder.Encode(batch) + batch_latent_recon = super(LDiffRec, self).p_sample(batch_latent) + prediction = self.autoencoder.Decode( + batch_latent_recon + ) # [batch_size, n1_items + n2_items + n3_items] + if self.n_cate > 1: + transform = torch.zeros((prediction.shape[0], prediction.shape[1] + 1)).to( + prediction.device + ) + transform[:, self.autoencoder.category_map] = prediction + else: + transform = prediction + return transform + + def predict(self, interaction): + item = interaction[self.ITEM_ID] + x_t = self.full_sort_predict(interaction) + scores = x_t[:, item] + return scores + + +def 
compute_loss(recon_x, x): + return -torch.mean( + torch.sum(F.log_softmax(recon_x, 1) * x, -1) + ) # multinomial log likelihood in MultVAE diff --git a/recbole/model/general_recommender/pop.py b/recbole/model/general_recommender/pop.py index 2984d23c6..a4d75d39c 100644 --- a/recbole/model/general_recommender/pop.py +++ b/recbole/model/general_recommender/pop.py @@ -6,6 +6,10 @@ # @Time : 2020/11/9 # @Author : Zihan Lin # @Email : zhlin@ruc.edu.cn +# UPDATE +# @Time :2023/9/21 +# @Author : Kesha Ou +# @Email :1582706091@qq.com r""" Pop @@ -43,7 +47,7 @@ def calculate_loss(self, interaction): self.max_cnt = torch.max(self.item_cnt, dim=0)[0] - return torch.nn.Parameter(torch.zeros(1)) + return torch.nn.Parameter(torch.zeros(1)).to(self.device) def predict(self, interaction): item = interaction[self.ITEM_ID] diff --git a/recbole/model/layers.py b/recbole/model/layers.py index 6defc637b..2179b91bb 100644 --- a/recbole/model/layers.py +++ b/recbole/model/layers.py @@ -52,7 +52,13 @@ class MLPLayers(nn.Module): """ def __init__( - self, layers, dropout=0.0, activation="relu", bn=False, init_method=None + self, + layers, + dropout=0.0, + activation="relu", + bn=False, + init_method=None, + last_activation=True, ): super(MLPLayers, self).__init__() self.layers = layers @@ -72,7 +78,8 @@ def __init__( activation_func = activation_layer(self.activation, output_size) if activation_func is not None: mlp_modules.append(activation_func) - + if self.activation is not None and not last_activation: + mlp_modules.pop() self.mlp_layers = nn.Sequential(*mlp_modules) if self.init_method is not None: self.apply(self.init_weights) diff --git a/recbole/model/sequential_recommender/bert4rec.py b/recbole/model/sequential_recommender/bert4rec.py index 0a759d6ff..936e57d77 100644 --- a/recbole/model/sequential_recommender/bert4rec.py +++ b/recbole/model/sequential_recommender/bert4rec.py @@ -3,6 +3,11 @@ # @Author : Hui Wang # @Email : hui.wang@ruc.edu.cn +# UPDATE +# @Time : 2023/9/4 +# 
@Author : Enze Liu +# @Email : enzeeliu@foxmail.com + r""" BERT4Rec ################################################ @@ -75,6 +80,10 @@ def __init__(self, config, dataset): self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self.dropout = nn.Dropout(self.hidden_dropout_prob) + self.output_ffn = nn.Linear(self.hidden_size, self.hidden_size) + self.output_gelu = nn.GELU() + self.output_ln = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) + self.output_bias = nn.Parameter(torch.zeros(self.n_items)) # we only need compute the loss at the masked position try: @@ -124,7 +133,9 @@ def forward(self, item_seq): trm_output = self.trm_encoder( input_emb, extended_attention_mask, output_all_encoded_layers=True ) - output = trm_output[-1] + ffn_output = self.output_ffn(trm_output[-1]) + ffn_output = self.output_gelu(ffn_output) + output = self.output_ln(ffn_output) return output # [B L H] def multi_hot_embed(self, masked_index, max_length): @@ -172,8 +183,14 @@ def calculate_loss(self, interaction): if self.loss_type == "BPR": pos_items_emb = self.item_embedding(pos_items) # [B mask_len H] neg_items_emb = self.item_embedding(neg_items) # [B mask_len H] - pos_score = torch.sum(seq_output * pos_items_emb, dim=-1) # [B mask_len] - neg_score = torch.sum(seq_output * neg_items_emb, dim=-1) # [B mask_len] + pos_score = ( + torch.sum(seq_output * pos_items_emb, dim=-1) + + self.output_bias[pos_items] + ) # [B mask_len] + neg_score = ( + torch.sum(seq_output * neg_items_emb, dim=-1) + + self.output_bias[neg_items] + ) # [B mask_len] targets = (masked_index > 0).float() loss = -torch.sum( torch.log(1e-14 + torch.sigmoid(pos_score - neg_score)) * targets @@ -183,8 +200,9 @@ def calculate_loss(self, interaction): elif self.loss_type == "CE": loss_fct = nn.CrossEntropyLoss(reduction="none") test_item_emb = self.item_embedding.weight[: self.n_items] # [item_num H] - logits = torch.matmul( - seq_output, test_item_emb.transpose(0, 1) + logits = ( + 
torch.matmul(seq_output, test_item_emb.transpose(0, 1)) + + self.output_bias ) # [B mask_len item_num] targets = (masked_index > 0).float().view(-1) # [B*mask_len] @@ -204,7 +222,9 @@ def predict(self, interaction): seq_output = self.forward(item_seq) seq_output = self.gather_indexes(seq_output, item_seq_len - 1) # [B H] test_item_emb = self.item_embedding(test_item) - scores = torch.mul(seq_output, test_item_emb).sum(dim=1) # [B] + scores = (torch.mul(seq_output, test_item_emb)).sum(dim=1) + self.output_bias[ + test_item + ] # [B] return scores def full_sort_predict(self, interaction): @@ -216,7 +236,7 @@ def full_sort_predict(self, interaction): test_items_emb = self.item_embedding.weight[ : self.n_items ] # delete masked token - scores = torch.matmul( - seq_output, test_items_emb.transpose(0, 1) + scores = ( + torch.matmul(seq_output, test_items_emb.transpose(0, 1)) + self.output_bias ) # [B, item_num] return scores diff --git a/recbole/properties/dataset/url.yaml b/recbole/properties/dataset/url.yaml index f0ace4736..6aef872d7 100644 --- a/recbole/properties/dataset/url.yaml +++ b/recbole/properties/dataset/url.yaml @@ -56,6 +56,7 @@ amazon-toys-games-18: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatas amazon-video-games-18: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/Amazon_ratings/Amazon2018/Amazon_Video_Games.zip anime: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/Anime/anime.zip avazu: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/Avazu/avazu.zip +amazon-m2: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/Amazon_M2/amazon_m2.zip beeradvocate: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/BeerAdvocate/BeerAdvocate.zip behance: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/Behance/Behance.zip book-crossing: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/Book-Crossing/book-crossing.zip diff --git a/recbole/properties/model/DiffRec.yaml 
b/recbole/properties/model/DiffRec.yaml new file mode 100644 index 000000000..1c8113e78 --- /dev/null +++ b/recbole/properties/model/DiffRec.yaml @@ -0,0 +1,20 @@ +# params for the diffusion +noise_schedule: 'linear' # (str) The schedule for noise generating: [linear, linear-var, cosine, binomial] +noise_scale: 0.001 # (float) The scale for noise generating +noise_min: 0.0005 # (float) Noise lower bound for noise generating +noise_max: 0.005 # (float) Noise upper bound for noise generating +sampling_noise: False # (bool) Whether to use sampling noise +sampling_steps: 0 # (int) Steps of the forward process during inference +reweight: True # (bool) Assign different weight to different timestep or not +mean_type: 'x0' # (str) MeanType for diffusion: [x0, eps] +steps: 5 # (int) Diffusion steps +history_num_per_term: 10 # (int) The number of history items needed to calculate loss weight +beta_fixed: True # (bool) Whether to fix the variance of the first step to prevent overfitting + +# params for the model +dims_dnn: [300] # (list of int) The dims for the DNN +embedding_size: 10 # (int) Timestep embedding size +mlp_act_func: 'tanh' # (str) Activation function for MLP +time-aware: False # (bool) T-DiffRec or not +w_max: 1 # (int) The upper bound of the time-aware interaction weight +w_min: 0.1 # (float) The lower bound of the time-aware interaction weight diff --git a/recbole/properties/model/EulerNet.yaml b/recbole/properties/model/EulerNet.yaml new file mode 100644 index 000000000..b069aaf4d --- /dev/null +++ b/recbole/properties/model/EulerNet.yaml @@ -0,0 +1,6 @@ +embedding_size: 16 # (int) The embedding size of features. +order_list: [30] # (list) The order vectors of EulerNet +drop_ex: 0.3 # (float) The dropout rate for the modulus and phase +drop_im: 0.3 # (float) The dropout rate for the real and imaginary part +apply_norm: False # (bool) Whether to perform the layer norm +reg_weight: 1e-5 # (float) The L2 regularization weight. 
\ No newline at end of file diff --git a/recbole/properties/model/LDiffRec.yaml b/recbole/properties/model/LDiffRec.yaml new file mode 100644 index 000000000..4820983a0 --- /dev/null +++ b/recbole/properties/model/LDiffRec.yaml @@ -0,0 +1,32 @@ +# params for autoencoder +n_cate: 1 # (int) Category num of items +reparam: True # (bool) Autoencoder with variational inference or not +in_dims: [300] # (list of int) The dims for the encoder +out_dims: [] # (list of int) The hidden dims for the decoder +ae_act_func: 'tanh' # (str) Activation function for AutoEncoder +lamda: 0.03 # (float) Hyper-parameter of multinomial log-likelihood for AE +anneal_cap: 0.005 # (float) The upper bound of the annealing weight +anneal_steps: 1000 # (int) The steps of annealing +vae_anneal_cap: 0.3 # (float) The upper bound of the VAE annealing weight +vae_anneal_steps: 200 # (int) The steps of VAE annealing + +# params for the diffusion +noise_schedule: 'linear' # (str) The schedule for noise generating: [linear, linear-var, cosine, binomial] +noise_scale: 0.1 # (float) The scale for noise generating +noise_min: 0.001 # (float) Noise lower bound for noise generating +noise_max: 0.005 # (float) Noise upper bound for noise generating +sampling_noise: False # (bool) Whether to use sampling noise +sampling_steps: 0 # (int) Steps of the forward process during inference +reweight: True # (bool) Assign different weight to different timestep or not +mean_type: 'x0' # (str) MeanType for diffusion: [x0, eps] +steps: 5 # (int) Diffusion steps +history_num_per_term: 10 # (int) The number of history items needed to calculate loss weight +beta_fixed: True # (bool) Whether to fix the variance of the first step to prevent overfitting + +# params for the model +dims_dnn: [300] # (list of int) The dims for the DNN +embedding_size: 10 # (int) Timestep embedding size +mlp_act_func: 'tanh' # (str) Activation function for MLP +time-aware: False # (bool) T-DiffRec or not +w_max: 1 # (int) The upper bound of the 
time-aware interaction weight +w_min: 0.1 # (float) The lower bound of the time-aware interaction weight \ No newline at end of file diff --git a/recbole/properties/model/lightgbm.yaml b/recbole/properties/model/lightgbm.yaml index c8296d765..93e5f0e5f 100644 --- a/recbole/properties/model/lightgbm.yaml +++ b/recbole/properties/model/lightgbm.yaml @@ -1,10 +1,8 @@ # Dataset convert_token_to_onehot: False # (bool) Whether to convert token type features into one-hot form. token_num_threhold: 10000 # (int) The threshold of one-hot conversion. -lgb_silent: False # (bool) Whether to print messages during construction. # Train -lgb_model: ~ # (file name of stored lgb model or 'Booster' instance) lgb_params: # (dict) Booster params. boosting: gbdt num_leaves: 90 @@ -15,7 +13,4 @@ lgb_params: # (dict) Booster params. lambda_l1: 0.1 metric: ['auc', 'binary_logloss'] force_row_wise: True -lgb_learning_rates: ~ # (list, callable or None) List of learning rates or a customized function. lgb_num_boost_round: 300 # (int) Number of boosting iterations. -lgb_early_stopping_rounds: ~ # (int or None) Activates early stopping. -lgb_verbose_eval: 100 # (bool or int) Requires at least one validation data to print evaluation metrics. 
\ No newline at end of file diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py index b5863b67d..e25713da0 100644 --- a/recbole/trainer/trainer.py +++ b/recbole/trainer/trainer.py @@ -175,6 +175,8 @@ def _build_optimizer(self, **kwargs): if learner.lower() == "adam": optimizer = optim.Adam(params, lr=learning_rate, weight_decay=weight_decay) + elif learner.lower() == "adamw": + optimizer = optim.AdamW(params, lr=learning_rate, weight_decay=weight_decay) elif learner.lower() == "sgd": optimizer = optim.SGD(params, lr=learning_rate, weight_decay=weight_decay) elif learner.lower() == "adagrad": @@ -1150,18 +1152,11 @@ def __init__(self, config, model): super(LightGBMTrainer, self).__init__(config, model) self.lgb = __import__("lightgbm") - self.boost_model = config["lgb_model"] - self.silent = config["lgb_silent"] # train params self.params = config["lgb_params"] self.num_boost_round = config["lgb_num_boost_round"] self.evals = () - self.early_stopping_rounds = config["lgb_early_stopping_rounds"] - self.evals_result = {} - self.verbose_eval = config["lgb_verbose_eval"] - self.learning_rates = config["lgb_learning_rates"] - self.callbacks = None self.deval_data = self.deval_label = None self.eval_pred = self.eval_true = None @@ -1174,7 +1169,7 @@ def _interaction_to_lib_datatype(self, dataloader): dataset(lgb.Dataset): Data in the form of 'lgb.Dataset'. 
""" data, label = self._interaction_to_sparse(dataloader) - return self.lgb.Dataset(data=data, label=label, silent=self.silent) + return self.lgb.Dataset(data=data, label=label) def _train_at_once(self, train_data, valid_data): r""" @@ -1187,16 +1182,7 @@ def _train_at_once(self, train_data, valid_data): self.dvalid = self._interaction_to_lib_datatype(valid_data) self.evals = [self.dtrain, self.dvalid] self.model = self.lgb.train( - self.params, - self.dtrain, - self.num_boost_round, - self.evals, - early_stopping_rounds=self.early_stopping_rounds, - evals_result=self.evals_result, - verbose_eval=self.verbose_eval, - learning_rates=self.learning_rates, - init_model=self.boost_model, - callbacks=self.callbacks, + self.params, self.dtrain, self.num_boost_round, self.evals ) self.model.save_model(self.temp_file) diff --git a/requirements.txt b/requirements.txt index f92e2ec3a..ade8cf9fc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ colorlog==4.7.2 colorama==0.4.4 tensorboard>=2.5.0 thop>=0.1.1.post2207130030 -ray>=1.13.0 +ray>=1.13.0, <=2.6.3 tabulate>=0.8.10 plotly>=4.0.0 texttable>=0.9.0 \ No newline at end of file diff --git a/run_example/recbole-using-all-items-for-prediction.ipynb b/run_example/recbole-using-all-items-for-prediction.ipynb index 33eca3bba..f23b5bb4f 100644 --- a/run_example/recbole-using-all-items-for-prediction.ipynb +++ b/run_example/recbole-using-all-items-for-prediction.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "id": "f0100cdc", "metadata": { @@ -54,6 +55,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "bbb84f80", "metadata": { @@ -590,6 +592,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "d0f27239", "metadata": { @@ -761,6 +764,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "c1127367", "metadata": { @@ -856,6 +860,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "3a74f096", "metadata": { @@ -1099,6 +1104,7 @@ ] }, 
{ + "attachments": {}, "cell_type": "markdown", "id": "e61e4796", "metadata": { @@ -1190,6 +1196,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "dfa2bf97", "metadata": { @@ -1252,6 +1259,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "352b35aa", "metadata": { @@ -1308,6 +1316,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "c9749297", "metadata": { @@ -1393,6 +1402,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "3a4ebb13", "metadata": { @@ -1419,6 +1429,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "0da61c25", "metadata": { @@ -1436,6 +1447,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "642300a0", "metadata": { @@ -1592,6 +1604,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "f22ab4c5", "metadata": { @@ -1660,6 +1673,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "0349a339", "metadata": { @@ -1736,6 +1750,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "61a9aaa8", "metadata": { @@ -1785,6 +1800,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "9d7071fb", "metadata": { @@ -1829,8 +1845,9 @@ "\n", "def add_last_item(old_interaction, last_item_id, max_len=50):\n", " new_seq_items = old_interaction['item_id_list'][-1]\n", - " if old_interaction['item_length'][-1].item() < max_len:\n", - " new_seq_items[old_interaction['item_length'][-1].item()] = last_item_id\n", + " item_length = old_interaction['item_length'][-1].item()\n", + " if item_length < max_len:\n", + " new_seq_items[item_length] = last_item_id\n", " else:\n", " new_seq_items = torch.roll(new_seq_items, -1)\n", " new_seq_items[-1] = last_item_id\n", @@ -1841,7 +1858,10 @@ " with torch.no_grad():\n", " uid_series = dataset.token2id(dataset.uid_field, [external_user_id])\n", " index = np.isin(dataset[dataset.uid_field].numpy(), uid_series)\n", - " input_interaction = dataset[index]\n", + "\n", + " # instead of passing in a bool array whose shape 
is the same as the dataset,\n", + " # we just filter the corresponding index to pass in, then we get great performance improvements.\n", + " input_interaction = dataset[index.nonzero()] \n", " test = {\n", " 'item_id_list': add_last_item(input_interaction, \n", " input_interaction['item_id'][-1].item(), model.max_seq_length),\n", @@ -1900,6 +1920,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "9c7564dc", "metadata": { diff --git a/tests/model/test_model_auto.py b/tests/model/test_model_auto.py index 9b7d33af8..672d8b659 100644 --- a/tests/model/test_model_auto.py +++ b/tests/model/test_model_auto.py @@ -239,6 +239,22 @@ def test_NCL(self): config_dict = {"model": "NCL", "num_clusters": 100} quick_test(config_dict) + def test_DiffRec(self): + config_dict = {"model": "DiffRec"} + quick_test(config_dict) + + def test_TDiffRec(self): + config_dict = {"model": "DiffRec", "time-aware": True} + quick_test(config_dict) + + def test_LDiffRec(self): + config_dict = {"model": "LDiffRec"} + quick_test(config_dict) + + def test_LTDiffRec(self): + config_dict = {"model": "LDiffRec", "time-aware": True} + quick_test(config_dict) + class TestContextRecommender(unittest.TestCase): # todo: more complex context information should be test, such as criteo dataset @@ -414,6 +430,13 @@ def test_kd_dagfm(self): } quick_test(config_dict) + def test_eulernet(self): + config_dict = { + "model": "EulerNet", + "threshold": {"rating": 4}, + } + quick_test(config_dict) + class TestSequentialRecommender(unittest.TestCase): def test_din(self):