Skip to content

Commit

Permalink
Merge pull request #247 from skippy/fix-class-inheritance
Browse files Browse the repository at this point in the history
allow Protobuf::Message based classes to be inherited
  • Loading branch information
localshred committed Feb 18, 2015
2 parents c0257ec + 9eb9520 commit 96fe1ba
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 77 deletions.
168 changes: 91 additions & 77 deletions lib/protobuf/message/fields.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,117 +3,131 @@ class Message
module Fields

def self.extended(other)
other.extend(ClassMethods)
::Protobuf.deprecator.define_deprecated_methods(
other.singleton_class,
:get_ext_field_by_name => :get_extension_field,
:get_ext_field_by_tag => :get_extension_field,
:get_field_by_name => :get_field,
:get_field_by_tag => :get_field,
)

def inherited(subclass)
inherit_fields!(subclass)
end
end

##
# Field Definition Methods
#
module ClassMethods

# Define an optional field.
#
def optional(type_class, name, tag, options = {})
define_field(:optional, type_class, name, tag, options)
end
##
# Field Definition Methods
#

# Define a repeated field.
#
def repeated(type_class, name, tag, options = {})
define_field(:repeated, type_class, name, tag, options)
end
# Define an optional field.
#
def optional(type_class, name, tag, options = {})
define_field(:optional, type_class, name, tag, options)
end

# Define a required field.
#
def required(type_class, name, tag, options = {})
define_field(:required, type_class, name, tag, options)
end
# Define a repeated field.
#
def repeated(type_class, name, tag, options = {})
define_field(:repeated, type_class, name, tag, options)
end

# Define an extension range.
#
def extensions(range)
extension_ranges << range
end
# Define a required field.
#
def required(type_class, name, tag, options = {})
define_field(:required, type_class, name, tag, options)
end

##
# Field Access Methods
#
# Define an extension range.
#
def extensions(range)
extension_ranges << range
end

def all_fields
@all_fields ||= field_store.values.uniq.sort_by(&:tag)
end
##
# Field Access Methods
#
def all_fields
@all_fields ||= field_store.values.uniq.sort_by(&:tag)
end

def extension_fields
@extension_fields ||= all_fields.select(&:extension?)
end
def extension_fields
@extension_fields ||= all_fields.select(&:extension?)
end

def extension_ranges
@extension_ranges ||= []
end
def extension_ranges
@extension_ranges ||= []
end

def extension_tag?(tag)
tag.respond_to?(:to_i) && get_extension_field(tag).present?
end
def extension_tag?(tag)
tag.respond_to?(:to_i) && get_extension_field(tag).present?
end

def field_store
@field_store ||= {}
end
def field_store
@field_store ||= {}
end

def fields
@fields ||= all_fields.reject(&:extension?)
end
def fields
@fields ||= all_fields.reject(&:extension?)
end

def field_tag?(tag, allow_extension = false)
tag.respond_to?(:to_i) && get_field(tag, allow_extension).present?
end
def field_tag?(tag, allow_extension = false)
tag.respond_to?(:to_i) && get_field(tag, allow_extension).present?
end

def get_extension_field(name_or_tag)
name_or_tag = name_or_tag.to_sym if name_or_tag.respond_to?(:to_sym)
field = field_store[name_or_tag]
field if field.try(:extension?) { false }
end
def get_extension_field(name_or_tag)
name_or_tag = name_or_tag.to_sym if name_or_tag.respond_to?(:to_sym)
field = field_store[name_or_tag]
field if field.try(:extension?) { false }
end

def get_field(name_or_tag, allow_extension = false)
name_or_tag = name_or_tag.to_sym if name_or_tag.respond_to?(:to_sym)
field = field_store[name_or_tag]
def get_field(name_or_tag, allow_extension = false)
name_or_tag = name_or_tag.to_sym if name_or_tag.respond_to?(:to_sym)
field = field_store[name_or_tag]

if field && (allow_extension || !field.extension?)
field
else
nil
if field && (allow_extension || !field.extension?)
field
else
nil
end
end
end

def define_field(rule, type_class, field_name, tag, options)
raise_if_tag_collision(tag, field_name)
raise_if_name_collision(field_name)
def define_field(rule, type_class, field_name, tag, options)
raise_if_tag_collision(tag, field_name)
raise_if_name_collision(field_name)

field = ::Protobuf::Field.build(self, rule, type_class, field_name, tag, options)
field_store[field_name] = field
field_store[tag] = field
field = ::Protobuf::Field.build(self, rule, type_class, field_name, tag, options)
field_store[field_name] = field
field_store[tag] = field

define_method("#{field_name}!") do
@values[field_name]
define_method("#{field_name}!") do
@values[field_name]
end
end
end

def raise_if_tag_collision(tag, field_name)
if get_field(tag, true)
fail TagCollisionError, %(Field number #{tag} has already been used in "#{name}" by field "#{field_name}".)
def raise_if_tag_collision(tag, field_name)
if get_field(tag, true)
fail TagCollisionError, %(Field number #{tag} has already been used in "#{name}" by field "#{field_name}".)
end
end
end

def raise_if_name_collision(field_name)
if get_field(field_name, true)
fail DuplicateFieldNameError, %(Field name #{field_name} has already been used in "#{name}".)
def raise_if_name_collision(field_name)
if get_field(field_name, true)
fail DuplicateFieldNameError, %(Field name #{field_name} has already been used in "#{name}".)
end
end
end

def inherit_fields!(subclass)
instance_variables.each do |iv|
subclass.instance_variable_set(iv, instance_variable_get(iv))
end
end
private :inherit_fields!

end
end
end
end
52 changes: 52 additions & 0 deletions spec/functional/class_inheritance_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
require 'spec_helper'
require 'spec/support/test/resource_service'

RSpec.describe 'works through class inheritance' do
module Corp
module Protobuf
class Error < ::Protobuf::Message
required :string, :foo, 1
end
end
end
module Corp
class ErrorHandler < Corp::Protobuf::Error
end
end

let(:args) { { :foo => 'bar' } }
let(:parent_class) { Corp::Protobuf::Error }
let(:inherited_class) { Corp::ErrorHandler }

specify '#encode' do
expected_result = "\n\x03bar"
expected_result.force_encoding(Encoding::BINARY)
expect(parent_class.new(args).encode).to eq(expected_result)
expect(inherited_class.new(args).encode).to eq(expected_result)
end

specify '#to_hash' do
expect(parent_class.new(args).to_hash).to eq(args)
expect(inherited_class.new(args).to_hash).to eq(args)
end

specify '#to_json' do
expect(parent_class.new(args).to_json).to eq(args.to_json)
expect(inherited_class.new(args).to_json).to eq(args.to_json)
end

specify '.encode' do
expected_result = "\n\x03bar"
expected_result.force_encoding(Encoding::BINARY)
expect(parent_class.encode(args)).to eq(expected_result)
expect(inherited_class.encode(args)).to eq(expected_result)
end

specify '.decode' do
raw_value = "\n\x03bar"
raw_value.force_encoding(Encoding::BINARY)
expect(parent_class.decode(raw_value).to_hash).to eq(args)
expect(inherited_class.decode(raw_value).to_hash).to eq(args)
end

end

0 comments on commit 96fe1ba

Please sign in to comment.