@@ -19,7 +19,9 @@ use crate::semantic_index::{
1919use crate :: types:: bound_super:: BoundSuperError ;
2020use crate :: types:: constraints:: { ConstraintSet , IteratorConstraintsExtension } ;
2121use crate :: types:: context:: InferContext ;
22- use crate :: types:: diagnostic:: { INVALID_TYPE_ALIAS_TYPE , SUPER_CALL_IN_NAMED_TUPLE_METHOD } ;
22+ use crate :: types:: diagnostic:: {
23+ CANNOT_OVERWRITE_ATTRIBUTE , INVALID_TYPE_ALIAS_TYPE , SUPER_CALL_IN_NAMED_TUPLE_METHOD ,
24+ } ;
2325use crate :: types:: enums:: enum_metadata;
2426use crate :: types:: function:: { DataclassTransformerParams , KnownFunction } ;
2527use crate :: types:: generics:: {
@@ -1923,6 +1925,68 @@ impl<'db> ClassLiteral<'db> {
19231925 Some ( typed_dict_params_from_class_def ( class_stmt) )
19241926 }
19251927
1928+ /// Returns a tuple containing both dataclass params and dataclass transform params
1929+ fn merged_dataclass_params (
1930+ self ,
1931+ db : & ' db dyn Db ,
1932+ field_policy : CodeGeneratorKind < ' db > ,
1933+ ) -> ( Option < DataclassParams < ' db > > , Option < DataclassParams < ' db > > ) {
1934+ let dataclass_params = self . dataclass_params ( db) ;
1935+
1936+ let mut transformer_params =
1937+ if let CodeGeneratorKind :: DataclassLike ( Some ( transformer_params) ) = field_policy {
1938+ Some ( DataclassParams :: from_transformer_params (
1939+ db,
1940+ transformer_params,
1941+ ) )
1942+ } else {
1943+ None
1944+ } ;
1945+
1946+ // Dataclass transformer flags can be overwritten using class arguments.
1947+ if let Some ( transformer_params) = transformer_params. as_mut ( ) {
1948+ if let Some ( class_def) = self . definition ( db) . kind ( db) . as_class ( ) {
1949+ let module = parsed_module ( db, self . file ( db) ) . load ( db) ;
1950+
1951+ if let Some ( arguments) = & class_def. node ( & module) . arguments {
1952+ let mut flags = transformer_params. flags ( db) ;
1953+
1954+ for keyword in & arguments. keywords {
1955+ if let Some ( arg_name) = & keyword. arg {
1956+ if let Some ( is_set) =
1957+ keyword. value . as_boolean_literal_expr ( ) . map ( |b| b. value )
1958+ {
1959+ for ( flag_name, flag) in DATACLASS_FLAGS {
1960+ if arg_name. as_str ( ) == * flag_name {
1961+ flags. set ( * flag, is_set) ;
1962+ }
1963+ }
1964+ }
1965+ }
1966+ }
1967+
1968+ * transformer_params =
1969+ DataclassParams :: new ( db, flags, transformer_params. field_specifiers ( db) ) ;
1970+ }
1971+ }
1972+ }
1973+
1974+ ( dataclass_params, transformer_params)
1975+ }
1976+
1977+ /// Checks if the given dataclass parameter flag is set for this class.
1978+ /// This checks both the `dataclass_params` and `transformer_params`.
1979+ fn has_dataclass_param (
1980+ self ,
1981+ db : & ' db dyn Db ,
1982+ field_policy : CodeGeneratorKind < ' db > ,
1983+ param : DataclassFlags ,
1984+ ) -> bool {
1985+ let ( dataclass_params, transformer_params) = self . merged_dataclass_params ( db, field_policy) ;
1986+ dataclass_params. is_some_and ( |params| params. flags ( db) . contains ( param) )
1987+ || transformer_params. is_some_and ( |params| params. flags ( db) . contains ( param) )
1988+ }
1989+
19261990 /// Return the explicit `metaclass` of this class, if one is defined.
19271991 ///
19281992 /// ## Note
@@ -2332,53 +2396,8 @@ impl<'db> ClassLiteral<'db> {
23322396 inherited_generic_context : Option < GenericContext < ' db > > ,
23332397 name : & str ,
23342398 ) -> Option < Type < ' db > > {
2335- let dataclass_params = self . dataclass_params ( db) ;
2336-
23372399 let field_policy = CodeGeneratorKind :: from_class ( db, self , specialization) ?;
23382400
2339- let mut transformer_params =
2340- if let CodeGeneratorKind :: DataclassLike ( Some ( transformer_params) ) = field_policy {
2341- Some ( DataclassParams :: from_transformer_params (
2342- db,
2343- transformer_params,
2344- ) )
2345- } else {
2346- None
2347- } ;
2348-
2349- // Dataclass transformer flags can be overwritten using class arguments.
2350- if let Some ( transformer_params) = transformer_params. as_mut ( ) {
2351- if let Some ( class_def) = self . definition ( db) . kind ( db) . as_class ( ) {
2352- let module = parsed_module ( db, self . file ( db) ) . load ( db) ;
2353-
2354- if let Some ( arguments) = & class_def. node ( & module) . arguments {
2355- let mut flags = transformer_params. flags ( db) ;
2356-
2357- for keyword in & arguments. keywords {
2358- if let Some ( arg_name) = & keyword. arg {
2359- if let Some ( is_set) =
2360- keyword. value . as_boolean_literal_expr ( ) . map ( |b| b. value )
2361- {
2362- for ( flag_name, flag) in DATACLASS_FLAGS {
2363- if arg_name. as_str ( ) == * flag_name {
2364- flags. set ( * flag, is_set) ;
2365- }
2366- }
2367- }
2368- }
2369- }
2370-
2371- * transformer_params =
2372- DataclassParams :: new ( db, flags, transformer_params. field_specifiers ( db) ) ;
2373- }
2374- }
2375- }
2376-
2377- let has_dataclass_param = |param| {
2378- dataclass_params. is_some_and ( |params| params. flags ( db) . contains ( param) )
2379- || transformer_params. is_some_and ( |params| params. flags ( db) . contains ( param) )
2380- } ;
2381-
23822401 let instance_ty =
23832402 Type :: instance ( db, self . apply_optional_specialization ( db, specialization) ) ;
23842403
@@ -2456,7 +2475,11 @@ impl<'db> ClassLiteral<'db> {
24562475 }
24572476
24582477 let is_kw_only = name == "__replace__"
2459- || kw_only. unwrap_or ( has_dataclass_param ( DataclassFlags :: KW_ONLY ) ) ;
2478+ || kw_only. unwrap_or ( self . has_dataclass_param (
2479+ db,
2480+ field_policy,
2481+ DataclassFlags :: KW_ONLY ,
2482+ ) ) ;
24602483
24612484 // Use the alias name if provided, otherwise use the field name
24622485 let parameter_name =
@@ -2498,7 +2521,7 @@ impl<'db> ClassLiteral<'db> {
24982521
24992522 match ( field_policy, name) {
25002523 ( CodeGeneratorKind :: DataclassLike ( _) , "__init__" ) => {
2501- if !has_dataclass_param ( DataclassFlags :: INIT ) {
2524+ if !self . has_dataclass_param ( db , field_policy , DataclassFlags :: INIT ) {
25022525 return None ;
25032526 }
25042527
@@ -2513,7 +2536,7 @@ impl<'db> ClassLiteral<'db> {
25132536 signature_from_fields ( vec ! [ cls_parameter] , Some ( Type :: none ( db) ) )
25142537 }
25152538 ( CodeGeneratorKind :: DataclassLike ( _) , "__lt__" | "__le__" | "__gt__" | "__ge__" ) => {
2516- if !has_dataclass_param ( DataclassFlags :: ORDER ) {
2539+ if !self . has_dataclass_param ( db , field_policy , DataclassFlags :: ORDER ) {
25172540 return None ;
25182541 }
25192542
@@ -2535,9 +2558,10 @@ impl<'db> ClassLiteral<'db> {
25352558 Some ( Type :: function_like_callable ( db, signature) )
25362559 }
25372560 ( CodeGeneratorKind :: DataclassLike ( _) , "__hash__" ) => {
2538- let unsafe_hash = has_dataclass_param ( DataclassFlags :: UNSAFE_HASH ) ;
2539- let frozen = has_dataclass_param ( DataclassFlags :: FROZEN ) ;
2540- let eq = has_dataclass_param ( DataclassFlags :: EQ ) ;
2561+ let unsafe_hash =
2562+ self . has_dataclass_param ( db, field_policy, DataclassFlags :: UNSAFE_HASH ) ;
2563+ let frozen = self . has_dataclass_param ( db, field_policy, DataclassFlags :: FROZEN ) ;
2564+ let eq = self . has_dataclass_param ( db, field_policy, DataclassFlags :: EQ ) ;
25412565
25422566 if unsafe_hash || ( frozen && eq) {
25432567 let signature = Signature :: new (
@@ -2560,11 +2584,12 @@ impl<'db> ClassLiteral<'db> {
25602584 ( CodeGeneratorKind :: DataclassLike ( _) , "__match_args__" )
25612585 if Program :: get ( db) . python_version ( db) >= PythonVersion :: PY310 =>
25622586 {
2563- if !has_dataclass_param ( DataclassFlags :: MATCH_ARGS ) {
2587+ if !self . has_dataclass_param ( db , field_policy , DataclassFlags :: MATCH_ARGS ) {
25642588 return None ;
25652589 }
25662590
2567- let kw_only_default = has_dataclass_param ( DataclassFlags :: KW_ONLY ) ;
2591+ let kw_only_default =
2592+ self . has_dataclass_param ( db, field_policy, DataclassFlags :: KW_ONLY ) ;
25682593
25692594 let fields = self . fields ( db, specialization, field_policy) ;
25702595 let match_args = fields
@@ -2582,8 +2607,8 @@ impl<'db> ClassLiteral<'db> {
25822607 ( CodeGeneratorKind :: DataclassLike ( _) , "__weakref__" )
25832608 if Program :: get ( db) . python_version ( db) >= PythonVersion :: PY311 =>
25842609 {
2585- if !has_dataclass_param ( DataclassFlags :: WEAKREF_SLOT )
2586- || !has_dataclass_param ( DataclassFlags :: SLOTS )
2610+ if !self . has_dataclass_param ( db , field_policy , DataclassFlags :: WEAKREF_SLOT )
2611+ || !self . has_dataclass_param ( db , field_policy , DataclassFlags :: SLOTS )
25872612 {
25882613 return None ;
25892614 }
@@ -2625,7 +2650,7 @@ impl<'db> ClassLiteral<'db> {
26252650 signature_from_fields ( vec ! [ self_parameter] , Some ( instance_ty) )
26262651 }
26272652 ( CodeGeneratorKind :: DataclassLike ( _) , "__setattr__" ) => {
2628- if has_dataclass_param ( DataclassFlags :: FROZEN ) {
2653+ if self . has_dataclass_param ( db , field_policy , DataclassFlags :: FROZEN ) {
26292654 let signature = Signature :: new (
26302655 Parameters :: new (
26312656 db,
@@ -2646,11 +2671,12 @@ impl<'db> ClassLiteral<'db> {
26462671 ( CodeGeneratorKind :: DataclassLike ( _) , "__slots__" )
26472672 if Program :: get ( db) . python_version ( db) >= PythonVersion :: PY310 =>
26482673 {
2649- has_dataclass_param ( DataclassFlags :: SLOTS ) . then ( || {
2650- let fields = self . fields ( db, specialization, field_policy) ;
2651- let slots = fields. keys ( ) . map ( |name| Type :: string_literal ( db, name) ) ;
2652- Type :: heterogeneous_tuple ( db, slots)
2653- } )
2674+ self . has_dataclass_param ( db, field_policy, DataclassFlags :: SLOTS )
2675+ . then ( || {
2676+ let fields = self . fields ( db, specialization, field_policy) ;
2677+ let slots = fields. keys ( ) . map ( |name| Type :: string_literal ( db, name) ) ;
2678+ Type :: heterogeneous_tuple ( db, slots)
2679+ } )
26542680 }
26552681 ( CodeGeneratorKind :: TypedDict , "__setitem__" ) => {
26562682 let fields = self . fields ( db, specialization, field_policy) ;
@@ -3036,6 +3062,42 @@ impl<'db> ClassLiteral<'db> {
30363062 . collect ( )
30373063 }
30383064
3065+ pub ( crate ) fn validate_members ( self , context : & InferContext < ' db , ' _ > ) {
3066+ let db = context. db ( ) ;
3067+ let Some ( field_policy) = CodeGeneratorKind :: from_class ( db, self , None ) else {
3068+ return ;
3069+ } ;
3070+ let class_body_scope = self . body_scope ( db) ;
3071+ let table = place_table ( db, class_body_scope) ;
3072+ let use_def = use_def_map ( db, class_body_scope) ;
3073+ for ( symbol_id, declarations) in use_def. all_end_of_scope_symbol_declarations ( ) {
3074+ let result = place_from_declarations ( db, declarations. clone ( ) ) ;
3075+ let attr = result. ignore_conflicting_declarations ( ) ;
3076+ let symbol = table. symbol ( symbol_id) ;
3077+ let name = symbol. name ( ) ;
3078+ if let Some ( Type :: FunctionLiteral ( literal) ) = attr. place . ignore_possibly_undefined ( )
3079+ && matches ! ( name. as_str( ) , "__setattr__" | "__delattr__" )
3080+ {
3081+ if let Some ( CodeGeneratorKind :: DataclassLike ( _) ) =
3082+ CodeGeneratorKind :: from_class ( db, self , None )
3083+ && self . has_dataclass_param ( db, field_policy, DataclassFlags :: FROZEN )
3084+ {
3085+ if let Some ( builder) = context. report_lint (
3086+ & CANNOT_OVERWRITE_ATTRIBUTE ,
3087+ literal. node ( db, context. file ( ) , context. module ( ) ) ,
3088+ ) {
3089+ let mut diagnostic = builder. into_diagnostic ( format_args ! (
3090+ "Cannot overwrite attribute `{}` in class `{}`" ,
3091+ name,
3092+ self . name( db)
3093+ ) ) ;
3094+ diagnostic. info ( name) ;
3095+ }
3096+ }
3097+ }
3098+ }
3099+ }
3100+
30393101 /// Returns a list of all annotated attributes defined in the body of this class. This is similar
30403102 /// to the `__annotations__` attribute at runtime, but also contains default values.
30413103 ///
0 commit comments