channels: handle messages Matrix <-> TG
Signed-off-by: Sumner Evans <sumner.evans@automattic.com>
This commit is contained in:
+74
-9
@@ -21,11 +21,13 @@ import (
|
|||||||
"go.mau.fi/mautrix-telegram/pkg/connector/ids"
|
"go.mau.fi/mautrix-telegram/pkg/connector/ids"
|
||||||
"go.mau.fi/mautrix-telegram/pkg/connector/media"
|
"go.mau.fi/mautrix-telegram/pkg/connector/media"
|
||||||
"go.mau.fi/mautrix-telegram/pkg/connector/msgconv"
|
"go.mau.fi/mautrix-telegram/pkg/connector/msgconv"
|
||||||
|
"go.mau.fi/mautrix-telegram/pkg/connector/store"
|
||||||
"go.mau.fi/mautrix-telegram/pkg/connector/util"
|
"go.mau.fi/mautrix-telegram/pkg/connector/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TelegramClient struct {
|
type TelegramClient struct {
|
||||||
main *TelegramConnector
|
main *TelegramConnector
|
||||||
|
ScopedStore *store.ScopedStore
|
||||||
telegramUserID int64
|
telegramUserID int64
|
||||||
loginID networkid.UserLoginID
|
loginID networkid.UserLoginID
|
||||||
userID networkid.UserID
|
userID networkid.UserID
|
||||||
@@ -103,13 +105,28 @@ func NewTelegramClient(ctx context.Context, tc *TelegramConnector, login *bridge
|
|||||||
UpdateDispatcher: tg.NewUpdateDispatcher(),
|
UpdateDispatcher: tg.NewUpdateDispatcher(),
|
||||||
EntityHandler: client.onEntityUpdate,
|
EntityHandler: client.onEntityUpdate,
|
||||||
}
|
}
|
||||||
dispatcher.OnNewMessage(client.onUpdateNewMessage)
|
dispatcher.OnNewMessage(func(ctx context.Context, e tg.Entities, update *tg.UpdateNewMessage) error {
|
||||||
dispatcher.OnNewChannelMessage(client.onUpdateNewChannelMessage)
|
return client.onUpdateNewMessage(ctx, update)
|
||||||
|
})
|
||||||
|
dispatcher.OnNewChannelMessage(func(ctx context.Context, e tg.Entities, update *tg.UpdateNewChannelMessage) error {
|
||||||
|
fmt.Printf("%+v\n", update)
|
||||||
|
return client.onUpdateNewMessage(ctx, update)
|
||||||
|
})
|
||||||
dispatcher.OnUserName(client.onUserName)
|
dispatcher.OnUserName(client.onUserName)
|
||||||
dispatcher.OnDeleteMessages(client.onDeleteMessages)
|
dispatcher.OnDeleteMessages(func(ctx context.Context, e tg.Entities, update *tg.UpdateDeleteMessages) error {
|
||||||
dispatcher.OnEditMessage(client.onMessageEdit)
|
return client.onDeleteMessages(ctx, update)
|
||||||
|
})
|
||||||
|
dispatcher.OnDeleteChannelMessages(func(ctx context.Context, e tg.Entities, update *tg.UpdateDeleteChannelMessages) error {
|
||||||
|
return client.onDeleteMessages(ctx, update)
|
||||||
|
})
|
||||||
|
dispatcher.OnEditMessage(func(ctx context.Context, e tg.Entities, update *tg.UpdateEditMessage) error {
|
||||||
|
return client.onMessageEdit(ctx, update)
|
||||||
|
})
|
||||||
|
dispatcher.OnEditChannelMessage(func(ctx context.Context, e tg.Entities, update *tg.UpdateEditChannelMessage) error {
|
||||||
|
return client.onMessageEdit(ctx, update)
|
||||||
|
})
|
||||||
|
|
||||||
store := tc.Store.GetScopedStore(telegramUserID)
|
client.ScopedStore = tc.Store.GetScopedStore(telegramUserID)
|
||||||
|
|
||||||
updatesManager := updates.New(updates.Config{
|
updatesManager := updates.New(updates.Config{
|
||||||
OnChannelTooLong: func(channelID int64) {
|
OnChannelTooLong: func(channelID int64) {
|
||||||
@@ -118,12 +135,12 @@ func NewTelegramClient(ctx context.Context, tc *TelegramConnector, login *bridge
|
|||||||
},
|
},
|
||||||
Handler: dispatcher,
|
Handler: dispatcher,
|
||||||
Logger: zaplog.Named("gaps"),
|
Logger: zaplog.Named("gaps"),
|
||||||
Storage: store,
|
Storage: client.ScopedStore,
|
||||||
AccessHasher: store,
|
AccessHasher: client.ScopedStore,
|
||||||
})
|
})
|
||||||
|
|
||||||
client.client = telegram.NewClient(tc.Config.AppID, tc.Config.AppHash, telegram.Options{
|
client.client = telegram.NewClient(tc.Config.AppID, tc.Config.AppHash, telegram.Options{
|
||||||
SessionStorage: store,
|
SessionStorage: client.ScopedStore,
|
||||||
Logger: zaplog,
|
Logger: zaplog,
|
||||||
UpdateHandler: updatesManager,
|
UpdateHandler: updatesManager,
|
||||||
})
|
})
|
||||||
@@ -184,7 +201,7 @@ func (t *TelegramClient) Disconnect() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) {
|
func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) {
|
||||||
fmt.Printf("%+v\n", portal)
|
fmt.Printf("get chat info %+v\n", portal)
|
||||||
peerType, id, err := ids.ParsePortalID(portal.ID)
|
peerType, id, err := ids.ParsePortalID(portal.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -253,6 +270,54 @@ func (t *TelegramClient) GetChatInfo(ctx context.Context, portal *bridgev2.Porta
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, user := range fullChat.Users {
|
||||||
|
memberList.Members = append(memberList.Members, bridgev2.ChatMember{
|
||||||
|
EventSender: bridgev2.EventSender{
|
||||||
|
IsFromMe: user.GetID() == t.telegramUserID,
|
||||||
|
SenderLogin: ids.MakeUserLoginID(user.GetID()),
|
||||||
|
Sender: ids.MakeUserID(user.GetID()),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case ids.PeerTypeChannel:
|
||||||
|
accessHash, found, err := t.ScopedStore.GetChannelAccessHash(ctx, t.telegramUserID, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get channel access hash: %w", err)
|
||||||
|
} else if !found {
|
||||||
|
return nil, fmt.Errorf("channel access hash not found for %d", id)
|
||||||
|
}
|
||||||
|
fullChat, err := t.client.API().ChannelsGetFullChannel(ctx, &tg.InputChannel{ChannelID: id, AccessHash: accessHash})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, c := range fullChat.Chats {
|
||||||
|
if c.GetID() == id {
|
||||||
|
switch chat := c.(type) {
|
||||||
|
case *tg.Chat:
|
||||||
|
name = chat.Title
|
||||||
|
case *tg.Channel:
|
||||||
|
name = chat.Title
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
chatFull, ok := fullChat.FullChat.(*tg.ChatFull)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("full chat is not %T", chatFull)
|
||||||
|
}
|
||||||
|
|
||||||
|
if photo, ok := chatFull.GetChatPhoto(); ok {
|
||||||
|
avatar = &bridgev2.Avatar{
|
||||||
|
ID: ids.MakeAvatarID(photo.GetID()),
|
||||||
|
Get: func(ctx context.Context) (data []byte, err error) {
|
||||||
|
data, _, err = media.NewTransferer(t.client.API()).WithPhoto(photo).Download(ctx)
|
||||||
|
return
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
memberList.IsFull = false
|
||||||
for _, user := range fullChat.Users {
|
for _, user := range fullChat.Users {
|
||||||
memberList.Members = append(memberList.Members, bridgev2.ChatMember{
|
memberList.Members = append(memberList.Members, bridgev2.ChatMember{
|
||||||
EventSender: bridgev2.EventSender{
|
EventSender: bridgev2.EventSender{
|
||||||
|
|||||||
@@ -49,18 +49,26 @@ func (tc *TelegramConnector) Download(ctx context.Context, mediaID networkid.Med
|
|||||||
&tg.InputMessageID{ID: int(info.MessageID)},
|
&tg.InputMessageID{ID: int(info.MessageID)},
|
||||||
})
|
})
|
||||||
case ids.PeerTypeChannel:
|
case ids.PeerTypeChannel:
|
||||||
// TODO test this
|
var accessHash int64
|
||||||
messages, err = client.client.API().ChannelsGetMessages(ctx, &tg.ChannelsGetMessagesRequest{
|
var found bool
|
||||||
Channel: &tg.InputChannel{ChannelID: info.ChatID},
|
accessHash, found, err = client.ScopedStore.GetChannelAccessHash(ctx, client.telegramUserID, info.ChatID)
|
||||||
ID: []tg.InputMessageClass{
|
if err != nil {
|
||||||
&tg.InputMessageID{ID: int(info.MessageID)},
|
return nil, fmt.Errorf("failed to get channel access hash: %w", err)
|
||||||
},
|
} else if !found {
|
||||||
})
|
return nil, fmt.Errorf("channel access hash not found for %d", info.ChatID)
|
||||||
|
} else {
|
||||||
|
messages, err = client.client.API().ChannelsGetMessages(ctx, &tg.ChannelsGetMessagesRequest{
|
||||||
|
Channel: &tg.InputChannel{ChannelID: info.ChatID, AccessHash: accessHash},
|
||||||
|
ID: []tg.InputMessageClass{
|
||||||
|
&tg.InputMessageID{ID: int(info.MessageID)},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown peer type %s", info.PeerType)
|
return nil, fmt.Errorf("unknown peer type %s", info.PeerType)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to get messages for %+v: %w", info, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var msgMedia tg.MessageMediaClass
|
var msgMedia tg.MessageMediaClass
|
||||||
|
|||||||
@@ -92,27 +92,6 @@ func ParsePortalID(portalID networkid.PortalID) (pt PeerType, id int64, err erro
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func InputPeerForPortalID(portalID networkid.PortalID) (tg.InputPeerClass, error) {
|
|
||||||
peerType, id, err := ParsePortalID(portalID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
switch peerType {
|
|
||||||
case PeerTypeUser:
|
|
||||||
return &tg.InputPeerUser{UserID: id}, nil
|
|
||||||
case PeerTypeChat:
|
|
||||||
return &tg.InputPeerChat{ChatID: id}, nil
|
|
||||||
case PeerTypeChannel:
|
|
||||||
return &tg.InputPeerChannel{ChannelID: id}, nil
|
|
||||||
default:
|
|
||||||
panic("invalid peer type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func InputPeerForPortalKey(portalKey networkid.PortalKey) (tg.InputPeerClass, error) {
|
|
||||||
return InputPeerForPortalID(portalKey.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func MakeAvatarID(photoID int64) networkid.AvatarID {
|
func MakeAvatarID(photoID int64) networkid.AvatarID {
|
||||||
return networkid.AvatarID(strconv.FormatInt(photoID, 10))
|
return networkid.AvatarID(strconv.FormatInt(photoID, 10))
|
||||||
}
|
}
|
||||||
|
|||||||
+10
-6
@@ -39,12 +39,11 @@ func getMediaFilenameAndCaption(content *event.MessageEventContent) (filename, c
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *TelegramClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage) (resp *bridgev2.MatrixMessageResponse, err error) {
|
func (t *TelegramClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage) (resp *bridgev2.MatrixMessageResponse, err error) {
|
||||||
sender := message.NewSender(t.client.API())
|
peer, err := t.inputPeerForPortalID(ctx, msg.Portal.ID)
|
||||||
peer, err := ids.InputPeerForPortalID(msg.Portal.ID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
builder := sender.To(peer)
|
builder := message.NewSender(t.client.API()).To(peer)
|
||||||
|
|
||||||
// TODO handle sticker
|
// TODO handle sticker
|
||||||
|
|
||||||
@@ -173,8 +172,13 @@ func (t *TelegramClient) HandleMatrixMessageRemove(ctx context.Context, msg *bri
|
|||||||
return err
|
return err
|
||||||
} else if messageID, err := ids.ParseMessageID(dbMsg.ID); err != nil {
|
} else if messageID, err := ids.ParseMessageID(dbMsg.ID); err != nil {
|
||||||
return err
|
return err
|
||||||
|
} else if peer, err := t.inputPeerForPortalID(ctx, msg.Portal.ID); err != nil {
|
||||||
|
return err
|
||||||
} else {
|
} else {
|
||||||
_, err = message.NewSender(t.client.API()).Self().Revoke().Messages(ctx, messageID)
|
_, err := message.NewSender(t.client.API()).
|
||||||
|
To(peer).
|
||||||
|
Revoke().
|
||||||
|
Messages(ctx, messageID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -224,7 +228,7 @@ func (t *TelegramClient) appendEmojiID(reactionList []tg.ReactionClass, emojiID
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *TelegramClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.MatrixReaction) (reaction *database.Reaction, err error) {
|
func (t *TelegramClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.MatrixReaction) (reaction *database.Reaction, err error) {
|
||||||
peer, err := ids.InputPeerForPortalID(msg.Portal.ID)
|
peer, err := t.inputPeerForPortalID(ctx, msg.Portal.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -255,7 +259,7 @@ func (t *TelegramClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *TelegramClient) HandleMatrixReactionRemove(ctx context.Context, msg *bridgev2.MatrixReactionRemove) error {
|
func (t *TelegramClient) HandleMatrixReactionRemove(ctx context.Context, msg *bridgev2.MatrixReactionRemove) error {
|
||||||
peer, err := ids.InputPeerForPortalID(msg.Portal.ID)
|
peer, err := t.inputPeerForPortalID(ctx, msg.Portal.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,6 +42,6 @@ func (c *Container) Upgrade(ctx context.Context) error {
|
|||||||
return c.Database.Upgrade(ctx)
|
return c.Database.Upgrade(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Container) GetScopedStore(telegramUserID int64) *scopedStore {
|
func (c *Container) GetScopedStore(telegramUserID int64) *ScopedStore {
|
||||||
return &scopedStore{c.Database, telegramUserID}
|
return &ScopedStore{c.Database, telegramUserID}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ import (
|
|||||||
"go.mau.fi/util/dbutil"
|
"go.mau.fi/util/dbutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// scopedStore is a wrapper around a database that implements
|
// ScopedStore is a wrapper around a database that implements
|
||||||
// [session.Storage] scoped to a specific Telegram user ID.
|
// [session.Storage] scoped to a specific Telegram user ID.
|
||||||
type scopedStore struct {
|
type ScopedStore struct {
|
||||||
db *dbutil.Database
|
db *dbutil.Database
|
||||||
telegramUserID int64
|
telegramUserID int64
|
||||||
}
|
}
|
||||||
@@ -60,22 +60,22 @@ const (
|
|||||||
`
|
`
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ session.Storage = (*scopedStore)(nil)
|
var _ session.Storage = (*ScopedStore)(nil)
|
||||||
|
|
||||||
func (s *scopedStore) LoadSession(ctx context.Context) (sessionData []byte, err error) {
|
func (s *ScopedStore) LoadSession(ctx context.Context) (sessionData []byte, err error) {
|
||||||
row := s.db.QueryRow(ctx, loadSessionQuery, s.telegramUserID)
|
row := s.db.QueryRow(ctx, loadSessionQuery, s.telegramUserID)
|
||||||
err = row.Scan(&sessionData)
|
err = row.Scan(&sessionData)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *scopedStore) StoreSession(ctx context.Context, data []byte) error {
|
func (s *ScopedStore) StoreSession(ctx context.Context, data []byte) error {
|
||||||
_, err := s.db.Exec(ctx, storeSessionQuery, s.telegramUserID, data)
|
_, err := s.db.Exec(ctx, storeSessionQuery, s.telegramUserID, data)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ updates.StateStorage = (*scopedStore)(nil)
|
var _ updates.StateStorage = (*ScopedStore)(nil)
|
||||||
|
|
||||||
func (s *scopedStore) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
|
func (s *ScopedStore) ForEachChannels(ctx context.Context, userID int64, f func(ctx context.Context, channelID int64, pts int) error) error {
|
||||||
s.assertUserIDMatches(userID)
|
s.assertUserIDMatches(userID)
|
||||||
rows, err := s.db.Query(ctx, allChannelsQuery, userID)
|
rows, err := s.db.Query(ctx, allChannelsQuery, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -93,7 +93,7 @@ func (s *scopedStore) ForEachChannels(ctx context.Context, userID int64, f func(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *scopedStore) GetChannelPts(ctx context.Context, userID int64, channelID int64) (pts int, found bool, err error) {
|
func (s *ScopedStore) GetChannelPts(ctx context.Context, userID int64, channelID int64) (pts int, found bool, err error) {
|
||||||
s.assertUserIDMatches(userID)
|
s.assertUserIDMatches(userID)
|
||||||
err = s.db.QueryRow(ctx, getChannelPtsQuery, userID, channelID).Scan(&pts)
|
err = s.db.QueryRow(ctx, getChannelPtsQuery, userID, channelID).Scan(&pts)
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
@@ -102,13 +102,13 @@ func (s *scopedStore) GetChannelPts(ctx context.Context, userID int64, channelID
|
|||||||
return pts, err == nil, err
|
return pts, err == nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *scopedStore) SetChannelPts(ctx context.Context, userID int64, channelID int64, pts int) (err error) {
|
func (s *ScopedStore) SetChannelPts(ctx context.Context, userID int64, channelID int64, pts int) (err error) {
|
||||||
s.assertUserIDMatches(userID)
|
s.assertUserIDMatches(userID)
|
||||||
_, err = s.db.Exec(ctx, setChannelPtsQuery, userID, channelID, pts)
|
_, err = s.db.Exec(ctx, setChannelPtsQuery, userID, channelID, pts)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *scopedStore) GetState(ctx context.Context, userID int64) (state updates.State, found bool, err error) {
|
func (s *ScopedStore) GetState(ctx context.Context, userID int64) (state updates.State, found bool, err error) {
|
||||||
s.assertUserIDMatches(userID)
|
s.assertUserIDMatches(userID)
|
||||||
err = s.db.QueryRow(ctx, getStateQuery, userID).Scan(&state.Pts, &state.Qts, &state.Date, &state.Seq)
|
err = s.db.QueryRow(ctx, getStateQuery, userID).Scan(&state.Pts, &state.Qts, &state.Date, &state.Seq)
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
@@ -117,45 +117,45 @@ func (s *scopedStore) GetState(ctx context.Context, userID int64) (state updates
|
|||||||
return state, err == nil, err
|
return state, err == nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *scopedStore) SetState(ctx context.Context, userID int64, state updates.State) (err error) {
|
func (s *ScopedStore) SetState(ctx context.Context, userID int64, state updates.State) (err error) {
|
||||||
s.assertUserIDMatches(userID)
|
s.assertUserIDMatches(userID)
|
||||||
_, err = s.db.Exec(ctx, setStateQuery, userID, state.Pts, state.Qts, state.Date, state.Seq)
|
_, err = s.db.Exec(ctx, setStateQuery, userID, state.Pts, state.Qts, state.Date, state.Seq)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *scopedStore) SetPts(ctx context.Context, userID int64, pts int) (err error) {
|
func (s *ScopedStore) SetPts(ctx context.Context, userID int64, pts int) (err error) {
|
||||||
s.assertUserIDMatches(userID)
|
s.assertUserIDMatches(userID)
|
||||||
_, err = s.db.Exec(ctx, setPtsQuery, userID, pts)
|
_, err = s.db.Exec(ctx, setPtsQuery, userID, pts)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *scopedStore) SetQts(ctx context.Context, userID int64, qts int) (err error) {
|
func (s *ScopedStore) SetQts(ctx context.Context, userID int64, qts int) (err error) {
|
||||||
s.assertUserIDMatches(userID)
|
s.assertUserIDMatches(userID)
|
||||||
_, err = s.db.Exec(ctx, setQtsQuery, userID, qts)
|
_, err = s.db.Exec(ctx, setQtsQuery, userID, qts)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *scopedStore) SetSeq(ctx context.Context, userID int64, seq int) (err error) {
|
func (s *ScopedStore) SetSeq(ctx context.Context, userID int64, seq int) (err error) {
|
||||||
s.assertUserIDMatches(userID)
|
s.assertUserIDMatches(userID)
|
||||||
_, err = s.db.Exec(ctx, setSeqQuery, userID, seq)
|
_, err = s.db.Exec(ctx, setSeqQuery, userID, seq)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *scopedStore) SetDate(ctx context.Context, userID int64, date int) (err error) {
|
func (s *ScopedStore) SetDate(ctx context.Context, userID int64, date int) (err error) {
|
||||||
s.assertUserIDMatches(userID)
|
s.assertUserIDMatches(userID)
|
||||||
_, err = s.db.Exec(ctx, setDateQuery, userID, date)
|
_, err = s.db.Exec(ctx, setDateQuery, userID, date)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *scopedStore) SetDateSeq(ctx context.Context, userID int64, date int, seq int) (err error) {
|
func (s *ScopedStore) SetDateSeq(ctx context.Context, userID int64, date int, seq int) (err error) {
|
||||||
s.assertUserIDMatches(userID)
|
s.assertUserIDMatches(userID)
|
||||||
_, err = s.db.Exec(ctx, setDateSeqQuery, userID, date, seq)
|
_, err = s.db.Exec(ctx, setDateSeqQuery, userID, date, seq)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ updates.ChannelAccessHasher = (*scopedStore)(nil)
|
var _ updates.ChannelAccessHasher = (*ScopedStore)(nil)
|
||||||
|
|
||||||
func (s *scopedStore) GetChannelAccessHash(ctx context.Context, userID int64, channelID int64) (accessHash int64, found bool, err error) {
|
func (s *ScopedStore) GetChannelAccessHash(ctx context.Context, userID int64, channelID int64) (accessHash int64, found bool, err error) {
|
||||||
s.assertUserIDMatches(userID)
|
s.assertUserIDMatches(userID)
|
||||||
err = s.db.QueryRow(ctx, getChannelAccessHashQuery, userID, channelID).Scan(&accessHash)
|
err = s.db.QueryRow(ctx, getChannelAccessHashQuery, userID, channelID).Scan(&accessHash)
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
@@ -164,7 +164,7 @@ func (s *scopedStore) GetChannelAccessHash(ctx context.Context, userID int64, ch
|
|||||||
return accessHash, err == nil, err
|
return accessHash, err == nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *scopedStore) SetChannelAccessHash(ctx context.Context, userID int64, channelID int64, accessHash int64) (err error) {
|
func (s *ScopedStore) SetChannelAccessHash(ctx context.Context, userID int64, channelID int64, accessHash int64) (err error) {
|
||||||
s.assertUserIDMatches(userID)
|
s.assertUserIDMatches(userID)
|
||||||
_, err = s.db.Exec(ctx, setChannelAccessHashQuery, userID, channelID, accessHash)
|
_, err = s.db.Exec(ctx, setChannelAccessHashQuery, userID, channelID, accessHash)
|
||||||
return
|
return
|
||||||
@@ -172,7 +172,7 @@ func (s *scopedStore) SetChannelAccessHash(ctx context.Context, userID int64, ch
|
|||||||
|
|
||||||
// Helper Functions
|
// Helper Functions
|
||||||
|
|
||||||
func (s *scopedStore) assertUserIDMatches(userID int64) {
|
func (s *ScopedStore) assertUserIDMatches(userID int64) {
|
||||||
if s.telegramUserID != userID {
|
if s.telegramUserID != userID {
|
||||||
panic(fmt.Sprintf("scoped store for %d function called with user ID %d", s.telegramUserID, userID))
|
panic(fmt.Sprintf("scoped store for %d function called with user ID %d", s.telegramUserID, userID))
|
||||||
}
|
}
|
||||||
|
|||||||
+38
-12
@@ -20,7 +20,15 @@ import (
|
|||||||
"go.mau.fi/mautrix-telegram/pkg/connector/util"
|
"go.mau.fi/mautrix-telegram/pkg/connector/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (t *TelegramClient) onUpdateNewMessage(ctx context.Context, e tg.Entities, update *tg.UpdateNewMessage) error {
|
type IGetMessage interface {
|
||||||
|
GetMessage() tg.MessageClass
|
||||||
|
}
|
||||||
|
|
||||||
|
type IGetMessages interface {
|
||||||
|
GetMessages() []int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TelegramClient) onUpdateNewMessage(ctx context.Context, update IGetMessage) error {
|
||||||
log := zerolog.Ctx(ctx)
|
log := zerolog.Ctx(ctx)
|
||||||
switch msg := update.GetMessage().(type) {
|
switch msg := update.GetMessage().(type) {
|
||||||
case *tg.Message:
|
case *tg.Message:
|
||||||
@@ -40,7 +48,7 @@ func (t *TelegramClient) onUpdateNewMessage(ctx context.Context, e tg.Entities,
|
|||||||
Type: bridgev2.RemoteEventMessage,
|
Type: bridgev2.RemoteEventMessage,
|
||||||
LogContext: func(c zerolog.Context) zerolog.Context {
|
LogContext: func(c zerolog.Context) zerolog.Context {
|
||||||
return c.
|
return c.
|
||||||
Int("message_id", update.Message.GetID()).
|
Int("message_id", msg.GetID()).
|
||||||
Str("sender", string(sender.Sender)).
|
Str("sender", string(sender.Sender)).
|
||||||
Str("sender_login", string(sender.SenderLogin)).
|
Str("sender_login", string(sender.SenderLogin)).
|
||||||
Bool("is_from_me", sender.IsFromMe)
|
Bool("is_from_me", sender.IsFromMe)
|
||||||
@@ -141,11 +149,6 @@ func (t *TelegramClient) getEventSender(msg messageWithSender) (sender bridgev2.
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TelegramClient) onUpdateNewChannelMessage(ctx context.Context, e tg.Entities, update *tg.UpdateNewChannelMessage) error {
|
|
||||||
fmt.Printf("update new channel message %+v\n", update)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TelegramClient) onUserName(ctx context.Context, e tg.Entities, update *tg.UpdateUserName) error {
|
func (t *TelegramClient) onUserName(ctx context.Context, e tg.Entities, update *tg.UpdateUserName) error {
|
||||||
ghost, err := t.main.Bridge.GetGhostByID(ctx, ids.MakeUserID(update.UserID))
|
ghost, err := t.main.Bridge.GetGhostByID(ctx, ids.MakeUserID(update.UserID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -159,8 +162,8 @@ func (t *TelegramClient) onUserName(ctx context.Context, e tg.Entities, update *
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TelegramClient) onDeleteMessages(ctx context.Context, e tg.Entities, update *tg.UpdateDeleteMessages) error {
|
func (t *TelegramClient) onDeleteMessages(ctx context.Context, update IGetMessages) error {
|
||||||
for _, messageID := range update.Messages {
|
for _, messageID := range update.GetMessages() {
|
||||||
parts, err := t.main.Bridge.DB.Message.GetAllPartsByID(ctx, t.loginID, ids.MakeMessageID(messageID))
|
parts, err := t.main.Bridge.DB.Message.GetAllPartsByID(ctx, t.loginID, ids.MakeMessageID(messageID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -198,9 +201,9 @@ func (t *TelegramClient) onEntityUpdate(ctx context.Context, e tg.Entities) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TelegramClient) onMessageEdit(ctx context.Context, e tg.Entities, update *tg.UpdateEditMessage) error {
|
func (t *TelegramClient) onMessageEdit(ctx context.Context, update IGetMessage) error {
|
||||||
fmt.Printf("message edit %+v\n", update)
|
fmt.Printf("message edit %+v\n", update)
|
||||||
msg, ok := update.Message.(*tg.Message)
|
msg, ok := update.GetMessage().(*tg.Message)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("edit message is not *tg.Message")
|
return fmt.Errorf("edit message is not *tg.Message")
|
||||||
}
|
}
|
||||||
@@ -261,7 +264,7 @@ func (t *TelegramClient) handleTelegramReactions(ctx context.Context, msg *tg.Me
|
|||||||
// return
|
// return
|
||||||
|
|
||||||
// TODO should calls to this be limited?
|
// TODO should calls to this be limited?
|
||||||
} else if peer, err := ids.InputPeerForPortalKey(ids.MakePortalKey(msg.PeerID)); err != nil {
|
} else if peer, err := t.inputPeerForPortalID(ctx, ids.MakePortalKey(msg.PeerID).ID); err != nil {
|
||||||
return err
|
return err
|
||||||
} else {
|
} else {
|
||||||
reactions, err := t.client.API().MessagesGetMessageReactionsList(ctx, &tg.MessagesGetMessageReactionsListRequest{
|
reactions, err := t.client.API().MessagesGetMessageReactionsList(ctx, &tg.MessagesGetMessageReactionsListRequest{
|
||||||
@@ -300,6 +303,29 @@ func (t *TelegramClient) handleTelegramReactions(ctx context.Context, msg *tg.Me
|
|||||||
return t.handleTelegramParsedReactionsLocked(ctx, dbMsg, reactions, customEmojiIDs, isFull, nil, nil)
|
return t.handleTelegramParsedReactionsLocked(ctx, dbMsg, reactions, customEmojiIDs, isFull, nil, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TelegramClient) inputPeerForPortalID(ctx context.Context, portalID networkid.PortalID) (tg.InputPeerClass, error) {
|
||||||
|
peerType, id, err := ids.ParsePortalID(portalID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch peerType {
|
||||||
|
case ids.PeerTypeUser:
|
||||||
|
return &tg.InputPeerUser{UserID: id}, nil
|
||||||
|
case ids.PeerTypeChat:
|
||||||
|
return &tg.InputPeerChat{ChatID: id}, nil
|
||||||
|
case ids.PeerTypeChannel:
|
||||||
|
accessHash, found, err := t.ScopedStore.GetChannelAccessHash(ctx, t.telegramUserID, id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if !found {
|
||||||
|
return nil, fmt.Errorf("channel access hash not found for %d", id)
|
||||||
|
}
|
||||||
|
return &tg.InputPeerChannel{ChannelID: id, AccessHash: accessHash}, nil
|
||||||
|
default:
|
||||||
|
panic("invalid peer type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func splitDMReactionCounts(res []tg.ReactionCount, theirUserID, myUserID int64) (reactions []tg.MessagePeerReaction) {
|
func splitDMReactionCounts(res []tg.ReactionCount, theirUserID, myUserID int64) (reactions []tg.MessagePeerReaction) {
|
||||||
for _, item := range res {
|
for _, item := range res {
|
||||||
if item.Count == 2 || item.ChosenOrder > 0 {
|
if item.Count == 2 || item.ChosenOrder > 0 {
|
||||||
|
|||||||
Reference in New Issue
Block a user