From 0e3b1b63a91f36d33921d00b465e3e917290184f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Dec 2025 19:12:34 +0200 Subject: [PATCH] gotd/updates: stop listening to channel on ChannelForbidden/Invalid --- pkg/gotd/telegram/updates/manager.go | 3 ++ pkg/gotd/telegram/updates/state.go | 42 ++++++++++++++++------ pkg/gotd/telegram/updates/state_channel.go | 32 +++++++++++++---- 3 files changed, 60 insertions(+), 17 deletions(-) diff --git a/pkg/gotd/telegram/updates/manager.go b/pkg/gotd/telegram/updates/manager.go index 3294ec44..64616ac4 100644 --- a/pkg/gotd/telegram/updates/manager.go +++ b/pkg/gotd/telegram/updates/manager.go @@ -106,6 +106,9 @@ func (m *Manager) Run(ctx context.Context, api API, userID int64, opt AuthOption AccessHash int64 }) if err := m.cfg.Storage.ForEachChannels(ctx, userID, func(ctx context.Context, channelID int64, pts int) error { + if pts == -1 { + return nil + } hash, found, err := m.cfg.AccessHasher.GetChannelAccessHash(ctx, userID, channelID) if err != nil { return errors.Wrap(err, "get channel access hash") diff --git a/pkg/gotd/telegram/updates/state.go b/pkg/gotd/telegram/updates/state.go index 2ef7c6b6..a48abd5e 100644 --- a/pkg/gotd/telegram/updates/state.go +++ b/pkg/gotd/telegram/updates/state.go @@ -3,6 +3,7 @@ package updates import ( "context" "fmt" + "sync" "time" "github.com/go-faster/errors" @@ -44,7 +45,8 @@ type internalState struct { idleTimeout *time.Timer // Channel states. - channels map[int64]*channelState + channels map[int64]*channelState + channelsLock sync.Mutex // Immutable fields. client API @@ -117,11 +119,10 @@ func newState(ctx context.Context, cfg stateConfig) *internalState { }) for id, info := range cfg.Channels { - state := s.newChannelState(id, info.AccessHash, info.Pts) - s.channels[id] = state - s.wg.Go(func() error { - return state.Run(ctx) - }) + if info.Pts == -1 { + continue + } + s.createAndRunChannelState(ctx, id, info.AccessHash, info.Pts) } return s @@ -332,7 +333,9 @@ func (s *internalState) handleChannel(ctx context.Context, channelID int64, date return nil } + s.channelsLock.Lock() state, ok := s.channels[channelID] + s.channelsLock.Unlock() if !ok { accessHash, found, err := s.hasher.GetChannelAccessHash(context.Background(), s.selfID, channelID) if err != nil { @@ -359,6 +362,9 @@ func (s *internalState) handleChannel(ctx context.Context, channelID int64, date } localPts, found, err := s.storage.GetChannelPts(ctx, s.selfID, channelID) + if localPts == -1 { + found = false + } if err != nil { localPts = pts - ptsCount s.log.Error("GetChannelPts error", zap.Error(err)) @@ -371,16 +377,30 @@ func (s *internalState) handleChannel(ctx context.Context, channelID int64, date } } - state = s.newChannelState(channelID, accessHash, localPts) - s.channels[channelID] = state - s.wg.Go(func() error { - return state.Run(ctx) - }) + state = s.createAndRunChannelState(ctx, channelID, accessHash, localPts) } return state.Push(ctx, cu) } +func (s *internalState) createAndRunChannelState(ctx context.Context, channelID, accessHash int64, initialPts int) (state *channelState) { + state = s.newChannelState(channelID, accessHash, initialPts) + s.channelsLock.Lock() + s.channels[channelID] = state + s.channelsLock.Unlock() + s.wg.Go(func() error { + err := state.Run(ctx) + if errors.Is(err, ErrRemoveChannelState) { + s.channelsLock.Lock() + delete(s.channels, channelID) + s.channelsLock.Unlock() + s.log.Info("Removed channel state due to error", zap.Int64("channel_id", channelID)) + } + return err + }) + return state +} + func (s *internalState) newChannelState(channelID, accessHash int64, initialPts int) *channelState { return newChannelState(channelStateConfig{ Out: s.internalQueue, diff --git a/pkg/gotd/telegram/updates/state_channel.go b/pkg/gotd/telegram/updates/state_channel.go index 1cadb64e..c8d322f8 100644 --- a/pkg/gotd/telegram/updates/state_channel.go +++ b/pkg/gotd/telegram/updates/state_channel.go @@ -2,6 +2,7 @@ package updates import ( "context" + "fmt" "time" "github.com/go-faster/errors" @@ -10,6 +11,7 @@ import ( "go.mau.fi/mautrix-telegram/pkg/gotd/telegram" "go.mau.fi/mautrix-telegram/pkg/gotd/tg" + "go.mau.fi/mautrix-telegram/pkg/gotd/tgerr" ) type channelUpdate struct { @@ -40,6 +42,9 @@ type channelState struct { tracer trace.Tracer handler telegram.UpdateHandler onTooLong func(channelID int64) error + + runCtx context.Context + stop context.CancelCauseFunc } type channelStateConfig struct { @@ -90,33 +95,42 @@ func (s *channelState) Push(ctx context.Context, u channelUpdate) error { select { case <-ctx.Done(): return ctx.Err() + case <-s.runCtx.Done(): + return s.runCtx.Err() case s.updates <- u: return nil } } +var ErrRemoveChannelState = errors.New("remove channel state") + func (s *channelState) Run(ctx context.Context) error { + s.runCtx, s.stop = context.WithCancelCause(ctx) + defer s.stop(nil) // Subscribe to channel updates. - if err := s.getDifference(ctx); err != nil { + if err := s.getDifference(s.runCtx); err != nil { s.log.Error("Failed to subscribe to channel updates", zap.Error(err)) } for { select { case u := <-s.updates: - ctx := trace.ContextWithSpanContext(ctx, u.span) + ctx := trace.ContextWithSpanContext(s.runCtx, u.span) if err := s.handleUpdate(ctx, u.update, u.entities); err != nil { s.log.Error("Handle update error", zap.Error(err)) } case <-s.pts.gapTimeout.C: s.log.Debug("Gap timeout") - s.getDifferenceLogger(ctx) - case <-ctx.Done(): - return ctx.Err() + s.getDifferenceLogger(s.runCtx) + case <-s.runCtx.Done(): + if cause := context.Cause(s.runCtx); cause != nil && ctx.Err() == nil { + return cause + } + return s.runCtx.Err() case <-s.idleTimeout.C: s.log.Debug("Idle timeout") s.resetIdleTimer() - s.getDifferenceLogger(ctx) + s.getDifferenceLogger(s.runCtx) } } } @@ -245,6 +259,12 @@ func (s *channelState) getDifference(ctx context.Context) error { Limit: s.diffLim, }) if err != nil { + if tgerr.Is(err, "CHANNEL_PRIVATE") || tgerr.Is(err, "CHANNEL_INVALID") { + if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, -1); err != nil { + s.log.Error("SetChannelPts error (clear)", zap.Error(err)) + } + s.stop(fmt.Errorf("%w: %w", ErrRemoveChannelState, err)) + } return errors.Wrap(err, "get channel difference") }