Skip to content

Commit 5c3f14a

Browse files
committed
[ty] diagnostic on overridden __setattr__ and __delattr__ in frozen dataclasses
astral-sh/ty#111
1 parent 698231a commit 5c3f14a

File tree

4 files changed

+70
-33
lines changed

4 files changed

+70
-33
lines changed

crates/ty_python_semantic/resources/mdtest/dataclasses/dataclasses.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,8 +390,7 @@ from dataclasses import dataclass
390390
class MyFrozenClass:
391391
x: int
392392

393-
# TODO: Emit a diagnostic here
394-
def __setattr__(self, name: str, value: object) -> None: ...
393+
def __setattr__(self, name: str, value: object) -> None: ... # error: [cannot-overwrite-attribute]
395394

396395
# TODO: Emit a diagnostic here
397396
def __delattr__(self, name: str) -> None: ...

crates/ty_python_semantic/src/types/class.rs

Lines changed: 57 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::semantic_index::{
1919
use crate::types::bound_super::BoundSuperError;
2020
use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension};
2121
use crate::types::context::InferContext;
22-
use crate::types::diagnostic::INVALID_TYPE_ALIAS_TYPE;
22+
use crate::types::diagnostic::{CANNOT_OVERWRITE_ATTRIBUTE, INVALID_TYPE_ALIAS_TYPE};
2323
use crate::types::enums::enum_metadata;
2424
use crate::types::function::{DataclassTransformerParams, KnownFunction};
2525
use crate::types::generics::{
@@ -1394,6 +1394,16 @@ impl<'db> ClassLiteral<'db> {
13941394
self.pep695_generic_context(db).is_some()
13951395
}
13961396

1397+
pub(crate) fn has_dataclass_param(self, db: &dyn Db, param: DataclassFlags) -> bool {
1398+
self.dataclass_params(db)
1399+
.is_some_and(|params| params.flags(db).contains(param))
1400+
|| self.dataclass_transformer_params(db).is_some_and(|params| {
1401+
DataclassParams::from_transformer_params(db, params)
1402+
.flags(db)
1403+
.contains(param)
1404+
})
1405+
}
1406+
13971407
#[salsa::tracked(cycle_initial=generic_context_cycle_initial,
13981408
heap_size=ruff_memory_usage::heap_size,
13991409
)]
@@ -2202,25 +2212,8 @@ impl<'db> ClassLiteral<'db> {
22022212
inherited_generic_context: Option<GenericContext<'db>>,
22032213
name: &str,
22042214
) -> Option<Type<'db>> {
2205-
let dataclass_params = self.dataclass_params(db);
2206-
22072215
let field_policy = CodeGeneratorKind::from_class(db, self, specialization)?;
22082216

2209-
let transformer_params =
2210-
if let CodeGeneratorKind::DataclassLike(Some(transformer_params)) = field_policy {
2211-
Some(DataclassParams::from_transformer_params(
2212-
db,
2213-
transformer_params,
2214-
))
2215-
} else {
2216-
None
2217-
};
2218-
2219-
let has_dataclass_param = |param| {
2220-
dataclass_params.is_some_and(|params| params.flags(db).contains(param))
2221-
|| transformer_params.is_some_and(|params| params.flags(db).contains(param))
2222-
};
2223-
22242217
let instance_ty =
22252218
Type::instance(db, self.apply_optional_specialization(db, specialization));
22262219

@@ -2298,7 +2291,7 @@ impl<'db> ClassLiteral<'db> {
22982291
}
22992292

23002293
let is_kw_only = name == "__replace__"
2301-
|| kw_only.unwrap_or(has_dataclass_param(DataclassFlags::KW_ONLY));
2294+
|| kw_only.unwrap_or(self.has_dataclass_param(db, DataclassFlags::KW_ONLY));
23022295

23032296
// Use the alias name if provided, otherwise use the field name
23042297
let parameter_name = alias.map(Name::new).unwrap_or(field_name);
@@ -2339,7 +2332,7 @@ impl<'db> ClassLiteral<'db> {
23392332

23402333
match (field_policy, name) {
23412334
(CodeGeneratorKind::DataclassLike(_), "__init__") => {
2342-
if !has_dataclass_param(DataclassFlags::INIT) {
2335+
if !self.has_dataclass_param(db, DataclassFlags::INIT) {
23432336
return None;
23442337
}
23452338

@@ -2354,7 +2347,7 @@ impl<'db> ClassLiteral<'db> {
23542347
signature_from_fields(vec![cls_parameter], Some(Type::none(db)))
23552348
}
23562349
(CodeGeneratorKind::DataclassLike(_), "__lt__" | "__le__" | "__gt__" | "__ge__") => {
2357-
if !has_dataclass_param(DataclassFlags::ORDER) {
2350+
if !self.has_dataclass_param(db, DataclassFlags::ORDER) {
23582351
return None;
23592352
}
23602353

@@ -2375,11 +2368,11 @@ impl<'db> ClassLiteral<'db> {
23752368
(CodeGeneratorKind::DataclassLike(_), "__match_args__")
23762369
if Program::get(db).python_version(db) >= PythonVersion::PY310 =>
23772370
{
2378-
if !has_dataclass_param(DataclassFlags::MATCH_ARGS) {
2371+
if !self.has_dataclass_param(db, DataclassFlags::MATCH_ARGS) {
23792372
return None;
23802373
}
23812374

2382-
let kw_only_default = has_dataclass_param(DataclassFlags::KW_ONLY);
2375+
let kw_only_default = self.has_dataclass_param(db, DataclassFlags::KW_ONLY);
23832376

23842377
let fields = self.fields(db, specialization, field_policy);
23852378
let match_args = fields
@@ -2397,8 +2390,8 @@ impl<'db> ClassLiteral<'db> {
23972390
(CodeGeneratorKind::DataclassLike(_), "__weakref__")
23982391
if Program::get(db).python_version(db) >= PythonVersion::PY311 =>
23992392
{
2400-
if !has_dataclass_param(DataclassFlags::WEAKREF_SLOT)
2401-
|| !has_dataclass_param(DataclassFlags::SLOTS)
2393+
if !self.has_dataclass_param(db, DataclassFlags::WEAKREF_SLOT)
2394+
|| !self.has_dataclass_param(db, DataclassFlags::SLOTS)
24022395
{
24032396
return None;
24042397
}
@@ -2440,7 +2433,7 @@ impl<'db> ClassLiteral<'db> {
24402433
signature_from_fields(vec![self_parameter], Some(instance_ty))
24412434
}
24422435
(CodeGeneratorKind::DataclassLike(_), "__setattr__") => {
2443-
if has_dataclass_param(DataclassFlags::FROZEN) {
2436+
if self.has_dataclass_param(db, DataclassFlags::FROZEN) {
24442437
let signature = Signature::new(
24452438
Parameters::new([
24462439
Parameter::positional_or_keyword(Name::new_static("self"))
@@ -2458,11 +2451,12 @@ impl<'db> ClassLiteral<'db> {
24582451
(CodeGeneratorKind::DataclassLike(_), "__slots__")
24592452
if Program::get(db).python_version(db) >= PythonVersion::PY310 =>
24602453
{
2461-
has_dataclass_param(DataclassFlags::SLOTS).then(|| {
2462-
let fields = self.fields(db, specialization, field_policy);
2463-
let slots = fields.keys().map(|name| Type::string_literal(db, name));
2464-
Type::heterogeneous_tuple(db, slots)
2465-
})
2454+
self.has_dataclass_param(db, DataclassFlags::SLOTS)
2455+
.then(|| {
2456+
let fields = self.fields(db, specialization, field_policy);
2457+
let slots = fields.keys().map(|name| Type::string_literal(db, name));
2458+
Type::heterogeneous_tuple(db, slots)
2459+
})
24662460
}
24672461
(CodeGeneratorKind::TypedDict, "__setitem__") => {
24682462
let fields = self.fields(db, specialization, field_policy);
@@ -2811,6 +2805,38 @@ impl<'db> ClassLiteral<'db> {
28112805
.collect()
28122806
}
28132807

2808+
pub(crate) fn validate_members(self, context: &InferContext<'db, '_>) {
2809+
let db = context.db();
2810+
let class_body_scope = self.body_scope(db);
2811+
let table = place_table(db, class_body_scope);
2812+
let use_def = use_def_map(db, class_body_scope);
2813+
for (symbol_id, declarations) in use_def.all_end_of_scope_symbol_declarations() {
2814+
let result = place_from_declarations(db, declarations.clone());
2815+
let attr = result.ignore_conflicting_declarations();
2816+
let symbol = table.symbol(symbol_id);
2817+
let name = symbol.name();
2818+
if let Some(Type::FunctionLiteral(literal)) = attr.place.ignore_possibly_undefined()
2819+
&& name == "__setattr__"
2820+
{
2821+
if let Some(CodeGeneratorKind::DataclassLike(_)) =
2822+
CodeGeneratorKind::from_class(db, self, None)
2823+
&& self.has_dataclass_param(db, DataclassFlags::FROZEN)
2824+
{
2825+
if let Some(builder) = context.report_lint(
2826+
&CANNOT_OVERWRITE_ATTRIBUTE,
2827+
literal.node(db, context.file(), context.module()),
2828+
) {
2829+
let mut diagnostic = builder.into_diagnostic(format_args!(
2830+
"Cannot overwrite attribute __setattr__ in class {}",
2831+
self.name(db)
2832+
));
2833+
diagnostic.info("__setattr__");
2834+
}
2835+
}
2836+
}
2837+
}
2838+
}
2839+
28142840
/// Returns a list of all annotated attributes defined in the body of this class. This is similar
28152841
/// to the `__annotations__` attribute at runtime, but also contains default values.
28162842
///

crates/ty_python_semantic/src/types/diagnostic.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) {
4242
registry.register_lint(&AMBIGUOUS_PROTOCOL_MEMBER);
4343
registry.register_lint(&CALL_NON_CALLABLE);
4444
registry.register_lint(&POSSIBLY_MISSING_IMPLICIT_CALL);
45+
registry.register_lint(&CANNOT_OVERWRITE_ATTRIBUTE);
4546
registry.register_lint(&CONFLICTING_ARGUMENT_FORMS);
4647
registry.register_lint(&CONFLICTING_DECLARATIONS);
4748
registry.register_lint(&CONFLICTING_METACLASS);
@@ -356,6 +357,15 @@ declare_lint! {
356357
}
357358
}
358359

360+
declare_lint! {
361+
/// TODO
362+
pub(crate) static CANNOT_OVERWRITE_ATTRIBUTE = {
363+
summary: "TODO",
364+
status: LintStatus::preview("1.0.0"),
365+
default_level: Level::Error,
366+
}
367+
}
368+
359369
declare_lint! {
360370
/// ## What it does
361371
/// Checks for classes definitions which will fail at runtime due to

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
951951
if let Some(protocol) = class.into_protocol_class(self.db()) {
952952
protocol.validate_members(&self.context);
953953
}
954+
955+
class.validate_members(&self.context);
954956
}
955957
}
956958

0 commit comments

Comments
 (0)