diff --git a/models/organization/org_list.go b/models/organization/org_list.go index 72ebf6f178..4c4168af1f 100644 --- a/models/organization/org_list.go +++ b/models/organization/org_list.go @@ -16,6 +16,31 @@ import ( "xorm.io/builder" ) +type OrgList []*Organization + +func (orgs OrgList) LoadTeams(ctx context.Context) (map[int64]TeamList, error) { + if len(orgs) == 0 { + return map[int64]TeamList{}, nil + } + + orgIDs := make([]int64, len(orgs)) + for i, org := range orgs { + orgIDs[i] = org.ID + } + + teams, err := GetTeamsByOrgIDs(ctx, orgIDs) + if err != nil { + return nil, err + } + + teamMap := make(map[int64]TeamList, len(orgs)) + for _, team := range teams { + teamMap[team.OrgID] = append(teamMap[team.OrgID], team) + } + + return teamMap, nil +} + // SearchOrganizationsOptions options to filter organizations type SearchOrganizationsOptions struct { db.ListOptions diff --git a/models/organization/org_list_test.go b/models/organization/org_list_test.go index fc8d148a1d..edc8996f3e 100644 --- a/models/organization/org_list_test.go +++ b/models/organization/org_list_test.go @@ -60,3 +60,14 @@ func TestGetUserOrgsList(t *testing.T) { assert.EqualValues(t, 2, orgs[0].NumRepos) } } + +func TestLoadOrgListTeams(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + orgs, err := organization.GetUserOrgsList(db.DefaultContext, &user_model.User{ID: 4}) + assert.NoError(t, err) + assert.Len(t, orgs, 1) + teamsMap, err := organization.OrgList(orgs).LoadTeams(db.DefaultContext) + assert.NoError(t, err) + assert.Len(t, teamsMap, 1) + assert.Len(t, teamsMap[3], 5) +} diff --git a/models/organization/team_list.go b/models/organization/team_list.go index 5b45429acf..cc2a50236a 100644 --- a/models/organization/team_list.go +++ b/models/organization/team_list.go @@ -126,3 +126,8 @@ func GetUserRepoTeams(ctx context.Context, orgID, userID, repoID int64) (teams T And("team_repo.repo_id=?", repoID). Find(&teams) } + +func GetTeamsByOrgIDs(ctx context.Context, orgIDs []int64) (TeamList, error) { + teams := make([]*Team, 0, 10) + return teams, db.GetEngine(ctx).Where(builder.In("org_id", orgIDs)).Find(&teams) +} diff --git a/services/oauth2_provider/access_token.go b/services/oauth2_provider/access_token.go index 77ddce0534..5cb6fb64c5 100644 --- a/services/oauth2_provider/access_token.go +++ b/services/oauth2_provider/access_token.go @@ -241,14 +241,15 @@ func GetOAuthGroupsForUser(ctx context.Context, user *user_model.User, onlyPubli return nil, fmt.Errorf("GetUserOrgList: %w", err) } + orgTeams, err := org_model.OrgList(orgs).LoadTeams(ctx) + if err != nil { + return nil, fmt.Errorf("LoadTeams: %w", err) + } + var groups []string for _, org := range orgs { groups = append(groups, org.Name) - teams, err := org.LoadTeams(ctx) - if err != nil { - return nil, fmt.Errorf("LoadTeams: %w", err) - } - for _, team := range teams { + for _, team := range orgTeams[org.ID] { if team.IsMember(ctx, user.ID) { groups = append(groups, org.Name+":"+team.LowerName) }