gotd/updates: stop listening to channel on ChannelForbidden/Invalid
This commit is contained in:
@@ -106,6 +106,9 @@ func (m *Manager) Run(ctx context.Context, api API, userID int64, opt AuthOption
|
||||
AccessHash int64
|
||||
})
|
||||
if err := m.cfg.Storage.ForEachChannels(ctx, userID, func(ctx context.Context, channelID int64, pts int) error {
|
||||
if pts == -1 {
|
||||
return nil
|
||||
}
|
||||
hash, found, err := m.cfg.AccessHasher.GetChannelAccessHash(ctx, userID, channelID)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "get channel access hash")
|
||||
|
||||
@@ -3,6 +3,7 @@ package updates
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
@@ -44,7 +45,8 @@ type internalState struct {
|
||||
idleTimeout *time.Timer
|
||||
|
||||
// Channel states.
|
||||
channels map[int64]*channelState
|
||||
channels map[int64]*channelState
|
||||
channelsLock sync.Mutex
|
||||
|
||||
// Immutable fields.
|
||||
client API
|
||||
@@ -117,11 +119,10 @@ func newState(ctx context.Context, cfg stateConfig) *internalState {
|
||||
})
|
||||
|
||||
for id, info := range cfg.Channels {
|
||||
state := s.newChannelState(id, info.AccessHash, info.Pts)
|
||||
s.channels[id] = state
|
||||
s.wg.Go(func() error {
|
||||
return state.Run(ctx)
|
||||
})
|
||||
if info.Pts == -1 {
|
||||
continue
|
||||
}
|
||||
s.createAndRunChannelState(ctx, id, info.AccessHash, info.Pts)
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -332,7 +333,9 @@ func (s *internalState) handleChannel(ctx context.Context, channelID int64, date
|
||||
return nil
|
||||
}
|
||||
|
||||
s.channelsLock.Lock()
|
||||
state, ok := s.channels[channelID]
|
||||
s.channelsLock.Unlock()
|
||||
if !ok {
|
||||
accessHash, found, err := s.hasher.GetChannelAccessHash(context.Background(), s.selfID, channelID)
|
||||
if err != nil {
|
||||
@@ -359,6 +362,9 @@ func (s *internalState) handleChannel(ctx context.Context, channelID int64, date
|
||||
}
|
||||
|
||||
localPts, found, err := s.storage.GetChannelPts(ctx, s.selfID, channelID)
|
||||
if localPts == -1 {
|
||||
found = false
|
||||
}
|
||||
if err != nil {
|
||||
localPts = pts - ptsCount
|
||||
s.log.Error("GetChannelPts error", zap.Error(err))
|
||||
@@ -371,16 +377,30 @@ func (s *internalState) handleChannel(ctx context.Context, channelID int64, date
|
||||
}
|
||||
}
|
||||
|
||||
state = s.newChannelState(channelID, accessHash, localPts)
|
||||
s.channels[channelID] = state
|
||||
s.wg.Go(func() error {
|
||||
return state.Run(ctx)
|
||||
})
|
||||
state = s.createAndRunChannelState(ctx, channelID, accessHash, localPts)
|
||||
}
|
||||
|
||||
return state.Push(ctx, cu)
|
||||
}
|
||||
|
||||
func (s *internalState) createAndRunChannelState(ctx context.Context, channelID, accessHash int64, initialPts int) (state *channelState) {
|
||||
state = s.newChannelState(channelID, accessHash, initialPts)
|
||||
s.channelsLock.Lock()
|
||||
s.channels[channelID] = state
|
||||
s.channelsLock.Unlock()
|
||||
s.wg.Go(func() error {
|
||||
err := state.Run(ctx)
|
||||
if errors.Is(err, ErrRemoveChannelState) {
|
||||
s.channelsLock.Lock()
|
||||
delete(s.channels, channelID)
|
||||
s.channelsLock.Unlock()
|
||||
s.log.Info("Removed channel state due to error", zap.Int64("channel_id", channelID))
|
||||
}
|
||||
return err
|
||||
})
|
||||
return state
|
||||
}
|
||||
|
||||
func (s *internalState) newChannelState(channelID, accessHash int64, initialPts int) *channelState {
|
||||
return newChannelState(channelStateConfig{
|
||||
Out: s.internalQueue,
|
||||
|
||||
@@ -2,6 +2,7 @@ package updates
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/telegram"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tg"
|
||||
"go.mau.fi/mautrix-telegram/pkg/gotd/tgerr"
|
||||
)
|
||||
|
||||
type channelUpdate struct {
|
||||
@@ -40,6 +42,9 @@ type channelState struct {
|
||||
tracer trace.Tracer
|
||||
handler telegram.UpdateHandler
|
||||
onTooLong func(channelID int64) error
|
||||
|
||||
runCtx context.Context
|
||||
stop context.CancelCauseFunc
|
||||
}
|
||||
|
||||
type channelStateConfig struct {
|
||||
@@ -90,33 +95,42 @@ func (s *channelState) Push(ctx context.Context, u channelUpdate) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-s.runCtx.Done():
|
||||
return s.runCtx.Err()
|
||||
case s.updates <- u:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var ErrRemoveChannelState = errors.New("remove channel state")
|
||||
|
||||
func (s *channelState) Run(ctx context.Context) error {
|
||||
s.runCtx, s.stop = context.WithCancelCause(ctx)
|
||||
defer s.stop(nil)
|
||||
// Subscribe to channel updates.
|
||||
if err := s.getDifference(ctx); err != nil {
|
||||
if err := s.getDifference(s.runCtx); err != nil {
|
||||
s.log.Error("Failed to subscribe to channel updates", zap.Error(err))
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case u := <-s.updates:
|
||||
ctx := trace.ContextWithSpanContext(ctx, u.span)
|
||||
ctx := trace.ContextWithSpanContext(s.runCtx, u.span)
|
||||
if err := s.handleUpdate(ctx, u.update, u.entities); err != nil {
|
||||
s.log.Error("Handle update error", zap.Error(err))
|
||||
}
|
||||
case <-s.pts.gapTimeout.C:
|
||||
s.log.Debug("Gap timeout")
|
||||
s.getDifferenceLogger(ctx)
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
s.getDifferenceLogger(s.runCtx)
|
||||
case <-s.runCtx.Done():
|
||||
if cause := context.Cause(s.runCtx); cause != nil && ctx.Err() == nil {
|
||||
return cause
|
||||
}
|
||||
return s.runCtx.Err()
|
||||
case <-s.idleTimeout.C:
|
||||
s.log.Debug("Idle timeout")
|
||||
s.resetIdleTimer()
|
||||
s.getDifferenceLogger(ctx)
|
||||
s.getDifferenceLogger(s.runCtx)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -245,6 +259,12 @@ func (s *channelState) getDifference(ctx context.Context) error {
|
||||
Limit: s.diffLim,
|
||||
})
|
||||
if err != nil {
|
||||
if tgerr.Is(err, "CHANNEL_PRIVATE") || tgerr.Is(err, "CHANNEL_INVALID") {
|
||||
if err := s.storage.SetChannelPts(ctx, s.selfID, s.channelID, -1); err != nil {
|
||||
s.log.Error("SetChannelPts error (clear)", zap.Error(err))
|
||||
}
|
||||
s.stop(fmt.Errorf("%w: %w", ErrRemoveChannelState, err))
|
||||
}
|
||||
return errors.Wrap(err, "get channel difference")
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user