Skip to content

Commit 59d5f5e

Browse files
zhangyubo0722TingquanGao
authored andcommitted
fix hip config
1 parent 61e5589 commit 59d5f5e

File tree

2 files changed

+29
-18
lines changed

2 files changed

+29
-18
lines changed

ppdet/engine/export_utils.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -245,14 +245,12 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
245245

246246
fuse_normalize = reader_cfg.get('fuse_normalize', False)
247247
sample_transforms = reader_cfg['sample_transforms']
248-
hpi_dynamic_shape = None
249248
for st in sample_transforms[1:]:
250249
for key, value in st.items():
251250
p = {'type': key}
252251
if key == 'Resize':
253252
if int(image_shape[1]) != -1:
254253
value['target_size'] = image_shape[1:]
255-
hpi_dynamic_shape = image_shape[1:]
256254
value['interp'] = value.get('interp', 1) # cv2.INTER_LINEAR
257255
if fuse_normalize and key == 'NormalizeImage':
258256
continue
@@ -277,7 +275,7 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
277275
preprocess_list.append(p)
278276
break
279277

280-
return preprocess_list, label_list, hpi_dynamic_shape
278+
return preprocess_list, label_list
281279

282280

283281
def _parse_tracker(tracker_cfg):
@@ -287,7 +285,7 @@ def _parse_tracker(tracker_cfg):
287285
return tracker_params
288286

289287

290-
def _dump_infer_config(config, path, image_shape, model):
288+
def _dump_infer_config(config, path, image_shape, model, input_spec):
291289
arch_state = False
292290
from ppdet.core.config.yaml_helpers import setup_orderdict
293291
setup_orderdict()
@@ -381,34 +379,47 @@ def _dump_infer_config(config, path, image_shape, model):
381379
reader_cfg = config['TestReader']
382380
dataset_cfg = config['TestDataset']
383381

384-
infer_cfg['Preprocess'], infer_cfg['label_list'], hpi_dynamic_shape = _parse_reader(
382+
infer_cfg['Preprocess'], infer_cfg['label_list'] = _parse_reader(
385383
reader_cfg, dataset_cfg, config['metric'], label_arch, image_shape[1:])
386384
if config.get("uniform_output_enabled", None):
385+
for d in input_spec:
386+
if 'image' in d:
387+
hpi_dynamic_shape = list(d['image'].shape[2:])
387388
def get_dynamic_shapes(hpi_shape):
388389
return [[1, 3] + hpi_shape, [1, 3] + hpi_shape, [8, 3] + hpi_shape]
389390

390-
dynamic_shapes = get_dynamic_shapes(hpi_dynamic_shape) if hpi_dynamic_shape else [
391+
dynamic_shapes = get_dynamic_shapes(hpi_dynamic_shape) if hpi_dynamic_shape != [-1, -1] else [
391392
[1, 3, 320, 320],
392393
[1, 3, 640, 640],
393394
[8, 3, 1280, 1280]
394395
]
395396
shapes = {
396397
"image": dynamic_shapes,
397-
"im_shape": [[1, 2], [1, 2], [8, 2]],
398398
"scale_factor": [[1, 2], [1, 2], [8, 2]]
399399
}
400-
trt_dynamic_shape = [
401-
[dim for _ in range(shape[0]) for dim in shape[2:]]
402-
for shape in dynamic_shapes
403-
]
404400
trt_dynamic_shape_input_data = {
405-
"im_shape": trt_dynamic_shape,
406401
"scale_factor": [
407402
[2, 2],
408403
[1, 1],
409404
[0.67 for _ in range(2 * shapes["scale_factor"][-1][0])]
410405
]
411406
}
407+
model_names_required_imgsize = [
408+
"DETR",
409+
"DINO",
410+
"RCNN",
411+
"YOLOv3",
412+
"CenterNet",
413+
"BlazeFace",
414+
"BlazeFace-FPN-SSH",
415+
]
416+
if any(name in config.get('pdx_model_name', None) for name in model_names_required_imgsize):
417+
shapes["im_shape"] = [[1, 2], [1, 2], [8, 2]]
418+
trt_dynamic_shape = [
419+
[dim for _ in range(shape[0]) for dim in shape[2:]]
420+
for shape in dynamic_shapes
421+
]
422+
trt_dynamic_shape_input_data["im_shape"] = trt_dynamic_shape
412423
hpi_config = OrderedDict({
413424
"backend_configs": OrderedDict({
414425
"paddle_infer": OrderedDict({

ppdet/engine/trainer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,11 +1217,6 @@ def _get_infer_cfg_and_input_spec(self,
12171217
if export_post_process and not export_benchmark:
12181218
image_shape = [None] + image_shape[1:]
12191219

1220-
# Save infer cfg
1221-
_dump_infer_config(self.cfg,
1222-
os.path.join(save_dir, yaml_name), image_shape,
1223-
model)
1224-
12251220
input_spec = [{
12261221
"image": InputSpec(
12271222
shape=image_shape, name='image'),
@@ -1263,6 +1258,11 @@ def _get_infer_cfg_and_input_spec(self,
12631258
"image": InputSpec(
12641259
shape=image_shape, name='image')
12651260
}]
1261+
1262+
# Save infer cfg
1263+
_dump_infer_config(self.cfg,
1264+
os.path.join(save_dir, yaml_name), image_shape,
1265+
model, input_spec)
12661266

12671267
return static_model, pruned_input_spec, input_spec
12681268

@@ -1299,7 +1299,7 @@ def export(self, output_dir='output_inference', for_fd=False):
12991299
try:
13001300
import encryption
13011301
except ModuleNotFoundError:
1302-
print("failed to import encryption")
1302+
logger.info("Skipping import of the encryption module.")
13031303
paddle_version = version.parse(paddle.__version__)
13041304
if self.cfg.get("export_with_pir", False):
13051305
assert (paddle_version >= version.parse(

0 commit comments

Comments
 (0)