gotd/updates: stop listening to channel on ChannelForbidden/Invalid

This commit is contained in:
Tulir Asokan
2025-12-10 19:12:34 +02:00
parent 7f13284b59
commit 0e3b1b63a9
3 changed files with 60 additions and 17 deletions
+3
View File
@@ -106,6 +106,9 @@ func (m *Manager) Run(ctx context.Context, api API, userID int64, opt AuthOption
AccessHash int64
})
if err := m.cfg.Storage.ForEachChannels(ctx, userID, func(ctx context.Context, channelID int64, pts int) error {
if pts == -1 {
return nil
}
hash, found, err := m.cfg.AccessHasher.GetChannelAccessHash(ctx, userID, channelID)
if err != nil {
return errors.Wrap(err, "get channel access hash")
+31 -11
View File
@@ -3,6 +3,7 @@ package updates
import (
"context"
"fmt"
"sync"
"time"
"github.com/go-faster/errors"
@@ -44,7 +45,8 @@ type internalState struct {
idleTimeout *time.Timer
// Channel states.
channels map[int64]*channelState
channels map[int64]*channelState
channelsLock sync.Mutex
// Immutable fields.
client API
@@ -117,11 +119,10 @@ func newState(ctx context.Context, cfg stateConfig) *internalState {
})
for id, info := range cfg.Channels {
state := s.newChannelState(id, info.AccessHash, info.Pts)
s.channels[id] = state
s.wg.Go(func() error {
return state.Run(ctx)
})
if info.Pts == -1 {
continue
}
s.createAndRunChannelState(ctx, id, info.AccessHash, info.Pts)
}
return s
@@ -332,7 +333,9 @@ func (s *internalState) handleChannel(ctx context.Context, channelID int64, date
return nil
}
s.channelsLock.Lock()
state, ok := s.channels[channelID]
s.channelsLock.Unlock()
if !ok {
accessHash, found, err := s.hasher.GetChannelAccessHash(context.Background(), s.selfID, channelID)
if err != nil {
@@ -359,6 +362,9 @@ func (s *internalState) handleChannel(ctx context.Context, channelID int64, date
}
localPts, found, err := s.storage.GetChannelPts(ctx, s.selfID, channelID)
if localPts == -1 {
found = false
}
if err != nil {
localPts = pts - ptsCount
s.log.Error("GetChannelPts error", zap.Error(err))
@@ -371,16 +377,30 @@ func (s *internalState) handleChannel(ctx context.Context, channelID int64, date
}
}
state = s.newChannelState(channelID, accessHash, localPts)
s.channels[channelID] = state
s.wg.Go(func() error {
return state.Run(ctx)
})
state = s.createAndRunChannelState(ctx, channelID, accessHash, localPts)
}
return state.Push(ctx, cu)
}
func (s *internalState) createAndRunChannelState(ctx context.Context, channelID, accessHash int64, initialPts int) (state *channelState) {
state = s.newChannelState(channelID, accessHash, initialPts)
s.channelsLock.Lock()
s.channels[channelID] = state
s.channelsLock.Unlock()
s.wg.Go(func() error {
err := state.Run(ctx)
if errors.Is(err, ErrRemoveChannelState) {
s.channelsLock.Lock()
delete(s.channels, channelID)
s.channelsLock.Unlock()
s.log.Info("Removed channel state due to error", zap.Int64("channel_id", channelID))
}
return err
})
return state
}
func (s *internalState) newChannelState(channelID, accessHash int64, initialPts int) *channelState {
return newChannelState(channelStateConfig{
Out: s.internalQueue,
+26 -6
View File
@@ -2,6 +2,7 @@ package updates
import (
"context"
"fmt"
"time"
"github.com/go-faster/errors"
@@ -10,6 +11,7 @@ import (
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
)
type channelUpdate struct {
@@ -40,6 +42,9 @@ type channelState struct {
tracer trace.Tracer
handler telegram.UpdateHandler
onTooLong func(channelID int64) error
runCtx context.Context
stop context.CancelCauseFunc
}
type channelStateConfig struct {
@@ -90,33 +95,42 @@ func (s *channelState) Push(ctx context.Context, u channelUpdate) error {
select {
case <-ctx.Done():
return ctx.Err()
case <-s.runCtx.Done():
return s.runCtx.Err()
case s.updates <- u:
return nil
}
}
var ErrRemoveChannelState = errors.New("remove channel state")
func (s *channelState) Run(ctx context.Context) error {
s.runCtx, s.stop = context.WithCancelCause(ctx)
defer s.stop(nil)
// Subscribe to channel updates.
if err := s.getDifference(ctx); err != nil {
if err := s.getDifference(s.runCtx); err != nil {
s.log.Error("Failed to subscribe to channel updates", zap.Error(err))
}
for {
select {
case u := <-s.updates:
ctx := trace.ContextWithSpanContext(ctx, u.span)
ctx := trace.ContextWithSpanContext(s.runCtx, u.span)
if err := s.handleUpdate(ctx, u.update, u.entities); err != nil {
s.log.Error("Handle update error", zap.Error(err))
}
case <-s.pts.gapTimeout.C:
s.log.Debug("Gap timeout")
s.getDifferenceLogger(ctx)
case <-ctx.Done():
return ctx.Err()
s.getDifferenceLogger(s.runCtx)
case <-s.runCtx.Done():
if cause := context.Cause(s.runCtx); cause != nil && ctx.Err() == nil {
return cause
}
return s.runCtx.Err()
case <-s.idleTimeout.C:
s.log.Debug("Idle timeout")
s.resetIdleTimer()
s.getDifferenceLogger(ctx)
s.getDifferenceLogger(s.runCtx)
}
}
}
@@ -245,6 +259,12 @@ func (s *channelState) getDifference(ctx context.Context) error {
Limit: s.diffLim,
})
if err != nil {
if tgerr.Is(err, "CHANNEL_PRIVATE") || tgerr.Is(err, "CHANNEL_INVALID") {
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, -1); err != nil {
s.log.Error("SetChannelPts error (clear)", zap.Error(err))
}
s.stop(fmt.Errorf("%w: %w", ErrRemoveChannelState, err))
}
return errors.Wrap(err, "get channel difference")
}