Skip to content

Commit da5dfae

Browse files
committed
[ty] diagnostic on overridden __setattr__ and __delattr__ in frozen dataclasses
astral-sh/ty#111
1 parent 285d641 commit da5dfae

File tree

7 files changed

+272
-139
lines changed

7 files changed

+272
-139
lines changed

crates/ty/docs/rules.md

Lines changed: 106 additions & 74 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/ty_python_semantic/resources/mdtest/dataclasses/dataclasses.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ frozen_instance = MyFrozenClass(1)
443443
frozen_instance.x = 2 # error: [invalid-assignment]
444444
```
445445

446-
If `__setattr__()` or `__delattr__()` is defined in the class, we should emit a diagnostic.
446+
If `__setattr__()` or `__delattr__()` is defined in the class, a diagnostic is emitted.
447447

448448
```py
449449
from dataclasses import dataclass
@@ -452,10 +452,10 @@ from dataclasses import dataclass
452452
class MyFrozenClass:
453453
x: int
454454

455-
# TODO: Emit a diagnostic here
455+
# error: [cannot-overwrite-attribute] "Cannot overwrite attribute `__setattr__` in class `MyFrozenClass`"
456456
def __setattr__(self, name: str, value: object) -> None: ...
457457

458-
# TODO: Emit a diagnostic here
458+
# error: [cannot-overwrite-attribute] "Cannot overwrite attribute `__delattr__` in class `MyFrozenClass`"
459459
def __delattr__(self, name: str) -> None: ...
460460
```
461461

crates/ty_python_semantic/src/types/class.rs

Lines changed: 123 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ use crate::semantic_index::{
1919
use crate::types::bound_super::BoundSuperError;
2020
use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension};
2121
use crate::types::context::InferContext;
22-
use crate::types::diagnostic::{INVALID_TYPE_ALIAS_TYPE, SUPER_CALL_IN_NAMED_TUPLE_METHOD};
22+
use crate::types::diagnostic::{
23+
CANNOT_OVERWRITE_ATTRIBUTE, INVALID_TYPE_ALIAS_TYPE, SUPER_CALL_IN_NAMED_TUPLE_METHOD,
24+
};
2325
use crate::types::enums::enum_metadata;
2426
use crate::types::function::{DataclassTransformerParams, KnownFunction};
2527
use crate::types::generics::{
@@ -1923,6 +1925,67 @@ impl<'db> ClassLiteral<'db> {
19231925
Some(typed_dict_params_from_class_def(class_stmt))
19241926
}
19251927

1928+
fn merged_dataclass_params(
1929+
self,
1930+
db: &'db dyn Db,
1931+
field_policy: CodeGeneratorKind<'db>,
1932+
) -> (Option<DataclassParams<'db>>, Option<DataclassParams<'db>>) {
1933+
let dataclass_params = self.dataclass_params(db);
1934+
1935+
let mut transformer_params =
1936+
if let CodeGeneratorKind::DataclassLike(Some(transformer_params)) = field_policy {
1937+
Some(DataclassParams::from_transformer_params(
1938+
db,
1939+
transformer_params,
1940+
))
1941+
} else {
1942+
None
1943+
};
1944+
1945+
// Dataclass transformer flags can be overwritten using class arguments.
1946+
if let Some(transformer_params) = transformer_params.as_mut() {
1947+
if let Some(class_def) = self.definition(db).kind(db).as_class() {
1948+
let module = parsed_module(db, self.file(db)).load(db);
1949+
1950+
if let Some(arguments) = &class_def.node(&module).arguments {
1951+
let mut flags = transformer_params.flags(db);
1952+
1953+
for keyword in &arguments.keywords {
1954+
if let Some(arg_name) = &keyword.arg {
1955+
if let Some(is_set) =
1956+
keyword.value.as_boolean_literal_expr().map(|b| b.value)
1957+
{
1958+
for (flag_name, flag) in DATACLASS_FLAGS {
1959+
if arg_name.as_str() == *flag_name {
1960+
flags.set(*flag, is_set);
1961+
}
1962+
}
1963+
}
1964+
}
1965+
}
1966+
1967+
*transformer_params =
1968+
DataclassParams::new(db, flags, transformer_params.field_specifiers(db));
1969+
}
1970+
}
1971+
}
1972+
1973+
(dataclass_params, transformer_params)
1974+
}
1975+
1976+
/// Checks if the given dataclass parameter flag is set for this class.
1977+
/// This checks both the `dataclass_params` and `transformer_params`.
1978+
fn has_dataclass_param(
1979+
self,
1980+
db: &'db dyn Db,
1981+
field_policy: CodeGeneratorKind<'db>,
1982+
param: DataclassFlags,
1983+
) -> bool {
1984+
let (dataclass_params, transformer_params) = self.merged_dataclass_params(db, field_policy);
1985+
dataclass_params.is_some_and(|params| params.flags(db).contains(param))
1986+
|| transformer_params.is_some_and(|params| params.flags(db).contains(param))
1987+
}
1988+
19261989
/// Return the explicit `metaclass` of this class, if one is defined.
19271990
///
19281991
/// ## Note
@@ -2332,53 +2395,8 @@ impl<'db> ClassLiteral<'db> {
23322395
inherited_generic_context: Option<GenericContext<'db>>,
23332396
name: &str,
23342397
) -> Option<Type<'db>> {
2335-
let dataclass_params = self.dataclass_params(db);
2336-
23372398
let field_policy = CodeGeneratorKind::from_class(db, self, specialization)?;
23382399

2339-
let mut transformer_params =
2340-
if let CodeGeneratorKind::DataclassLike(Some(transformer_params)) = field_policy {
2341-
Some(DataclassParams::from_transformer_params(
2342-
db,
2343-
transformer_params,
2344-
))
2345-
} else {
2346-
None
2347-
};
2348-
2349-
// Dataclass transformer flags can be overwritten using class arguments.
2350-
if let Some(transformer_params) = transformer_params.as_mut() {
2351-
if let Some(class_def) = self.definition(db).kind(db).as_class() {
2352-
let module = parsed_module(db, self.file(db)).load(db);
2353-
2354-
if let Some(arguments) = &class_def.node(&module).arguments {
2355-
let mut flags = transformer_params.flags(db);
2356-
2357-
for keyword in &arguments.keywords {
2358-
if let Some(arg_name) = &keyword.arg {
2359-
if let Some(is_set) =
2360-
keyword.value.as_boolean_literal_expr().map(|b| b.value)
2361-
{
2362-
for (flag_name, flag) in DATACLASS_FLAGS {
2363-
if arg_name.as_str() == *flag_name {
2364-
flags.set(*flag, is_set);
2365-
}
2366-
}
2367-
}
2368-
}
2369-
}
2370-
2371-
*transformer_params =
2372-
DataclassParams::new(db, flags, transformer_params.field_specifiers(db));
2373-
}
2374-
}
2375-
}
2376-
2377-
let has_dataclass_param = |param| {
2378-
dataclass_params.is_some_and(|params| params.flags(db).contains(param))
2379-
|| transformer_params.is_some_and(|params| params.flags(db).contains(param))
2380-
};
2381-
23822400
let instance_ty =
23832401
Type::instance(db, self.apply_optional_specialization(db, specialization));
23842402

@@ -2456,7 +2474,11 @@ impl<'db> ClassLiteral<'db> {
24562474
}
24572475

24582476
let is_kw_only = name == "__replace__"
2459-
|| kw_only.unwrap_or(has_dataclass_param(DataclassFlags::KW_ONLY));
2477+
|| kw_only.unwrap_or(self.has_dataclass_param(
2478+
db,
2479+
field_policy,
2480+
DataclassFlags::KW_ONLY,
2481+
));
24602482

24612483
// Use the alias name if provided, otherwise use the field name
24622484
let parameter_name =
@@ -2498,7 +2520,7 @@ impl<'db> ClassLiteral<'db> {
24982520

24992521
match (field_policy, name) {
25002522
(CodeGeneratorKind::DataclassLike(_), "__init__") => {
2501-
if !has_dataclass_param(DataclassFlags::INIT) {
2523+
if !self.has_dataclass_param(db, field_policy, DataclassFlags::INIT) {
25022524
return None;
25032525
}
25042526

@@ -2513,7 +2535,7 @@ impl<'db> ClassLiteral<'db> {
25132535
signature_from_fields(vec![cls_parameter], Some(Type::none(db)))
25142536
}
25152537
(CodeGeneratorKind::DataclassLike(_), "__lt__" | "__le__" | "__gt__" | "__ge__") => {
2516-
if !has_dataclass_param(DataclassFlags::ORDER) {
2538+
if !self.has_dataclass_param(db, field_policy, DataclassFlags::ORDER) {
25172539
return None;
25182540
}
25192541

@@ -2535,9 +2557,10 @@ impl<'db> ClassLiteral<'db> {
25352557
Some(Type::function_like_callable(db, signature))
25362558
}
25372559
(CodeGeneratorKind::DataclassLike(_), "__hash__") => {
2538-
let unsafe_hash = has_dataclass_param(DataclassFlags::UNSAFE_HASH);
2539-
let frozen = has_dataclass_param(DataclassFlags::FROZEN);
2540-
let eq = has_dataclass_param(DataclassFlags::EQ);
2560+
let unsafe_hash =
2561+
self.has_dataclass_param(db, field_policy, DataclassFlags::UNSAFE_HASH);
2562+
let frozen = self.has_dataclass_param(db, field_policy, DataclassFlags::FROZEN);
2563+
let eq = self.has_dataclass_param(db, field_policy, DataclassFlags::EQ);
25412564

25422565
if unsafe_hash || (frozen && eq) {
25432566
let signature = Signature::new(
@@ -2560,11 +2583,12 @@ impl<'db> ClassLiteral<'db> {
25602583
(CodeGeneratorKind::DataclassLike(_), "__match_args__")
25612584
if Program::get(db).python_version(db) >= PythonVersion::PY310 =>
25622585
{
2563-
if !has_dataclass_param(DataclassFlags::MATCH_ARGS) {
2586+
if !self.has_dataclass_param(db, field_policy, DataclassFlags::MATCH_ARGS) {
25642587
return None;
25652588
}
25662589

2567-
let kw_only_default = has_dataclass_param(DataclassFlags::KW_ONLY);
2590+
let kw_only_default =
2591+
self.has_dataclass_param(db, field_policy, DataclassFlags::KW_ONLY);
25682592

25692593
let fields = self.fields(db, specialization, field_policy);
25702594
let match_args = fields
@@ -2582,8 +2606,8 @@ impl<'db> ClassLiteral<'db> {
25822606
(CodeGeneratorKind::DataclassLike(_), "__weakref__")
25832607
if Program::get(db).python_version(db) >= PythonVersion::PY311 =>
25842608
{
2585-
if !has_dataclass_param(DataclassFlags::WEAKREF_SLOT)
2586-
|| !has_dataclass_param(DataclassFlags::SLOTS)
2609+
if !self.has_dataclass_param(db, field_policy, DataclassFlags::WEAKREF_SLOT)
2610+
|| !self.has_dataclass_param(db, field_policy, DataclassFlags::SLOTS)
25872611
{
25882612
return None;
25892613
}
@@ -2625,7 +2649,7 @@ impl<'db> ClassLiteral<'db> {
26252649
signature_from_fields(vec![self_parameter], Some(instance_ty))
26262650
}
26272651
(CodeGeneratorKind::DataclassLike(_), "__setattr__") => {
2628-
if has_dataclass_param(DataclassFlags::FROZEN) {
2652+
if self.has_dataclass_param(db, field_policy, DataclassFlags::FROZEN) {
26292653
let signature = Signature::new(
26302654
Parameters::new(
26312655
db,
@@ -2646,11 +2670,12 @@ impl<'db> ClassLiteral<'db> {
26462670
(CodeGeneratorKind::DataclassLike(_), "__slots__")
26472671
if Program::get(db).python_version(db) >= PythonVersion::PY310 =>
26482672
{
2649-
has_dataclass_param(DataclassFlags::SLOTS).then(|| {
2650-
let fields = self.fields(db, specialization, field_policy);
2651-
let slots = fields.keys().map(|name| Type::string_literal(db, name));
2652-
Type::heterogeneous_tuple(db, slots)
2653-
})
2673+
self.has_dataclass_param(db, field_policy, DataclassFlags::SLOTS)
2674+
.then(|| {
2675+
let fields = self.fields(db, specialization, field_policy);
2676+
let slots = fields.keys().map(|name| Type::string_literal(db, name));
2677+
Type::heterogeneous_tuple(db, slots)
2678+
})
26542679
}
26552680
(CodeGeneratorKind::TypedDict, "__setitem__") => {
26562681
let fields = self.fields(db, specialization, field_policy);
@@ -3036,6 +3061,42 @@ impl<'db> ClassLiteral<'db> {
30363061
.collect()
30373062
}
30383063

3064+
pub(crate) fn validate_members(self, context: &InferContext<'db, '_>) {
3065+
let db = context.db();
3066+
let Some(field_policy) = CodeGeneratorKind::from_class(db, self, None) else {
3067+
return;
3068+
};
3069+
let class_body_scope = self.body_scope(db);
3070+
let table = place_table(db, class_body_scope);
3071+
let use_def = use_def_map(db, class_body_scope);
3072+
for (symbol_id, declarations) in use_def.all_end_of_scope_symbol_declarations() {
3073+
let result = place_from_declarations(db, declarations.clone());
3074+
let attr = result.ignore_conflicting_declarations();
3075+
let symbol = table.symbol(symbol_id);
3076+
let name = symbol.name();
3077+
if let Some(Type::FunctionLiteral(literal)) = attr.place.ignore_possibly_undefined()
3078+
&& matches!(name.as_str(), "__setattr__" | "__delattr__")
3079+
{
3080+
if let Some(CodeGeneratorKind::DataclassLike(_)) =
3081+
CodeGeneratorKind::from_class(db, self, None)
3082+
&& self.has_dataclass_param(db, field_policy, DataclassFlags::FROZEN)
3083+
{
3084+
if let Some(builder) = context.report_lint(
3085+
&CANNOT_OVERWRITE_ATTRIBUTE,
3086+
literal.node(db, context.file(), context.module()),
3087+
) {
3088+
let mut diagnostic = builder.into_diagnostic(format_args!(
3089+
"Cannot overwrite attribute `{}` in class `{}`",
3090+
name,
3091+
self.name(db)
3092+
));
3093+
diagnostic.info(name);
3094+
}
3095+
}
3096+
}
3097+
}
3098+
}
3099+
30393100
/// Returns a list of all annotated attributes defined in the body of this class. This is similar
30403101
/// to the `__annotations__` attribute at runtime, but also contains default values.
30413102
///

crates/ty_python_semantic/src/types/diagnostic.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) {
5151
registry.register_lint(&AMBIGUOUS_PROTOCOL_MEMBER);
5252
registry.register_lint(&CALL_NON_CALLABLE);
5353
registry.register_lint(&POSSIBLY_MISSING_IMPLICIT_CALL);
54+
registry.register_lint(&CANNOT_OVERWRITE_ATTRIBUTE);
5455
registry.register_lint(&CONFLICTING_ARGUMENT_FORMS);
5556
registry.register_lint(&CONFLICTING_DECLARATIONS);
5657
registry.register_lint(&CONFLICTING_METACLASS);
@@ -393,6 +394,32 @@ declare_lint! {
393394
}
394395
}
395396

397+
declare_lint! {
398+
/// ## What it does
399+
/// Checks for dataclass definitions that have both `frozen=True` and a custom `__setattr__`
400+
/// method defined.
401+
///
402+
/// ## Why is this bad?
403+
/// Frozen dataclasses synthesize `__setattr__` and `__delattr__` methods which raise a
404+
/// `FrozenInstanceError` to emulate immutability.
405+
///
406+
/// Overriding either of these methods raises a runtime error.
407+
///
408+
/// ## Examples
409+
/// ```python
410+
/// from dataclasses import dataclass
411+
///
412+
/// @dataclass(frozen=True)
413+
/// class A:
414+
/// def __setattr__(self, name: str, value: object) -> None: ...
415+
/// ```
416+
pub(crate) static CANNOT_OVERWRITE_ATTRIBUTE = {
417+
summary: "detects dataclasses with `frozen=True` that have a custom `__setattr__` or `__delattr__` implementation",
418+
status: LintStatus::preview("1.0.0"),
419+
default_level: Level::Error,
420+
}
421+
}
422+
396423
declare_lint! {
397424
/// ## What it does
398425
/// Checks for classes definitions which will fail at runtime due to

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
990990
if let Some(protocol) = class.into_protocol_class(self.db()) {
991991
protocol.validate_members(&self.context);
992992
}
993+
994+
class.validate_members(&self.context);
993995
}
994996
}
995997

crates/ty_server/tests/e2e/snapshots/e2e__commands__debug_command.snap

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Settings: Settings {
3131
"ambiguous-protocol-member": Warning (Default),
3232
"byte-string-type-annotation": Error (Default),
3333
"call-non-callable": Error (Default),
34+
"cannot-overwrite-attribute": Error (Default),
3435
"conflicting-argument-forms": Error (Default),
3536
"conflicting-declarations": Error (Default),
3637
"conflicting-metaclass": Error (Default),

ty.schema.json

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)