diff --git a/product_category_tax/product.py b/product_category_tax/product.py index 141f29e..4606546 100644 --- a/product_category_tax/product.py +++ b/product_category_tax/product.py @@ -24,19 +24,48 @@ from openerp import models, fields, api, _ from openerp.exceptions import ValidationError -class ProductTemplate(models.Model): - _inherit = 'product.template' +class ProductCategTaxMixin(models.AbstractModel): + _name = 'product.categ.tax.mixin' @api.onchange('categ_id') def onchange_categ_id(self): if self.categ_id: - # I cannot use the commented line below: - #self.taxes_id = self.categ_id.sale_tax_ids.ids - # because it ADDS the taxes (equivalent of (4, ID)) instead - # of replacing the taxes... and I want to REPLACE the taxes - # So I have to use the awful syntax (6, 0, [IDs]) - self.taxes_id = [(6, 0, self.categ_id.sale_tax_ids.ids)] - self.supplier_taxes_id = [(6, 0, self.categ_id.purchase_tax_ids.ids)] + self.taxes_id, self.supplier_taxes_id = ( + self.apply_tax_from_category(self.categ_id)) + + @api.model + def apply_tax_from_category(self, categ_id): + # self.ensure_one() + # I cannot use the commented line below: + # self.taxes_id = self.categ_id.sale_tax_ids.ids + # because it ADDS the taxes (equivalent of (4, ID)) instead + # of replacing the taxes... and I want to REPLACE the taxes + # So I have to use the awful syntax (6, 0, [IDs]) + # values are sent to ('taxes_id' and 'supplier_taxes_id') + return ([(6, 0, categ_id.sale_tax_ids.ids)], + [(6, 0, categ_id.purchase_tax_ids.ids)]) + + @api.model + def write_or_create(self, vals): + if vals.get('categ_id'): + vals['taxes_id'], vals['supplier_taxes_id'] = ( + self.apply_tax_from_category( + self.categ_id.browse(vals['categ_id']))) + + @api.model + def create(self, vals): + self.write_or_create(vals) + return super(ProductCategTaxMixin, self).create(vals) + + @api.multi + def write(self, vals): + self.write_or_create(vals) + return super(ProductCategTaxMixin, self).create(vals) + + +class ProductTemplate(models.Model): + _inherit = ['product.template', 'product.categ.tax.mixin'] + _name = 'product.template' @api.one @api.constrains('taxes_id', 'supplier_taxes_id') @@ -61,13 +90,8 @@ class ProductTemplate(models.Model): class ProductProduct(models.Model): - _inherit = 'product.product' - - @api.onchange('categ_id') - def onchange_categ_id(self): - if self.categ_id: - self.taxes_id = [(6, 0, self.categ_id.sale_tax_ids.ids)] - self.supplier_taxes_id = [(6, 0, self.categ_id.purchase_tax_ids.ids)] + _inherit = ['product.product', 'product.categ.tax.mixin'] + _name = 'product.product' class ProductCategory(models.Model):