Markov Chain Estimation with In-Context Learning

Preliminaries

Core Idea
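
In short (reconstructed from the code below): transition matrices are drawn row-wise from a symmetric Dirichlet(alpha) prior, random walks are simulated from them, and a small LlamaModel is trained to predict the next state from random orthogonal state embeddings that are resampled for every batch, so the transition structure must be inferred in context. On the last CUTS positions of each chain, the model's cross-entropy is compared against two baselines: the oracle (the true transition row) and the smoothed empirical estimator, i.e. the posterior mean under the same Dirichlet(alpha) prior,

$$
\hat{P}(x' \mid x) = \frac{N(x \to x') + \alpha}{\sum_{y} N(x \to y) + \alpha K},
$$

where $N(x \to x')$ counts the transitions observed so far and $K$ is the number of states.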

Experimental Results

Personal Tests

Loss Curves

max_num_states
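
(min_num_states and max_num_states bound the number of states K of each sampled chain; the sampler clamps the lower bound to at least CUTS.)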

minlen
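
(minlen and maxlen bound the sampled chain length, i.e. how much in-context evidence the model sees.)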

alpha
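
(alpha plays two roles in the sampler: the concentration of the Dirichlet prior that generates transition rows, and the smoothing constant of the empirical baseline; smaller alpha gives more peaked, near-deterministic rows.)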

Code
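
Three scripts follow: main.py trains the in-context estimator, sampler.py generates the synthetic Markov chains together with the empirical and oracle baselines, and test.py reloads a saved checkpoint for evaluation only.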

# main.py


from typing import Dict, Tuple, Union

import torch
import freerec
from transformers import LlamaModel, LlamaConfig
from einops import rearrange

from sampler import TrainRandomWalkSource, ValidRandomWalkSource, CUTS

freerec.declare(version='1.0.1')

cfg = freerec.parser.Parser()
cfg.add_argument("--alpha", type=float, default=0.1)
cfg.add_argument("--min-num-states", type=int, default=30)
cfg.add_argument("--max-num-states", type=int, default=30)
cfg.add_argument("--minlen", type=int, default=1000)
cfg.add_argument("--maxlen", type=int, default=1000)

cfg.set_defaults(
    description="Markov",
    root="../../data",
    dataset='Amazon2014Beauty_550_LOU',
    epochs=1000,
    batch_size=500,
    optimizer='AdamW',
    lr=3e-4,
    weight_decay=0.01,
    which4best="Model",
    seed=1,
)
cfg.compile()


cfg.llama_config = LlamaConfig(
    vocab_size=0,
    hidden_size=256,
    intermediate_size=256,
    num_hidden_layers=4,
    num_attention_heads=2,
    max_position_embeddings=cfg.maxlen,
    tie_word_embeddings=True,
    attention_dropout=0.
)


class MarkovICL(freerec.models.SeqRecArch):

    def __init__(
        self, dataset: freerec.data.datasets.RecDataSet,
    ) -> None:
        super().__init__(dataset)


        self.model = LlamaModel(cfg.llama_config)

        self.criterion = freerec.criterions.CrossEntropy4Logits(reduction='mean')
        self.reset_parameters()

    def reset_parameters(self): ...  # keep the HF default initialization

    def sure_trainpipe(self):
        return TrainRandomWalkSource(
            self.dataset.train(),
            datasize=10000,
            alpha=cfg.alpha,
            min_num_states=cfg.min_num_states,
            max_num_states=cfg.max_num_states,
            minlen=cfg.minlen,
            maxlen=cfg.maxlen,
        ).add_(
            offset=self.NUM_PADS, modified_fields=(self.ISeq, self.IPos)
        ).lpad_(
            cfg.maxlen, modified_fields=(self.ISeq, self.IPos),
            padding_value=self.PADDING_VALUE
        ).batch_(cfg.batch_size).tensor_()

    def sure_validpipe(self):
        return ValidRandomWalkSource(
            self.dataset.valid(),
            datasize=10000,
            alpha=cfg.alpha,
            min_num_states=cfg.min_num_states,
            max_num_states=cfg.max_num_states,
            minlen=cfg.minlen,
            maxlen=cfg.maxlen,
        ).add_(
            offset=self.NUM_PADS, modified_fields=(self.ISeq, self.IPos)
        ).lpad_(
            cfg.maxlen, modified_fields=(self.ISeq, self.IPos),
            padding_value=self.PADDING_VALUE
        ).batch_(cfg.batch_size).tensor_()

    def sure_testpipe(self):
        return ValidRandomWalkSource(
            self.dataset.test(),
            datasize=10000,
            alpha=cfg.alpha,
            min_num_states=cfg.min_num_states,
            max_num_states=cfg.max_num_states,
            minlen=cfg.minlen,
            maxlen=cfg.maxlen,
        ).add_(
            offset=self.NUM_PADS, modified_fields=(self.ISeq, self.IPos)
        ).lpad_(
            cfg.maxlen, modified_fields=(self.ISeq, self.IPos),
            padding_value=self.PADDING_VALUE
        ).batch_(cfg.batch_size).tensor_()

    # @staticmethod
    # def _ortho_vocab(B, V, D, device):
    #     _d = D // 2
    #     # Batched random orthogonal embeddings
    #     emb_dict = torch.randn(B, max(V, _d), _d, dtype=torch.float32, device=device)
    #     emb_dict, _ = torch.linalg.qr(emb_dict)
    #     emb_dict = emb_dict[:, :V, :_d]  # V vectors of size D
    #     # Now pad with zeros : B x V x D -> B x V x 2D
    #     emb_dict = torch.cat([emb_dict, torch.zeros(B, V, _d, device=device)], dim=-1)
    #     return emb_dict

    @staticmethod
    def _ortho_vocab(B, V, D, device):
        # Batched random orthogonal embeddings: QR of a Gaussian matrix yields
        # an orthonormal basis, so (for V <= D) the V rows are mutually
        # orthonormal state embeddings, resampled independently per batch row.
        emb_dict = torch.randn(B, max(V, D), D, dtype=torch.float32, device=device)
        emb_dict, _ = torch.linalg.qr(emb_dict)
        return emb_dict[:, :V, :]  # (B, V, D)
    
    def encode(
        self, data: Dict[freerec.data.fields.Field, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        chains = data[self.ISeq]
        B, S = chains.shape
        # A fresh random orthogonal vocabulary is sampled for every batch, so
        # state identities carry no information across batches.
        voc = self._ortho_vocab(
            B=B, V=int(chains.max()) + 1,
            D=self.model.config.hidden_size,
            device=self.device
        )

        row_index = torch.arange(B, device=self.device).view(-1, 1)
        emb = voc[row_index, chains]  # (B, S, D): per-sample state-embedding lookup

        out = self.model(inputs_embeds=emb, output_attentions=False)

        return out.last_hidden_state, voc

    def fit(
        self, data: Dict[freerec.data.fields.Field, torch.Tensor]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
        hiddens, voc = self.encode(data)
        indices = data[self.ISeq] != self.PADDING_VALUE
        logits = torch.einsum("BMD,BND->BMN", hiddens, voc) # (B, M, N)
        logits = logits[indices]
        labels = data[self.IPos][indices] # (*,)
        rec_loss = self.criterion(logits, labels)

        return rec_loss

    def recommend_from_full(
        self, data: Dict[freerec.data.fields.Field, torch.Tensor]
    ) -> Tuple[torch.Tensor, float, float]:
        hiddens, voc = self.encode(data)
        hiddens = hiddens[:, -CUTS:, :]
        target = data[self.IPos][:, -CUTS:] # (B, M)
        logits = torch.einsum("BMD,BND->BMN", hiddens, voc) # (B, M, N)
        logits = rearrange(logits, "B M N -> B N M")

        return self.criterion(logits, target), data['empirical'].mean().item(), data['oracle'].mean().item()


class CoachForMarkov(freerec.launcher.Coach):

    def train_per_epoch(self, epoch: int):
        for data in self.dataloader:
            data = self.dict_to_device(data)
            loss = self.model(data)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
           
            self.monitor(
                loss.item(), 
                n=data[self.Size], reduction="mean", 
                mode='train', pool=['LOSS']
            )

    def set_other(self):
        self.register_metric(
            "MODEL", lambda x: x, best_caster=min
        )
        self.register_metric(
            "EMPIRICAL", lambda x: x, best_caster=min
        )
        self.register_metric(
            "ORACLE", lambda x: x, best_caster=min
        )

    def evaluate(self, epoch: int, step: int = -1, mode: str = 'valid'):
        for data in self.dataloader:
            bsz = data[self.Size]

            data = self.dict_to_device(data)
            model_loss, empirical_loss, oracle_loss = self.model(data, ranking='full')
            self.monitor(
                model_loss,
                n=bsz, reduction="mean", mode=mode,
                pool=['Model']
            )
            self.monitor(
                empirical_loss,
                n=bsz, reduction="mean", mode=mode,
                pool=['EMPIRICAL']
            )
            self.monitor(
                oracle_loss,
                n=bsz, reduction="mean", mode=mode,
                pool=['ORACLE']
            )


def main():

    dataset: freerec.data.datasets.RecDataSet
    try:
        dataset = getattr(freerec.data.datasets, cfg.dataset)(root=cfg.root)
    except AttributeError:
        dataset = freerec.data.datasets.RecDataSet(cfg.root, cfg.dataset, tasktag=cfg.tasktag)

    model = MarkovICL(dataset)

    # datapipe
    trainpipe = model.sure_trainpipe()
    validpipe = model.sure_validpipe()
    testpipe = model.sure_testpipe()

    coach = CoachForMarkov(
        dataset=dataset,
        trainpipe=trainpipe,
        validpipe=validpipe,
        testpipe=testpipe,
        model=model,
        cfg=cfg
    )
    coach.fit()


if __name__ == "__main__":
    main()
# sampler.py

from typing import List, Tuple

import numpy as np
import random
from freerec.data.tags import ITEM, ID, SEQUENCE, POSITIVE
from freerec.data.datasets.base import RecDataSet
from freerec.data.postprocessing.base import Source
from freerec.data.postprocessing.source import OrderedSource


CUTS = 20


class TrainRandomWalkSource(Source):

    def __init__(
        self, dataset: RecDataSet, datasize: int,
        alpha: float = 0.1,
        min_num_states: int = 30, max_num_states: int = 30, 
        minlen: int = 1000, maxlen: int = 1000
    ) -> None:
        super().__init__(dataset, tuple(), datasize, shuffle=False)

        self._rng = random.Random()

        self.alpha = alpha
        self.min_num_states = max(min_num_states, CUTS)
        self.max_num_states = max(max_num_states, self.min_num_states)  # keep randint bounds valid
        self.minlen = minlen
        self.maxlen = maxlen

        self.Item = self.fields[ITEM, ID]
        self.ISeq = self.Item.fork(SEQUENCE)
        self.IPos = self.Item.fork(POSITIVE)
    
    def sample_transition_matrix(self, num_states: int) -> np.ndarray:
        # Each row of the transition matrix is drawn from a symmetric Dirichlet(alpha).
        return np.random.dirichlet([self.alpha] * num_states, size=num_states)

    def sample_num_states(self):
        return self._rng.randint(self.min_num_states, self.max_num_states)

    def sample_chain_length(self):
        return self._rng.randint(self.minlen, self.maxlen)

    def estimate_transition_probability(self, chain: np.ndarray, num_states: int):
        # Dirichlet-smoothed (posterior-mean) estimate of P(. | x), where x is
        # the current, i.e. last, state of the chain.
        counts = np.zeros((num_states,))
        x = chain[-1]

        # Earlier visits to x; the final one is x itself, whose successor is
        # the quantity being predicted, so it is dropped.
        positions = np.where(chain == x)[0][:-1] + 1

        np.add.at(counts, chain[positions], 1)

        return (counts + self.alpha) / (counts.sum() + self.alpha * num_states)

    def cross_entropy_from_probs(self, probs: np.ndarray, target: np.ndarray):
        # probs: (CUTS, NUM_STATES); target: (CUTS,)
        probs = np.where(probs == 0, 1.e-8, probs)  # avoid log(0)

        probs = np.take_along_axis(probs, target[:, None], axis=1)
        return np.mean(-np.log(probs)).item()

    def sample_chain(self) -> Tuple[List[int], List[int], float, float]:
        k = self.sample_num_states()
        P = self.sample_transition_matrix(num_states=k)
        n = self.sample_chain_length()
        cprobs = P.cumsum(axis=1)
        rands = np.random.rand(n)

        # Simulate the chain from state 0 by inverse-CDF sampling.
        chain = np.zeros(n, dtype=int)
        for i in range(1, n):
            chain[i] = np.searchsorted(cprobs[chain[i - 1]], rands[i])
        seq, target = chain[:-1], chain[1:]
        s = len(seq) - CUTS
        # Baselines over the last CUTS positions: the smoothed empirical
        # estimate from the observed prefix, and the oracle row of P.
        estimation = np.stack([
            self.estimate_transition_probability(
                seq[:s + i], k
            )
            for i in range(1, CUTS + 1)
        ], axis=0)
        oracle = np.stack([
            P[seq[s + i]]
            for i in range(CUTS)
        ], axis=0)
        empirical_loss = self.cross_entropy_from_probs(estimation, target[-CUTS:])
        oracle_loss = self.cross_entropy_from_probs(oracle, target[-CUTS:])
        return seq.tolist(), target.tolist(), empirical_loss, oracle_loss

    def __iter__(self):
        for _ in self.launcher:
            seq, target, empirical, oracle = self.sample_chain()
            yield {
                self.ISeq: seq, self.IPos: target,
                'empirical': empirical,
                'oracle': oracle
            }


class ValidRandomWalkSource(OrderedSource):

    def __init__(
        self, dataset: RecDataSet, datasize: int,
        alpha: float = 0.1,
        min_num_states: int = 30, max_num_states: int = 30, 
        minlen: int = 1000, maxlen: int = 1000
    ) -> None:

        # Materialize a fixed pool of chains so evaluation is identical across epochs.
        source = TrainRandomWalkSource(
            dataset, datasize, alpha, min_num_states, max_num_states, minlen, maxlen
        )

        super().__init__(dataset, list(source))
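
For intuition, here is a minimal standalone sketch (NumPy only, no freerec; the filename, seed, and sizes are arbitrary choices for illustration) of the smoothed estimator above approaching the oracle transition row on one long chain:

# demo_estimator.py (illustrative sketch, not part of the original code)

import numpy as np

alpha, K, n = 0.1, 30, 1000
rng = np.random.default_rng(0)

# True transition matrix: rows ~ Dirichlet(alpha), as in TrainRandomWalkSource.
P = rng.dirichlet([alpha] * K, size=K)
cprobs = P.cumsum(axis=1)

# Simulate a chain from state 0 by inverse-CDF sampling.
chain = np.zeros(n, dtype=int)
rands = rng.random(n)
for i in range(1, n):
    chain[i] = np.searchsorted(cprobs[chain[i - 1]], rands[i])

# Smoothed empirical estimate of P(. | x) for the current state x.
x = chain[-1]
prev = np.where(chain == x)[0][:-1]    # earlier visits to x
counts = np.zeros(K)
np.add.at(counts, chain[prev + 1], 1)  # their observed successors
est = (counts + alpha) / (counts.sum() + alpha * K)

print("TV distance to oracle row:", 0.5 * np.abs(est - P[x]).sum())
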
# test.py

from typing import Dict, Tuple, Union

import torch
import freerec
from transformers import LlamaModel, LlamaConfig
from einops import rearrange

from sampler import TrainRandomWalkSource, ValidRandomWalkSource, CUTS

freerec.declare(version='1.0.1')

cfg = freerec.parser.Parser()

cfg.add_argument("--path", type=str, default="./logs/Markov/Amazon2014Beauty_550_LOU/0817134602")
cfg.add_argument("--alpha", type=float, default=0.1)
cfg.add_argument("--min-num-states", type=int, default=30)
cfg.add_argument("--max-num-states", type=int, default=30)
cfg.add_argument("--minlen", type=int, default=1000)
cfg.add_argument("--maxlen", type=int, default=1000)

cfg.set_defaults(
    description="Markov",
    root="../../data",
    dataset='Amazon2014Beauty_550_LOU',
    epochs=1000,
    batch_size=500,
    optimizer='AdamW',
    lr=3e-4,
    weight_decay=0.01,
    which4best="Model",
    seed=1,
)
cfg.compile()

cfg.epochs = 1  # a single evaluation pass

cfg.llama_config = LlamaConfig(
    vocab_size=0,
    hidden_size=256,
    intermediate_size=256,
    num_hidden_layers=4,
    num_attention_heads=2,
    max_position_embeddings=cfg.maxlen,
    tie_word_embeddings=True,
    attention_dropout=0.
)


class MarkovICL(freerec.models.SeqRecArch):

    def __init__(
        self, dataset: freerec.data.datasets.RecDataSet,
    ) -> None:
        super().__init__(dataset)


        self.model = LlamaModel(cfg.llama_config)

        self.criterion = freerec.criterions.CrossEntropy4Logits(reduction='mean')
        self.reset_parameters()

    def reset_parameters(self): ...  # keep the HF default initialization

    def sure_trainpipe(self):
        return TrainRandomWalkSource(
            self.dataset.train(),
            datasize=10000,
            alpha=cfg.alpha,
            min_num_states=cfg.min_num_states,
            max_num_states=cfg.max_num_states,
            minlen=cfg.minlen,
            maxlen=cfg.maxlen,
        ).add_(
            offset=self.NUM_PADS, modified_fields=(self.ISeq, self.IPos)
        ).lpad_(
            cfg.maxlen, modified_fields=(self.ISeq, self.IPos),
            padding_value=self.PADDING_VALUE
        ).batch_(cfg.batch_size).tensor_()

    def sure_validpipe(self):
        return ValidRandomWalkSource(
            self.dataset.valid(),
            datasize=10000,
            alpha=cfg.alpha,
            min_num_states=cfg.min_num_states,
            max_num_states=cfg.max_num_states,
            minlen=cfg.minlen,
            maxlen=cfg.maxlen,
        ).add_(
            offset=self.NUM_PADS, modified_fields=(self.ISeq, self.IPos)
        ).lpad_(
            cfg.maxlen, modified_fields=(self.ISeq, self.IPos),
            padding_value=self.PADDING_VALUE
        ).batch_(cfg.batch_size).tensor_()

    def sure_testpipe(self):
        return ValidRandomWalkSource(
            self.dataset.test(),
            datasize=10000,
            alpha=cfg.alpha,
            min_num_states=cfg.min_num_states,
            max_num_states=cfg.max_num_states,
            minlen=cfg.minlen,
            maxlen=cfg.maxlen,
        ).add_(
            offset=self.NUM_PADS, modified_fields=(self.ISeq, self.IPos)
        ).lpad_(
            cfg.maxlen, modified_fields=(self.ISeq, self.IPos),
            padding_value=self.PADDING_VALUE
        ).batch_(cfg.batch_size).tensor_()

    @staticmethod
    def _ortho_vocab(B, V, D, device):
        # Half-dimension variant (cf. the commented-out version in main.py):
        # D // 2 orthogonal dims, zero-padded back to D.
        _d = D // 2
        # Batched random orthogonal embeddings
        emb_dict = torch.randn(B, max(V, _d), _d, dtype=torch.float32, device=device)
        emb_dict, _ = torch.linalg.qr(emb_dict)
        emb_dict = emb_dict[:, :V, :_d]  # V vectors of size D
        # Now pad with zeros : B x V x D -> B x V x 2D
        emb_dict = torch.cat([emb_dict, torch.zeros(B, V, _d, device=device)], dim=-1)
        return emb_dict
    
    def encode(
        self, data: Dict[freerec.data.fields.Field, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        chains = data[self.ISeq]
        B, S = chains.shape
        # A fresh random orthogonal vocabulary is sampled for every batch, so
        # state identities carry no information across batches.
        voc = self._ortho_vocab(
            B=B, V=int(chains.max()) + 1,
            D=self.model.config.hidden_size,
            device=self.device
        )

        row_index = torch.arange(B, device=self.device).view(-1, 1)
        emb = voc[row_index, chains]  # (B, S, D): per-sample state-embedding lookup

        out = self.model(inputs_embeds=emb, output_attentions=False)

        return out.last_hidden_state, voc

    def fit(
        self, data: Dict[freerec.data.fields.Field, torch.Tensor]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
        hiddens, voc = self.encode(data)
        indices = data[self.ISeq] != self.PADDING_VALUE
        logits = torch.einsum("BMD,BND->BMN", hiddens, voc) # (B, M, N)
        logits = logits[indices]
        labels = data[self.IPos][indices] # (*,)
        rec_loss = self.criterion(logits, labels)

        return rec_loss

    def recommend_from_full(
        self, data: Dict[freerec.data.fields.Field, torch.Tensor]
    ) -> Tuple[torch.Tensor, float, float]:
        hiddens, voc = self.encode(data)
        hiddens = hiddens[:, -CUTS:, :]
        target = data[self.IPos][:, -CUTS:] # (B, M)
        logits = torch.einsum("BMD,BND->BMN", hiddens, voc) # (B, M, N)
        logits = rearrange(logits, "B M N -> B N M")

        return self.criterion(logits, target), data['empirical'].mean().item(), data['oracle'].mean().item()


class CoachForMarkov(freerec.launcher.Coach):

    def train_per_epoch(self, epoch: int):
        ...  # evaluation-only script: no training step

    def set_other(self):
        self.register_metric(
            "MODEL", lambda x: x, best_caster=min
        )
        self.register_metric(
            "EMPIRICAL", lambda x: x, best_caster=min
        )
        self.register_metric(
            "ORACLE", lambda x: x, best_caster=min
        )

    def evaluate(self, epoch: int, step: int = -1, mode: str = 'valid'):
        for data in self.dataloader:
            bsz = data[self.Size]

            data = self.dict_to_device(data)
            model_loss, empirical_loss, oracle_loss = self.model(data, ranking='full')
            self.monitor(
                model_loss,
                n=bsz, reduction="mean", mode=mode,
                pool=['Model']
            )
            self.monitor(
                empirical_loss,
                n=bsz, reduction="mean", mode=mode,
                pool=['EMPIRICAL']
            )
            self.monitor(
                oracle_loss,
                n=bsz, reduction="mean", mode=mode,
                pool=['ORACLE']
            )


def main():

    dataset: freerec.data.datasets.RecDataSet
    try:
        dataset = getattr(freerec.data.datasets, cfg.dataset)(root=cfg.root)
    except AttributeError:
        dataset = freerec.data.datasets.RecDataSet(cfg.root, cfg.dataset, tasktag=cfg.tasktag)

    model = MarkovICL(dataset)

    # datapipe
    trainpipe = model.sure_trainpipe()
    validpipe = model.sure_validpipe()
    testpipe = model.sure_testpipe()

    coach = CoachForMarkov(
        dataset=dataset,
        trainpipe=trainpipe,
        validpipe=validpipe,
        testpipe=testpipe,
        model=model,
        cfg=cfg
    )
    coach.load(cfg.path, filename="best.pt")
    coach.fit()


if __name__ == "__main__":
    main()
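
Note that test.py mirrors main.py, except that it loads a saved checkpoint (best.pt) via coach.load, forces a single evaluation pass (cfg.epochs = 1, empty train_per_epoch), and uses the half-dimension zero-padded _ortho_vocab variant; presumably this is meant to match the variant the checkpoint was trained with.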

References

  1. Lepage S., Mary J., and Picard D. Markov Chain Estimation with In-Context Learning. arXiv, 2025.