Commit 6493a88a authored by qiuqunfeng's avatar qiuqunfeng
Browse files

Refactor getEnabledRuleNum function to return error and improve handling of rule category status

The getEnabledRuleNum function now returns an error for better error handling. It calculates the number of enabled WAF rule categories based on the WAF service's rule category status, ensuring accurate counts in the GetWaf method. This change enhances the robustness of the WAF service response.
parent 0228ee23
...@@ -61,19 +61,28 @@ func NewWafService(clusterClientManager *utils.ClusterClientManager, db *gorm.DB ...@@ -61,19 +61,28 @@ func NewWafService(clusterClientManager *utils.ClusterClientManager, db *gorm.DB
return &wafService{clusterClientManager: clusterClientManager, db: db, gatewayUrl: gatewayUrl, elasticClient: elasticClient} return &wafService{clusterClientManager: clusterClientManager, db: db, gatewayUrl: gatewayUrl, elasticClient: elasticClient}
} }
func getEnabledRuleNum(db *gorm.DB, wafService *model.WafService) int { func getEnabledRuleNum(db *gorm.DB, wafService *model.WafService) (int, error) {
ruleNum := 0 // Get total number of rule categories
var ruleCategoryNum int64 var totalCategories int64
if err := db.Model(&model.WafRuleCategory{}).Where("status = ?", 0).Count(&ruleCategoryNum).Error; err != nil { if err := db.Model(&model.WafRuleCategory{}).Where("status = ?", 0).Count(&totalCategories).Error; err != nil {
log.Error().Msgf("failed to get rule categories: %v", err) return 0, fmt.Errorf("failed to get rule categories: %v", err)
return 0
} }
if wafService.RuleCategoryStatus != nil && len(wafService.RuleCategoryStatus.CategoryID) == 1 {
ruleNum = len(wafService.RuleCategoryStatus.CategoryID) // If no rule category status is set, all categories are enabled
if wafService.RuleCategoryStatus == nil {
return int(totalCategories), nil
}
// If status is 0, all categories are enabled
if wafService.RuleCategoryStatus.Status == 0 {
return int(totalCategories), nil
} }
ruleCategoryNum = ruleCategoryNum - int64(ruleNum) // If status is 1, count only enabled categories
return int(ruleCategoryNum) disabledCount := len(wafService.RuleCategoryStatus.CategoryID)
enabledCount := int(totalCategories) - disabledCount
return enabledCount, nil
} }
func (s *wafService) GetWaf(ctx context.Context, regionCode, namespace, gatewayName string) (*WafService, error) { func (s *wafService) GetWaf(ctx context.Context, regionCode, namespace, gatewayName string) (*WafService, error) {
...@@ -108,11 +117,15 @@ func (s *wafService) GetWaf(ctx context.Context, regionCode, namespace, gatewayN ...@@ -108,11 +117,15 @@ func (s *wafService) GetWaf(ctx context.Context, regionCode, namespace, gatewayN
hosts := strings.Join(listener.Hosts, "@") hosts := strings.Join(listener.Hosts, "@")
listeners = append(listeners, fmt.Sprintf("%s-%d", hosts, listener.Port)) listeners = append(listeners, fmt.Sprintf("%s-%d", hosts, listener.Port))
} }
ruleNum, err := getEnabledRuleNum(s.db, wafService)
if err != nil {
return nil, fmt.Errorf("failed to get enabled rule count: %v", err)
}
return &WafService{ return &WafService{
GatewayName: wafService.GatewayName, GatewayName: wafService.GatewayName,
Mode: wafService.Mode, Mode: wafService.Mode,
RuleNum: getEnabledRuleNum(s.db, wafService), RuleNum: ruleNum,
AttackNum: wafService.AttackNum, AttackNum: wafService.AttackNum,
Listeners: listeners, Listeners: listeners,
}, nil }, nil
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment