@@ -19,7 +19,9 @@ use crate::semantic_index::{
1919use crate :: types:: bound_super:: BoundSuperError ;
2020use crate :: types:: constraints:: { ConstraintSet , IteratorConstraintsExtension } ;
2121use crate :: types:: context:: InferContext ;
22- use crate :: types:: diagnostic:: { INVALID_TYPE_ALIAS_TYPE , SUPER_CALL_IN_NAMED_TUPLE_METHOD } ;
22+ use crate :: types:: diagnostic:: {
23+ CANNOT_OVERWRITE_ATTRIBUTE , INVALID_TYPE_ALIAS_TYPE , SUPER_CALL_IN_NAMED_TUPLE_METHOD ,
24+ } ;
2325use crate :: types:: enums:: enum_metadata;
2426use crate :: types:: function:: { DataclassTransformerParams , KnownFunction } ;
2527use crate :: types:: generics:: {
@@ -1923,6 +1925,67 @@ impl<'db> ClassLiteral<'db> {
19231925 Some ( typed_dict_params_from_class_def ( class_stmt) )
19241926 }
19251927
1928+ fn merged_dataclass_params (
1929+ self ,
1930+ db : & ' db dyn Db ,
1931+ field_policy : CodeGeneratorKind < ' db > ,
1932+ ) -> ( Option < DataclassParams < ' db > > , Option < DataclassParams < ' db > > ) {
1933+ let dataclass_params = self . dataclass_params ( db) ;
1934+
1935+ let mut transformer_params =
1936+ if let CodeGeneratorKind :: DataclassLike ( Some ( transformer_params) ) = field_policy {
1937+ Some ( DataclassParams :: from_transformer_params (
1938+ db,
1939+ transformer_params,
1940+ ) )
1941+ } else {
1942+ None
1943+ } ;
1944+
1945+ // Dataclass transformer flags can be overwritten using class arguments.
1946+ if let Some ( transformer_params) = transformer_params. as_mut ( ) {
1947+ if let Some ( class_def) = self . definition ( db) . kind ( db) . as_class ( ) {
1948+ let module = parsed_module ( db, self . file ( db) ) . load ( db) ;
1949+
1950+ if let Some ( arguments) = & class_def. node ( & module) . arguments {
1951+ let mut flags = transformer_params. flags ( db) ;
1952+
1953+ for keyword in & arguments. keywords {
1954+ if let Some ( arg_name) = & keyword. arg {
1955+ if let Some ( is_set) =
1956+ keyword. value . as_boolean_literal_expr ( ) . map ( |b| b. value )
1957+ {
1958+ for ( flag_name, flag) in DATACLASS_FLAGS {
1959+ if arg_name. as_str ( ) == * flag_name {
1960+ flags. set ( * flag, is_set) ;
1961+ }
1962+ }
1963+ }
1964+ }
1965+ }
1966+
1967+ * transformer_params =
1968+ DataclassParams :: new ( db, flags, transformer_params. field_specifiers ( db) ) ;
1969+ }
1970+ }
1971+ }
1972+
1973+ ( dataclass_params, transformer_params)
1974+ }
1975+
1976+ /// Checks if the given dataclass parameter flag is set for this class.
1977+ /// This checks both the `dataclass_params` and `transformer_params`.
1978+ fn has_dataclass_param (
1979+ self ,
1980+ db : & ' db dyn Db ,
1981+ field_policy : CodeGeneratorKind < ' db > ,
1982+ param : DataclassFlags ,
1983+ ) -> bool {
1984+ let ( dataclass_params, transformer_params) = self . merged_dataclass_params ( db, field_policy) ;
1985+ dataclass_params. is_some_and ( |params| params. flags ( db) . contains ( param) )
1986+ || transformer_params. is_some_and ( |params| params. flags ( db) . contains ( param) )
1987+ }
1988+
19261989 /// Return the explicit `metaclass` of this class, if one is defined.
19271990 ///
19281991 /// ## Note
@@ -2332,53 +2395,8 @@ impl<'db> ClassLiteral<'db> {
23322395 inherited_generic_context : Option < GenericContext < ' db > > ,
23332396 name : & str ,
23342397 ) -> Option < Type < ' db > > {
2335- let dataclass_params = self . dataclass_params ( db) ;
2336-
23372398 let field_policy = CodeGeneratorKind :: from_class ( db, self , specialization) ?;
23382399
2339- let mut transformer_params =
2340- if let CodeGeneratorKind :: DataclassLike ( Some ( transformer_params) ) = field_policy {
2341- Some ( DataclassParams :: from_transformer_params (
2342- db,
2343- transformer_params,
2344- ) )
2345- } else {
2346- None
2347- } ;
2348-
2349- // Dataclass transformer flags can be overwritten using class arguments.
2350- if let Some ( transformer_params) = transformer_params. as_mut ( ) {
2351- if let Some ( class_def) = self . definition ( db) . kind ( db) . as_class ( ) {
2352- let module = parsed_module ( db, self . file ( db) ) . load ( db) ;
2353-
2354- if let Some ( arguments) = & class_def. node ( & module) . arguments {
2355- let mut flags = transformer_params. flags ( db) ;
2356-
2357- for keyword in & arguments. keywords {
2358- if let Some ( arg_name) = & keyword. arg {
2359- if let Some ( is_set) =
2360- keyword. value . as_boolean_literal_expr ( ) . map ( |b| b. value )
2361- {
2362- for ( flag_name, flag) in DATACLASS_FLAGS {
2363- if arg_name. as_str ( ) == * flag_name {
2364- flags. set ( * flag, is_set) ;
2365- }
2366- }
2367- }
2368- }
2369- }
2370-
2371- * transformer_params =
2372- DataclassParams :: new ( db, flags, transformer_params. field_specifiers ( db) ) ;
2373- }
2374- }
2375- }
2376-
2377- let has_dataclass_param = |param| {
2378- dataclass_params. is_some_and ( |params| params. flags ( db) . contains ( param) )
2379- || transformer_params. is_some_and ( |params| params. flags ( db) . contains ( param) )
2380- } ;
2381-
23822400 let instance_ty =
23832401 Type :: instance ( db, self . apply_optional_specialization ( db, specialization) ) ;
23842402
@@ -2456,7 +2474,11 @@ impl<'db> ClassLiteral<'db> {
24562474 }
24572475
24582476 let is_kw_only = name == "__replace__"
2459- || kw_only. unwrap_or ( has_dataclass_param ( DataclassFlags :: KW_ONLY ) ) ;
2477+ || kw_only. unwrap_or ( self . has_dataclass_param (
2478+ db,
2479+ field_policy,
2480+ DataclassFlags :: KW_ONLY ,
2481+ ) ) ;
24602482
24612483 // Use the alias name if provided, otherwise use the field name
24622484 let parameter_name =
@@ -2498,7 +2520,7 @@ impl<'db> ClassLiteral<'db> {
24982520
24992521 match ( field_policy, name) {
25002522 ( CodeGeneratorKind :: DataclassLike ( _) , "__init__" ) => {
2501- if !has_dataclass_param ( DataclassFlags :: INIT ) {
2523+ if !self . has_dataclass_param ( db , field_policy , DataclassFlags :: INIT ) {
25022524 return None ;
25032525 }
25042526
@@ -2513,7 +2535,7 @@ impl<'db> ClassLiteral<'db> {
25132535 signature_from_fields ( vec ! [ cls_parameter] , Some ( Type :: none ( db) ) )
25142536 }
25152537 ( CodeGeneratorKind :: DataclassLike ( _) , "__lt__" | "__le__" | "__gt__" | "__ge__" ) => {
2516- if !has_dataclass_param ( DataclassFlags :: ORDER ) {
2538+ if !self . has_dataclass_param ( db , field_policy , DataclassFlags :: ORDER ) {
25172539 return None ;
25182540 }
25192541
@@ -2535,9 +2557,10 @@ impl<'db> ClassLiteral<'db> {
25352557 Some ( Type :: function_like_callable ( db, signature) )
25362558 }
25372559 ( CodeGeneratorKind :: DataclassLike ( _) , "__hash__" ) => {
2538- let unsafe_hash = has_dataclass_param ( DataclassFlags :: UNSAFE_HASH ) ;
2539- let frozen = has_dataclass_param ( DataclassFlags :: FROZEN ) ;
2540- let eq = has_dataclass_param ( DataclassFlags :: EQ ) ;
2560+ let unsafe_hash =
2561+ self . has_dataclass_param ( db, field_policy, DataclassFlags :: UNSAFE_HASH ) ;
2562+ let frozen = self . has_dataclass_param ( db, field_policy, DataclassFlags :: FROZEN ) ;
2563+ let eq = self . has_dataclass_param ( db, field_policy, DataclassFlags :: EQ ) ;
25412564
25422565 if unsafe_hash || ( frozen && eq) {
25432566 let signature = Signature :: new (
@@ -2560,11 +2583,12 @@ impl<'db> ClassLiteral<'db> {
25602583 ( CodeGeneratorKind :: DataclassLike ( _) , "__match_args__" )
25612584 if Program :: get ( db) . python_version ( db) >= PythonVersion :: PY310 =>
25622585 {
2563- if !has_dataclass_param ( DataclassFlags :: MATCH_ARGS ) {
2586+ if !self . has_dataclass_param ( db , field_policy , DataclassFlags :: MATCH_ARGS ) {
25642587 return None ;
25652588 }
25662589
2567- let kw_only_default = has_dataclass_param ( DataclassFlags :: KW_ONLY ) ;
2590+ let kw_only_default =
2591+ self . has_dataclass_param ( db, field_policy, DataclassFlags :: KW_ONLY ) ;
25682592
25692593 let fields = self . fields ( db, specialization, field_policy) ;
25702594 let match_args = fields
@@ -2582,8 +2606,8 @@ impl<'db> ClassLiteral<'db> {
25822606 ( CodeGeneratorKind :: DataclassLike ( _) , "__weakref__" )
25832607 if Program :: get ( db) . python_version ( db) >= PythonVersion :: PY311 =>
25842608 {
2585- if !has_dataclass_param ( DataclassFlags :: WEAKREF_SLOT )
2586- || !has_dataclass_param ( DataclassFlags :: SLOTS )
2609+ if !self . has_dataclass_param ( db , field_policy , DataclassFlags :: WEAKREF_SLOT )
2610+ || !self . has_dataclass_param ( db , field_policy , DataclassFlags :: SLOTS )
25872611 {
25882612 return None ;
25892613 }
@@ -2625,7 +2649,7 @@ impl<'db> ClassLiteral<'db> {
26252649 signature_from_fields ( vec ! [ self_parameter] , Some ( instance_ty) )
26262650 }
26272651 ( CodeGeneratorKind :: DataclassLike ( _) , "__setattr__" ) => {
2628- if has_dataclass_param ( DataclassFlags :: FROZEN ) {
2652+ if self . has_dataclass_param ( db , field_policy , DataclassFlags :: FROZEN ) {
26292653 let signature = Signature :: new (
26302654 Parameters :: new (
26312655 db,
@@ -2646,11 +2670,12 @@ impl<'db> ClassLiteral<'db> {
26462670 ( CodeGeneratorKind :: DataclassLike ( _) , "__slots__" )
26472671 if Program :: get ( db) . python_version ( db) >= PythonVersion :: PY310 =>
26482672 {
2649- has_dataclass_param ( DataclassFlags :: SLOTS ) . then ( || {
2650- let fields = self . fields ( db, specialization, field_policy) ;
2651- let slots = fields. keys ( ) . map ( |name| Type :: string_literal ( db, name) ) ;
2652- Type :: heterogeneous_tuple ( db, slots)
2653- } )
2673+ self . has_dataclass_param ( db, field_policy, DataclassFlags :: SLOTS )
2674+ . then ( || {
2675+ let fields = self . fields ( db, specialization, field_policy) ;
2676+ let slots = fields. keys ( ) . map ( |name| Type :: string_literal ( db, name) ) ;
2677+ Type :: heterogeneous_tuple ( db, slots)
2678+ } )
26542679 }
26552680 ( CodeGeneratorKind :: TypedDict , "__setitem__" ) => {
26562681 let fields = self . fields ( db, specialization, field_policy) ;
@@ -3036,6 +3061,42 @@ impl<'db> ClassLiteral<'db> {
30363061 . collect ( )
30373062 }
30383063
3064+ pub ( crate ) fn validate_members ( self , context : & InferContext < ' db , ' _ > ) {
3065+ let db = context. db ( ) ;
3066+ let Some ( field_policy) = CodeGeneratorKind :: from_class ( db, self , None ) else {
3067+ return ;
3068+ } ;
3069+ let class_body_scope = self . body_scope ( db) ;
3070+ let table = place_table ( db, class_body_scope) ;
3071+ let use_def = use_def_map ( db, class_body_scope) ;
3072+ for ( symbol_id, declarations) in use_def. all_end_of_scope_symbol_declarations ( ) {
3073+ let result = place_from_declarations ( db, declarations. clone ( ) ) ;
3074+ let attr = result. ignore_conflicting_declarations ( ) ;
3075+ let symbol = table. symbol ( symbol_id) ;
3076+ let name = symbol. name ( ) ;
3077+ if let Some ( Type :: FunctionLiteral ( literal) ) = attr. place . ignore_possibly_undefined ( )
3078+ && matches ! ( name. as_str( ) , "__setattr__" | "__delattr__" )
3079+ {
3080+ if let Some ( CodeGeneratorKind :: DataclassLike ( _) ) =
3081+ CodeGeneratorKind :: from_class ( db, self , None )
3082+ && self . has_dataclass_param ( db, field_policy, DataclassFlags :: FROZEN )
3083+ {
3084+ if let Some ( builder) = context. report_lint (
3085+ & CANNOT_OVERWRITE_ATTRIBUTE ,
3086+ literal. node ( db, context. file ( ) , context. module ( ) ) ,
3087+ ) {
3088+ let mut diagnostic = builder. into_diagnostic ( format_args ! (
3089+ "Cannot overwrite attribute `{}` in class `{}`" ,
3090+ name,
3091+ self . name( db)
3092+ ) ) ;
3093+ diagnostic. info ( name) ;
3094+ }
3095+ }
3096+ }
3097+ }
3098+ }
3099+
30393100 /// Returns a list of all annotated attributes defined in the body of this class. This is similar
30403101 /// to the `__annotations__` attribute at runtime, but also contains default values.
30413102 ///
0 commit comments