Commit 6493a88a authored by qiuqunfeng's avatar qiuqunfeng
Browse files

Refactor getEnabledRuleNum function to return error and improve handling of rule category status

The getEnabledRuleNum function now returns an error for better error handling. It calculates the number of enabled WAF rule categories based on the WAF service's rule category status, ensuring accurate counts in the GetWaf method. This change enhances the robustness of the WAF service response.
parent 0228ee23
......@@ -61,19 +61,28 @@ func NewWafService(clusterClientManager *utils.ClusterClientManager, db *gorm.DB
return &wafService{clusterClientManager: clusterClientManager, db: db, gatewayUrl: gatewayUrl, elasticClient: elasticClient}
}
func getEnabledRuleNum(db *gorm.DB, wafService *model.WafService) int {
ruleNum := 0
var ruleCategoryNum int64
if err := db.Model(&model.WafRuleCategory{}).Where("status = ?", 0).Count(&ruleCategoryNum).Error; err != nil {
log.Error().Msgf("failed to get rule categories: %v", err)
return 0
func getEnabledRuleNum(db *gorm.DB, wafService *model.WafService) (int, error) {
// Get total number of rule categories
var totalCategories int64
if err := db.Model(&model.WafRuleCategory{}).Where("status = ?", 0).Count(&totalCategories).Error; err != nil {
return 0, fmt.Errorf("failed to get rule categories: %v", err)
}
if wafService.RuleCategoryStatus != nil && len(wafService.RuleCategoryStatus.CategoryID) == 1 {
ruleNum = len(wafService.RuleCategoryStatus.CategoryID)
// If no rule category status is set, all categories are enabled
if wafService.RuleCategoryStatus == nil {
return int(totalCategories), nil
}
// If status is 0, all categories are enabled
if wafService.RuleCategoryStatus.Status == 0 {
return int(totalCategories), nil
}
ruleCategoryNum = ruleCategoryNum - int64(ruleNum)
return int(ruleCategoryNum)
// If status is 1, count only enabled categories
disabledCount := len(wafService.RuleCategoryStatus.CategoryID)
enabledCount := int(totalCategories) - disabledCount
return enabledCount, nil
}
func (s *wafService) GetWaf(ctx context.Context, regionCode, namespace, gatewayName string) (*WafService, error) {
......@@ -108,11 +117,15 @@ func (s *wafService) GetWaf(ctx context.Context, regionCode, namespace, gatewayN
hosts := strings.Join(listener.Hosts, "@")
listeners = append(listeners, fmt.Sprintf("%s-%d", hosts, listener.Port))
}
ruleNum, err := getEnabledRuleNum(s.db, wafService)
if err != nil {
return nil, fmt.Errorf("failed to get enabled rule count: %v", err)
}
return &WafService{
GatewayName: wafService.GatewayName,
Mode: wafService.Mode,
RuleNum: getEnabledRuleNum(s.db, wafService),
RuleNum: ruleNum,
AttackNum: wafService.AttackNum,
Listeners: listeners,
}, nil
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment