updates: add wrapper for API calls to update users

Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
Sumner Evans
2024-08-21 13:45:45 -06:00
parent 284178df65
commit 0670c2b2bc
6 changed files with 124 additions and 73 deletions
+34
View File
@@ -0,0 +1,34 @@
package connector
import (
"context"
"fmt"
"github.com/gotd/td/tg"
)
type hasUpdates interface {
GetUsers() []tg.UserClass
}
// Wrapper for API calls that return a response with updates.
func APICallWithUpdates[U hasUpdates](ctx context.Context, t *TelegramClient, fn func() (U, error)) (U, error) {
resp, err := fn()
if err != nil {
return resp, err
}
// TODO do we also need to expand this to chats and messages?
for _, user := range resp.GetUsers() {
user, ok := user.(*tg.User)
if !ok {
return resp, fmt.Errorf("user is %T not *tg.User", user)
}
err := t.updateGhost(ctx, user.ID, user)
if err != nil {
return resp, err
}
}
return resp, nil
}
+7 -1
View File
@@ -36,14 +36,20 @@ func (t *TelegramClient) FetchMessages(ctx context.Context, fetchParams bridgev2
return nil, err return nil, err
} }
} }
msgs, err := APICallWithUpdates(ctx, t, func() (tg.ModifiedMessagesMessages, error) {
rawMsgs, err := t.client.API().MessagesGetHistory(ctx, &req) rawMsgs, err := t.client.API().MessagesGetHistory(ctx, &req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
msgs, ok := rawMsgs.(interface{ GetMessages() []tg.MessageClass }) msgs, ok := rawMsgs.(tg.ModifiedMessagesMessages)
if !ok { if !ok {
return nil, fmt.Errorf("unsupported messages type %T", rawMsgs) return nil, fmt.Errorf("unsupported messages type %T", rawMsgs)
} }
return msgs, nil
})
if err != nil {
return nil, err
}
var markRead bool // TODO implement var markRead bool // TODO implement
messages := msgs.GetMessages() messages := msgs.GetMessages()
+26 -16
View File
@@ -53,11 +53,7 @@ func (t *TelegramClient) getDMChatInfo(ctx context.Context, userID int64) (*brid
return &chatInfo, nil return &chatInfo, nil
} }
func (t *TelegramClient) getGroupChatInfo(ctx context.Context, fullChat *tg.MessagesChatFull, chatID int64) (*bridgev2.ChatInfo, bool, error) { func (t *TelegramClient) getGroupChatInfo(fullChat *tg.MessagesChatFull, chatID int64) (*bridgev2.ChatInfo, bool, error) {
if err := t.updateUsersFromResponse(ctx, fullChat); err != nil {
return nil, false, err
}
var name *string var name *string
var isBroadcastChannel, isMegagroup bool var isBroadcastChannel, isMegagroup bool
for _, c := range fullChat.GetChats() { for _, c := range fullChat.GetChats() {
@@ -140,7 +136,6 @@ func (t *TelegramClient) filterChannelParticipants(chatParticipants []tg.Channel
} }
func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) {
// fmt.Printf("get chat info %+v\n", portal)
peerType, id, err := ids.ParsePortalID(portal.ID) peerType, id, err := ids.ParsePortalID(portal.ID)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -150,11 +145,13 @@ func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Porta
case ids.PeerTypeUser: case ids.PeerTypeUser:
return t.getDMChatInfo(ctx, id) return t.getDMChatInfo(ctx, id)
case ids.PeerTypeChat: case ids.PeerTypeChat:
fullChat, err := t.client.API().MessagesGetFullChat(ctx, id) fullChat, err := APICallWithUpdates(ctx, t, func() (*tg.MessagesChatFull, error) {
return t.client.API().MessagesGetFullChat(ctx, id)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
chatInfo, _, err := t.getGroupChatInfo(ctx, fullChat, id) chatInfo, _, err := t.getGroupChatInfo(fullChat, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -202,12 +199,14 @@ func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Porta
return nil, fmt.Errorf("channel access hash not found for %d", id) return nil, fmt.Errorf("channel access hash not found for %d", id)
} }
inputChannel := &tg.InputChannel{ChannelID: id, AccessHash: accessHash} inputChannel := &tg.InputChannel{ChannelID: id, AccessHash: accessHash}
fullChat, err := t.client.API().ChannelsGetFullChannel(ctx, inputChannel) fullChat, err := APICallWithUpdates(ctx, t, func() (*tg.MessagesChatFull, error) {
return t.client.API().ChannelsGetFullChannel(ctx, inputChannel)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
chatInfo, isBroadcastChannel, err := t.getGroupChatInfo(ctx, fullChat, id) chatInfo, isBroadcastChannel, err := t.getGroupChatInfo(fullChat, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -245,6 +244,7 @@ func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Porta
limit := t.main.Config.MemberList.NormalizedMaxInitialSync() limit := t.main.Config.MemberList.NormalizedMaxInitialSync()
if limit <= 200 { if limit <= 200 {
participants, err := APICallWithUpdates(ctx, t, func() (*tg.ChannelsChannelParticipants, error) {
p, err := t.client.API().ChannelsGetParticipants(ctx, &tg.ChannelsGetParticipantsRequest{ p, err := t.client.API().ChannelsGetParticipants(ctx, &tg.ChannelsGetParticipantsRequest{
Channel: inputChannel, Channel: inputChannel,
Filter: &tg.ChannelParticipantsRecent{}, Filter: &tg.ChannelParticipantsRecent{},
@@ -256,16 +256,20 @@ func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Porta
participants, ok := p.(*tg.ChannelsChannelParticipants) participants, ok := p.(*tg.ChannelsChannelParticipants)
if !ok { if !ok {
return nil, fmt.Errorf("returned participants is %T not *tg.ChannelsChannelParticipants", p) return nil, fmt.Errorf("returned participants is %T not *tg.ChannelsChannelParticipants", p)
} else {
return participants, nil
} }
chatInfo.Members.IsFull = len(participants.Participants) < limit })
if err := t.updateUsersFromResponse(ctx, participants); err != nil { if err != nil {
return nil, err return nil, err
} }
chatInfo.Members.IsFull = len(participants.Participants) < limit
chatInfo.Members.Members = append(chatInfo.Members.Members, t.filterChannelParticipants(participants.Participants, limit)...) chatInfo.Members.Members = append(chatInfo.Members.Members, t.filterChannelParticipants(participants.Participants, limit)...)
} else { } else {
remaining := t.main.Config.MemberList.NormalizedMaxInitialSync() remaining := t.main.Config.MemberList.NormalizedMaxInitialSync()
var offset int var offset int
for remaining > 0 { for remaining > 0 {
participants, err := APICallWithUpdates(ctx, t, func() (*tg.ChannelsChannelParticipants, error) {
p, err := t.client.API().ChannelsGetParticipants(ctx, &tg.ChannelsGetParticipantsRequest{ p, err := t.client.API().ChannelsGetParticipants(ctx, &tg.ChannelsGetParticipantsRequest{
Channel: inputChannel, Channel: inputChannel,
Filter: &tg.ChannelParticipantsSearch{}, Filter: &tg.ChannelParticipantsSearch{},
@@ -276,6 +280,16 @@ func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Porta
return nil, err return nil, err
} }
participants, ok := p.(*tg.ChannelsChannelParticipants) participants, ok := p.(*tg.ChannelsChannelParticipants)
if !ok {
return nil, fmt.Errorf("returned participants is %T not *tg.ChannelsChannelParticipants", p)
} else {
return participants, nil
}
})
if err != nil {
return nil, err
}
participants, ok := p.(*tg.ChannelsChannelParticipants)
if !ok { if !ok {
return nil, fmt.Errorf("returned participants is %T not *tg.ChannelsChannelParticipants", p) return nil, fmt.Errorf("returned participants is %T not *tg.ChannelsChannelParticipants", p)
} }
@@ -283,10 +297,6 @@ func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Porta
chatInfo.Members.IsFull = true chatInfo.Members.IsFull = true
break break
} }
if err := t.updateUsersFromResponse(ctx, participants); err != nil {
return nil, err
}
chatInfo.Members.Members = append(chatInfo.Members.Members, t.filterChannelParticipants(participants.Participants, limit)...) chatInfo.Members.Members = append(chatInfo.Members.Members, t.filterChannelParticipants(participants.Participants, limit)...)
offset += len(participants.Participants) offset += len(participants.Participants)
-15
View File
@@ -383,21 +383,6 @@ func (t *TelegramClient) Disconnect() {
t.clientCancel() t.clientCancel()
} }
func (t *TelegramClient) updateUsersFromResponse(ctx context.Context, resp interface{ GetUsers() []tg.UserClass }) error {
// TODO table for the access hashes?
for _, user := range resp.GetUsers() {
user, ok := user.(*tg.User)
if !ok {
return fmt.Errorf("user is %T not *tg.User", user)
}
err := t.updateGhost(ctx, user.ID, user)
if err != nil {
return err
}
}
return nil
}
func (t *TelegramClient) GetUserInfo(ctx context.Context, ghost *bridgev2.Ghost) (*bridgev2.UserInfo, error) { func (t *TelegramClient) GetUserInfo(ctx context.Context, ghost *bridgev2.Ghost) (*bridgev2.UserInfo, error) {
id, err := ids.ParseUserID(ghost.ID) id, err := ids.ParseUserID(ghost.ID)
if err != nil { if err != nil {
+22 -8
View File
@@ -50,12 +50,21 @@ func (tc *TelegramConnector) Download(ctx context.Context, mediaID networkid.Med
} }
client := userLogin.Client.(*TelegramClient) client := userLogin.Client.(*TelegramClient)
var messages tg.MessagesMessagesClass var messages tg.ModifiedMessagesMessages
switch info.PeerType { switch info.PeerType {
case ids.PeerTypeUser, ids.PeerTypeChat: case ids.PeerTypeUser, ids.PeerTypeChat:
messages, err = client.client.API().MessagesGetMessages(ctx, []tg.InputMessageClass{ messages, err = APICallWithUpdates(ctx, client, func() (tg.ModifiedMessagesMessages, error) {
m, err := client.client.API().MessagesGetMessages(ctx, []tg.InputMessageClass{
&tg.InputMessageID{ID: int(info.MessageID)}, &tg.InputMessageID{ID: int(info.MessageID)},
}) })
if err != nil {
return nil, err
} else if messages, ok := m.(tg.ModifiedMessagesMessages); !ok {
return nil, fmt.Errorf("unsupported messages type %T", messages)
} else {
return messages, nil
}
})
case ids.PeerTypeChannel: case ids.PeerTypeChannel:
var accessHash int64 var accessHash int64
var found bool var found bool
@@ -65,12 +74,21 @@ func (tc *TelegramConnector) Download(ctx context.Context, mediaID networkid.Med
} else if !found { } else if !found {
return nil, fmt.Errorf("channel access hash not found for %d", info.ChatID) return nil, fmt.Errorf("channel access hash not found for %d", info.ChatID)
} else { } else {
messages, err = client.client.API().ChannelsGetMessages(ctx, &tg.ChannelsGetMessagesRequest{ messages, err = APICallWithUpdates(ctx, client, func() (tg.ModifiedMessagesMessages, error) {
m, err := client.client.API().ChannelsGetMessages(ctx, &tg.ChannelsGetMessagesRequest{
Channel: &tg.InputChannel{ChannelID: info.ChatID, AccessHash: accessHash}, Channel: &tg.InputChannel{ChannelID: info.ChatID, AccessHash: accessHash},
ID: []tg.InputMessageClass{ ID: []tg.InputMessageClass{
&tg.InputMessageID{ID: int(info.MessageID)}, &tg.InputMessageID{ID: int(info.MessageID)},
}, },
}) })
if err != nil {
return nil, err
} else if messages, ok := m.(tg.ModifiedMessagesMessages); !ok {
return nil, fmt.Errorf("unsupported messages type %T", messages)
} else {
return messages, nil
}
})
} }
default: default:
return nil, fmt.Errorf("unknown peer type %s", info.PeerType) return nil, fmt.Errorf("unknown peer type %s", info.PeerType)
@@ -80,11 +98,8 @@ func (tc *TelegramConnector) Download(ctx context.Context, mediaID networkid.Med
} }
var msgMedia tg.MessageMediaClass var msgMedia tg.MessageMediaClass
if m, ok := messages.(getMessages); !ok {
return nil, fmt.Errorf("unknown message type %T", messages)
} else {
var found bool var found bool
for _, message := range m.GetMessages() { for _, message := range messages.GetMessages() {
if msg, ok := message.(*tg.Message); ok && msg.ID == int(info.MessageID) { if msg, ok := message.(*tg.Message); ok && msg.ID == int(info.MessageID) {
msgMedia = msg.Media msgMedia = msg.Media
found = true found = true
@@ -94,7 +109,6 @@ func (tc *TelegramConnector) Download(ctx context.Context, mediaID networkid.Med
if !found { if !found {
return nil, fmt.Errorf("no media found with ID %d", info.MessageID) return nil, fmt.Errorf("no media found with ID %d", info.MessageID)
} }
}
transferer := media.NewTransferer(client.client.API()) transferer := media.NewTransferer(client.client.API())
var readyTransferer *media.ReadyTransferer var readyTransferer *media.ReadyTransferer
+3 -1
View File
@@ -51,9 +51,11 @@ func (t *TelegramClient) computeReactionsList(ctx context.Context, msg *tg.Messa
} else if peer, err := t.inputPeerForPortalID(ctx, ids.MakePortalKey(msg.PeerID, t.loginID).ID); err != nil { } else if peer, err := t.inputPeerForPortalID(ctx, ids.MakePortalKey(msg.PeerID, t.loginID).ID); err != nil {
return nil, false, nil, fmt.Errorf("failed to get input peer: %w", err) return nil, false, nil, fmt.Errorf("failed to get input peer: %w", err)
} else { } else {
reactions, err := t.client.API().MessagesGetMessageReactionsList(ctx, &tg.MessagesGetMessageReactionsListRequest{ reactions, err := APICallWithUpdates(ctx, t, func() (*tg.MessagesMessageReactionsList, error) {
return t.client.API().MessagesGetMessageReactionsList(ctx, &tg.MessagesGetMessageReactionsListRequest{
Peer: peer, ID: msg.ID, Limit: 100, Peer: peer, ID: msg.ID, Limit: 100,
}) })
})
if err != nil { if err != nil {
return nil, false, nil, fmt.Errorf("failed to get reactions list: %w", err) return nil, false, nil, fmt.Errorf("failed to get reactions list: %w", err)
} }